diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c5c1339b8..775324b5d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,7 +60,7 @@ jobs: report_paths: "**/build/test-results/test/TEST-*.xml" unit_test_jdk8: - name: Unit test with docker service [JDK8] + name: Unit test with CLI runs-on: ubuntu-latest-16-cores timeout-minutes: 30 steps: @@ -82,9 +82,9 @@ jobs: - name: Set up Gradle uses: gradle/actions/setup-gradle@ac396bf1a80af16236baf54bd7330ae21dc6ece5 # v6 - - name: Start containerized server and dependencies + - name: Start CLI server env: - TEMPORAL_CLI_VERSION: 1.6.1-server-1.31.0-151.0 + TEMPORAL_CLI_VERSION: 1.7.0 run: | wget -O temporal_cli.tar.gz https://github.com/temporalio/cli/releases/download/v${TEMPORAL_CLI_VERSION}/temporal_cli_${TEMPORAL_CLI_VERSION}_linux_amd64.tar.gz tar -xzf temporal_cli.tar.gz diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java index b01503956..1dceb67fb 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java @@ -37,6 +37,7 @@ public ActivityPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, double activitiesPerSecond, @Nonnull TrackingSlotSupplier slotSupplier, @@ -53,6 +54,7 @@ public ActivityPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java index 520ce7a37..d2fddde3f 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java @@ -105,6 +105,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), taskQueueActivitiesPerSecond, this.slotSupplier, @@ -113,7 +114,7 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { @@ -125,6 +126,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), taskQueueActivitiesPerSecond, this.slotSupplier, @@ -133,7 +135,8 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java index 60ebcbf65..b23d16184 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java @@ -43,6 +43,7 @@ public AsyncActivityPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, double activitiesPerSecond, @Nonnull TrackingSlotSupplier slotSupplier, @@ -59,6 +60,7 @@ public AsyncActivityPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (activitiesPerSecond > 0) { pollRequest.setTaskQueueMetadata( TaskQueueMetadata.newBuilder() diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java index efc4dc807..1ba3b84d1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncNexusPollTask.java @@ -41,6 +41,7 @@ public AsyncNexusPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull Scope metricsScope, @Nonnull Supplier serverCapabilities, @@ -57,6 +58,8 @@ public AsyncNexusPollTask( .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); + if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequest.setDeploymentOptions( WorkerVersioningProtoUtils.deploymentOptionsToProto( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java index 510634379..7859484bb 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncPoller.java @@ -29,7 +29,6 @@ final class AsyncPoller extends BasePoller { private final List> asyncTaskPollers; private final PollerOptions pollerOptions; private final PollerBehaviorAutoscaling pollerBehavior; - private final boolean serverSupportsAutoscaling; private final Scope workerMetricsScope; private Throttler pollRateThrottler; private final Thread.UncaughtExceptionHandler uncaughtExceptionHandler = @@ -43,7 +42,7 @@ final class AsyncPoller extends BasePoller { PollTaskAsync asyncTaskPoller, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - boolean serverSupportsAutoscaling, + NamespaceCapabilities namespaceCapabilities, Scope workerMetricsScope) { this( slotSupplier, @@ -51,7 +50,7 @@ final class AsyncPoller extends BasePoller { Collections.singletonList(asyncTaskPoller), taskExecutor, pollerOptions, - serverSupportsAutoscaling, + namespaceCapabilities, workerMetricsScope); } @@ -61,9 +60,9 @@ final class AsyncPoller extends BasePoller { List> asyncTaskPollers, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - boolean serverSupportsAutoscaling, + NamespaceCapabilities namespaceCapabilities, Scope workerMetricsScope) { - super(taskExecutor); + super(taskExecutor, namespaceCapabilities); Objects.requireNonNull(slotSupplier, "slotSupplier cannot be null"); Objects.requireNonNull(slotReservationData, "slotReservation data should not be null"); Objects.requireNonNull(asyncTaskPollers, "asyncTaskPollers should not be null"); @@ -82,7 +81,6 @@ final class AsyncPoller extends BasePoller { + " is not supported for AsyncPoller. Only PollerBehaviorAutoscaling is supported."); } this.pollerBehavior = (PollerBehaviorAutoscaling) pollerOptions.getPollerBehavior(); - this.serverSupportsAutoscaling = serverSupportsAutoscaling; this.pollerOptions = pollerOptions; this.workerMetricsScope = workerMetricsScope; } @@ -114,7 +112,7 @@ public boolean start() { pollerBehavior.getMinConcurrentTaskPollers(), pollerBehavior.getMaxConcurrentTaskPollers(), pollerBehavior.getInitialConcurrentTaskPollers(), - serverSupportsAutoscaling, + namespaceCapabilities.isPollerAutoscaling(), (newTarget) -> { log.debug( "Updating maximum number of pollers for {} to: {}", @@ -136,12 +134,14 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean return super.shutdown(shutdownManager, interruptTasks) .thenApply( (f) -> { - for (PollTaskAsync asyncTaskPoller : asyncTaskPollers) { - try { - log.debug("Shutting down async poller: {}", asyncTaskPoller.getLabel()); - asyncTaskPoller.cancel(new RuntimeException("Shutting down poller")); - } catch (Throwable e) { - log.error("Error while cancelling poll task", e); + if (!namespaceCapabilities.isGracefulPollShutdown()) { + for (PollTaskAsync asyncTaskPoller : asyncTaskPollers) { + try { + log.debug("Shutting down async poller: {}", asyncTaskPoller.getLabel()); + asyncTaskPoller.cancel(new RuntimeException("Shutting down poller")); + } catch (Throwable e) { + log.error("Error while cancelling poll task", e); + } } } return null; diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java index c30dbc9e1..3bfa796a3 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncWorkflowPollTask.java @@ -52,6 +52,7 @@ public AsyncWorkflowPollTask( @Nonnull String taskQueue, @Nullable String stickyTaskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @@ -67,6 +68,8 @@ public AsyncWorkflowPollTask( .setNamespace(Objects.requireNonNull(namespace)) .setIdentity(Objects.requireNonNull(identity)); + pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); + if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( WorkerVersioningProtoUtils.deploymentOptionsToProto( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java index 9b8141fc0..febd6241a 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/BasePoller.java @@ -27,9 +27,14 @@ abstract class BasePoller implements SuspendableWorker { protected ExecutorService pollExecutor; - protected BasePoller(ShutdownableTaskExecutor taskExecutor) { + protected final NamespaceCapabilities namespaceCapabilities; + + protected BasePoller( + ShutdownableTaskExecutor taskExecutor, NamespaceCapabilities namespaceCapabilities) { Objects.requireNonNull(taskExecutor, "taskExecutor should not be null"); this.taskExecutor = taskExecutor; + this.namespaceCapabilities = + Objects.requireNonNull(namespaceCapabilities, "namespaceCapabilities should not be null"); } @Override @@ -55,15 +60,24 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean return CompletableFuture.completedFuture(null); } - return shutdownManager - // it's ok to forcefully shutdown pollers, because they are stuck in a long poll call - // so we don't risk loosing any progress doing that. - .shutdownExecutorNow(pollExecutor, this + "#pollExecutor", Duration.ofSeconds(1)) - .exceptionally( - e -> { - log.error("Unexpected exception during shutdown", e); - return null; - }); + CompletableFuture pollExecutorShutdown; + if (namespaceCapabilities.isGracefulPollShutdown()) { + // When graceful poll shutdown is enabled, the server will complete outstanding polls with + // empty responses after ShutdownWorker is called. We simply wait for polls to return. + pollExecutorShutdown = + shutdownManager.shutdownExecutor( + pollExecutor, this + "#pollExecutor", Duration.ofSeconds(80)); + } else { + // Old behaviour forcibly stops outstanding polls. + pollExecutorShutdown = + shutdownManager.shutdownExecutorNow( + pollExecutor, this + "#pollExecutor", Duration.ofSeconds(1)); + } + return pollExecutorShutdown.exceptionally( + e -> { + log.error("Unexpected exception during shutdown", e); + return null; + }); } @Override diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java index 8dcaa6f33..7fe0335b1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/MultiThreadedPoller.java @@ -52,8 +52,9 @@ public MultiThreadedPoller( PollTask pollTask, ShutdownableTaskExecutor taskExecutor, PollerOptions pollerOptions, - Scope workerMetricsScope) { - super(taskExecutor); + Scope workerMetricsScope, + NamespaceCapabilities namespaceCapabilities) { + super(taskExecutor, namespaceCapabilities); Objects.requireNonNull(identity, "identity cannot be null"); Objects.requireNonNull(pollTask, "poll service should not be null"); Objects.requireNonNull(pollerOptions, "pollerOptions should not be null"); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java index 4fa9d09a5..a3410fa25 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NamespaceCapabilities.java @@ -1,5 +1,6 @@ package io.temporal.internal.worker; +import io.temporal.api.namespace.v1.NamespaceInfo.Capabilities; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -9,14 +10,28 @@ */ public final class NamespaceCapabilities { private final AtomicBoolean pollerAutoscaling = new AtomicBoolean(false); + private final AtomicBoolean gracefulPollShutdown = new AtomicBoolean(false); private final AtomicBoolean workerHeartbeats = new AtomicBoolean(false); + public void setFromCapabilities(Capabilities capabilities) { + if (capabilities.getPollerAutoscaling()) { + pollerAutoscaling.set(true); + } + if (capabilities.getWorkerPollCompleteOnShutdown()) { + gracefulPollShutdown.set(true); + } + } + public boolean isPollerAutoscaling() { return pollerAutoscaling.get(); } - public void setPollerAutoscaling(boolean value) { - pollerAutoscaling.set(value); + public boolean isGracefulPollShutdown() { + return gracefulPollShutdown.get(); + } + + public void setGracefulPollShutdown(boolean value) { + gracefulPollShutdown.set(value); } public boolean isWorkerHeartbeats() { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java index 4116825b9..0ccab5944 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusPollTask.java @@ -34,6 +34,7 @@ public NexusPollTask( @Nonnull String namespace, @Nonnull String taskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull Scope metricsScope, @@ -49,6 +50,7 @@ public NexusPollTask( .setNamespace(namespace) .setIdentity(identity) .setTaskQueue(TaskQueue.newBuilder().setName(taskQueue)); + pollRequest.setWorkerInstanceKey(workerInstanceKey); if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequest.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java index d826e5543..ac364a747 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java @@ -111,6 +111,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), workerMetricsScope, service.getServerCapabilities(), @@ -118,7 +119,7 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { poller = @@ -129,6 +130,7 @@ public boolean start() { namespace, taskQueue, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), this.slotSupplier, workerMetricsScope, @@ -136,7 +138,8 @@ public boolean start() { pollerTracker), this.pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java index f8baba01d..559370772 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SingleWorkerOptions.java @@ -40,6 +40,7 @@ public static final class Builder { private Duration drainStickyTaskQueueTimeout; private boolean usingVirtualThreads; private WorkerDeploymentOptions deploymentOptions; + private String workerInstanceKey; private Builder() {} @@ -64,6 +65,7 @@ private Builder(SingleWorkerOptions options) { this.drainStickyTaskQueueTimeout = options.getDrainStickyTaskQueueTimeout(); this.usingVirtualThreads = options.isUsingVirtualThreads(); this.deploymentOptions = options.getDeploymentOptions(); + this.workerInstanceKey = options.getWorkerInstanceKey(); } public Builder setIdentity(String identity) { @@ -155,6 +157,11 @@ public Builder setDeploymentOptions(WorkerDeploymentOptions deploymentOptions) { return this; } + public Builder setWorkerInstanceKey(String workerInstanceKey) { + this.workerInstanceKey = workerInstanceKey; + return this; + } + public SingleWorkerOptions build() { PollerOptions pollerOptions = this.pollerOptions; if (pollerOptions == null) { @@ -193,7 +200,8 @@ public SingleWorkerOptions build() { this.defaultHeartbeatThrottleInterval, drainStickyTaskQueueTimeout, usingVirtualThreads, - this.deploymentOptions); + this.deploymentOptions, + this.workerInstanceKey); } } @@ -214,6 +222,7 @@ public SingleWorkerOptions build() { private final Duration drainStickyTaskQueueTimeout; private final boolean usingVirtualThreads; private final WorkerDeploymentOptions deploymentOptions; + private final String workerInstanceKey; private SingleWorkerOptions( String identity, @@ -232,7 +241,8 @@ private SingleWorkerOptions( Duration defaultHeartbeatThrottleInterval, Duration drainStickyTaskQueueTimeout, boolean usingVirtualThreads, - WorkerDeploymentOptions deploymentOptions) { + WorkerDeploymentOptions deploymentOptions, + String workerInstanceKey) { this.identity = identity; this.binaryChecksum = binaryChecksum; this.buildId = buildId; @@ -250,6 +260,7 @@ private SingleWorkerOptions( this.drainStickyTaskQueueTimeout = drainStickyTaskQueueTimeout; this.usingVirtualThreads = usingVirtualThreads; this.deploymentOptions = deploymentOptions; + this.workerInstanceKey = workerInstanceKey; } public String getIdentity() { @@ -331,6 +342,10 @@ public WorkerDeploymentOptions getDeploymentOptions() { return deploymentOptions; } + public String getWorkerInstanceKey() { + return workerInstanceKey; + } + public WorkerVersioningOptions getWorkerVersioningOptions() { return new WorkerVersioningOptions( this.getBuildId(), this.isUsingBuildIdForVersioning(), this.getDeploymentOptions()); diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java index 51ab7a700..18cf7fd4a 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/SyncWorkflowWorker.java @@ -3,9 +3,7 @@ import static io.temporal.internal.common.InternalUtils.createStickyTaskQueue; import io.temporal.api.common.v1.Payloads; -import io.temporal.api.enums.v1.TaskQueueType; import io.temporal.api.taskqueue.v1.TaskQueue; -import io.temporal.api.worker.v1.WorkerHeartbeat; import io.temporal.client.WorkflowClient; import io.temporal.common.converter.DataConverter; import io.temporal.common.converter.EncodedValues; @@ -24,11 +22,9 @@ import io.temporal.workflow.Functions.Func1; import java.lang.reflect.Type; import java.time.Duration; -import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.*; -import java.util.function.Supplier; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.slf4j.Logger; @@ -64,8 +60,6 @@ public SyncWorkflowWorker( @Nonnull WorkflowClient client, @Nonnull String namespace, @Nonnull String taskQueue, - @Nonnull String workerInstanceKey, - @Nonnull Supplier> activeTaskQueueTypesSupplier, @Nonnull SingleWorkerOptions singleWorkerOptions, @Nonnull SingleWorkerOptions localActivityOptions, @Nonnull WorkflowRunLockManager runLocks, @@ -123,8 +117,6 @@ public SyncWorkflowWorker( client.getWorkflowServiceStubs(), namespace, taskQueue, - workerInstanceKey, - activeTaskQueueTypesSupplier, stickyTaskQueueName, singleWorkerOptions, runLocks, @@ -250,10 +242,6 @@ public TrackingSlotSupplier getLocalActivitySlotSupplier( return laWorker.getSlotSupplier(); } - public void setHeartbeatSupplier(Supplier supplier) { - workflowWorker.setHeartbeatSupplier(supplier); - } - public boolean hasStickyQueue() { return workflowWorker.hasStickyQueue(); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java index cdb5e5163..18607b5d1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowPollTask.java @@ -47,6 +47,7 @@ public WorkflowPollTask( @Nonnull String taskQueue, @Nullable String stickyTaskQueue, @Nonnull String identity, + @Nonnull String workerInstanceKey, @Nonnull WorkerVersioningOptions versioningOptions, @Nonnull TrackingSlotSupplier slotSupplier, @Nonnull StickyQueueBalancer stickyQueueBalancer, @@ -73,6 +74,7 @@ public WorkflowPollTask( PollWorkflowTaskQueueRequest.newBuilder() .setNamespace(Objects.requireNonNull(namespace)) .setIdentity(Objects.requireNonNull(identity)); + pollRequestBuilder.setWorkerInstanceKey(workerInstanceKey); if (versioningOptions.getWorkerDeploymentOptions() != null) { pollRequestBuilder.setDeploymentOptions( diff --git a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java index f98316d5d..a128c7b75 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/worker/WorkflowWorker.java @@ -13,11 +13,8 @@ import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.enums.v1.QueryResultType; import io.temporal.api.enums.v1.TaskQueueKind; -import io.temporal.api.enums.v1.TaskQueueType; -import io.temporal.api.enums.v1.WorkerStatus; import io.temporal.api.enums.v1.WorkflowTaskFailedCause; import io.temporal.api.failure.v1.Failure; -import io.temporal.api.worker.v1.WorkerHeartbeat; import io.temporal.api.workflowservice.v1.*; import io.temporal.failure.ApplicationFailure; import io.temporal.internal.logging.LoggerTag; @@ -33,7 +30,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.slf4j.Logger; @@ -41,7 +37,6 @@ import org.slf4j.MDC; final class WorkflowWorker implements SuspendableWorker { - private static final String GRACEFUL_SHUTDOWN_MESSAGE = "graceful shutdown"; private static final Logger log = LoggerFactory.getLogger(WorkflowWorker.class); private final WorkflowRunLockManager runLocks; @@ -58,9 +53,6 @@ final class WorkflowWorker implements SuspendableWorker { private final GrpcRetryer grpcRetryer; private final EagerActivityDispatcher eagerActivityDispatcher; private final TrackingSlotSupplier slotSupplier; - private volatile Supplier heartbeatSupplier; - private final String workerInstanceKey; - private final Supplier> activeTaskQueueTypesSupplier; private final TaskCounter taskCounter = new TaskCounter(); private final PollerTracker pollerTracker = new PollerTracker(); @@ -79,8 +71,6 @@ public WorkflowWorker( @Nonnull WorkflowServiceStubs service, @Nonnull String namespace, @Nonnull String taskQueue, - @Nonnull String workerInstanceKey, - @Nonnull Supplier> activeTaskQueueTypesSupplier, @Nullable String stickyTaskQueueName, @Nonnull SingleWorkerOptions options, @Nonnull WorkflowRunLockManager runLocks, @@ -92,8 +82,6 @@ public WorkflowWorker( this.service = Objects.requireNonNull(service); this.namespace = Objects.requireNonNull(namespace); this.taskQueue = Objects.requireNonNull(taskQueue); - this.workerInstanceKey = Objects.requireNonNull(workerInstanceKey); - this.activeTaskQueueTypesSupplier = Objects.requireNonNull(activeTaskQueueTypesSupplier); this.options = Objects.requireNonNull(options); this.stickyTaskQueueName = stickyTaskQueueName; this.pollerOptions = getPollerOptions(options); @@ -133,6 +121,7 @@ public boolean start() { taskQueue, null, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -146,6 +135,7 @@ public boolean start() { taskQueue, stickyTaskQueueName, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -162,6 +152,7 @@ public boolean start() { taskQueue, null, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, workerMetricsScope, @@ -175,7 +166,7 @@ public boolean start() { pollers, this.pollTaskExecutor, pollerOptions, - namespaceCapabilities.isPollerAutoscaling(), + namespaceCapabilities, workerMetricsScope); } else { PollerBehaviorSimpleMaximum pollerBehavior = @@ -193,6 +184,7 @@ public boolean start() { taskQueue, stickyTaskQueueName, options.getIdentity(), + options.getWorkerInstanceKey(), options.getWorkerVersioningOptions(), slotSupplier, stickyQueueBalancer, @@ -202,7 +194,8 @@ public boolean start() { stickyPollerTracker), pollTaskExecutor, pollerOptions, - workerMetricsScope); + workerMetricsScope, + namespaceCapabilities); } poller.start(); workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1); @@ -232,46 +225,23 @@ public CompletableFuture shutdown(ShutdownManager shutdownManager, boolean stickyQueueBalancer, options.getDrainStickyTaskQueueTimeout()) : CompletableFuture.completedFuture(null)) .thenCompose(ignore -> poller.shutdown(shutdownManager, interruptTasks)); - return CompletableFuture.allOf( - pollerShutdown.thenCompose( - ignore -> { - ShutdownWorkerRequest.Builder shutdownReq = - ShutdownWorkerRequest.newBuilder() - .setIdentity(options.getIdentity()) - .setNamespace(namespace) - .setTaskQueue(taskQueue) - .setWorkerInstanceKey(workerInstanceKey) - .setReason(GRACEFUL_SHUTDOWN_MESSAGE) - .addAllTaskQueueTypes(activeTaskQueueTypesSupplier.get()); - if (stickyTaskQueueName != null) { - shutdownReq.setStickyTaskQueue(stickyTaskQueueName); - } - if (heartbeatSupplier != null) { - shutdownReq.setWorkerHeartbeat( - heartbeatSupplier.get().toBuilder() - .setStatus(WorkerStatus.WORKER_STATUS_SHUTTING_DOWN) - .build()); - } - return shutdownManager.waitOnWorkerShutdownRequest( - service.futureStub().shutdownWorker(shutdownReq.build())); - }), - pollerShutdown - .thenCompose( - ignore -> - !interruptTasks - ? shutdownManager.waitForSupplierPermitsReleasedUnlimited( - slotSupplier, supplierName) - : CompletableFuture.completedFuture(null)) - .thenCompose( - ignore -> - pollTaskExecutor != null - ? pollTaskExecutor.shutdown(shutdownManager, interruptTasks) - : CompletableFuture.completedFuture(null)) - .exceptionally( - e -> { - log.error("Unexpected exception during shutdown", e); - return null; - })); + return pollerShutdown + .thenCompose( + ignore -> + !interruptTasks + ? shutdownManager.waitForSupplierPermitsReleasedUnlimited( + slotSupplier, supplierName) + : CompletableFuture.completedFuture(null)) + .thenCompose( + ignore -> + pollTaskExecutor != null + ? pollTaskExecutor.shutdown(shutdownManager, interruptTasks) + : CompletableFuture.completedFuture(null)) + .exceptionally( + e -> { + log.error("Unexpected exception during shutdown", e); + return null; + }); } @Override @@ -363,10 +333,6 @@ public WorkflowTaskDispatchHandle reserveWorkflowExecutor() { .orElse(null); } - public void setHeartbeatSupplier(Supplier supplier) { - this.heartbeatSupplier = supplier; - } - public TrackingSlotSupplier getSlotSupplier() { return slotSupplier; } diff --git a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java index ce599c6ad..19a52c65e 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/Worker.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/Worker.java @@ -13,6 +13,7 @@ import io.temporal.api.worker.v1.WorkerHostInfo; import io.temporal.api.worker.v1.WorkerPollerInfo; import io.temporal.api.worker.v1.WorkerSlotsInfo; +import io.temporal.api.workflowservice.v1.ShutdownWorkerRequest; import io.temporal.client.WorkflowClient; import io.temporal.client.WorkflowClientOptions; import io.temporal.common.Experimental; @@ -27,6 +28,7 @@ import io.temporal.internal.worker.TaskCounter; import io.temporal.serviceclient.MetricsTag; import io.temporal.serviceclient.Version; +import io.temporal.serviceclient.WorkflowServiceStubs; import io.temporal.worker.tuning.*; import io.temporal.workflow.Functions; import io.temporal.workflow.Functions.Func; @@ -59,17 +61,22 @@ public final class Worker { private static final Logger log = LoggerFactory.getLogger(Worker.class); private final WorkerOptions options; private final String taskQueue; + private final String workerInstanceKey = UUID.randomUUID().toString(); private final List plugins; + private final WorkflowServiceStubs service; + private final String namespace; + private final String identity; + private final String stickyTaskQueueName; final SyncWorkflowWorker workflowWorker; final SyncActivityWorker activityWorker; final SyncNexusWorker nexusWorker; private final AtomicBoolean started = new AtomicBoolean(); private volatile boolean shuttingDown = false; - private final String workerInstanceKey = UUID.randomUUID().toString(); private volatile Instant startTime; private final WorkflowClientOptions clientOptions; private final @Nonnull WorkflowExecutorCache cache; private final Map previousHeartbeatSnapshots = new ConcurrentHashMap<>(); + private volatile Supplier heartbeatSupplier; private static final class TaskSnapshot { final int processed; @@ -106,22 +113,30 @@ private static final class TaskSnapshot { @Nonnull NamespaceCapabilities namespaceCapabilities) { Objects.requireNonNull(client, "client should not be null"); + Objects.requireNonNull(namespaceCapabilities, "namespaceCapabilities should not be null"); this.plugins = Objects.requireNonNull(plugins, "plugins should not be null"); Preconditions.checkArgument( !Strings.isNullOrEmpty(taskQueue), "taskQueue should not be an empty string"); this.taskQueue = taskQueue; + this.service = client.getWorkflowServiceStubs(); this.options = WorkerOptions.newBuilder(options).validateAndBuildWithDefaults(); this.clientOptions = client.getOptions(); this.cache = cache; factoryOptions = WorkerFactoryOptions.newBuilder(factoryOptions).validateAndBuildWithDefaults(); WorkflowClientOptions clientOptions = client.getOptions(); String namespace = clientOptions.getNamespace(); + this.namespace = namespace; Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); Scope taggedScope = metricsScope.tagged(tags); SingleWorkerOptions activityOptions = toActivityOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); if (this.options.isLocalActivityWorkerOnly()) { activityWorker = null; } else { @@ -149,7 +164,12 @@ private static final class TaskSnapshot { SingleWorkerOptions nexusOptions = toNexusOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); SlotSupplier nexusSlotSupplier = this.options.getWorkerTuner() == null ? new FixedSizeSlotSupplier<>(this.options.getMaxConcurrentNexusExecutionSize()) @@ -167,10 +187,16 @@ private static final class TaskSnapshot { clientOptions, taskQueue, contextPropagators, - taggedScope); + taggedScope, + workerInstanceKey); SingleWorkerOptions localActivityOptions = toLocalActivityOptions( - factoryOptions, this.options, clientOptions, contextPropagators, taggedScope); + factoryOptions, + this.options, + clientOptions, + contextPropagators, + taggedScope, + workerInstanceKey); SlotSupplier workflowSlotSupplier = this.options.getWorkerTuner() == null @@ -183,18 +209,20 @@ private static final class TaskSnapshot { : this.options.getWorkerTuner().getLocalActivitySlotSupplier(); attachMetricsToResourceController(taggedScope, localActivitySlotSupplier); + this.identity = singleWorkerOptions.getIdentity(); + this.stickyTaskQueueName = + useStickyTaskQueue ? getStickyTaskQueueName(client.getOptions().getIdentity()) : null; + workflowWorker = new SyncWorkflowWorker( client, namespace, taskQueue, - workerInstanceKey, - this::getActiveTaskQueueTypes, singleWorkerOptions, localActivityOptions, runLocks, cache, - useStickyTaskQueue ? getStickyTaskQueueName(client.getOptions().getIdentity()) : null, + stickyTaskQueueName, workflowThreadExecutor, eagerActivityDispatcher, workflowSlotSupplier, @@ -454,19 +482,48 @@ void start() { } CompletableFuture shutdown(ShutdownManager shutdownManager, boolean interruptUserTasks) { - shuttingDown = true; - CompletableFuture workflowWorkerShutdownFuture = - workflowWorker.shutdown(shutdownManager, interruptUserTasks); - CompletableFuture nexusWorkerShutdownFuture = - nexusWorker.shutdown(shutdownManager, interruptUserTasks); - if (activityWorker != null) { - return CompletableFuture.allOf( - activityWorker.shutdown(shutdownManager, interruptUserTasks), - workflowWorkerShutdownFuture, - nexusWorkerShutdownFuture); - } else { - return CompletableFuture.allOf(workflowWorkerShutdownFuture, nexusWorkerShutdownFuture); + ShutdownWorkerRequest.Builder requestBuilder = + ShutdownWorkerRequest.newBuilder() + .setNamespace(namespace) + .setIdentity(identity) + .setWorkerInstanceKey(workerInstanceKey) + .setTaskQueue(taskQueue) + .setReason("graceful shutdown") + .addAllTaskQueueTypes(getActiveTaskQueueTypes()); + if (stickyTaskQueueName != null) { + requestBuilder.setStickyTaskQueue(stickyTaskQueueName); + } + if (heartbeatSupplier != null) { + requestBuilder.setWorkerHeartbeat( + heartbeatSupplier.get().toBuilder() + .setStatus(WorkerStatus.WORKER_STATUS_SHUTTING_DOWN) + .build()); } + CompletableFuture shutdownWorkerRpc = + shutdownManager.waitOnWorkerShutdownRequest( + service.futureStub().shutdownWorker(requestBuilder.build())); + + // When interrupting tasks (shutdownNow), fire the RPC but don't block on it — proceed to + // shut down pollers immediately. For graceful shutdown, wait for the RPC so the server can + // complete outstanding polls with empty responses before we start waiting on them. + CompletableFuture preShutdown = + interruptUserTasks ? CompletableFuture.completedFuture(null) : shutdownWorkerRpc; + + return preShutdown.thenCompose( + ignore -> { + CompletableFuture workflowWorkerShutdownFuture = + workflowWorker.shutdown(shutdownManager, interruptUserTasks); + CompletableFuture nexusWorkerShutdownFuture = + nexusWorker.shutdown(shutdownManager, interruptUserTasks); + if (activityWorker != null) { + return CompletableFuture.allOf( + activityWorker.shutdown(shutdownManager, interruptUserTasks), + workflowWorkerShutdownFuture, + nexusWorkerShutdownFuture); + } else { + return CompletableFuture.allOf(workflowWorkerShutdownFuture, nexusWorkerShutdownFuture); + } + }); } boolean isTerminated() { @@ -491,6 +548,10 @@ String getWorkerInstanceKey() { return workerInstanceKey; } + void setHeartbeatSupplier(Supplier supplier) { + this.heartbeatSupplier = supplier; + } + List getActiveTaskQueueTypes() { List types = new ArrayList<>(); if (workflowWorker.isAnyTypeSupported()) { @@ -826,8 +887,10 @@ private static SingleWorkerOptions toActivityOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setUsingVirtualThreads(options.isUsingVirtualThreadsOnActivityWorker()) .setPollerOptions( PollerOptions.newBuilder() @@ -848,8 +911,10 @@ private static SingleWorkerOptions toNexusOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior( @@ -870,7 +935,8 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( WorkflowClientOptions clientOptions, String taskQueue, List contextPropagators, - Scope metricsScope) { + Scope metricsScope, + String workerInstanceKey) { Map tags = new ImmutableMap.Builder(1).put(MetricsTag.TASK_QUEUE, taskQueue).build(); @@ -899,7 +965,8 @@ private static SingleWorkerOptions toWorkflowWorkerOptions( } } - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior( @@ -921,8 +988,10 @@ private static SingleWorkerOptions toLocalActivityOptions( WorkerOptions options, WorkflowClientOptions clientOptions, List contextPropagators, - Scope metricsScope) { - return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators) + Scope metricsScope, + String workerInstanceKey) { + return toSingleWorkerOptions( + factoryOptions, options, clientOptions, contextPropagators, workerInstanceKey) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -939,7 +1008,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( WorkerFactoryOptions factoryOptions, WorkerOptions options, WorkflowClientOptions clientOptions, - List contextPropagators) { + List contextPropagators, + String workerInstanceKey) { String buildId = null; if (options.getBuildId() != null) { buildId = options.getBuildId(); @@ -962,7 +1032,8 @@ private static SingleWorkerOptions.Builder toSingleWorkerOptions( .setWorkerInterceptors(factoryOptions.getWorkerInterceptors()) .setMaxHeartbeatThrottleInterval(options.getMaxHeartbeatThrottleInterval()) .setDefaultHeartbeatThrottleInterval(options.getDefaultHeartbeatThrottleInterval()) - .setDeploymentOptions(options.getDeploymentOptions()); + .setDeploymentOptions(options.getDeploymentOptions()) + .setWorkerInstanceKey(workerInstanceKey); } /** diff --git a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java index 741990624..c9bb8eb21 100644 --- a/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java +++ b/temporal-sdk/src/main/java/io/temporal/worker/WorkerFactory.java @@ -268,17 +268,8 @@ public synchronized void start() { DescribeNamespaceRequest.newBuilder() .setNamespace(workflowClient.getOptions().getNamespace()) .build()); - if (describeNamespaceResponse.getNamespaceInfo().getCapabilities().getWorkerHeartbeats()) { - namespaceCapabilities.setWorkerHeartbeats(true); - } else { - log.debug( - "Server does not support worker heartbeats for namespace {}", - workflowClient.getOptions().getNamespace()); - } - - if (describeNamespaceResponse.getNamespaceInfo().getCapabilities().getPollerAutoscaling()) { - namespaceCapabilities.setPollerAutoscaling(true); - } + namespaceCapabilities.setFromCapabilities( + describeNamespaceResponse.getNamespaceInfo().getCapabilities()); // Build plugin execution chain (reverse order for proper nesting) Consumer startChain = WorkerFactory::doStart; @@ -321,7 +312,7 @@ private void doStart() { Supplier heartbeatSupplier = worker.buildHeartbeatCallback(workerGroupingKey); hbManager.registerWorker(namespace, worker.getWorkerInstanceKey(), heartbeatSupplier); - worker.workflowWorker.setHeartbeatSupplier(heartbeatSupplier); + worker.setHeartbeatSupplier(heartbeatSupplier); } } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java index 2ade97762..5faa34ca7 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/AsyncPollerTest.java @@ -133,7 +133,7 @@ private AsyncPoller newPoller( pollTask, taskExecutor, options, - false, + new NamespaceCapabilities(), new NoopScope()); } diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java new file mode 100644 index 000000000..ef0e93495 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/GracefulPollShutdownTest.java @@ -0,0 +1,150 @@ +package io.temporal.internal.worker; + +import static org.junit.Assert.*; + +import com.uber.m3.tally.NoopScope; +import io.temporal.api.namespace.v1.NamespaceInfo.Capabilities; +import io.temporal.worker.tuning.PollerBehaviorSimpleMaximum; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nonnull; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * Tests that an in-flight poll survives shutdown when graceful poll shutdown is enabled, and is + * killed promptly when it is not. + */ +@RunWith(Parameterized.class) +public class GracefulPollShutdownTest { + + @Parameterized.Parameter public boolean graceful; + + @Parameterized.Parameters(name = "graceful={0}") + public static Object[] data() { + return new Object[] {true, false}; + } + + @Test(timeout = 10_000) + public void inflightPollSurvivesShutdownOnlyWhenGraceful() throws Exception { + NamespaceCapabilities capabilities = new NamespaceCapabilities(); + capabilities.setFromCapabilities( + Capabilities.newBuilder().setWorkerPollCompleteOnShutdown(graceful).build()); + + AtomicReference processedTask = new AtomicReference<>(); + CountDownLatch taskProcessedLatch = new CountDownLatch(1); + ShutdownableTaskExecutor taskExecutor = + new ShutdownableTaskExecutor() { + @Override + public void process(@Nonnull String task) { + processedTask.set(task); + taskProcessedLatch.countDown(); + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public CompletableFuture shutdown( + ShutdownManager shutdownManager, boolean interruptTasks) { + return CompletableFuture.completedFuture(null); + } + + @Override + public void awaitTermination(long timeout, TimeUnit unit) {} + }; + + // -- poll task: first call returns immediately, second blocks until released -- + CountDownLatch secondPollStarted = new CountDownLatch(1); + CountDownLatch releasePoll = new CountDownLatch(1); + + MultiThreadedPoller.PollTask pollTask = + new MultiThreadedPoller.PollTask() { + private int callCount = 0; + + @Override + public synchronized String poll() { + callCount++; + if (callCount == 1) { + return "task-1"; + } else if (callCount == 2) { + secondPollStarted.countDown(); + try { + releasePoll.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + return "task-2"; + } + // Subsequent calls just block until interrupted (simulates long poll) + try { + Thread.sleep(Long.MAX_VALUE); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return null; + } + }; + + // -- create poller with 1 thread so polls are sequential -- + MultiThreadedPoller poller = + new MultiThreadedPoller<>( + "test-identity", + pollTask, + taskExecutor, + PollerOptions.newBuilder() + .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) + .setPollThreadNamePrefix("test-poller") + .build(), + new NoopScope(), + capabilities); + + poller.start(); + + // Wait for the first task to be processed (proves poller is running) + assertTrue("first task should be processed", taskProcessedLatch.await(5, TimeUnit.SECONDS)); + assertEquals("task-1", processedTask.get()); + + // Wait for the second poll to be in-flight + assertTrue("second poll should start", secondPollStarted.await(5, TimeUnit.SECONDS)); + + // Trigger shutdown (don't interrupt tasks) + ShutdownManager shutdownManager = new ShutdownManager(); + CompletableFuture shutdownFuture = poller.shutdown(shutdownManager, false); + + if (graceful) { + // In graceful mode the poller waits for the in-flight poll to complete. + // The shutdown should NOT have completed yet since the poll is still blocked. + assertFalse("shutdown should not complete while poll is in-flight", shutdownFuture.isDone()); + + // Simulate the server returning the poll response (as it would after ShutdownWorker RPC) + releasePoll.countDown(); + + // Wait for shutdown to complete - the poll should return "task-2" and be processed + shutdownFuture.get(5, TimeUnit.SECONDS); + + assertEquals("task-2", processedTask.get()); + } else { + // In legacy mode the poller forcefully interrupts in-flight polls. + // Shutdown should complete quickly without releasing the blocked poll. + shutdownFuture.get(5, TimeUnit.SECONDS); + + // The second task should NOT have been processed since the poll was killed. + assertNotEquals( + "task-2 should not be processed in legacy mode", "task-2", processedTask.get()); + } + + shutdownManager.close(); + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java index e4223c0b5..c6f11a61a 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/SlotSupplierTest.java @@ -80,6 +80,7 @@ public void supplierIsCalledAppropriately() { TASK_QUEUE, "stickytaskqueue", "", + "test-instance-key", new WorkerVersioningOptions("", false, null), trackingSS, stickyQueueBalancer, @@ -172,6 +173,7 @@ public void asyncPollerSupplierIsCalledAppropriately() throws Exception { TASK_QUEUE, null, "", + "test-instance-key", new WorkerVersioningOptions("", false, null), trackingSS, metricsScope, diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java index 59538ac8b..ab806c960 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/StickyQueueBacklogTest.java @@ -68,6 +68,7 @@ public void stickyQueueBacklogResetTest() { "taskqueue", "stickytaskqueue", "", + "test-instance-key", new WorkerVersioningOptions("", false, null), slotSupplier, stickyQueueBalancer, @@ -97,6 +98,7 @@ public void stickyQueueBacklogResetTest() { .setKind(TaskQueueKind.TASK_QUEUE_KIND_STICKY) .build()) .setNamespace("default") + .setWorkerInstanceKey("test-instance-key") .build()))) .thenReturn(pollResponse); if (throwOnPoll) { diff --git a/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java b/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java index d4e6e947c..d4f1824c2 100644 --- a/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java +++ b/temporal-sdk/src/test/java/io/temporal/internal/worker/WorkflowWorkerTest.java @@ -14,7 +14,6 @@ import com.uber.m3.util.ImmutableMap; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.common.v1.WorkflowType; -import io.temporal.api.enums.v1.TaskQueueType; import io.temporal.api.workflowservice.v1.*; import io.temporal.common.reporter.TestStatsReporter; import io.temporal.internal.common.InternalUtils; @@ -30,12 +29,8 @@ import io.temporal.worker.tuning.SlotSupplier; import io.temporal.worker.tuning.WorkflowSlotInfo; import java.time.Duration; -import java.util.Arrays; -import java.util.List; import java.util.UUID; import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; import org.junit.Test; import org.mockito.stubbing.Answer; import org.slf4j.Logger; @@ -74,12 +69,11 @@ public void concurrentPollRequestLockTest() throws Exception { client, "default", "task_queue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky_task_queue", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(3)) @@ -246,12 +240,11 @@ public void respondWorkflowTaskFailureMetricTest() throws Exception { client, "default", "task_queue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky_task_queue", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -391,12 +384,11 @@ public boolean isAnyTypeSupported() { client, "default", "taskQueue", - "test-worker-instance-key", - java.util.Collections::emptyList, "sticky", SingleWorkerOptions.newBuilder() .setIdentity("test_identity") .setBuildId(UUID.randomUUID().toString()) + .setWorkerInstanceKey(UUID.randomUUID().toString()) .setPollerOptions( PollerOptions.newBuilder() .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) @@ -444,80 +436,6 @@ public boolean isAnyTypeSupported() { worker.shutdown(new ShutdownManager(), true).get(); } - @Test - public void activeTaskQueueTypesEvaluatedAtShutdownTime() throws Exception { - WorkflowServiceStubs client = mock(WorkflowServiceStubs.class); - when(client.getServerCapabilities()) - .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); - - WorkflowRunLockManager runLockManager = new WorkflowRunLockManager(); - Scope metricsScope = new NoopScope(); - WorkflowExecutorCache cache = new WorkflowExecutorCache(10, runLockManager, metricsScope); - SlotSupplier slotSupplier = new FixedSizeSlotSupplier<>(10); - - WorkflowTaskHandler taskHandler = mock(WorkflowTaskHandler.class); - when(taskHandler.isAnyTypeSupported()).thenReturn(true); - - // Supplier that starts with WORKFLOW only, then adds NEXUS later - AtomicReference> typesRef = - new AtomicReference<>(Arrays.asList(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); - Supplier> supplier = typesRef::get; - - EagerActivityDispatcher eagerActivityDispatcher = mock(EagerActivityDispatcher.class); - WorkflowWorker worker = - new WorkflowWorker( - client, - "default", - "task_queue", - "test-worker-instance-key", - supplier, - null, - SingleWorkerOptions.newBuilder() - .setIdentity("test_identity") - .setBuildId(UUID.randomUUID().toString()) - .setPollerOptions( - PollerOptions.newBuilder() - .setPollerBehavior(new PollerBehaviorSimpleMaximum(1)) - .build()) - .setMetricsScope(metricsScope) - .build(), - runLockManager, - cache, - taskHandler, - eagerActivityDispatcher, - slotSupplier, - new NamespaceCapabilities()); - - // Simulate registering Nexus after construction - typesRef.set( - Arrays.asList( - TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW, - TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY, - TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); - - WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = - mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); - when(client.futureStub()).thenReturn(futureStub); - when(futureStub.shutdownWorker(any(ShutdownWorkerRequest.class))) - .thenReturn(Futures.immediateFuture(ShutdownWorkerResponse.newBuilder().build())); - - worker.shutdown(new ShutdownManager(), true).get(5, TimeUnit.SECONDS); - - org.mockito.ArgumentCaptor captor = - org.mockito.ArgumentCaptor.forClass(ShutdownWorkerRequest.class); - verify(futureStub).shutdownWorker(captor.capture()); - List shutdownTypes = captor.getValue().getTaskQueueTypesList(); - assertTrue( - "ShutdownWorkerRequest should include NEXUS type added after construction", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); - assertTrue( - "ShutdownWorkerRequest should include WORKFLOW type", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); - assertTrue( - "ShutdownWorkerRequest should include ACTIVITY type", - shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY)); - } - private ReplayWorkflowFactory setUpMockWorkflowFactory() throws Throwable { ReplayWorkflow mockWorkflow = mock(ReplayWorkflow.class); ReplayWorkflowFactory mockFactory = mock(ReplayWorkflowFactory.class); diff --git a/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java new file mode 100644 index 000000000..d48e39725 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/worker/WorkerShutdownTest.java @@ -0,0 +1,141 @@ +package io.temporal.worker; + +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.google.common.util.concurrent.Futures; +import com.uber.m3.tally.NoopScope; +import com.uber.m3.tally.Scope; +import io.nexusrpc.handler.OperationHandler; +import io.nexusrpc.handler.OperationImpl; +import io.nexusrpc.handler.ServiceImpl; +import io.temporal.activity.ActivityInterface; +import io.temporal.activity.ActivityMethod; +import io.temporal.api.enums.v1.TaskQueueType; +import io.temporal.api.workflowservice.v1.GetSystemInfoResponse; +import io.temporal.api.workflowservice.v1.ShutdownWorkerRequest; +import io.temporal.api.workflowservice.v1.ShutdownWorkerResponse; +import io.temporal.api.workflowservice.v1.WorkflowServiceGrpc; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.internal.sync.WorkflowThreadExecutor; +import io.temporal.internal.worker.NamespaceCapabilities; +import io.temporal.internal.worker.ShutdownManager; +import io.temporal.internal.worker.WorkflowExecutorCache; +import io.temporal.internal.worker.WorkflowRunLockManager; +import io.temporal.serviceclient.WorkflowServiceStubs; +import io.temporal.workflow.WorkflowInterface; +import io.temporal.workflow.WorkflowMethod; +import io.temporal.workflow.shared.TestNexusServices; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +public class WorkerShutdownTest { + + @WorkflowInterface + public interface TestWorkflow { + @WorkflowMethod + void run(); + } + + public static class TestWorkflowImpl implements TestWorkflow { + @Override + public void run() {} + } + + @ActivityInterface + public interface TestActivity { + @ActivityMethod + void doThing(); + } + + public static class TestActivityImpl implements TestActivity { + @Override + public void doThing() {} + } + + @ServiceImpl(service = TestNexusServices.TestNexusService1.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler operation() { + return OperationHandler.sync((ctx, details, now) -> "Hello " + now); + } + } + + /** + * Verifies that the active task queue types in the ShutdownWorkerRequest are evaluated at + * shutdown time, not at Worker construction time. Types registered after construction must be + * reflected in the request. + */ + @Test + public void activeTaskQueueTypesEvaluatedAtShutdownTime() throws Exception { + WorkflowServiceStubs service = mock(WorkflowServiceStubs.class); + when(service.getServerCapabilities()) + .thenReturn(() -> GetSystemInfoResponse.Capabilities.newBuilder().build()); + + WorkflowServiceGrpc.WorkflowServiceFutureStub futureStub = + mock(WorkflowServiceGrpc.WorkflowServiceFutureStub.class); + when(service.futureStub()).thenReturn(futureStub); + when(futureStub.shutdownWorker(any(ShutdownWorkerRequest.class))) + .thenReturn(Futures.immediateFuture(ShutdownWorkerResponse.newBuilder().build())); + + WorkflowServiceGrpc.WorkflowServiceBlockingStub blockingStub = + mock(WorkflowServiceGrpc.WorkflowServiceBlockingStub.class); + when(service.blockingStub()).thenReturn(blockingStub); + when(blockingStub.withOption(any(), any())).thenReturn(blockingStub); + + WorkflowClient client = mock(WorkflowClient.class); + when(client.getWorkflowServiceStubs()).thenReturn(service); + when(client.getOptions()) + .thenReturn( + WorkflowClientOptions.newBuilder() + .setNamespace("test-ns") + .validateAndBuildWithDefaults()); + + Scope metricsScope = new NoopScope(); + WorkflowRunLockManager runLocks = new WorkflowRunLockManager(); + WorkflowExecutorCache cache = new WorkflowExecutorCache(10, runLocks, metricsScope); + WorkflowThreadExecutor wfThreadExecutor = mock(WorkflowThreadExecutor.class); + + Worker worker = + new Worker( + client, + "test-task-queue", + WorkerFactoryOptions.newBuilder().build(), + WorkerOptions.newBuilder().build(), + metricsScope, + runLocks, + cache, + false, + wfThreadExecutor, + Collections.emptyList(), + Collections.emptyList(), + new NamespaceCapabilities()); + + // Register types AFTER worker construction. The request built by shutdown should reflect + // these registrations, proving that getActiveTaskQueueTypes() is evaluated lazily. + worker.registerWorkflowImplementationTypes(TestWorkflowImpl.class); + worker.registerActivitiesImplementations(new TestActivityImpl()); + worker.registerNexusServiceImplementation(new TestNexusServiceImpl()); + + worker.shutdown(new ShutdownManager(), true).get(5, TimeUnit.SECONDS); + + ArgumentCaptor captor = + ArgumentCaptor.forClass(ShutdownWorkerRequest.class); + verify(futureStub).shutdownWorker(captor.capture()); + List shutdownTypes = captor.getValue().getTaskQueueTypesList(); + assertTrue( + "ShutdownWorkerRequest should include WORKFLOW type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_WORKFLOW)); + assertTrue( + "ShutdownWorkerRequest should include ACTIVITY type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_ACTIVITY)); + assertTrue( + "ShutdownWorkerRequest should include NEXUS type registered after construction", + shutdownTypes.contains(TaskQueueType.TASK_QUEUE_TYPE_NEXUS)); + } +}