Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/fix-slot-backpressure.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
'@pgflow/edge-worker': patch
---

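Fix slot-aware backpressure and shutdown handling in the edge worker: each poll is now limited to the number of free execution slots, a freed slot is refilled without waiting for the rest of the previous batch to finish, and the SQL connection is closed only after the worker has stopped.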

Fixes #114
18 changes: 13 additions & 5 deletions pkgs/edge-worker/src/core/BatchProcessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ export class BatchProcessor<TMessage extends IMessage> {
}

async processBatch() {
const availableSlots = this.executionController.availableSlots;
if (availableSlots <= 0) {
await this.executionController.waitForSlot();
return;
}

this.logger.polling();
const messageRecords = await this.poller.poll();
const messageRecords = await this.poller.poll(availableSlots);

if (this.signal.aborted) {
this.logger.info('Discarding messageRecords because worker is stopping');
Expand All @@ -28,10 +34,12 @@ export class BatchProcessor<TMessage extends IMessage> {

this.logger.taskCount(messageRecords.length);

const startPromises = messageRecords.map((message) =>
this.executionController.start(message)
);
await Promise.all(startPromises);
for (const message of messageRecords) {
void this.executionController.start(message).catch(() => {
// ExecutionController already logs task failures; swallow here so
// refilling the next slot does not produce unhandled rejections.
});
}
}

async awaitCompletion() {
Expand Down
50 changes: 48 additions & 2 deletions pkgs/edge-worker/src/core/ExecutionController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ export class ExecutionController<TMessage extends IMessage> {
private promiseQueue: PromiseQueue;
private signal: AbortSignal;
private createExecutor: (record: TMessage, signal: AbortSignal) => IExecutor;
private readonly maxConcurrent: number;
private slotWaiters = new Set<() => void>();

constructor(
executorFactory: (record: TMessage, signal: AbortSignal) => IExecutor,
Expand All @@ -20,23 +22,59 @@ export class ExecutionController<TMessage extends IMessage> {
) {
this.signal = abortSignal;
this.createExecutor = executorFactory;
this.maxConcurrent = config.maxConcurrent;
this.promiseQueue = newQueue(config.maxConcurrent);
this.logger = logger;
}

async start(record: TMessage) {
/**
 * Number of additional tasks that can be started right now.
 *
 * Derived as `maxConcurrent` minus whatever `promiseQueue.size()` reports
 * (presumably the tasks currently held by the queue — confirm against the
 * queue implementation). The result is clamped so it never goes below zero.
 */
get availableSlots(): number {
  const occupied = this.promiseQueue.size();
  const remaining = this.maxConcurrent - occupied;
  return Math.max(remaining, 0);
}

start(record: TMessage) {
const executor = this.createExecutor(record, this.signal);

this.logger.debug(`Scheduling execution of task ${executor.msgId}`);

return await this.promiseQueue.add(async () => {
return this.promiseQueue.add(async () => {
try {
this.logger.debug(`Executing task ${executor.msgId}...`);
await executor.execute();
this.logger.debug(`Execution successful for ${executor.msgId}`);
} catch (error) {
this.logger.error(`Execution failed for ${executor.msgId}`, error);
throw error;
} finally {
this.notifySlotWaiters();
}
});
}

/**
 * Suspends the caller until it is worth checking for capacity again.
 *
 * Resolves immediately when the abort signal is already set or at least one
 * slot is free. Otherwise it resolves on whichever happens first: the slot
 * waiters are notified, or the abort signal fires. It never rejects.
 *
 * No slot is reserved for the caller; read `availableSlots` again after
 * this resolves.
 */
async waitForSlot(): Promise<void> {
  const shouldProceed = (): boolean =>
    this.signal.aborted || this.availableSlots > 0;

  if (shouldProceed()) {
    return;
  }

  await new Promise<void>((done) => {
    // A single wake-up handler serves both triggers. It unregisters itself
    // from the waiter set and from the abort signal before settling, so no
    // listener is left behind regardless of which trigger fired.
    const wake = (): void => {
      this.slotWaiters.delete(wake);
      this.signal.removeEventListener('abort', wake);
      done();
    };

    this.slotWaiters.add(wake);
    this.signal.addEventListener('abort', wake, { once: true });

    // Re-check after registering so a state change that happened in the
    // meantime cannot leave this promise pending.
    if (shouldProceed()) {
      wake();
    }
  });
}
Expand All @@ -50,4 +88,12 @@ export class ExecutionController<TMessage extends IMessage> {
);
await this.promiseQueue.done();
}

/**
 * Wakes every caller currently registered in `slotWaiters`.
 *
 * The set is copied and emptied before any callback runs, so a waiter that
 * registers again while being woken is kept for the next notification
 * instead of being invoked in this pass.
 */
private notifySlotWaiters() {
  const pending = Array.from(this.slotWaiters);
  this.slotWaiters.clear();
  pending.forEach((wake) => {
    wake();
  });
}
}
16 changes: 10 additions & 6 deletions pkgs/edge-worker/src/core/Worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export class Worker {
private lifecycle: ILifecycle;
private logger: Logger;
private abortController = new AbortController();
private readonly requestShutdown?: () => void;

private batchProcessor: IBatchProcessor;
private sql: postgres.Sql;
Expand All @@ -16,12 +17,14 @@ export class Worker {
batchProcessor: IBatchProcessor,
lifecycle: ILifecycle,
sql: postgres.Sql,
logger: Logger
logger: Logger,
requestShutdown?: () => void
) {
this.sql = sql;
this.lifecycle = lifecycle;
this.batchProcessor = batchProcessor;
this.logger = logger;
this.requestShutdown = requestShutdown;
}

startOnlyOnce(workerBootstrap: WorkerBootstrap) {
Expand Down Expand Up @@ -69,12 +72,13 @@ export class Worker {
return;
}

this.lifecycle.transitionToStopping();
this.lifecycle.transitionToStopping();

try {
// Signal deprecation (which includes "Stopped accepting new messages")
this.logDeprecation();
this.abortController.abort();
try {
// Signal deprecation (which includes "Stopped accepting new messages")
this.logDeprecation();
this.requestShutdown?.();
this.abortController.abort();

try {
this.logger.debug('-> Waiting for main loop to complete');
Expand Down
2 changes: 1 addition & 1 deletion pkgs/edge-worker/src/core/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export type { Json } from '@pgflow/core';
export type Supplier<T> = () => T;

export interface IPoller<IMessage> {
poll(): Promise<IMessage[]>;
poll(limit?: number): Promise<IMessage[]>;
}

export interface IExecutor {
Expand Down
7 changes: 4 additions & 3 deletions pkgs/edge-worker/src/flow/StepTaskPoller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,24 @@ export class StepTaskPoller<TFlow extends AnyFlow>
this.logger = logger;
}

async poll(): Promise<StepTaskWithMessage<TFlow>[]> {
async poll(limit?: number): Promise<StepTaskWithMessage<TFlow>[]> {
if (this.isAborted()) {
this.logger.debug('Polling aborted, returning empty array');
return [];
}

const workerId = this.getWorkerId();
const batchSize = limit ?? this.config.batchSize;
this.logger.debug(
`Two-phase polling for flow tasks with batch size ${this.config.batchSize}, maxPollSeconds: ${this.config.maxPollSeconds}, pollIntervalMs: ${this.config.pollIntervalMs}`
`Two-phase polling for flow tasks with batch size ${batchSize}, maxPollSeconds: ${this.config.maxPollSeconds}, pollIntervalMs: ${this.config.pollIntervalMs}`
);

try {
// Phase 1: Read messages from queue
const messages = await this.adapter.readMessages(
this.config.queueName,
this.config.visibilityTimeout ?? 2,
this.config.batchSize,
batchSize,
this.config.maxPollSeconds,
this.config.pollIntervalMs
);
Expand Down
8 changes: 7 additions & 1 deletion pkgs/edge-worker/src/flow/createFlowWorker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,5 +204,11 @@ export function createFlowWorker<
);

// Return Worker
return new Worker(batchProcessor, lifecycle, sql, createLogger('Worker'));
return new Worker(
batchProcessor,
lifecycle,
sql,
createLogger('Worker'),
() => platformAdapter.requestShutdown()
);
}
12 changes: 7 additions & 5 deletions pkgs/edge-worker/src/platform/SupabasePlatformAdapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,17 @@ export class SupabasePlatformAdapter implements PlatformAdapter<SupabaseResource
}

async stopWorker(): Promise<void> {
// Trigger shutdown signal
this.abortController.abort();

// Cleanup resources
await this._platformResources.sql.end();
this.requestShutdown();

if (this.worker) {
await this.worker.stop();
}

await this._platformResources.sql.end();
}

/**
 * Triggers shutdown by aborting this adapter's AbortController.
 *
 * Does not stop the worker or release any resources itself; it only flips
 * the abort signal. Calling it again after the controller is already
 * aborted has no further effect.
 */
requestShutdown(): void {
this.abortController.abort();
}

createLogger(module: string): Logger {
Expand Down
5 changes: 5 additions & 0 deletions pkgs/edge-worker/src/platform/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ export interface PlatformAdapter<TResources extends Record<string, unknown> = Re
*/
stopWorker(): Promise<void>;

/**
* Trigger the shared shutdown signal used by pollers, executors, and contexts.
*/
requestShutdown(): void;

/**
* Get the connection string for the database
* Returns undefined if sql was provided directly via config
Expand Down
8 changes: 5 additions & 3 deletions pkgs/edge-worker/src/queue/ReadWithPollPoller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@ export class ReadWithPollPoller<TPayload extends Json> {
this.logger = logger;
}

async poll(): Promise<PgmqMessageRecord<TPayload>[]> {
async poll(limit?: number): Promise<PgmqMessageRecord<TPayload>[]> {
if (this.isAborted()) {
this.logger.debug('Polling aborted, returning empty array');
return [];
}

this.logger.debug(`Polling queue '${this.queue.queueName}' with batch size ${this.config.batchSize}`);
const batchSize = limit ?? this.config.batchSize;

this.logger.debug(`Polling queue '${this.queue.queueName}' with batch size ${batchSize}`);
const messages = await this.queue.readWithPoll(
this.config.batchSize,
batchSize,
this.config.visibilityTimeout,
this.config.maxPollSeconds,
this.config.pollIntervalMs
Expand Down
8 changes: 7 additions & 1 deletion pkgs/edge-worker/src/queue/createQueueWorker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -218,5 +218,11 @@ export function createQueueWorker<TPayload extends Json, TResources extends Reco
createLogger('BatchProcessor')
);

return new Worker(batchProcessor, lifecycle, sql, createLogger('Worker'));
return new Worker(
batchProcessor,
lifecycle,
sql,
createLogger('Worker'),
() => platformAdapter.requestShutdown()
);
}
1 change: 1 addition & 0 deletions pkgs/edge-worker/tests/integration/_helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export function createTestPlatformAdapter(sql: postgres.Sql): PlatformAdapter<Su
get platformResources() { return platformResources; },
get connectionString() { return integrationConfig.dbUrl; },
get isLocalEnvironment() { return false; },
requestShutdown() { abortController.abort(); },
async startWorker(_createWorkerFn: CreateWorkerFn) {},
async stopWorker() {},
};
Expand Down
71 changes: 70 additions & 1 deletion pkgs/edge-worker/tests/integration/flow/mapFlow.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ import {
assertAllStepsCompleted,
} from './_testHelpers.ts';

// Root map flow for the slot-refill test: every input element names a task
// (`id`) and how long its handler should take (`delayMs`). The handler
// waits for that long and then returns the id as the task output.
const SlotRefillFlow = new Flow<{ id: string; delayMs: number }[]>({
  slug: 'test_slot_refill_map_flow',
}).map({ slug: 'process' }, async ({ id, delayMs }) => {
  await delay(delayMs);
  return id;
});

// Test 1: Root map - flow input is array, map processes each element
const RootMapFlow = new Flow<number[]>({ slug: 'test_root_map_flow' })
.map({ slug: 'double' }, async (num) => {
Expand Down Expand Up @@ -212,4 +219,66 @@ Deno.test(
await worker.stop();
}
})
);
);

// Regression test for slot-aware backpressure: with two execution slots and
// three map tasks, the third task must start as soon as the short task frees
// its slot, not only after the long-running task has finished as well.
Deno.test(
'map worker refills a freed slot before the slowest task in the previous batch finishes',
withPgNoTransaction(async (sql) => {
await sql`select pgflow_tests.reset_db();`;

// Two slots and a batch size of two: only two of the three tasks can run at
// the same time, so one task has to wait for a slot.
const worker = startWorker(sql, SlotRefillFlow, {
maxConcurrent: 2,
batchSize: 2,
maxPollSeconds: 1,
pollIntervalMs: 100,
});

try {
await createRootMapFlow(sql, 'test_slot_refill_map_flow', 'process');

// One long task (1000 ms) and two short ones (50 ms each).
// NOTE(review): the assertions below presume 'slow' and 'fast' are picked up
// first and 'refill' is the task that waits — confirm that the order in
// which tasks are polled guarantees this.
const flowRun = await startFlow(sql, SlotRefillFlow, [
{ id: 'slow', delayMs: 1000 },
{ id: 'fast', delayMs: 50 },
{ id: 'refill', delayMs: 50 },
]);

await waitForRunCompletion(sql, flowRun.run_id);

// Read start/completion timestamps for each task of the 'process' step. The
// handler returns the item's id, so the task output identifies the task.
const taskTimes = await sql<
{ output: string; started_at: string | null; completed_at: string | null }[]
>`
SELECT output #>> '{}' AS output, started_at::text, completed_at::text
FROM pgflow.step_tasks
WHERE run_id = ${flowRun.run_id}
AND step_slug = 'process'
`;

// Index the rows by task output ('slow' | 'fast' | 'refill').
const timings = new Map(
taskTimes.map((task) => [task.output, task])
);

const slow = timings.get('slow');
const fast = timings.get('fast');
const refill = timings.get('refill');

// Guard: the timestamps compared below must be present, which also makes
// the non-null assertions that follow safe.
assertEquals(!!slow?.started_at, true);
assertEquals(!!fast?.completed_at, true);
assertEquals(!!refill?.started_at, true);

// Lower bound: 'refill' did not start before 'fast' had completed, i.e. it
// really had to wait for a slot.
assertEquals(
new Date(refill!.started_at!).getTime() >=
new Date(fast!.completed_at!).getTime(),
true,
'refill task should wait for a free slot'
);
// Upper bound: 'refill' started while 'slow' was still running, i.e. the
// freed slot was reused before the whole previous batch was done.
assertEquals(
new Date(refill!.started_at!).getTime() <
new Date(slow!.completed_at!).getTime(),
true,
'refill task should start before the slow task from the previous batch finishes'
);
} finally {
// Always stop the worker, even when an assertion above fails.
await worker.stop();
}
})
);
Loading
Loading