diff --git a/.changeset/fix-slot-backpressure.md b/.changeset/fix-slot-backpressure.md new file mode 100644 index 000000000..9997267e4 --- /dev/null +++ b/.changeset/fix-slot-backpressure.md @@ -0,0 +1,7 @@ +--- +'@pgflow/edge-worker': patch +--- + +Fix slot-aware backpressure and shutdown handling in the edge worker. + +Fixes #114 diff --git a/pkgs/edge-worker/src/core/BatchProcessor.ts b/pkgs/edge-worker/src/core/BatchProcessor.ts index 04c9019ee..81dc37efb 100644 --- a/pkgs/edge-worker/src/core/BatchProcessor.ts +++ b/pkgs/edge-worker/src/core/BatchProcessor.ts @@ -18,8 +18,14 @@ export class BatchProcessor { } async processBatch() { + const availableSlots = this.executionController.availableSlots; + if (availableSlots <= 0) { + await this.executionController.waitForSlot(); + return; + } + this.logger.polling(); - const messageRecords = await this.poller.poll(); + const messageRecords = await this.poller.poll(availableSlots); if (this.signal.aborted) { this.logger.info('Discarding messageRecords because worker is stopping'); @@ -28,10 +34,12 @@ export class BatchProcessor { this.logger.taskCount(messageRecords.length); - const startPromises = messageRecords.map((message) => - this.executionController.start(message) - ); - await Promise.all(startPromises); + for (const message of messageRecords) { + void this.executionController.start(message).catch(() => { + // ExecutionController already logs task failures; swallow here so + // refilling the next slot does not produce unhandled rejections. + }); + } } async awaitCompletion() { diff --git a/pkgs/edge-worker/src/core/ExecutionController.ts b/pkgs/edge-worker/src/core/ExecutionController.ts index a569f6fc6..d85c9af6d 100644 --- a/pkgs/edge-worker/src/core/ExecutionController.ts +++ b/pkgs/edge-worker/src/core/ExecutionController.ts @@ -11,6 +11,8 @@ export class ExecutionController { private promiseQueue: PromiseQueue; private signal: AbortSignal; private createExecutor: (record: TMessage, signal: AbortSignal) => IExecutor; + private readonly maxConcurrent: number; + private slotWaiters = new Set<() => void>(); constructor( executorFactory: (record: TMessage, signal: AbortSignal) => IExecutor, @@ -20,16 +22,21 @@ export class ExecutionController { ) { this.signal = abortSignal; this.createExecutor = executorFactory; + this.maxConcurrent = config.maxConcurrent; this.promiseQueue = newQueue(config.maxConcurrent); this.logger = logger; } - async start(record: TMessage) { + get availableSlots(): number { + return Math.max(0, this.maxConcurrent - this.promiseQueue.size()); + } + + start(record: TMessage) { const executor = this.createExecutor(record, this.signal); this.logger.debug(`Scheduling execution of task ${executor.msgId}`); - return await this.promiseQueue.add(async () => { + return this.promiseQueue.add(async () => { try { this.logger.debug(`Executing task ${executor.msgId}...`); await executor.execute(); @@ -37,6 +44,37 @@ export class ExecutionController { } catch (error) { this.logger.error(`Execution failed for ${executor.msgId}`, error); throw error; + } finally { + this.notifySlotWaiters(); + } + }); + } + + async waitForSlot(): Promise { + if (this.signal.aborted || this.availableSlots > 0) { + return; + } + + await new Promise((resolve) => { + const onAbort = () => { + cleanup(); + resolve(); + }; + const onSlotFreed = () => { + cleanup(); + resolve(); + }; + const cleanup = () => { + this.slotWaiters.delete(onSlotFreed); + this.signal.removeEventListener('abort', onAbort); + }; + + this.slotWaiters.add(onSlotFreed); + this.signal.addEventListener('abort', onAbort, { once: true }); + + if (this.signal.aborted || this.availableSlots > 0) { + cleanup(); + resolve(); } }); } @@ -50,4 +88,12 @@ export class ExecutionController { ); await this.promiseQueue.done(); } + + private notifySlotWaiters() { + const waiters = [...this.slotWaiters]; + this.slotWaiters.clear(); + for (const waiter of waiters) { + waiter(); + } + } } diff --git a/pkgs/edge-worker/src/core/Worker.ts b/pkgs/edge-worker/src/core/Worker.ts index f952cbe14..05fc190ab 100644 --- a/pkgs/edge-worker/src/core/Worker.ts +++ b/pkgs/edge-worker/src/core/Worker.ts @@ -6,6 +6,7 @@ export class Worker { private lifecycle: ILifecycle; private logger: Logger; private abortController = new AbortController(); + private readonly requestShutdown?: () => void; private batchProcessor: IBatchProcessor; private sql: postgres.Sql; @@ -16,12 +17,14 @@ export class Worker { batchProcessor: IBatchProcessor, lifecycle: ILifecycle, sql: postgres.Sql, - logger: Logger + logger: Logger, + requestShutdown?: () => void ) { this.sql = sql; this.lifecycle = lifecycle; this.batchProcessor = batchProcessor; this.logger = logger; + this.requestShutdown = requestShutdown; } startOnlyOnce(workerBootstrap: WorkerBootstrap) { @@ -69,12 +72,13 @@ export class Worker { return; } - this.lifecycle.transitionToStopping(); + this.lifecycle.transitionToStopping(); - try { - // Signal deprecation (which includes "Stopped accepting new messages") - this.logDeprecation(); - this.abortController.abort(); + try { + // Signal deprecation (which includes "Stopped accepting new messages") + this.logDeprecation(); + this.requestShutdown?.(); + this.abortController.abort(); try { this.logger.debug('-> Waiting for main loop to complete'); diff --git a/pkgs/edge-worker/src/core/types.ts b/pkgs/edge-worker/src/core/types.ts index 762cdd423..b3f5bc331 100644 --- a/pkgs/edge-worker/src/core/types.ts +++ b/pkgs/edge-worker/src/core/types.ts @@ -7,7 +7,7 @@ export type { Json } from '@pgflow/core'; export type Supplier = () => T; export interface IPoller { - poll(): Promise; + poll(limit?: number): Promise; } export interface IExecutor { diff --git a/pkgs/edge-worker/src/flow/StepTaskPoller.ts b/pkgs/edge-worker/src/flow/StepTaskPoller.ts index af2bb139c..dedf13846 100644 --- a/pkgs/edge-worker/src/flow/StepTaskPoller.ts +++ b/pkgs/edge-worker/src/flow/StepTaskPoller.ts @@ -36,15 +36,16 @@ export class StepTaskPoller this.logger = logger; } - async poll(): Promise[]> { + async poll(limit?: number): Promise[]> { if (this.isAborted()) { this.logger.debug('Polling aborted, returning empty array'); return []; } const workerId = this.getWorkerId(); + const batchSize = limit ?? this.config.batchSize; this.logger.debug( - `Two-phase polling for flow tasks with batch size ${this.config.batchSize}, maxPollSeconds: ${this.config.maxPollSeconds}, pollIntervalMs: ${this.config.pollIntervalMs}` + `Two-phase polling for flow tasks with batch size ${batchSize}, maxPollSeconds: ${this.config.maxPollSeconds}, pollIntervalMs: ${this.config.pollIntervalMs}` ); try { @@ -52,7 +53,7 @@ export class StepTaskPoller const messages = await this.adapter.readMessages( this.config.queueName, this.config.visibilityTimeout ?? 2, - this.config.batchSize, + batchSize, this.config.maxPollSeconds, this.config.pollIntervalMs ); diff --git a/pkgs/edge-worker/src/flow/createFlowWorker.ts b/pkgs/edge-worker/src/flow/createFlowWorker.ts index 9f53b9572..b024fd1b6 100644 --- a/pkgs/edge-worker/src/flow/createFlowWorker.ts +++ b/pkgs/edge-worker/src/flow/createFlowWorker.ts @@ -204,5 +204,11 @@ export function createFlowWorker< ); // Return Worker - return new Worker(batchProcessor, lifecycle, sql, createLogger('Worker')); + return new Worker( + batchProcessor, + lifecycle, + sql, + createLogger('Worker'), + () => platformAdapter.requestShutdown() + ); } diff --git a/pkgs/edge-worker/src/platform/SupabasePlatformAdapter.ts b/pkgs/edge-worker/src/platform/SupabasePlatformAdapter.ts index c11010911..92d999fd1 100644 --- a/pkgs/edge-worker/src/platform/SupabasePlatformAdapter.ts +++ b/pkgs/edge-worker/src/platform/SupabasePlatformAdapter.ts @@ -104,15 +104,17 @@ export class SupabasePlatformAdapter implements PlatformAdapter { - // Trigger shutdown signal - this.abortController.abort(); - - // Cleanup resources - await this._platformResources.sql.end(); + this.requestShutdown(); if (this.worker) { await this.worker.stop(); } + + await this._platformResources.sql.end(); + } + + requestShutdown(): void { + this.abortController.abort(); } createLogger(module: string): Logger { diff --git a/pkgs/edge-worker/src/platform/types.ts b/pkgs/edge-worker/src/platform/types.ts index 80b89e437..972a4eff0 100644 --- a/pkgs/edge-worker/src/platform/types.ts +++ b/pkgs/edge-worker/src/platform/types.ts @@ -77,6 +77,11 @@ export interface PlatformAdapter = Re */ stopWorker(): Promise; + /** + * Trigger the shared shutdown signal used by pollers, executors, and contexts. + */ + requestShutdown(): void; + /** * Get the connection string for the database * Returns undefined if sql was provided directly via config diff --git a/pkgs/edge-worker/src/queue/ReadWithPollPoller.ts b/pkgs/edge-worker/src/queue/ReadWithPollPoller.ts index e49170090..121ad417f 100644 --- a/pkgs/edge-worker/src/queue/ReadWithPollPoller.ts +++ b/pkgs/edge-worker/src/queue/ReadWithPollPoller.ts @@ -22,15 +22,17 @@ export class ReadWithPollPoller { this.logger = logger; } - async poll(): Promise[]> { + async poll(limit?: number): Promise[]> { if (this.isAborted()) { this.logger.debug('Polling aborted, returning empty array'); return []; } - this.logger.debug(`Polling queue '${this.queue.queueName}' with batch size ${this.config.batchSize}`); + const batchSize = limit ?? this.config.batchSize; + + this.logger.debug(`Polling queue '${this.queue.queueName}' with batch size ${batchSize}`); const messages = await this.queue.readWithPoll( - this.config.batchSize, + batchSize, this.config.visibilityTimeout, this.config.maxPollSeconds, this.config.pollIntervalMs diff --git a/pkgs/edge-worker/src/queue/createQueueWorker.ts b/pkgs/edge-worker/src/queue/createQueueWorker.ts index 081a5b850..78d659008 100644 --- a/pkgs/edge-worker/src/queue/createQueueWorker.ts +++ b/pkgs/edge-worker/src/queue/createQueueWorker.ts @@ -218,5 +218,11 @@ export function createQueueWorker platformAdapter.requestShutdown() + ); } diff --git a/pkgs/edge-worker/tests/integration/_helpers.ts b/pkgs/edge-worker/tests/integration/_helpers.ts index af521c485..5dd4fbe42 100644 --- a/pkgs/edge-worker/tests/integration/_helpers.ts +++ b/pkgs/edge-worker/tests/integration/_helpers.ts @@ -43,6 +43,7 @@ export function createTestPlatformAdapter(sql: postgres.Sql): PlatformAdapter>({ + slug: 'test_slot_refill_map_flow', +}).map({ slug: 'process' }, async (item) => { + await delay(item.delayMs); + return item.id; +}); + // Test 1: Root map - flow input is array, map processes each element const RootMapFlow = new Flow({ slug: 'test_root_map_flow' }) .map({ slug: 'double' }, async (num) => { @@ -212,4 +219,66 @@ Deno.test( await worker.stop(); } }) -); \ No newline at end of file +); + +Deno.test( + 'map worker refills a freed slot before the slowest task in the previous batch finishes', + withPgNoTransaction(async (sql) => { + await sql`select pgflow_tests.reset_db();`; + + const worker = startWorker(sql, SlotRefillFlow, { + maxConcurrent: 2, + batchSize: 2, + maxPollSeconds: 1, + pollIntervalMs: 100, + }); + + try { + await createRootMapFlow(sql, 'test_slot_refill_map_flow', 'process'); + + const flowRun = await startFlow(sql, SlotRefillFlow, [ + { id: 'slow', delayMs: 1000 }, + { id: 'fast', delayMs: 50 }, + { id: 'refill', delayMs: 50 }, + ]); + + await waitForRunCompletion(sql, flowRun.run_id); + + const taskTimes = await sql< + { output: string; started_at: string | null; completed_at: string | null }[] + >` + SELECT output #>> '{}' AS output, started_at::text, completed_at::text + FROM pgflow.step_tasks + WHERE run_id = ${flowRun.run_id} + AND step_slug = 'process' + `; + + const timings = new Map( + taskTimes.map((task) => [task.output, task]) + ); + + const slow = timings.get('slow'); + const fast = timings.get('fast'); + const refill = timings.get('refill'); + + assertEquals(!!slow?.started_at, true); + assertEquals(!!fast?.completed_at, true); + assertEquals(!!refill?.started_at, true); + + assertEquals( + new Date(refill!.started_at!).getTime() >= + new Date(fast!.completed_at!).getTime(), + true, + 'refill task should wait for a free slot' + ); + assertEquals( + new Date(refill!.started_at!).getTime() < + new Date(slow!.completed_at!).getTime(), + true, + 'refill task should start before the slow task from the previous batch finishes' + ); + } finally { + await worker.stop(); + } + }) +); diff --git a/pkgs/edge-worker/tests/integration/maxConcurrent.test.ts b/pkgs/edge-worker/tests/integration/maxConcurrent.test.ts index 5ae7d27b3..432a0ce5c 100644 --- a/pkgs/edge-worker/tests/integration/maxConcurrent.test.ts +++ b/pkgs/edge-worker/tests/integration/maxConcurrent.test.ts @@ -15,6 +15,99 @@ async function sleepFor1s() { await delay(1000); } +Deno.test( + 'refills a freed slot before the slowest task in the previous batch finishes', + withTransaction(async (sql) => { + const taskTimes = new Map< + string, + { startedAt?: number; finishedAt?: number } + >(); + + const worker = createQueueWorker( + async (rawPayload: { id: string; delayMs: number } | string) => { + const payload = typeof rawPayload === 'string' + ? JSON.parse(rawPayload) as { id: string; delayMs: number } + : rawPayload; + const timing = taskTimes.get(payload.id) ?? {}; + timing.startedAt = Date.now(); + taskTimes.set(payload.id, timing); + + await delay(payload.delayMs); + + timing.finishedAt = Date.now(); + taskTimes.set(payload.id, timing); + }, + { + sql, + maxConcurrent: 2, + batchSize: 2, + maxPollSeconds: 1, + visibilityTimeout: 5, + queueName: QUEUE_NAME, + }, + createFakeLogger, + createTestPlatformAdapter(sql) + ); + + try { + worker.startOnlyOnce({ + edgeFunctionName: 'test', + workerId: crypto.randomUUID(), + }); + await waitForQueue(sql, QUEUE_NAME); + + await sql` + SELECT pgmq.send_batch( + ${QUEUE_NAME}, + ARRAY[ + ${JSON.stringify({ id: 'slow', delayMs: 1000 })}::jsonb, + ${JSON.stringify({ id: 'fast', delayMs: 50 })}::jsonb, + ${JSON.stringify({ id: 'refill', delayMs: 50 })}::jsonb + ] + ) + `; + + await waitFor( + () => { + const slow = taskTimes.get('slow'); + const fast = taskTimes.get('fast'); + const refill = taskTimes.get('refill'); + + return taskTimes.size === 3 && slow?.finishedAt && fast?.finishedAt && refill?.finishedAt + ? { slow, fast, refill } + : false; + }, + { + timeoutMs: 5000, + pollIntervalMs: 20, + description: 'all queue tasks to finish', + } + ); + + const slow = taskTimes.get('slow')!; + const fast = taskTimes.get('fast')!; + const refill = taskTimes.get('refill')!; + + assertEquals(typeof slow.startedAt, 'number'); + assertEquals(typeof fast.startedAt, 'number'); + assertEquals(typeof refill.startedAt, 'number'); + + assertEquals( + (refill.startedAt as number) >= (fast.finishedAt as number), + true, + 'refill task should not start before a slot is freed' + ); + assertEquals( + (refill.startedAt as number) < (slow.finishedAt as number), + true, + 'refill task should start before the slow task from the previous batch finishes' + ); + } finally { + await worker.stop(); + } + }) +); + Deno.test( 'maxConcurrent option is respected', withTransaction(async (sql) => { diff --git a/pkgs/edge-worker/tests/integration/stopping_worker.test.ts b/pkgs/edge-worker/tests/integration/stopping_worker.test.ts new file mode 100644 index 000000000..a931838e1 --- /dev/null +++ b/pkgs/edge-worker/tests/integration/stopping_worker.test.ts @@ -0,0 +1,113 @@ +import { assertEquals } from '@std/assert'; +import { createQueueWorker } from '../../src/queue/createQueueWorker.ts'; +import type { PlatformAdapter } from '../../src/platform/types.ts'; +import type { SupabaseEnv, SupabaseResources } from '@pgflow/dsl/supabase'; +import { createServiceSupabaseClient } from '../../src/core/supabase-utils.ts'; +import { integrationConfig } from '../config.ts'; +import { withTransaction } from '../db.ts'; +import { createFakeLogger } from '../fakes.ts'; +import { waitFor } from '../e2e/_helpers.ts'; +import { waitForQueue } from '../helpers.ts'; + +const QUEUE_NAME = 'stopping_worker'; +const TEST_SUPABASE_ENV: SupabaseEnv = { + SUPABASE_DB_URL: 'postgresql://test', + SUPABASE_URL: 'https://test.supabase.co', + SUPABASE_ANON_KEY: 'test-anon-key', + SUPABASE_SERVICE_ROLE_KEY: 'test-service-key', + SB_EXECUTION_ID: 'test-execution-id', +}; + +function createAbortAwarePlatformAdapter( + sql: Parameters[0] extends (sql: infer T) => unknown ? T : never +): PlatformAdapter { + const abortController = new AbortController(); + const platformResources: SupabaseResources = { + sql, + supabase: createServiceSupabaseClient(TEST_SUPABASE_ENV), + }; + + return { + get env() { + return TEST_SUPABASE_ENV; + }, + get shutdownSignal() { + return abortController.signal; + }, + get platformResources() { + return platformResources; + }, + get connectionString() { + return integrationConfig.dbUrl; + }, + get isLocalEnvironment() { + return false; + }, + requestShutdown() { + abortController.abort(); + }, + async startWorker() {}, + async stopWorker() {}, + }; +} + +Deno.test( + 'worker.stop aborts the handler shutdown signal', + withTransaction(async (sql) => { + let handlerStarted = false; + let sawAbortedSignal = false; + + const worker = createQueueWorker( + async (_payload, context) => { + handlerStarted = true; + await new Promise((resolve) => { + context.shutdownSignal.addEventListener( + 'abort', + () => { + sawAbortedSignal = context.shutdownSignal.aborted; + resolve(); + }, + { once: true } + ); + }); + }, + { + sql, + maxConcurrent: 1, + maxPollSeconds: 1, + visibilityTimeout: 5, + queueName: QUEUE_NAME, + }, + createFakeLogger, + createAbortAwarePlatformAdapter(sql) + ); + + try { + worker.startOnlyOnce({ + edgeFunctionName: 'test', + workerId: crypto.randomUUID(), + }); + await waitForQueue(sql, QUEUE_NAME); + + await sql`SELECT pgmq.send(${QUEUE_NAME}, '{}'::jsonb)`; + + await waitFor(() => handlerStarted || false, { + timeoutMs: 5000, + pollIntervalMs: 50, + description: 'handler to start', + }); + + const stopStartedAt = Date.now(); + await worker.stop(); + + assertEquals(Date.now() - stopStartedAt < 300, true); + assertEquals(sawAbortedSignal, true); + } finally { + try { + await worker.stop(); + } catch { + // ignore cleanup errors after test assertion + } + } + }) +); diff --git a/pkgs/edge-worker/tests/unit/platform/SupabasePlatformAdapter.test.ts b/pkgs/edge-worker/tests/unit/platform/SupabasePlatformAdapter.test.ts index 6d22d5488..a7bad2bdf 100644 --- a/pkgs/edge-worker/tests/unit/platform/SupabasePlatformAdapter.test.ts +++ b/pkgs/edge-worker/tests/unit/platform/SupabasePlatformAdapter.test.ts @@ -337,3 +337,35 @@ Deno.test({ assertEquals(adapter.shutdownSignal.aborted, true); }, }); + +Deno.test({ + name: 'stopWorker drains worker before closing sql', + sanitizeResources: false, + fn: async () => { + const callOrder: string[] = []; + const deps = createMockDeps(); + const adapter = new SupabasePlatformAdapter(undefined, deps); + + (adapter as unknown as { worker: Worker | null }).worker = { + startOnlyOnce: () => {}, + stop: () => { + callOrder.push('worker.stop'); + return Promise.resolve(); + }, + } as unknown as Worker; + + const sql = adapter.sql as unknown as { end: () => Promise }; + sql.end = () => { + callOrder.push('sql.end'); + return Promise.resolve(); + }; + + adapter.shutdownSignal.addEventListener('abort', () => { + callOrder.push('abort'); + }, { once: true }); + + await adapter.stopWorker(); + + assertEquals(callOrder, ['abort', 'worker.stop', 'sql.end']); + }, +});