diff --git a/pyproject.toml b/pyproject.toml
index ac9127e49..61fe159fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -86,6 +86,7 @@ dev = [
     "opentelemetry-exporter-otlp-proto-grpc>=1.11.1,<2",
     "opentelemetry-semantic-conventions>=0.40b0,<1",
     "opentelemetry-sdk-extension-aws>=2.0.0,<3",
+    "async-timeout>=4.0,<6; python_version < '3.11'",
 ]
 
 [tool.poe.tasks]
diff --git a/temporalio/contrib/workflow_streams/README.md b/temporalio/contrib/workflow_streams/README.md
new file mode 100644
index 000000000..ba5582e52
--- /dev/null
+++ b/temporalio/contrib/workflow_streams/README.md
@@ -0,0 +1,33 @@
+# Temporal Workflow Streams
+
+> ⚠️ **This package is currently at an experimental release stage.** ⚠️
+
+**Workflow Streams** is a Temporal Python SDK contrib library that gives a
+Workflow a durable, offset-addressed event channel for keeping outside
+observers updated on the progress of the Workflow and its Activities.
+Typical uses include driving a UI for a long-running AI agent, surfacing
+status during in-flight payment or order processing, and reporting progress
+from data pipelines. It is not designed for ultra-low-latency applications
+such as real-time voice; per-roundtrip latency is around 100ms, and cost
+scales with durable batches rather than tokens.
+
+Under the hood the stream is built directly on Temporal's existing
+message-passing primitives: Signals carry publishes, Updates serve
+long-poll subscriptions, and a Query exposes the current global offset.
+The library packages the boilerplate that turns those primitives into
+a usable stream: batching to amortize per-event overhead, deduplication
+for exactly-once delivery, topic filtering, and continue-as-new helpers
+that hand stream state across Workflow runs.
+
+## Documentation
+
+📖 **The full guide lives in the Temporal documentation site:**
+**[Workflow Streams - Python SDK](https://docs.temporal.io/develop/python/libraries/workflow-streams)**
+
+It covers installation, enabling streaming on a Workflow, publishing from
+Workflows and Activities, subscribing, continue-as-new, delivery semantics,
+codec and payload encoding, architecture, and caveats, with runnable code
+snippets throughout.
+
+For runnable end-to-end examples, see the
+[Workflow Streams samples](https://github.com/temporalio/samples-python/tree/main/workflow-streams).
diff --git a/temporalio/contrib/workflow_streams/__init__.py b/temporalio/contrib/workflow_streams/__init__.py
new file mode 100644
index 000000000..41f670f0c
--- /dev/null
+++ b/temporalio/contrib/workflow_streams/__init__.py
@@ -0,0 +1,43 @@
+"""Workflow Streams for Temporal workflows.
+
+.. warning::
+    This package is experimental and may change in future versions.
+
+The Workflow Streams contrib library gives a workflow a durable,
+offset-addressed event channel built from Signals and polling Updates
+with an SSE bridge. Cost scales with durable batches, not tokens.
+Latency is around 100ms per roundtrip; not for ultra-low-latency voice.
+
+See :py:class:`WorkflowStream` for the workflow-side stream object and
+:py:class:`WorkflowStreamClient` for the external client interface.
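+
+A minimal end-to-end sketch (the workflow class and topic name below are
+illustrative, not part of this package's API)::
+
+    @workflow.defn
+    class MyWorkflow:
+        @workflow.init
+        def __init__(self) -> None:
+            self.stream = WorkflowStream()
+            self.events = self.stream.topic("events", type=str)
+
+        @workflow.run
+        async def run(self) -> None:
+            self.events.publish("started")
+            ...
+
+    # Elsewhere, with a connected Temporal Client:
+    stream_client = WorkflowStreamClient.create(client, workflow_id)
+    async for item in stream_client.subscribe("events", result_type=str):
+        print(item.offset, item.data)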
+""" + +from temporalio.contrib.workflow_streams._client import WorkflowStreamClient +from temporalio.contrib.workflow_streams._stream import WorkflowStream +from temporalio.contrib.workflow_streams._topic_handle import ( + TopicHandle, + WorkflowTopicHandle, +) +from temporalio.contrib.workflow_streams._types import ( + PollInput, + PollResult, + PublishEntry, + PublisherState, + PublishInput, + WorkflowStreamItem, + WorkflowStreamState, +) + +__all__ = [ + "PollInput", + "PollResult", + "PublishEntry", + "PublishInput", + "PublisherState", + "TopicHandle", + "WorkflowStream", + "WorkflowStreamClient", + "WorkflowStreamItem", + "WorkflowStreamState", + "WorkflowTopicHandle", +] diff --git a/temporalio/contrib/workflow_streams/_client.py b/temporalio/contrib/workflow_streams/_client.py new file mode 100644 index 000000000..e28437e69 --- /dev/null +++ b/temporalio/contrib/workflow_streams/_client.py @@ -0,0 +1,628 @@ +"""External-side client for Workflow Streams. + +Used by activities, starters, and any code with a workflow handle to +publish messages and subscribe to topics on a workflow that hosts a +:class:`WorkflowStream`. + +Each published value is turned into a :class:`Payload` via the client's +sync payload converter. The **codec chain** (e.g. encryption, compression) +is **not** run per item โ€” it runs once at the envelope +level when Temporal's SDK encodes the ``__temporal_workflow_stream_publish`` +signal args and the ``__temporal_workflow_stream_poll`` update result. +Running the codec per item as well would double-encrypt / double-compress, +because the envelope path covers the items again. The per-item +``Payload`` still carries the encoding metadata (``encoding: json/plain``, +``messageType``, etc.) required by ``subscribe(result_type=T)`` on the +consumer side. +""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from collections.abc import AsyncIterator +from datetime import timedelta +from typing import Any, TypeVar, overload + +from typing_extensions import Self + +from temporalio import activity +from temporalio.api.common.v1 import Payload +from temporalio.client import ( + Client, + WorkflowExecutionStatus, + WorkflowHandle, + WorkflowUpdateFailedError, + WorkflowUpdateRPCTimeoutOrCancelledError, +) +from temporalio.converter import DataConverter, PayloadConverter +from temporalio.service import RPCError, RPCStatusCode + +from ._topic_handle import TopicHandle +from ._types import ( + PollInput, + PollResult, + PublishEntry, + PublishInput, + WorkflowStreamItem, + _decode_payload, + _encode_payload, +) + +T = TypeVar("T") + + +class WorkflowStreamClient: + """Client for publishing to and subscribing from a workflow stream. + + .. warning:: + This class is experimental and may change in future versions. + + Create via :py:meth:`create` (explicit client + workflow id), + :py:meth:`from_within_activity` (infer both from the current activity + context), or by passing a handle directly to the constructor. + + For publishing, bind a typed topic handle and use the client as + an async context manager to get automatic batching:: + + client = WorkflowStreamClient.create(temporal_client, workflow_id) + events = client.topic("events", type=MyEvent) + async with client: + events.publish(my_event) + events.publish(another_event, force_flush=True) + ... # more publishing + # Buffer is flushed automatically on context manager exit. 
+ + For subscribing:: + + client = WorkflowStreamClient.create(temporal_client, workflow_id) + async for item in client.subscribe(["events"], result_type=MyEvent): + process(item.data) + """ + + def __init__( + self, + handle: WorkflowHandle[Any, Any], + *, + client: Client | None = None, + batch_interval: timedelta = timedelta(seconds=2), + max_batch_size: int | None = None, + max_retry_duration: timedelta = timedelta(seconds=600), + ) -> None: + """Create a stream client from a workflow handle. + + Prefer :py:meth:`create` โ€” it enables continue-as-new following + in ``subscribe()`` and supplies the :class:`Client` needed to + reach the data converter chain. + + Args: + handle: Workflow handle to the workflow hosting the stream. + client: Temporal client whose payload converter will be used + to turn published values into ``Payload`` objects and to + decode subscriptions when ``result_type`` is set. The + codec chain is **not** applied per item (doing so would + double-encrypt โ€” see module docstring). If ``None``, the + default payload converter is used. + batch_interval: Interval between automatic flushes. + max_batch_size: Auto-flush when buffer reaches this size. + max_retry_duration: Maximum time to retry a failed flush + before raising TimeoutError. Must be less than the + workflow's ``publisher_ttl`` (default 15 minutes) to + preserve exactly-once delivery. Default: 10 minutes. + """ + self._handle: WorkflowHandle[Any, Any] = handle + self._client: Client | None = client + self._workflow_id = handle.id + self._batch_interval = batch_interval + self._max_batch_size = max_batch_size + self._max_retry_duration = max_retry_duration + self._buffer: list[tuple[str, Any]] = [] + self._flush_event = asyncio.Event() + self._flush_task: asyncio.Task[None] | None = None + self._flush_lock = asyncio.Lock() + self._publisher_id: str = uuid.uuid4().hex[:16] + self._sequence: int = 0 + self._pending: list[PublishEntry] | None = None + self._pending_seq: int = 0 + self._pending_since: float | None = None + self._topic_types: dict[str, type[Any]] = {} + + @classmethod + def create( + cls, + client: Client, + workflow_id: str, + *, + batch_interval: timedelta = timedelta(seconds=2), + max_batch_size: int | None = None, + max_retry_duration: timedelta = timedelta(seconds=600), + ) -> WorkflowStreamClient: + """Create a stream client from a Temporal client and workflow ID. + + Use this when the caller has an explicit ``Client`` and + ``workflow_id`` in hand (starters, BFFs, other workflows' + activities). For code running inside an activity that targets + its own parent workflow, see :py:meth:`from_within_activity`. + + A client created through this method follows continue-as-new + chains in ``subscribe()`` and uses the client's payload + converter for per-item ``Payload`` construction. + + Args: + client: Temporal client. + workflow_id: ID of the workflow hosting the stream. + batch_interval: Interval between automatic flushes. + max_batch_size: Auto-flush when buffer reaches this size. + max_retry_duration: Maximum time to retry a failed flush + before raising TimeoutError. Default: 10 minutes. 
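+
+        A sketch of the cross-workflow case mentioned above: an activity
+        publishing to a workflow other than its parent (the activity and
+        topic names are illustrative)::
+
+            @activity.defn
+            async def report_progress(target_workflow_id: str) -> None:
+                stream_client = WorkflowStreamClient.create(
+                    activity.client(), target_workflow_id
+                )
+                async with stream_client:
+                    stream_client.topic("status", type=str).publish("working")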
+ """ + handle = client.get_workflow_handle(workflow_id) + return cls( + handle, + client=client, + batch_interval=batch_interval, + max_batch_size=max_batch_size, + max_retry_duration=max_retry_duration, + ) + + @classmethod + def from_within_activity( + cls, + *, + batch_interval: timedelta = timedelta(seconds=2), + max_batch_size: int | None = None, + max_retry_duration: timedelta = timedelta(seconds=600), + ) -> WorkflowStreamClient: + """Create a stream client targeting the current activity's parent workflow. + + Must be called from within an activity that was scheduled by a + workflow. The Temporal client and parent workflow id are taken + from the activity context. + + Standalone activities โ€” those started directly via + :py:meth:`temporalio.client.Client.start_activity` rather than + from a workflow โ€” have no parent workflow, so this method + raises. Use :py:meth:`create` from a standalone activity, + passing ``activity.client()`` and the target workflow id + explicitly (typically threaded through the activity's input). + + Args: + batch_interval: Interval between automatic flushes. + max_batch_size: Auto-flush when buffer reaches this size. + max_retry_duration: Maximum time to retry a failed flush + before raising TimeoutError. Default: 10 minutes. + """ + info = activity.info() + workflow_id = info.workflow_id + if workflow_id is None: + raise RuntimeError( + "from_within_activity requires an activity scheduled by a workflow; " + "this activity has no parent workflow. From a standalone " + "activity, use WorkflowStreamClient.create(activity.client(), " + "workflow_id) with the target workflow id passed in explicitly." + ) + return cls.create( + activity.client(), + workflow_id, + batch_interval=batch_interval, + max_batch_size=max_batch_size, + max_retry_duration=max_retry_duration, + ) + + async def __aenter__(self) -> Self: + """Start the background flusher task.""" + self._flush_task = asyncio.create_task(self._run_flusher()) + return self + + async def __aexit__(self, *_exc: object) -> None: + """Stop the flusher and flush any remaining buffered entries.""" + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + self._flush_task = None + # Drain both pending and buffer. A single _flush() processes + # either pending OR buffer, not both โ€” so if the flusher was + # cancelled mid-signal (pending set) while the producer added + # more items (buffer non-empty), a single final flush would + # orphan the buffer. + while self._pending is not None or self._buffer: + await self._flush() + + def _publish_to_topic( + self, topic: str, value: Any, *, force_flush: bool = False + ) -> None: + """Internal publish path used by :class:`TopicHandle`. + + Not part of the public API โ€” call + :meth:`TopicHandle.publish` instead. + """ + self._buffer.append((topic, value)) + if force_flush or ( + self._max_batch_size is not None + and len(self._buffer) >= self._max_batch_size + ): + self._flush_event.set() + + @overload + def topic(self, name: str) -> TopicHandle[Any]: ... + @overload + def topic(self, name: str, *, type: type[T]) -> TopicHandle[T]: ... + + def topic( + self, name: str, *, type: type[T] | None = None + ) -> TopicHandle[T] | TopicHandle[Any]: + """Return a typed handle for publishing to and subscribing from ``name``. + + The handle records the topic name and value type so call sites + do not have to repeat them. 
Each :class:`WorkflowStreamClient` + instance binds a topic name to exactly one type: a second call + with an unequal type raises ``RuntimeError``. Repeating the + same call with the same type is idempotent and returns an + equivalent handle. + + Type uniformity is checked only on this client instance โ€” it + does not coordinate across processes. The check uses Python + equality on the type object; subtype and union-superset + relationships are not recognized. + + Omitting ``type`` (or passing ``type=typing.Any``) is the + documented escape hatch for heterogeneous topics or + dynamic-topic forwarders: the handle accepts any value, and + subscribers receive the converter's default decoded value. + Pre-built ``Payload`` values can be passed to + :meth:`TopicHandle.publish` regardless of the bound type + (zero-copy fast path) โ€” there is no need to bind the topic to + ``Payload`` itself, and doing so would break the subscribe + path (use ``result_type=RawValue`` on + :meth:`WorkflowStreamClient.subscribe` if you need raw + payloads on a subscriber). + + Args: + name: Topic name. + type: Value type bound to this handle. Used as the + ``result_type`` when subscribing through the handle. + Defaults to ``typing.Any`` (heterogeneous topic). + + Returns: + :class:`TopicHandle` bound to ``name`` and the resolved + type. + + Raises: + RuntimeError: If ``name`` is already bound on this client + to a different type. + """ + bound: Any = Any if type is None else type + if bound is Payload: + raise RuntimeError( + "Cannot bind a topic to type=Payload: the payload converter " + "has no Payload decode path, so TopicHandle.subscribe would " + "fail. Pre-built Payload values can be passed to " + "TopicHandle.publish on any-typed handle (zero-copy fast " + "path); omit type (or pass type=typing.Any) for " + "heterogeneous topics, and subscribe via " + "WorkflowStreamClient.subscribe with result_type=RawValue " + "when raw payloads are needed." + ) + existing = self._topic_types.get(name) + if existing is not None and existing != bound: + raise RuntimeError( + f"Topic {name!r} is already bound to type {existing!r} on this " + f"client; refusing to rebind to {bound!r}. Use a single type " + f"per topic, or omit type (=typing.Any) for heterogeneous topics." + ) + self._topic_types[name] = bound + return TopicHandle(self, name, bound) + + async def flush(self) -> None: + """Flush buffered (and pending) items and wait for server confirmation. + + Returns once the items buffered at call time have been signaled to + the workflow and acknowledged by the server. Returns immediately + if there is nothing to send. + + This is in addition to the declarative ``force_flush=True`` on + :py:meth:`TopicHandle.publish` and to the automatic flush on + context-manager exit. Use this when you need a synchronization + point โ€” proof that prior publications have reached the + server โ€” at a moment that does not naturally correspond to a + specific event. + + Safe to call concurrently with topic-handle publishes and with + the background flusher: the flush lock serializes signal sends. + Items added concurrently after entry may piggyback on this + flush or be deferred to a subsequent one. + + Raises: + TimeoutError: If a pending batch from a prior failure cannot + be sent within ``max_retry_duration``. The pending batch + is dropped; subsequent publications use a fresh sequence. 
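+
+        A sketch of the synchronization-point pattern (the topic handle and
+        helper names are illustrative)::
+
+            async with stream_client:
+                status.publish("step-1-done")
+                # Prior publications are now signaled and acknowledged.
+                await stream_client.flush()
+                await begin_step_2()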
+ """ + while self._pending is not None or self._buffer: + await self._flush() + + def _payload_converter(self) -> PayloadConverter: + """Return the sync payload converter for per-item encode/decode. + + Uses the configured client's payload converter when available; + otherwise falls back to the default. The codec chain + (e.g. encryption, compression) is intentionally not + invoked here โ€” it runs once at the envelope level when the + signal/update goes over the wire. See module docstring. + """ + if self._client is not None: + return self._client.data_converter.payload_converter + return DataConverter.default.payload_converter + + def _encode_buffer(self, entries: list[tuple[str, Any]]) -> list[PublishEntry]: + """Convert buffered (topic, value) pairs to wire entries. + + Non-Payload values go through the sync payload converter so the + resulting ``Payload`` carries encoding metadata for + ``result_type=`` decode on the consumer side. Pre-built + Payloads bypass conversion. + """ + converter = self._payload_converter() + out: list[PublishEntry] = [] + for topic, value in entries: + if isinstance(value, Payload): + payload = value + else: + payload = converter.to_payloads([value])[0] + out.append(PublishEntry(topic=topic, data=_encode_payload(payload))) + return out + + async def _flush(self) -> None: + """Send buffered or pending messages to the workflow via signal. + + On failure, the pending batch and sequence are kept for retry. + Only advances the confirmed sequence on success. + """ + async with self._flush_lock: + if self._pending is not None: + # Retry path: check max_retry_duration + if ( + self._pending_since is not None + and time.monotonic() - self._pending_since + > self._max_retry_duration.total_seconds() + ): + # Advance confirmed sequence so the next batch gets + # a fresh sequence number. Without this, the next + # batch reuses pending_seq, which the workflow may + # have already accepted โ€” causing silent dedup + # (data loss). See DropPendingFixed / + # SequenceFreshness in the design doc. + self._sequence = self._pending_seq + self._pending = None + self._pending_seq = 0 + self._pending_since = None + raise TimeoutError( + f"Flush retry exceeded max_retry_duration " + f"({self._max_retry_duration}). Pending batch dropped. " + f"If the signal was delivered, items are in the log. " + f"If not, they are lost." + ) + batch = self._pending + seq = self._pending_seq + elif self._buffer: + # New batch path. Encode before clearing the buffer so + # a payload-converter exception leaves the items in + # place for inspection or retry rather than silently + # dropping them. + batch = self._encode_buffer(self._buffer) + self._buffer = [] + seq = self._sequence + 1 + self._pending = batch + self._pending_seq = seq + self._pending_since = time.monotonic() + else: + return + + try: + # If the SDK ever exposes request_id on signal() and the + # server dedups it across CAN, pinning + # request_id=f"{publisher_id}:{seq}" here lets the + # workflow-side dedup go away. See DESIGN ยง"Replace + # workflow-side dedup with server-side request_id". 
+ await self._handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=batch, + publisher_id=self._publisher_id, + sequence=seq, + ), + ) + # Success: advance confirmed sequence, clear pending + self._sequence = seq + self._pending = None + self._pending_seq = 0 + self._pending_since = None + except Exception: + # Pending stays set for retry on the next _flush() call + raise + + async def _run_flusher(self) -> None: + """Background task: wait for timer OR force_flush wakeup, then flush.""" + while True: + try: + await asyncio.wait_for( + self._flush_event.wait(), + timeout=self._batch_interval.total_seconds(), + ) + except asyncio.TimeoutError: + pass + self._flush_event.clear() + await self._flush() + + @overload + def subscribe( + self, + topics: str | list[str] | None = ..., + from_offset: int = ..., + *, + result_type: type[T], + poll_cooldown: timedelta = ..., + ) -> AsyncIterator[WorkflowStreamItem[T]]: ... + @overload + def subscribe( + self, + topics: str | list[str] | None = ..., + from_offset: int = ..., + *, + result_type: None = None, + poll_cooldown: timedelta = ..., + ) -> AsyncIterator[WorkflowStreamItem[Any]]: ... + + async def subscribe( + self, + topics: str | list[str] | None = None, + from_offset: int = 0, + *, + result_type: type | None = None, + poll_cooldown: timedelta = timedelta(milliseconds=100), + ) -> AsyncIterator[WorkflowStreamItem[Any]]: + """Async iterator that polls for new items. + + Automatically follows continue-as-new chains when the client + was created via :py:meth:`create`. + + Args: + topics: Topic filter. A single topic name, a list of topic + names, or None. None or empty list means all topics. + from_offset: Global offset to start reading from. + result_type: Optional target type. Each yielded + :class:`WorkflowStreamItem` has its ``data`` decoded via + the client's sync payload converter. When omitted, the + converter's default ``Any`` decoding is used (for the + stock JSON converter that means a Python primitive, + ``dict``, or ``list``). Pass + ``result_type=temporalio.common.RawValue`` for an + opaque ``RawValue`` wrapping the original + ``Payload`` โ€” useful for heterogeneous topics where + the caller dispatches on ``Payload.metadata`` or wants + to forward the bytes without decoding. + poll_cooldown: Minimum interval between polls to avoid + overwhelming the workflow when items arrive faster + than the poll round-trip. Defaults to 100ms. + + Yields: + :class:`WorkflowStreamItem` for each matching item. + """ + if result_type is Payload: + raise RuntimeError( + "Cannot subscribe with result_type=Payload: the payload " + "converter has no Payload decode path. Omit result_type " + "for default decoding, or pass result_type=RawValue to " + "receive a RawValue wrapping the raw Payload." + ) + topic_filter: list[str] + if topics is None: + topic_filter = [] + elif isinstance(topics, str): + topic_filter = [topics] + else: + topic_filter = topics + offset = from_offset + while True: + try: + result: PollResult = await self._handle.execute_update( + "__temporal_workflow_stream_poll", + PollInput(topics=topic_filter, from_offset=offset), + result_type=PollResult, + ) + except asyncio.CancelledError: + return + except WorkflowUpdateFailedError as e: + cause_type = getattr(e.cause, "type", None) + if cause_type == "TruncatedOffset": + # Subscriber fell behind truncation. Retry from + # offset 0 which the stream treats as "from the + # beginning of whatever exists" (i.e., from + # base_offset). 
+ offset = 0 + continue + if cause_type == "AcceptedUpdateCompletedWorkflow": + # Workflow returned (or continued-as-new) before + # this poll's update completed. Either follow the + # chain or exit cleanly. + if await self._follow_continue_as_new(): + continue + return + raise + except WorkflowUpdateRPCTimeoutOrCancelledError: + if await self._follow_continue_as_new(): + continue + return + except RPCError as e: + # Workflow may have completed between polls; subscribe + # exits cleanly on terminal status so callers don't + # have to wrap the iterator in error handling for the + # normal end-of-stream case. + if e.status != RPCStatusCode.NOT_FOUND: + raise + if await self._follow_continue_as_new(): + continue + if await self._workflow_in_terminal_state(): + return + raise + converter = self._payload_converter() + for wire_item in result.items: + payload = _decode_payload(wire_item.data) + data: Any = ( + converter.from_payload(payload) + if result_type is None + else converter.from_payload(payload, result_type) + ) + yield WorkflowStreamItem( + topic=wire_item.topic, + data=data, + offset=wire_item.offset, + ) + offset = result.next_offset + cooldown_secs = poll_cooldown.total_seconds() + if not result.more_ready and cooldown_secs > 0: + await asyncio.sleep(cooldown_secs) + + async def _follow_continue_as_new(self) -> bool: + """Check if the workflow continued-as-new and re-target the handle. + + Returns True if the handle was updated (caller should retry). + """ + if self._client is None: + return False + try: + desc = await self._handle.describe() + except Exception: + return False + if desc.status == WorkflowExecutionStatus.CONTINUED_AS_NEW: + self._handle = self._client.get_workflow_handle(self._workflow_id) + return True + return False + + async def _workflow_in_terminal_state(self) -> bool: + """Return True if the workflow has reached a terminal state. + + Used by ``subscribe()`` to distinguish "workflow finished โ€” + stream is done" from "wrong workflow id" when a poll RPC + returns NOT_FOUND. + """ + try: + desc = await self._handle.describe() + except Exception: + return False + return desc.status in ( + WorkflowExecutionStatus.COMPLETED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.CANCELED, + WorkflowExecutionStatus.TERMINATED, + WorkflowExecutionStatus.TIMED_OUT, + ) + + async def get_offset(self) -> int: + """Query the current global offset (base_offset + log length).""" + return await self._handle.query( + "__temporal_workflow_stream_offset", result_type=int + ) diff --git a/temporalio/contrib/workflow_streams/_stream.py b/temporalio/contrib/workflow_streams/_stream.py new file mode 100644 index 000000000..2753f04c2 --- /dev/null +++ b/temporalio/contrib/workflow_streams/_stream.py @@ -0,0 +1,469 @@ +"""Workflow-side stream object for Workflow Streams. + +Instantiate :class:`WorkflowStream` once from your workflow's ``@workflow.init`` +method. The constructor registers the stream signal, update, and query +handlers on the current workflow via +:func:`temporalio.workflow.set_signal_handler`, +:func:`temporalio.workflow.set_update_handler`, and +:func:`temporalio.workflow.set_query_handler`. + +For workflows that support continue-as-new, include a +``WorkflowStreamState | None`` field on the workflow input and pass it as +``prior_state`` โ€” it is ``None`` on fresh starts and carries accumulated +state on continue-as-new. 
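+
+A sketch of that continue-as-new wiring (the input dataclass and workflow
+names are illustrative)::
+
+    @dataclass
+    class MyInput:
+        stream_state: WorkflowStreamState | None = None
+
+    @workflow.defn
+    class MyWorkflow:
+        @workflow.init
+        def __init__(self, input: MyInput) -> None:
+            self.stream = WorkflowStream(prior_state=input.stream_state)
+
+        @workflow.run
+        async def run(self, input: MyInput) -> None:
+            ...
+            await self.stream.continue_as_new(
+                lambda state: [MyInput(stream_state=state)]
+            )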
+ +Workflow-side and client-side topic handles +(:meth:`WorkflowTopicHandle.publish` and +:meth:`TopicHandle.publish`) both use the synchronous payload +converter for per-item ``Payload`` construction. The codec chain +(e.g. encryption, compression) is **not** run per item on either +side โ€” it runs once at the envelope level when Temporal's SDK +encodes the signal/update that carries the batch. Running it per +item as well would double-encrypt, because every signal arg +already goes through the client's ``DataConverter.encode`` at +dispatch time. +""" + +from __future__ import annotations + +import sys +from collections.abc import Sequence +from datetime import timedelta +from typing import Any, Callable, NoReturn, TypeVar, overload + +from temporalio import workflow +from temporalio.api.common.v1 import Payload +from temporalio.exceptions import ApplicationError + +from ._topic_handle import WorkflowTopicHandle +from ._types import ( + PollInput, + PollResult, + PublisherState, + PublishInput, + WorkflowStreamItem, + WorkflowStreamState, + _decode_payload, + _encode_payload, + _WorkflowStreamWireItem, +) + +_PUBLISH_SIGNAL = "__temporal_workflow_stream_publish" +_POLL_UPDATE = "__temporal_workflow_stream_poll" +_OFFSET_QUERY = "__temporal_workflow_stream_offset" + +_MAX_POLL_RESPONSE_BYTES = 1_000_000 + +T = TypeVar("T") + + +def _payload_wire_size(payload: Payload, topic: str) -> int: + """Approximate poll-response contribution of a single item. + + Wire form is ``_WorkflowStreamWireItem(topic, base64(proto(Payload)), offset)``. + Base64 inflates by ~4/3; we use the serialized length as a + conservative approximation. + """ + return (payload.ByteSize() * 4 + 2) // 3 + len(topic) + + +class WorkflowStream: + """Workflow-side stream object โ€” append-only log with publish/poll handlers. + + .. warning:: + This class is experimental and may change in future versions. + + Construct once from ``@workflow.init``; the constructor registers + the stream signal, update, and query handlers on the current + workflow. Raises :class:`RuntimeError` if a ``WorkflowStream`` has + already been registered on the workflow. + + Registered handlers: + + - ``__temporal_workflow_stream_publish`` signal โ€” external publish with dedup + - ``__temporal_workflow_stream_poll`` update โ€” long-poll subscription + - ``__temporal_workflow_stream_offset`` query โ€” current log length + + Note: + Because the publish handler is registered dynamically from + ``__init__``, on the activation where the stream is + constructed the publish signal can be buffered until after + class-level signal/update handlers are scheduled. Define + such handlers as ``async`` and ``await asyncio.sleep(0)`` + before reading stream state, so the publish signal is + processed first. + """ + + def __init__(self, prior_state: WorkflowStreamState | None = None) -> None: + """Initialize stream state and register workflow handlers. + + Must be called directly from the workflow's ``@workflow.init`` + method. Calls made from ``@workflow.run``, helper methods, or + signal/update/query handlers raise :class:`RuntimeError`. + + The check inspects the immediate caller's frame and requires the + function name to be ``__init__``. + + Args: + prior_state: State carried from a previous run via + :meth:`get_state` through continue-as-new, or ``None`` + on first start. 
+ + Raises: + RuntimeError: If not called directly from a method named + ``__init__``, or if the stream signal handler is + already registered on this workflow (i.e., + ``WorkflowStream`` was instantiated twice). + + Note: + When carrying state across continue-as-new, type the + carrying field as ``WorkflowStreamState | None``, not + ``Any``. The default data converter deserializes ``Any`` + fields as plain dicts, which silently strips the + ``WorkflowStreamState`` type and breaks the new run. + """ + caller = sys._getframe(1) + caller_name = caller.f_code.co_name + if caller_name != "__init__": + raise RuntimeError( + "WorkflowStream must be constructed directly from the workflow's " + f"@workflow.init method, not from {caller_name!r}." + ) + if workflow.get_signal_handler(_PUBLISH_SIGNAL) is not None: + raise RuntimeError( + "WorkflowStream is already registered on this workflow. " + "Construct WorkflowStream(...) at most once from @workflow.init." + ) + + if prior_state is not None: + self._log: list[WorkflowStreamItem[Payload]] = [ + WorkflowStreamItem(topic=item.topic, data=_decode_payload(item.data)) + for item in prior_state.log + ] + self._base_offset: int = prior_state.base_offset + self._publishers: dict[str, PublisherState] = { + pid: PublisherState(sequence=ps.sequence, last_seen=ps.last_seen) + for pid, ps in prior_state.publishers.items() + } + else: + self._log = [] + self._base_offset = 0 + self._publishers = {} + self._detaching: bool = False + self._topic_types: dict[str, type[Any]] = {} + + workflow.set_signal_handler(_PUBLISH_SIGNAL, self._on_publish) + workflow.set_update_handler( + _POLL_UPDATE, self._on_poll, validator=self._validate_poll + ) + workflow.set_query_handler(_OFFSET_QUERY, self._on_offset) + + def _publish_to_topic(self, topic: str, value: Any) -> None: + """Internal publish path used by :class:`WorkflowTopicHandle`. + + Not part of the public API โ€” call + :meth:`WorkflowTopicHandle.publish` instead. + """ + if isinstance(value, Payload): + payload = value + else: + payload = workflow.payload_converter().to_payloads([value])[0] + self._log.append(WorkflowStreamItem(topic=topic, data=payload)) + + @overload + def topic(self, name: str) -> WorkflowTopicHandle[Any]: ... + @overload + def topic(self, name: str, *, type: type[T]) -> WorkflowTopicHandle[T]: ... + + def topic( + self, name: str, *, type: type[T] | None = None + ) -> WorkflowTopicHandle[T] | WorkflowTopicHandle[Any]: + """Return a typed handle for publishing to ``name`` from this workflow. + + The handle records the topic name and value type so call sites + do not have to repeat them. Each :class:`WorkflowStream` + instance binds a topic name to exactly one type: a second call + with an unequal type raises ``RuntimeError``. Repeating the + same call with the same type is idempotent and returns an + equivalent handle. + + Type uniformity is checked only on this stream instance โ€” it + does not coordinate across publishers (other workflows, + activities, external clients). The check uses Python equality + on the type object; subtype and union-superset relationships + are not recognized. + + Omitting ``type`` (or passing ``type=typing.Any``) is the + documented escape hatch for heterogeneous topics. Pre-built + ``Payload`` values can be passed to + :meth:`WorkflowTopicHandle.publish` regardless of the bound + type (zero-copy fast path) โ€” there is no need to bind the + topic to ``Payload`` itself. + + Args: + name: Topic name. + type: Value type bound to this handle. 
Defaults to + ``typing.Any`` (heterogeneous topic). + + Returns: + :class:`WorkflowTopicHandle` bound to ``name`` and the + resolved type. + + Raises: + RuntimeError: If ``name`` is already bound on this stream + to a different type. + """ + bound: Any = Any if type is None else type + if bound is Payload: + raise RuntimeError( + "Cannot bind a topic to type=Payload. Pre-built Payload " + "values can be passed to WorkflowTopicHandle.publish on " + "any-typed handle (zero-copy fast path); omit type (or " + "pass type=typing.Any) for heterogeneous topics." + ) + existing = self._topic_types.get(name) + if existing is not None and existing != bound: + raise RuntimeError( + f"Topic {name!r} is already bound to type {existing!r} on this " + f"workflow stream; refusing to rebind to {bound!r}. Use a " + f"single type per topic, or omit type (=typing.Any) for " + f"heterogeneous topics." + ) + self._topic_types[name] = bound + return WorkflowTopicHandle(self, name, bound) + + def get_state( + self, *, publisher_ttl: timedelta = timedelta(seconds=900) + ) -> WorkflowStreamState: + """Return a serializable snapshot of stream state for continue-as-new. + + Drops dedup state for publishers idle longer than + ``publisher_ttl``. The TTL must exceed the + ``max_retry_duration`` of any client that may still be + retrying a failed flush. + + Args: + publisher_ttl: Duration after which an idle publisher's + dedup state is dropped. Default 15 minutes. + """ + now = workflow.now() + + active_publishers = { + pid: ps + for pid, ps in self._publishers.items() + if now - ps.last_seen < publisher_ttl + } + + return WorkflowStreamState( + log=[ + _WorkflowStreamWireItem( + topic=item.topic, data=_encode_payload(item.data) + ) + for item in self._log + ], + base_offset=self._base_offset, + publishers=active_publishers, + ) + + def detach_pollers(self) -> None: + """Release waiting pollers and reject new poll updates. + + After this call the stream's ``__temporal_workflow_stream_poll`` + update handler releases its in-flight subscribers on this run: + each waiting poll returns its current item batch (often empty) + so the consumer can either follow continue-as-new or stop, and + new polls are rejected at the validator. Publishes still land + in the in-memory log and ``get_state`` / ``continue_as_new`` + remain valid โ€” the stream is being held open just long enough + to snapshot state and hand off to the next run. + + Call this before + ``await workflow.wait_condition(workflow.all_handlers_finished)`` + and ``workflow.continue_as_new()``. + """ + self._detaching = True + + async def continue_as_new( + self, + build_args: Callable[[WorkflowStreamState], Sequence[Any]], + *, + publisher_ttl: timedelta = timedelta(seconds=900), + ) -> NoReturn: + """Detach pollers, wait for handlers, continue-as-new with built args. + + Replaces this three-line recipe for the common case where the + only continue-as-new parameter that varies is ``args``: + + .. code-block:: python + + self.stream.detach_pollers() + await workflow.wait_condition(workflow.all_handlers_finished) + workflow.continue_as_new(args=...) + + ``build_args`` is invoked *after* pollers have been detached, + with the post-detach :class:`WorkflowStreamState` as its single + argument. The caller threads that state into whatever input + dataclass the workflow expects: + + .. 
code-block:: python + + await self.stream.continue_as_new(lambda state: [WorkflowInput( + items_processed=self.items_processed, + stream_state=state, + )]) + + Workflows that need to override other CAN parameters + (``task_queue``, ``retry_policy``, ``run_timeout``, etc.) should + keep using the explicit ``detach_pollers`` / ``wait_condition`` / + ``workflow.continue_as_new(...)`` recipe. + + Args: + build_args: Callable that receives the post-detach stream + state and returns the positional ``args`` for the new + run. + publisher_ttl: Forwarded to :meth:`get_state`. + + Does not return; ``workflow.continue_as_new`` raises an internal + exception that the SDK uses to close the run. + """ + self.detach_pollers() + await workflow.wait_condition(workflow.all_handlers_finished) + workflow.continue_as_new( + args=build_args(self.get_state(publisher_ttl=publisher_ttl)), + ) + + def truncate(self, up_to_offset: int) -> None: + """Discard log entries before ``up_to_offset``. + + After truncation, polls requesting an offset before the new + base will receive an ApplicationError. All global offsets + remain monotonic. + + Raises ApplicationError (not ValueError) when ``up_to_offset`` + is past the end of the log so that callers invoking this from + an update handler surface it as an update failure rather than + a workflow-task poison pill. + + Args: + up_to_offset: The global offset to truncate up to + (exclusive). Entries at offsets + ``[base_offset, up_to_offset)`` are discarded. + """ + log_index = up_to_offset - self._base_offset + if log_index <= 0: + return + if log_index > len(self._log): + raise ApplicationError( + f"Cannot truncate to offset {up_to_offset}: " + f"valid range is [{self._base_offset}, {self._base_offset + len(self._log)})", + type="TruncateOutOfRange", + non_retryable=True, + ) + self._log = self._log[log_index:] + self._base_offset = up_to_offset + + def _on_publish(self, payload: PublishInput) -> None: + """Receive publications from external clients (activities, starters). + + Deduplicates using (publisher_id, sequence). If publisher_id is + set and the sequence is <= the last seen sequence for that + publisher, the entire batch is dropped as a duplicate. Batches + are atomic: the dedup decision applies to the whole batch, not + individual items. + + This block is a polyfill for missing server-side ``request_id`` + dedup across continue-as-new. If the SDK ever exposes + ``request_id`` on signals and the server dedups it across CAN, + this branch and the ``_publishers`` state become redundant. See + DESIGN ยง"Replace workflow-side dedup with server-side + request_id" for the migration plan. + """ + if payload.publisher_id: + existing = self._publishers.get(payload.publisher_id) + if existing is not None and payload.sequence <= existing.sequence: + return + self._publishers[payload.publisher_id] = PublisherState( + sequence=payload.sequence, + last_seen=workflow.now(), + ) + for entry in payload.items: + self._log.append( + WorkflowStreamItem(topic=entry.topic, data=_decode_payload(entry.data)) + ) + + async def _on_poll(self, payload: PollInput) -> PollResult: + """Long-poll: block until new items available or detaching, then return.""" + # Re-evaluate the predicate against current ``_base_offset`` on + # every iteration: a ``truncate()`` between this poll's arrival + # and the wait firing changes ``log_offset`` underneath us, so + # capturing it as a local would freeze the wait against stale + # state and the poll would only return when the long-poll RPC + # times out. 
+ await workflow.wait_condition( + lambda: ( + payload.from_offset < self._base_offset + or len(self._log) > payload.from_offset - self._base_offset + or self._detaching + ), + ) + log_offset = payload.from_offset - self._base_offset + if log_offset < 0: + if payload.from_offset == 0: + # "From the beginning" โ€” start at whatever is available. + log_offset = 0 + else: + # Subscriber had a specific position that's been + # truncated. ApplicationError fails this update (client + # gets the error) without crashing the workflow task โ€” + # avoids a poison pill during replay. + raise ApplicationError( + f"Requested offset {payload.from_offset} has been truncated. " + f"Current base offset is {self._base_offset}.", + type="TruncatedOffset", + non_retryable=True, + ) + all_new = self._log[log_offset:] + if payload.topics: + topic_set = set(payload.topics) + candidates = [ + (self._base_offset + log_offset + i, item) + for i, item in enumerate(all_new) + if item.topic in topic_set + ] + else: + candidates = [ + (self._base_offset + log_offset + i, item) + for i, item in enumerate(all_new) + ] + # Cap response size to ~1MB wire bytes. + wire_items: list[_WorkflowStreamWireItem] = [] + size = 0 + more_ready = False + next_offset = self._base_offset + len(self._log) + for off, item in candidates: + item_size = _payload_wire_size(item.data, item.topic) + if size + item_size > _MAX_POLL_RESPONSE_BYTES and wire_items: + # Resume from this item on the next poll. + next_offset = off + more_ready = True + break + size += item_size + wire_items.append( + _WorkflowStreamWireItem( + topic=item.topic, data=_encode_payload(item.data), offset=off + ) + ) + return PollResult( + items=wire_items, + next_offset=next_offset, + more_ready=more_ready, + ) + + def _validate_poll(self, _payload: PollInput) -> None: + """Reject new polls when pollers are detached for continue-as-new.""" + if self._detaching: + raise RuntimeError("Workflow pollers are detached for continue-as-new") + + def _on_offset(self) -> int: + """Return the current global offset (base_offset + log length).""" + return self._base_offset + len(self._log) diff --git a/temporalio/contrib/workflow_streams/_topic_handle.py b/temporalio/contrib/workflow_streams/_topic_handle.py new file mode 100644 index 000000000..3b94e226f --- /dev/null +++ b/temporalio/contrib/workflow_streams/_topic_handle.py @@ -0,0 +1,164 @@ +"""Typed topic handles for Workflow Streams. + +A topic handle is a thin typed view over an underlying publisher. It +carries the topic name and the value type ``T`` so call sites do not +have to repeat them on every publish, and so cross-language SDKs can +mirror the binding cleanly. + +Type-uniformity is enforced per publisher instance: each +:class:`WorkflowStreamClient` (or :class:`WorkflowStream`) maps a topic +name to exactly one bound ``T``. Re-binding the same name to an +unequal type raises ``RuntimeError``. The check uses Python equality +on the type object โ€” primitives, dataclasses, generic aliases, and +unions all compare structurally โ€” and intentionally does not attempt +to recognize subtype or union-superset relationships. 
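+
+A sketch of the binding rules (topic and type names are illustrative)::
+
+    events = client.topic("events", type=AgentEvent)  # binds the type
+    client.topic("events", type=AgentEvent)           # same type: idempotent
+    client.topic("events", type=OtherEvent)           # unequal type: RuntimeError
+    misc = client.topic("misc")                       # typing.Any: heterogeneous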
+""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from datetime import timedelta +from typing import TYPE_CHECKING, Generic, TypeVar + +from temporalio.api.common.v1 import Payload + +from ._types import WorkflowStreamItem + +if TYPE_CHECKING: + from ._client import WorkflowStreamClient + from ._stream import WorkflowStream + +T = TypeVar("T") + + +class TopicHandle(Generic[T]): + """Client-side handle for publishing to and subscribing from a single topic. + + .. warning:: + This class is experimental and may change in future versions. + + Constructed via :meth:`WorkflowStreamClient.topic`. Publishes share + the underlying client's batching, dedup, and codec path; this + object holds only the topic name and bound type. + """ + + def __init__( + self, + client: WorkflowStreamClient, + name: str, + type: type[T], + ) -> None: + """Bind the handle to a client, topic name, and type. + + Prefer :meth:`WorkflowStreamClient.topic` over calling this + directly; the factory is what records the per-client type + binding and rejects conflicts. + """ + self._client = client + self._name = name + self._type = type + + @property + def name(self) -> str: + """The topic name this handle is bound to.""" + return self._name + + @property + def type(self) -> type[T]: + """The value type this handle is bound to.""" + return self._type + + def publish(self, value: T | Payload, *, force_flush: bool = False) -> None: + """Buffer ``value`` for publishing on this topic. + + Equivalent to the underlying client's publish path; the value + flows through the same buffer, batch interval, and dedup + sequence. + + Args: + value: Value to publish. Goes through the client's sync + payload converter at flush time. A pre-built + :class:`temporalio.api.common.v1.Payload` bypasses + conversion (zero-copy fast path), regardless of the + handle's bound type. + force_flush: If True, wake the flusher to send immediately + (fire-and-forget โ€” does not block the caller). + """ + self._client._publish_to_topic(self._name, value, force_flush=force_flush) + + async def subscribe( + self, + from_offset: int = 0, + *, + poll_cooldown: timedelta = timedelta(milliseconds=100), + ) -> AsyncIterator[WorkflowStreamItem[T]]: + """Async iterator over items on this topic, decoded as ``T``. + + For raw ``Payload`` access, or any other decode type that + differs from the handle's bound ``T``, use + :meth:`WorkflowStreamClient.subscribe` directly with an + explicit ``result_type`` (typically + :class:`temporalio.common.RawValue`). The handle's bound + type intentionally cannot be ``Payload`` โ€” the converter has + no Payload decode path. + + Args: + from_offset: Global offset to start reading from. + poll_cooldown: Minimum interval between polls when there + are no new items. + """ + async for item in self._client.subscribe( + [self._name], + from_offset=from_offset, + result_type=self._type, + poll_cooldown=poll_cooldown, + ): + yield item + + +class WorkflowTopicHandle(Generic[T]): + """Workflow-side handle for publishing to a single topic. + + .. warning:: + This class is experimental and may change in future versions. + + Constructed via :meth:`WorkflowStream.topic`. Has no + ``subscribe`` โ€” workflows do not consume their own stream. + """ + + def __init__( + self, + stream: WorkflowStream, + name: str, + type: type[T], + ) -> None: + """Bind the handle to a stream, topic name, and type. 
+ + Prefer :meth:`WorkflowStream.topic` over calling this directly; + the factory is what records the per-stream type binding and + rejects conflicts. + """ + self._stream = stream + self._name = name + self._type = type + + @property + def name(self) -> str: + """The topic name this handle is bound to.""" + return self._name + + @property + def type(self) -> type[T]: + """The value type this handle is bound to.""" + return self._type + + def publish(self, value: T | Payload) -> None: + """Append ``value`` to the workflow stream on this topic. + + Args: + value: Value to publish. Goes through the workflow's sync + payload converter. A pre-built + :class:`temporalio.api.common.v1.Payload` bypasses + conversion, regardless of the handle's bound type. + """ + self._stream._publish_to_topic(self._name, value) diff --git a/temporalio/contrib/workflow_streams/_types.py b/temporalio/contrib/workflow_streams/_types.py new file mode 100644 index 000000000..94bfb1a9b --- /dev/null +++ b/temporalio/contrib/workflow_streams/_types.py @@ -0,0 +1,171 @@ +"""Shared data types for the Workflow Streams contrib module. + +The user-facing ``data`` fields on :class:`WorkflowStreamItem` are +:class:`temporalio.api.common.v1.Payload`. Per-item values are converted to +``Payload`` by the payload converter at publish time, and the resulting +bytes/metadata are preserved per item so subscribers can decode with +``subscribe(result_type=T)``. The codec chain (e.g. encryption, compression) +applies once at the outer signal/update envelope level โ€” not separately to each +embedded item โ€” so codec behavior is symmetric between workflow-side and +client-side publishing. + +The wire representation (``PublishEntry``, ``_WorkflowStreamWireItem``) uses +base64-encoded ``Payload.SerializeToString()`` bytes because the default JSON +payload converter cannot serialize a ``Payload`` embedded inside a dataclass +(it only special-cases top-level Payloads on signal/update args). +""" + +from __future__ import annotations + +import base64 +from dataclasses import dataclass, field +from datetime import datetime +from typing import Generic, TypeVar + +from temporalio.api.common.v1 import Payload + +T = TypeVar("T") + + +# basedpyright flags _-prefixed module-level functions as unused even when +# sibling modules import them (_stream.py, _client.py). Vanilla pyright does +# not. Suppressions below are required for `poe lint`. +def _encode_payload(payload: Payload) -> str: # pyright: ignore[reportUnusedFunction] + """Wire format: base64(Payload.SerializeToString()).""" + return base64.b64encode(payload.SerializeToString()).decode("ascii") + + +def _decode_payload(wire: str) -> Payload: # pyright: ignore[reportUnusedFunction] + """Inverse of :func:`_encode_payload`.""" + payload = Payload() + payload.ParseFromString(base64.b64decode(wire)) + return payload + + +@dataclass +class WorkflowStreamItem(Generic[T]): + """A single item in the workflow stream's log. + + .. warning:: + This class is experimental and may change in future versions. + + The ``data`` field carries the decoded value produced by + :meth:`WorkflowStreamClient.subscribe`. The generic parameter ``T`` + matches the ``result_type`` passed to ``subscribe``: an instance of + ``T`` when ``result_type=T``, the converter's default ``Any`` + decoding when ``result_type`` is omitted, or a + :class:`temporalio.common.RawValue` wrapping the original + ``Payload`` when ``result_type=RawValue``. 
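+
+    A sketch of how ``result_type`` shapes ``data`` (topic and type names
+    are illustrative)::
+
+        async for item in client.subscribe("events", result_type=AgentEvent):
+            ...  # item.data is an AgentEvent
+        async for item in client.subscribe("events"):
+            ...  # item.data is the converter's default decoding
+        async for item in client.subscribe("events", result_type=RawValue):
+            ...  # item.data is a RawValue wrapping the original Payload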
+ + The ``offset`` field is populated at poll time from the item's + position in the global log. + """ + + topic: str + data: T + offset: int = 0 + + +@dataclass +class PublishEntry: + """A single entry to publish via signal (wire type). + + .. warning:: + This class is experimental and may change in future versions. + + ``data`` is base64-encoded ``Payload.SerializeToString()`` output โ€” + see module docstring for why a nested ``Payload`` cannot be used + directly. + """ + + topic: str + data: str + + +@dataclass +class PublishInput: + """Signal payload: batch of entries to publish. + + .. warning:: + This class is experimental and may change in future versions. + + Includes publisher_id and sequence to ensure exactly-once delivery. + """ + + items: list[PublishEntry] = field(default_factory=list) + publisher_id: str = "" + sequence: int = 0 + + +@dataclass +class PollInput: + """Update payload: request to poll for new items. + + .. warning:: + This class is experimental and may change in future versions. + """ + + topics: list[str] = field(default_factory=list) + from_offset: int = 0 + + +@dataclass +class _WorkflowStreamWireItem: + """Wire representation of a WorkflowStreamItem (base64 of serialized Payload).""" + + topic: str + data: str + offset: int = 0 + + +@dataclass +class PollResult: + """Update response: items matching the poll request. + + .. warning:: + This class is experimental and may change in future versions. + + ``items`` use the wire representation. When ``more_ready`` is True, + the response was truncated to stay within size limits and the + subscriber should poll again immediately rather than applying a + cooldown delay. + """ + + items: list[_WorkflowStreamWireItem] = field(default_factory=list) + next_offset: int = 0 + more_ready: bool = False + + +@dataclass +class PublisherState: + """Per-publisher dedup state. + + .. warning:: + This class is experimental and may change in future versions. + + Tracks the last accepted ``sequence`` and the ``workflow.now()`` at + which it was accepted, used together for at-least-once dedup and + TTL-based pruning at continue-as-new time. + """ + + sequence: int + last_seen: datetime + + +@dataclass +class WorkflowStreamState: + """Serializable snapshot of stream state for continue-as-new. + + .. warning:: + This class is experimental and may change in future versions. + + The containing workflow input must type the field as + ``WorkflowStreamState | None``, not ``Any``, so the default data converter + can reconstruct the dataclass from JSON. + + Log items use the wire representation for serialization stability. + """ + + log: list[_WorkflowStreamWireItem] = field(default_factory=list) + base_offset: int = 0 + publishers: dict[str, PublisherState] = field(default_factory=dict) diff --git a/tests/contrib/workflow_streams/__init__.py b/tests/contrib/workflow_streams/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contrib/workflow_streams/test_payload_roundtrip.py b/tests/contrib/workflow_streams/test_payload_roundtrip.py new file mode 100644 index 000000000..545d3a405 --- /dev/null +++ b/tests/contrib/workflow_streams/test_payload_roundtrip.py @@ -0,0 +1,137 @@ +"""Regression guards for the workflow_streams Payload wire format. + +1. The default JSON converter does not handle ``Payload`` embedded in a + dataclass โ€” serialization fails with ``TypeError``. This rules out a + naive nested-Payload wire format. +2. A proto-serialized ``Payload`` inside a dataclass does round-trip. 
+ This is the wire format used: base64 of ``Payload.SerializeToString()`` + inside ``PublishEntry``/``_WorkflowStreamWireItem``, surfacing + ``Payload`` (or a decoded value via ``result_type=``) at the user API. +""" + +from __future__ import annotations + +import base64 +import uuid +from dataclasses import dataclass, field + +import pytest + +from temporalio import workflow +from temporalio.api.common.v1 import Payload +from temporalio.client import Client +from tests.helpers import new_worker + + +@dataclass +class NestedPayloadEnvelope: + items: list[Payload] = field(default_factory=list) + + +@dataclass +class SerializedEntry: + topic: str + data: str # base64(Payload.SerializeToString()) + + +@dataclass +class SerializedEnvelope: + items: list[SerializedEntry] = field(default_factory=list) + + +@workflow.defn +class NestedPayloadWorkflow: + def __init__(self) -> None: + self._received: NestedPayloadEnvelope | None = None + + @workflow.signal + def receive(self, envelope: NestedPayloadEnvelope) -> None: + self._received = envelope + + @workflow.query + def decoded_strings(self) -> list[str]: + assert self._received is not None + conv = workflow.payload_converter() + return [conv.from_payload(p, str) for p in self._received.items] + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._received is not None) + + +@workflow.defn +class SerializedPayloadWorkflow: + def __init__(self) -> None: + self._received: SerializedEnvelope | None = None + + @workflow.signal + def receive(self, envelope: SerializedEnvelope) -> None: + self._received = envelope + + @workflow.query + def decoded_strings(self) -> list[str]: + assert self._received is not None + conv = workflow.payload_converter() + out: list[str] = [] + for entry in self._received.items: + p = Payload() + p.ParseFromString(base64.b64decode(entry.data)) + out.append(conv.from_payload(p, str)) + return out + + @workflow.query + def topics(self) -> list[str]: + assert self._received is not None + return [e.topic for e in self._received.items] + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._received is not None) + + +@pytest.mark.asyncio +async def test_nested_payload_in_dataclass_fails(client: Client) -> None: + """Confirm the load-bearing negative result: Payload inside dataclass doesn't serialize.""" + conv = client.data_converter.payload_converter + payloads = [conv.to_payloads([v])[0] for v in ["hello", "world"]] + envelope = NestedPayloadEnvelope(items=payloads) + + async with new_worker(client, NestedPayloadWorkflow) as worker: + handle = await client.start_workflow( + NestedPayloadWorkflow.run, + id=f"nested-payload-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + with pytest.raises(TypeError, match="Payload is not JSON serializable"): + await handle.signal(NestedPayloadWorkflow.receive, envelope) + await handle.terminate() + + +@pytest.mark.asyncio +async def test_serialized_payload_fallback_round_trips(client: Client) -> None: + """Proto-serialize Payload -> base64 -> dataclass round-trips through signal.""" + conv = client.data_converter.payload_converter + originals = ["hello", "world", "payload"] + payloads = [conv.to_payloads([v])[0] for v in originals] + envelope = SerializedEnvelope( + items=[ + SerializedEntry( + topic=f"t{i}", + data=base64.b64encode(p.SerializeToString()).decode("ascii"), + ) + for i, p in enumerate(payloads) + ] + ) + + async with new_worker(client, SerializedPayloadWorkflow) as worker: + handle = await 
client.start_workflow( + SerializedPayloadWorkflow.run, + id=f"serialized-payload-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + await handle.signal(SerializedPayloadWorkflow.receive, envelope) + decoded = await handle.query(SerializedPayloadWorkflow.decoded_strings) + assert decoded == originals + topics = await handle.query(SerializedPayloadWorkflow.topics) + assert topics == ["t0", "t1", "t2"] + await handle.result() diff --git a/tests/contrib/workflow_streams/test_workflow_streams.py b/tests/contrib/workflow_streams/test_workflow_streams.py new file mode 100644 index 000000000..203e1313d --- /dev/null +++ b/tests/contrib/workflow_streams/test_workflow_streams.py @@ -0,0 +1,2569 @@ +"""E2E integration tests for temporalio.contrib.workflow_streams.""" + +from __future__ import annotations + +import asyncio +import sys +import uuid +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, cast +from unittest.mock import patch + +if sys.version_info >= (3, 11): + from asyncio import timeout as _async_timeout # pyright: ignore[reportUnreachable] +else: + from async_timeout import ( # pyright: ignore[reportUnreachable] + timeout as _async_timeout, + ) + +import google.protobuf.duration_pb2 +import nexusrpc +import nexusrpc.handler +import pytest + +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice.v1 +import temporalio.api.workflowservice.v1 +from temporalio import activity, nexus, workflow +from temporalio.client import ( + Client, + WorkflowHandle, + WorkflowUpdateFailedError, + WorkflowUpdateStage, +) +from temporalio.common import RawValue +from temporalio.contrib.workflow_streams import ( + PollInput, + PollResult, + PublishEntry, + PublishInput, + TopicHandle, + WorkflowStream, + WorkflowStreamClient, + WorkflowStreamItem, + WorkflowStreamState, + WorkflowTopicHandle, +) +from temporalio.contrib.workflow_streams._types import _encode_payload +from temporalio.converter import DataConverter +from temporalio.exceptions import ApplicationError +from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers import assert_eq_eventually, new_worker +from tests.helpers.nexus import make_nexus_endpoint_name + + +def _wire_bytes(data: bytes) -> str: + """Build a PublishEntry.data string from raw bytes. + + Mirrors what :class:`WorkflowStreamClient` produces on the encode path: + default payload converter turns the bytes into a ``Payload``, which + is then proto-serialized and base64-encoded for the wire. 
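+
+    The decode side (mirroring ``SerializedPayloadWorkflow`` in
+    test_payload_roundtrip.py) inverts each step: base64-decode,
+    proto-parse, converter-decode::
+
+        p = Payload()
+        p.ParseFromString(base64.b64decode(entry.data))
+        value = conv.from_payload(p, bytes)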
+ """ + payload = DataConverter.default.payload_converter.to_payloads([data])[0] + return _encode_payload(payload) + + +# --------------------------------------------------------------------------- +# Test workflows (must be module-level, not local classes) +# --------------------------------------------------------------------------- + + +@workflow.defn +class BasicWorkflowStreamWorkflow: + @workflow.init + def __init__(self) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class ActivityPublishWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_items", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + self.stream.topic("status", type=bytes).publish(b"activity_done") + await workflow.wait_condition(lambda: self._closed) + + +@dataclass +class AgentEvent: + kind: str + payload: dict[str, Any] + + +@workflow.defn +class StructuredPublishWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + for i in range(count): + self.stream.topic("events", type=AgentEvent).publish( + AgentEvent(kind="tick", payload={"i": i}) + ) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class TopicHandlePublishWorkflow: + """Workflow that publishes via the workflow-side topic handle.""" + + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self.events = self.stream.topic("events", type=AgentEvent) + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + for i in range(count): + self.events.publish(AgentEvent(kind="tick", payload={"i": i})) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class WorkflowSidePublishWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + for i in range(count): + self.stream.topic("events", type=bytes).publish(f"item-{i}".encode()) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class MultiTopicWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_multi_topic", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class InterleavedWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: 
+ self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + self.stream.topic("status", type=bytes).publish(b"started") + await workflow.execute_activity( + "publish_items", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + self.stream.topic("status", type=bytes).publish(b"done") + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class PriorityWorkflow: + @workflow.init + def __init__(self) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self) -> None: + await workflow.execute_activity( + "publish_with_priority", + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class FlushOnExitWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_batch_test", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class MaxBatchWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.query + def publisher_sequences(self) -> dict[str, int]: + return {pid: ps.sequence for pid, ps in self.stream._publishers.items()} + + @workflow.run + async def run(self, count: int) -> None: + await workflow.execute_activity( + "publish_with_max_batch", + count, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + self.stream.topic("status", type=bytes).publish(b"activity_done") + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class LateWorkflowStreamWorkflow: + """Calls WorkflowStream() from @workflow.run, not from @workflow.init. + + The constructor inspects the caller's frame and requires the + function name to be ``__init__``; called from ``run``, it must + raise ``RuntimeError``. The workflow returns the error message so + the test can assert on it without forcing a workflow task failure. + """ + + @workflow.run + async def run(self) -> str: + try: + WorkflowStream() + except RuntimeError as e: + return str(e) + return "no error raised" + + +@workflow.defn +class DoubleInitWorkflow: + """Calls WorkflowStream() twice from @workflow.init. + + The first call succeeds; the second must raise RuntimeError because + the workflow stream signal handler is already registered. The workflow + stashes the error message so the test can assert on it without + forcing a workflow task failure. 
+ """ + + @workflow.init + def __init__(self) -> None: + self.stream = WorkflowStream() + self._closed = False + self.double_init_error: str | None = None + try: + WorkflowStream() + except RuntimeError as e: + self.double_init_error = str(e) + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.query + def get_double_init_error(self) -> str | None: + return self.double_init_error + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._closed) + + +# --------------------------------------------------------------------------- +# Activities +# --------------------------------------------------------------------------- + + +@activity.defn(name="publish_items") +async def publish_items(count: int) -> None: + client = WorkflowStreamClient.from_within_activity( + batch_interval=timedelta(milliseconds=500) + ) + async with client: + for i in range(count): + activity.heartbeat() + client.topic("events", type=bytes).publish(f"item-{i}".encode()) + + +@activity.defn(name="publish_multi_topic") +async def publish_multi_topic(count: int) -> None: + topics = ["a", "b", "c"] + client = WorkflowStreamClient.from_within_activity( + batch_interval=timedelta(milliseconds=500) + ) + async with client: + for i in range(count): + activity.heartbeat() + topic = topics[i % len(topics)] + client.topic(topic, type=bytes).publish(f"{topic}-{i}".encode()) + + +@activity.defn(name="publish_with_priority") +async def publish_with_priority() -> None: + # Long batch_interval AND long post-publish hold ensure that only a + # working force_flush wakeup can deliver items before __aexit__ flushes. + # The hold is deliberately much longer than the test's collect timeout + # so a regression (force_flush no-op) surfaces as a missing item rather + # than flaking on slow CI. + client = WorkflowStreamClient.from_within_activity( + batch_interval=timedelta(seconds=60) + ) + async with client: + client.topic("events", type=bytes).publish(b"normal-0") + client.topic("events", type=bytes).publish(b"normal-1") + client.topic("events", type=bytes).publish(b"priority", force_flush=True) + for _ in range(100): + activity.heartbeat() + await asyncio.sleep(0.1) + + +@activity.defn(name="publish_batch_test") +async def publish_batch_test(count: int) -> None: + client = WorkflowStreamClient.from_within_activity( + batch_interval=timedelta(seconds=60) + ) + async with client: + for i in range(count): + activity.heartbeat() + client.topic("events", type=bytes).publish(f"item-{i}".encode()) + + +@activity.defn(name="publish_with_max_batch") +async def publish_with_max_batch(count: int) -> None: + client = WorkflowStreamClient.from_within_activity( + batch_interval=timedelta(seconds=60), max_batch_size=3 + ) + async with client: + for i in range(count): + activity.heartbeat() + client.topic("events", type=bytes).publish(f"item-{i}".encode()) + # Yield so the flusher task can run when max_batch_size triggers + # _flush_event. Real workloads (e.g. agents awaiting LLM streams) + # yield constantly; a tight loop with no awaits would never let + # the flusher fire and would collapse back to exit-only flushing. + await asyncio.sleep(0) + # Long batch_interval ensures only max_batch_size triggers flushes. + # Context manager exit flushes any remainder. 
+ + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +async def _is_different_run( + old_handle: WorkflowHandle[Any, Any], + new_handle: WorkflowHandle[Any, Any], +) -> bool: + """Check if new_handle points to a different run than old_handle.""" + try: + desc = await new_handle.describe() + return desc.run_id != old_handle.result_run_id + except Exception: + return False + + +async def collect_items( + client: Client, + handle: WorkflowHandle[Any, Any], + topics: list[str] | None, + from_offset: int, + expected_count: int, + timeout: float = 15.0, + *, + result_type: type | None = bytes, +) -> list[WorkflowStreamItem]: + """Subscribe and collect exactly expected_count items, with timeout. + + Default ``result_type=bytes`` matches the bytes-oriented tests that + compare ``item.data`` against literal byte strings. Pass + ``result_type=None`` for the converter's default ``Any`` decoding, + or ``result_type=RawValue`` for a ``RawValue``-wrapped ``Payload``. + """ + stream = WorkflowStreamClient.create(client, handle.id) + items: list[WorkflowStreamItem] = [] + try: + async with _async_timeout(timeout): + async for item in stream.subscribe( + topics=topics, + from_offset=from_offset, + poll_cooldown=timedelta(0), + result_type=result_type, + ): + items.append(item) + if len(items) >= expected_count: + break + except asyncio.TimeoutError: + pass + return items + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_activity_publish_and_subscribe(client: Client) -> None: + """Activity publishes items, external client subscribes and receives them.""" + count = 10 + async with new_worker( + client, + ActivityPublishWorkflow, + activities=[publish_items], + ) as worker: + handle = await client.start_workflow( + ActivityPublishWorkflow.run, + count, + id=f"workflow-stream-basic-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # Collect activity items + the "activity_done" status item + items = await collect_items(client, handle, None, 0, count + 1) + assert len(items) == count + 1 + + # Check activity items + for i in range(count): + assert items[i].topic == "events" + assert items[i].data == f"item-{i}".encode() + + # Check workflow-side status item + assert items[count].topic == "status" + assert items[count].data == b"activity_done" + + await handle.signal(ActivityPublishWorkflow.close) + + +@pytest.mark.asyncio +async def test_structured_type_round_trip(client: Client) -> None: + """Workflow publishes dataclass values; subscriber decodes via result_type.""" + count = 4 + async with new_worker(client, StructuredPublishWorkflow) as worker: + handle = await client.start_workflow( + StructuredPublishWorkflow.run, + count, + id=f"workflow-stream-structured-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + items = await collect_items( + client, handle, None, 0, count, result_type=AgentEvent + ) + assert len(items) == count + for i, item in enumerate(items): + assert isinstance(item.data, AgentEvent) + assert item.data == AgentEvent(kind="tick", payload={"i": i}) + + await handle.signal(StructuredPublishWorkflow.close) + + +@pytest.mark.asyncio +async def test_subscribe_default_decode_and_raw_value(client: Client) -> None: + """No ``result_type`` decodes via Any; ``result_type=RawValue`` yields a ``Payload``.""" + 
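+    # result_type selects the decode path (as exercised below):
+    #   result_type=T        -> converter decodes each item into T
+    #   result_type=None     -> converter-default decoding; a dataclass
+    #                           arrives as a plain dict under JSON
+    #   result_type=RawValue -> no decode; item.data wraps the raw Payload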
count = 2 + async with new_worker(client, StructuredPublishWorkflow) as worker: + handle = await client.start_workflow( + StructuredPublishWorkflow.run, + count, + id=f"workflow-stream-default-decode-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + any_items = await collect_items( + client, handle, None, 0, count, result_type=None + ) + assert len(any_items) == count + for i, item in enumerate(any_items): + # Default JSON converter decodes a dataclass to a plain dict. + assert item.data == {"kind": "tick", "payload": {"i": i}} + + raw_items = await collect_items( + client, handle, None, 0, count, result_type=RawValue + ) + assert len(raw_items) == count + for item in raw_items: + assert isinstance(item.data, RawValue) + assert item.data.payload.data # non-empty serialized JSON bytes + + await handle.signal(StructuredPublishWorkflow.close) + + +@pytest.mark.asyncio +async def test_subscribe_with_payload_result_type_rejected(client: Client) -> None: + """``subscribe(result_type=Payload)`` raises โ€” there is no Payload decode path. + + Mirrors the topic-handle rejection (``stream.topic(name, type=Payload)``) + so the direct ``subscribe`` API can't smuggle in the same ambiguity that + the topic-handle layer already guards against. Users wanting raw payloads + pass ``result_type=RawValue``. + """ + from temporalio.api.common.v1 import Payload + + handle = client.get_workflow_handle("nonexistent-workflow-id") + stream = WorkflowStreamClient(handle) + with pytest.raises(RuntimeError, match="result_type=Payload"): + async for _ in stream.subscribe(result_type=Payload): + pass + + +@pytest.mark.asyncio +async def test_topic_handle_workflow_side_publish_and_subscribe( + client: Client, +) -> None: + """Workflow publishes via WorkflowStream.topic; client subscribes via TopicHandle.""" + count = 3 + async with new_worker(client, TopicHandlePublishWorkflow) as worker: + handle = await client.start_workflow( + TopicHandlePublishWorkflow.run, + count, + id=f"workflow-stream-topic-handle-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + stream = WorkflowStreamClient.create(client, handle.id) + events = stream.topic("events", type=AgentEvent) + assert isinstance(events, TopicHandle) + assert events.name == "events" + assert events.type is AgentEvent + + items: list[WorkflowStreamItem] = [] + async with _async_timeout(15.0): + async for item in events.subscribe(poll_cooldown=timedelta(0)): + items.append(item) + if len(items) >= count: + break + assert [item.data for item in items] == [ + AgentEvent(kind="tick", payload={"i": i}) for i in range(count) + ] + + await handle.signal(TopicHandlePublishWorkflow.close) + + +@workflow.defn +class TopicHandleUniquenessWorkflow: + """Probes the WorkflowStream.topic uniqueness check in @workflow.init. + + Returns a tuple (idempotent_ok, error_message) so the test can assert + both branches: same-type rebind is silent, different-type rebind raises. 
+ """ + + @workflow.init + def __init__(self) -> None: + from temporalio.api.common.v1 import Payload + + self.stream = WorkflowStream() + first = self.stream.topic("events", type=AgentEvent) + self._idempotent_ok = ( + isinstance( + self.stream.topic("events", type=AgentEvent), WorkflowTopicHandle + ) + and first.type is AgentEvent + ) + try: + self.stream.topic("events", type=bytes) + except RuntimeError as exc: + self._error = str(exc) + else: + self._error = "" + try: + self.stream.topic("misused", type=Payload) + except RuntimeError as exc: + self._payload_error = str(exc) + else: + self._payload_error = "" + + @workflow.run + async def run(self) -> tuple[bool, str, str]: + return (self._idempotent_ok, self._error, self._payload_error) + + +@pytest.mark.asyncio +async def test_topic_handle_uniqueness_on_workflow_stream(client: Client) -> None: + """Same-type rebind is idempotent; different-type rebind raises in @workflow.init. + + Also covers the workflow-side rejection of ``type=Payload`` โ€” + binding a topic to ``Payload`` itself has no decode path, so + ``WorkflowStream.topic`` raises in ``@workflow.init``. + """ + async with new_worker(client, TopicHandleUniquenessWorkflow) as worker: + idempotent_ok, error, payload_error = await client.execute_workflow( + TopicHandleUniquenessWorkflow.run, + id=f"workflow-stream-handle-unique-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + assert idempotent_ok is True + assert "already bound to type" in error + assert "events" in error + assert "type=Payload" in payload_error + + +@pytest.mark.asyncio +async def test_topic_handle_client_uniqueness(client: Client) -> None: + """Re-binding a topic name to a different type on a client raises.""" + handle = client.get_workflow_handle("nonexistent-workflow-id") + stream = WorkflowStreamClient(handle) + + first = stream.topic("events", type=AgentEvent) + assert first.name == "events" + assert first.type is AgentEvent + + # Same type is idempotent. + again = stream.topic("events", type=AgentEvent) + assert again.type is AgentEvent + + # Different type raises. + with pytest.raises(RuntimeError, match="already bound to type"): + stream.topic("events", type=bytes) + + # Different topic with a different type is fine. + other = stream.topic("other", type=bytes) + assert other.type is bytes + + # Any escape hatch coexists on a different topic. Omitting ``type`` + # is the documented form (defaults to ``typing.Any``); we also + # exercise the explicit ``type=Any`` path with the cast required + # because ``Any`` is a typing special form rather than a class. + raw = stream.topic("forwarded") + assert raw.type is Any + explicit = stream.topic( + "forwarded-explicit", type=cast(type[Any], cast(object, Any)) + ) + assert explicit.type is Any + + # Binding to Payload itself is rejected โ€” subscribers would have + # no decode path. Pre-built Payload values can still be published + # via a normally-typed handle (zero-copy fast path). 
+ from temporalio.api.common.v1 import Payload + + with pytest.raises(RuntimeError, match="type=Payload"): + stream.topic("misused", type=Payload) + + +@pytest.mark.asyncio +async def test_topic_handle_payload_passthrough(client: Client) -> None: + """Pre-built Payloads pass through topic.publish regardless of bound type.""" + count = 2 + async with new_worker(client, BasicWorkflowStreamWorkflow) as worker: + handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=f"workflow-stream-handle-payload-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + stream = WorkflowStreamClient.create( + client, handle.id, batch_interval=timedelta(milliseconds=50) + ) + events = stream.topic("events", type=bytes) + async with stream: + converter = DataConverter.default.payload_converter + for i in range(count): + payload = converter.to_payloads([f"raw-{i}".encode()])[0] + events.publish(payload) + await stream.flush() + + items = await collect_items(client, handle, ["events"], 0, count) + assert [item.data for item in items] == [ + f"raw-{i}".encode() for i in range(count) + ] + + await handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_topic_filtering(client: Client) -> None: + """Publish to multiple topics, subscribe with filter.""" + count = 9 # 3 per topic + async with new_worker( + client, + MultiTopicWorkflow, + activities=[publish_multi_topic], + ) as worker: + handle = await client.start_workflow( + MultiTopicWorkflow.run, + count, + id=f"workflow-stream-filter-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Subscribe to topic "a" only โ€” should get 3 items + a_items = await collect_items(client, handle, ["a"], 0, 3) + assert len(a_items) == 3 + assert all(item.topic == "a" for item in a_items) + + # Subscribe to ["a", "c"] โ€” should get 6 items + ac_items = await collect_items(client, handle, ["a", "c"], 0, 6) + assert len(ac_items) == 6 + assert all(item.topic in ("a", "c") for item in ac_items) + + # Subscribe to all (None) โ€” should get all 9 + all_items = await collect_items(client, handle, None, 0, 9) + assert len(all_items) == 9 + + await handle.signal(MultiTopicWorkflow.close) + + +@pytest.mark.asyncio +async def test_subscribe_from_offset_and_per_item_offsets(client: Client) -> None: + """Subscribe from zero and non-zero offsets; each item carries its global offset.""" + count = 5 + async with new_worker( + client, + WorkflowSidePublishWorkflow, + ) as worker: + handle = await client.start_workflow( + WorkflowSidePublishWorkflow.run, + count, + id=f"workflow-stream-offset-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Subscribe from offset 0 โ€” all items, offsets 0..count-1 + all_items = await collect_items(client, handle, None, 0, count) + assert len(all_items) == count + for i, item in enumerate(all_items): + assert item.offset == i + assert item.data == f"item-{i}".encode() + + # Subscribe from offset 3 โ€” items 3, 4 with offsets 3, 4 + later_items = await collect_items(client, handle, None, 3, 2) + assert len(later_items) == 2 + assert later_items[0].offset == 3 + assert later_items[0].data == b"item-3" + assert later_items[1].offset == 4 + assert later_items[1].data == b"item-4" + + await handle.signal(WorkflowSidePublishWorkflow.close) + + +@pytest.mark.asyncio +async def test_per_item_offsets_with_topic_filter(client: Client) -> None: + """Per-item offsets are global (not per-topic) even when filtering.""" + count = 9 # 3 per topic (a, b, c round-robin) + async with new_worker( + client, + 
MultiTopicWorkflow, + activities=[publish_multi_topic], + ) as worker: + handle = await client.start_workflow( + MultiTopicWorkflow.run, + count, + id=f"workflow-stream-item-offset-filter-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Subscribe to topic "a" only โ€” items are at global offsets 0, 3, 6 + a_items = await collect_items(client, handle, ["a"], 0, 3) + assert len(a_items) == 3 + assert a_items[0].offset == 0 + assert a_items[1].offset == 3 + assert a_items[2].offset == 6 + + # Subscribe to topic "b" โ€” items are at global offsets 1, 4, 7 + b_items = await collect_items(client, handle, ["b"], 0, 3) + assert len(b_items) == 3 + assert b_items[0].offset == 1 + assert b_items[1].offset == 4 + assert b_items[2].offset == 7 + + await handle.signal(MultiTopicWorkflow.close) + + +@pytest.mark.asyncio +async def test_poll_truncated_offset_returns_application_error(client: Client) -> None: + """Polling a truncated offset raises ApplicationError (not ValueError) + and does not crash the workflow task.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 5, + id=f"workflow-stream-trunc-error-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Truncate up to offset 3 via update โ€” completion is explicit. + await handle.execute_update("truncate", 3) + + # Poll from offset 1 (truncated) โ€” should get ApplicationError, + # NOT crash the workflow task. Catching WorkflowUpdateFailedError is + # sufficient to prove the handler raised ApplicationError: Temporal's + # update protocol completes the update with this error only when the + # handler raises ApplicationError. A bare ValueError (or any other + # exception) would fail the workflow task instead, causing + # execute_update to hang โ€” not raise. The follow-up collect_items + # below proves the workflow task wasn't poisoned. + with pytest.raises(WorkflowUpdateFailedError) as exc_info: + await handle.execute_update( + "__temporal_workflow_stream_poll", + PollInput(topics=[], from_offset=1), + result_type=PollResult, + ) + cause = exc_info.value.cause + assert isinstance(cause, ApplicationError) + assert cause.type == "TruncatedOffset" + + # Workflow should still be usable โ€” poll from valid offset 3 + items = await collect_items(client, handle, None, 3, 2) + assert len(items) == 2 + assert items[0].offset == 3 + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_truncate_past_end_raises_application_error(client: Client) -> None: + """truncate() with an offset past the log end raises ApplicationError + (type=TruncateOutOfRange) โ€” the update surfaces as a clean failure + without poisoning the workflow task.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 2, + id=f"workflow-stream-trunc-oor-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Only 2 items exist; asking to truncate to offset 5 is out of range. + with pytest.raises(WorkflowUpdateFailedError) as exc_info: + await handle.execute_update("truncate", 5) + cause = exc_info.value.cause + assert isinstance(cause, ApplicationError) + assert cause.type == "TruncateOutOfRange" + + # Workflow task wasn't poisoned โ€” a valid poll still completes. 
+ items = await collect_items(client, handle, None, 0, 2) + assert len(items) == 2 + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_subscribe_recovers_from_truncation(client: Client) -> None: + """subscribe() auto-recovers when offset falls behind truncation.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 5, + id=f"workflow-stream-trunc-recover-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Truncate first 3. The update returns after the handler completes. + await handle.execute_update("truncate", 3) + + # subscribe from offset 1 (truncated) โ€” should auto-recover + # and deliver items from base_offset (3) + stream = WorkflowStreamClient(handle) + items: list[WorkflowStreamItem] = [] + try: + async with _async_timeout(5): + async for item in stream.subscribe( + from_offset=1, poll_cooldown=timedelta(0), result_type=bytes + ): + items.append(item) + if len(items) >= 2: + break + except asyncio.TimeoutError: + pass + assert len(items) == 2 + assert items[0].offset == 3 + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_truncate_during_waiting_poll_raises_truncated_offset( + client: Client, +) -> None: + """A truncate that advances ``base_offset`` past a waiting poll's + ``from_offset`` must wake the poll and raise ``TruncatedOffset``. + + Reproduces the bug where ``_on_poll`` captured ``log_offset`` once + before ``wait_condition`` and then sliced ``self._log[log_offset:]`` + against the post-truncate state. With the old predicate + ``len(self._log) > log_offset`` the wait would either never fire + (truncation shrinks the log below the captured offset) or fire on a + later publish and silently emit the wrong items at offsets the + subscriber had already moved past. + """ + async with new_worker(client, TruncateRaceWorkflow) as worker: + handle = await client.start_workflow( + TruncateRaceWorkflow.run, + id=f"workflow-stream-trunc-race-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Seed: 5 items at offsets 0..4. base_offset stays 0. + await handle.execute_update(TruncateRaceWorkflow.publish, 5) + + # Park a poll from offset=10 โ€” past the current end of the log. + # With wait_for_stage=ACCEPTED the handler has begun executing + # and is parked at workflow.wait_condition by the time the + # client gets the handle back. + poll_handle = await handle.start_update( + "__temporal_workflow_stream_poll", + PollInput(topics=[], from_offset=10), + result_type=PollResult, + wait_for_stage=WorkflowUpdateStage.ACCEPTED, + ) + + # In one workflow activation: publish 7 more items (log grows to + # 12 entries at offsets 0..11) and then truncate to 11. Result: + # base_offset=11, log=[item @11]. The waiting poll's + # from_offset=10 is now strictly less than base_offset, so the + # fixed predicate must wake it and the post-wait recompute must + # raise TruncatedOffset. Both halves of the fix are exercised: + # without the predicate change the wait stays asleep through + # this activation; without the post-wait recompute the slice + # silently returns wrong items / next_offset. 
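+        #
+        # Sketched shape of the fixed handler (assumed names; the real
+        # handler lives in the contrib _stream module):
+        #
+        #     await workflow.wait_condition(
+        #         lambda: self._base_offset + len(self._log) > from_offset
+        #         or from_offset < self._base_offset
+        #     )
+        #     if from_offset < self._base_offset:  # recompute after the wait
+        #         raise ApplicationError("truncated", type="TruncatedOffset")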
+ await handle.execute_update(TruncateRaceWorkflow.publish_then_truncate, (7, 11)) + + with pytest.raises(WorkflowUpdateFailedError) as exc_info: + await poll_handle.result() + cause = exc_info.value.cause + assert isinstance(cause, ApplicationError) + assert cause.type == "TruncatedOffset" + + await handle.signal(TruncateRaceWorkflow.close) + + +@pytest.mark.asyncio +async def test_workflow_and_activity_publish_interleaved(client: Client) -> None: + """Workflow publishes status events around activity publishing.""" + count = 5 + async with new_worker( + client, + InterleavedWorkflow, + activities=[publish_items], + ) as worker: + handle = await client.start_workflow( + InterleavedWorkflow.run, + count, + id=f"workflow-stream-interleave-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Total: 1 (started) + count (activity) + 1 (done) = count + 2 + items = await collect_items(client, handle, None, 0, count + 2) + assert len(items) == count + 2 + + # First item is workflow-side "started" + assert items[0].topic == "status" + assert items[0].data == b"started" + + # Middle items are from activity + for i in range(count): + assert items[i + 1].topic == "events" + assert items[i + 1].data == f"item-{i}".encode() + + # Last item is workflow-side "done" + assert items[count + 1].topic == "status" + assert items[count + 1].data == b"done" + + await handle.signal(InterleavedWorkflow.close) + + +@pytest.mark.asyncio +async def test_priority_flush(client: Client) -> None: + """Priority publish triggers immediate flush without waiting for timer.""" + async with new_worker( + client, + PriorityWorkflow, + activities=[publish_with_priority], + ) as worker: + handle = await client.start_workflow( + PriorityWorkflow.run, + id=f"workflow-stream-priority-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # If priority works, items arrive within milliseconds of the publish. + # The activity holds for ~10s after priority publish; this timeout + # gives plenty of margin for workflow/worker scheduling on slow CI + # while staying well below the activity hold so a regression (no + # priority wakeup) surfaces as a missing item, not a pass via + # __aexit__ flush. + items = await collect_items(client, handle, None, 0, 3, timeout=5.0) + assert len(items) == 3 + assert items[2].data == b"priority" + + await handle.signal(PriorityWorkflow.close) + + +@pytest.mark.asyncio +async def test_iterator_cancellation(client: Client) -> None: + """Cancelling a subscription iterator after it has yielded an item + completes cleanly.""" + async with new_worker( + client, + BasicWorkflowStreamWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=f"workflow-stream-cancel-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Seed one item so the iterator provably reaches an active state + # before we cancel โ€” no sleep-based wait. + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"seed"))] + ), + ) + + stream_client = WorkflowStreamClient.create(client, handle.id) + first_item = asyncio.Event() + items: list[WorkflowStreamItem] = [] + + async def subscribe_and_collect() -> None: + async for item in stream_client.subscribe( + from_offset=0, poll_cooldown=timedelta(0), result_type=bytes + ): + items.append(item) + first_item.set() + + task = asyncio.create_task(subscribe_and_collect()) + # Bounded wait so a subscribe regression fails fast instead of hanging. 
+ async with _async_timeout(5): + await first_item.wait() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert len(items) == 1 + assert items[0].data == b"seed" + + await handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_context_manager_flushes_on_exit(client: Client) -> None: + """Context manager exit flushes all buffered items.""" + count = 5 + async with new_worker( + client, + FlushOnExitWorkflow, + activities=[publish_batch_test], + ) as worker: + handle = await client.start_workflow( + FlushOnExitWorkflow.run, + count, + id=f"workflow-stream-flush-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Despite 60s batch interval, all items arrive because __aexit__ flushes + items = await collect_items(client, handle, None, 0, count, timeout=15.0) + assert len(items) == count + for i in range(count): + assert items[i].data == f"item-{i}".encode() + + await handle.signal(FlushOnExitWorkflow.close) + + +@pytest.mark.asyncio +async def test_explicit_flush_barrier(client: Client) -> None: + """``await client.flush()`` is a synchronization point. + + Verifies the documented contract: + 1. Returns immediately when the buffer is empty. + 2. After it returns, items published before the call are durable + on the workflow side (observable via ``get_offset()``) โ€” even + when the timer-driven flush would not yet have fired. + 3. Calling it again after a successful flush is a no-op. + + Uses a 60s ``batch_interval`` so a regression where ``flush()`` + silently relies on the background timer surfaces as a hang + against the test's 5s timeout, not a slow pass. + """ + async with new_worker( + client, + BasicWorkflowStreamWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=f"workflow-stream-flush-barrier-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + stream = WorkflowStreamClient.create( + client, handle.id, batch_interval=timedelta(seconds=60) + ) + + async with _async_timeout(5): + # 1. Empty-buffer flush is a no-op (must not block). + assert await stream.get_offset() == 0 + await stream.flush() + assert await stream.get_offset() == 0 + + # 2. Flush makes prior publishes visible without waiting on + # the 60s batch timer. + stream.topic("events", type=bytes).publish(b"a") + stream.topic("events", type=bytes).publish(b"b") + stream.topic("events", type=bytes).publish(b"c") + await stream.flush() + assert await stream.get_offset() == 3 + + # 3. Second flush with no new items is a no-op. + await stream.flush() + assert await stream.get_offset() == 3 + + await handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_concurrent_subscribers(client: Client) -> None: + """Two subscribers on different topics make interleaved progress. + + Publishes A-0, waits for subscriber A to observe it; publishes B-0, + waits for subscriber B to observe it. At this point both subscribers + have received exactly one item and are polling for their second, + so both subscriptions are provably in flight at the same time. + Then publishes A-1, B-1 the same way. A sequential execution (A drains + then B starts) cannot satisfy the ordering because B's first item + isn't published until after A has already received its first. 
+ """ + async with new_worker( + client, + BasicWorkflowStreamWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=f"workflow-stream-concurrent-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + stream = WorkflowStreamClient(handle) + a_items: list[WorkflowStreamItem] = [] + b_items: list[WorkflowStreamItem] = [] + a_got = [asyncio.Event(), asyncio.Event()] + b_got = [asyncio.Event(), asyncio.Event()] + + async def collect( + topic: str, + collected: list[WorkflowStreamItem], + events: list[asyncio.Event], + ) -> None: + async for item in stream.subscribe( + topics=[topic], + from_offset=0, + poll_cooldown=timedelta(0), + result_type=bytes, + ): + collected.append(item) + events[len(collected) - 1].set() + if len(collected) >= len(events): + break + + a_task = asyncio.create_task(collect("a", a_items, a_got)) + b_task = asyncio.create_task(collect("b", b_items, b_got)) + + async def publish(topic: str, data: bytes) -> None: + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput(items=[PublishEntry(topic=topic, data=_wire_bytes(data))]), + ) + + try: + async with _async_timeout(10): + await publish("a", b"a-0") + await a_got[0].wait() + await publish("b", b"b-0") + await b_got[0].wait() + # Both subscribers are now mid-subscription, each having + # seen one item and polling for the next. + await publish("a", b"a-1") + await a_got[1].wait() + await publish("b", b"b-1") + await b_got[1].wait() + + await asyncio.gather(a_task, b_task) + finally: + a_task.cancel() + b_task.cancel() + + assert [i.data for i in a_items] == [b"a-0", b"a-1"] + assert [i.data for i in b_items] == [b"b-0", b"b-1"] + + await handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_max_batch_size(client: Client) -> None: + """max_batch_size triggers auto-flush without waiting for timer.""" + count = 7 # with max_batch_size=3: flushes at 3, 6, then remainder 1 on exit + async with new_worker( + client, + MaxBatchWorkflow, + activities=[publish_with_max_batch], + max_cached_workflows=0, + ) as worker: + handle = await client.start_workflow( + MaxBatchWorkflow.run, + count, + id=f"workflow-stream-maxbatch-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # count items from activity + 1 "activity_done" from workflow + items = await collect_items(client, handle, None, 0, count + 1, timeout=15.0) + assert len(items) == count + 1 + for i in range(count): + assert items[i].data == f"item-{i}".encode() + + # max_batch_size actually engages: at least one flush fires during + # the publish loop, so 7 items ship as >=2 signals. Without this + # assertion the test would pass even if max_batch_size were ignored + # and all 7 items went out in a single exit-time flush (batch_count + # == 1). Note: max_batch_size is a *trigger* threshold, not a cap โ€” + # the flusher may take more items from the buffer than max_batch_size + # if more were added while a prior signal was in flight, so the exact + # batch count depends on interleaving. Asserting >= 2 is the + # non-flaky way to verify the mechanism is live. 
+ seqs = await handle.query(MaxBatchWorkflow.publisher_sequences) + assert len(seqs) == 1, f"expected one publisher, got {seqs}" + (batch_count,) = seqs.values() + assert batch_count >= 2, ( + f"expected >=2 batches with max_batch_size=3 and 7 items, got " + f"{batch_count} โ€” max_batch_size did not trigger a mid-loop flush" + ) + + await handle.signal(MaxBatchWorkflow.close) + + +@pytest.mark.asyncio +async def test_replay_safety(client: Client) -> None: + """Workflow stream broker survives workflow replay (max_cached_workflows=0).""" + async with new_worker( + client, + InterleavedWorkflow, + activities=[publish_items], + max_cached_workflows=0, + ) as worker: + handle = await client.start_workflow( + InterleavedWorkflow.run, + 5, + id=f"workflow-stream-replay-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # 1 (started) + 5 (activity) + 1 (done) = 7 + items = await collect_items(client, handle, None, 0, 7) + # Full ordered sequence โ€” endpoint-only checks would miss mid-stream + # replay corruption (reordering, duplication, dropped items). + assert [i.data for i in items] == [ + b"started", + b"item-0", + b"item-1", + b"item-2", + b"item-3", + b"item-4", + b"done", + ] + assert [i.offset for i in items] == list(range(7)) + await handle.signal(InterleavedWorkflow.close) + + +@pytest.mark.asyncio +async def test_flush_retry_preserves_items_after_failures( + client: Client, +) -> None: + """After flush failures, a subsequent successful flush delivers all items + in publish order, exactly once. + + Exercises the retry code path behaviorally: simulated delivery failures + must not drop items, must not duplicate them on retry, and must not + reorder items published during the failed state. + """ + async with new_worker(client, BasicWorkflowStreamWorkflow) as worker: + handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=f"workflow-stream-flush-retry-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + stream = WorkflowStreamClient(handle) + real_signal = handle.signal + fail_remaining = 2 + + async def maybe_failing_signal(*args: Any, **kwargs: Any) -> Any: + nonlocal fail_remaining + if fail_remaining > 0: + fail_remaining -= 1 + raise RuntimeError("simulated delivery failure") + return await real_signal(*args, **kwargs) + + with patch.object(handle, "signal", side_effect=maybe_failing_signal): + stream.topic("events", type=bytes).publish(b"item-0") + stream.topic("events", type=bytes).publish(b"item-1") + with pytest.raises(RuntimeError): + await stream._flush() + + # Publish more during the failed state โ€” must not overtake the + # pending retry on eventual delivery. + stream.topic("events", type=bytes).publish(b"item-2") + with pytest.raises(RuntimeError): + await stream._flush() + + # Third flush succeeds, delivering the pending retry batch. + await stream._flush() + # Fourth flush delivers the buffered "item-2". 
+ await stream._flush() + + items = await collect_items(client, handle, None, 0, 3) + assert [i.data for i in items] == [b"item-0", b"item-1", b"item-2"] + + await handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_flush_raises_after_max_retry_duration(client: Client) -> None: + """When max_retry_duration is exceeded, flush raises TimeoutError and the + client can resume publishing without losing subsequent items.""" + async with new_worker(client, BasicWorkflowStreamWorkflow) as worker: + handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=f"workflow-stream-retry-expiry-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Inject a controllable clock into the client module. The client's + # retry check compares `time.monotonic() - _pending_since` against + # `max_retry_duration`, so advancing the clock between flushes makes + # the timeout fire deterministically regardless of wall-clock speed + # or clock resolution. + stream = WorkflowStreamClient( + handle, max_retry_duration=timedelta(milliseconds=100) + ) + real_signal = handle.signal + fail_signals = True + + async def maybe_failing_signal(*args: Any, **kwargs: Any) -> Any: + if fail_signals: + raise RuntimeError("simulated failure") + return await real_signal(*args, **kwargs) + + clock = [0.0] + with ( + patch( + "temporalio.contrib.workflow_streams._client.time.monotonic", + side_effect=lambda: clock[0], + ), + patch.object(handle, "signal", side_effect=maybe_failing_signal), + ): + stream.topic("events", type=bytes).publish(b"lost") + + # First flush fails and enters the pending-retry state. + with pytest.raises(RuntimeError): + await stream._flush() + + # Advance the clock well past max_retry_duration. + clock[0] = 10.0 + + # Next flush raises TimeoutError โ€” the pending batch is abandoned. + with pytest.raises(TimeoutError, match="max_retry_duration"): + await stream._flush() + + # Stop failing signals; subsequent publishes must succeed. 
+ fail_signals = False + stream.topic("events", type=bytes).publish(b"kept") + await stream._flush() + + items = await collect_items(client, handle, None, 0, 1) + assert len(items) == 1 + assert items[0].data == b"kept" + + await handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_dedup_rejects_duplicate_signal(client: Client) -> None: + """Workflow deduplicates signals with the same publisher_id + sequence.""" + async with new_worker( + client, + BasicWorkflowStreamWorkflow, + ) as worker: + handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=f"workflow-stream-dedup-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Send a batch with publisher_id and sequence + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"item-0"))], + publisher_id="test-pub", + sequence=1, + ), + ) + + # Send the same sequence again โ€” should be deduped + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"duplicate"))], + publisher_id="test-pub", + sequence=1, + ), + ) + + # Send a new sequence โ€” should go through + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"item-1"))], + publisher_id="test-pub", + sequence=2, + ), + ) + + # Should have 2 items, not 3 (collect_items' update call acts as barrier) + items = await collect_items(client, handle, None, 0, 2) + assert len(items) == 2 + assert items[0].data == b"item-0" + assert items[1].data == b"item-1" + + # Verify offset is 2 (not 3) + stream_client = WorkflowStreamClient(handle) + offset = await stream_client.get_offset() + assert offset == 2 + + await handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_double_init_raises(client: Client) -> None: + """Instantiating WorkflowStream twice from @workflow.init raises RuntimeError. + + The first WorkflowStream() registers the __temporal_workflow_stream_publish signal handler; the + second call detects the existing handler and raises rather than + silently overwriting it. + """ + async with new_worker(client, DoubleInitWorkflow) as worker: + handle = await client.start_workflow( + DoubleInitWorkflow.run, + id=f"workflow-stream-double-init-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + err = await handle.query(DoubleInitWorkflow.get_double_init_error) + assert err is not None + assert "already registered" in err + await handle.signal(DoubleInitWorkflow.close) + + +@pytest.mark.asyncio +async def test_workflow_stream_outside_init_raises(client: Client) -> None: + """Constructing WorkflowStream outside @workflow.init raises RuntimeError. + + The workflow calls WorkflowStream() from @workflow.run; the caller-frame + guard must reject the call because the caller's function name is + ``run``, not ``__init__``. 
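+
+    A guard of roughly this shape can be written with :mod:`inspect`
+    (assumed detail for illustration; the contrib code may differ)::
+
+        frame = inspect.currentframe()
+        caller = frame.f_back if frame else None
+        if caller is None or caller.f_code.co_name != "__init__":
+            raise RuntimeError(...)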
+ """ + async with new_worker(client, LateWorkflowStreamWorkflow) as worker: + result = await client.execute_workflow( + LateWorkflowStreamWorkflow.run, + id=f"workflow-stream-late-init-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + assert "must be constructed directly from the workflow's" in result + assert "'run'" in result + + +@pytest.mark.asyncio +async def test_truncate_stream(client: Client) -> None: + """WorkflowStream.truncate discards prefix and adjusts base_offset.""" + async with new_worker( + client, + TruncateWorkflow, + ) as worker: + handle = await client.start_workflow( + TruncateWorkflow.run, + 5, + id=f"workflow-stream-truncate-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Verify all 5 items + items = await collect_items(client, handle, None, 0, 5) + assert len(items) == 5 + + # Truncate up to offset 3 (discard items 0, 1, 2). The update + # returns after the handler completes. + await handle.execute_update("truncate", 3) + + # Offset should still be 5 (truncation moves base_offset, not tail) + stream_client = WorkflowStreamClient(handle) + offset = await stream_client.get_offset() + assert offset == 5 + + # Reading from offset 3 should work (items 3, 4) + items_after = await collect_items(client, handle, None, 3, 2) + assert len(items_after) == 2 + assert items_after[0].data == b"item-3" + assert items_after[1].data == b"item-4" + + await handle.signal("close") + + +@pytest.mark.asyncio +async def test_ttl_pruning_in_get_stream_state(client: Client) -> None: + """WorkflowStream.get_state prunes publishers whose last-seen time exceeds the + TTL while retaining newer publishers. The log itself is unaffected. + + Uses a wall-clock gap between publishes so that workflow.time() + advances between the two publishers' tasks. workflow.time() can't be + cleanly injected from outside, so a short real sleep is the mechanism. + """ + async with new_worker( + client, + TTLTestWorkflow, + ) as worker: + handle = await client.start_workflow( + TTLTestWorkflow.run, + id=f"workflow-stream-ttl-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # pub-old arrives first. + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"old"))], + publisher_id="pub-old", + sequence=1, + ), + ) + + # Sanity: pub-old is recorded (generous TTL retains it). + state_before = await handle.query(TTLTestWorkflow.get_state_with_ttl, 9999.0) + assert "pub-old" in state_before.publishers + + # Let workflow.time() advance by real wall-clock time. Use a + # generous gap (1.0s) relative to the TTL (0.5s) so the test + # tolerates CI scheduling delays โ€” pub-old must be >=0.5s past, + # pub-new must be <0.5s past, at the moment of the query. + await asyncio.sleep(1.0) + + # pub-new arrives after the gap. + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[PublishEntry(topic="events", data=_wire_bytes(b"new"))], + publisher_id="pub-new", + sequence=1, + ), + ) + + # TTL=0.5s prunes pub-old (~1.0s old) but keeps pub-new (~0s). + state = await handle.query(TTLTestWorkflow.get_state_with_ttl, 0.5) + assert "pub-old" not in state.publishers + assert "pub-new" in state.publishers + # Log contents are not touched by publisher pruning. 
+ assert len(state.log) == 2 + + await handle.signal("close") + + +# --------------------------------------------------------------------------- +# Truncate and TTL test workflows +# --------------------------------------------------------------------------- + + +@workflow.defn +class TruncateWorkflow: + """Test scaffolding that exposes WorkflowStream.truncate via a user-authored + update. + + The contrib module does not define a built-in external truncate API โ€” + truncation is a workflow-internal decision (typically driven by + consumer progress or a retention policy). Workflows that want external + control wire up their own signal or update. We use an update here so + callers get explicit completion (signals are fire-and-forget). + + The ``truncate`` update is ``async`` and opens with + ``await asyncio.sleep(0)`` โ€” the documented recipe from the + contrib/stream README for sync-shaped handlers that read ``WorkflowStream`` + state. The yield lets any buffered ``__temporal_workflow_stream_publish`` signal in + the same activation apply before the handler inspects ``self._log``. + This keeps the test workflow aligned with the pattern users are + directed to follow. + + ``prepub_count`` seeds the log with N byte-payload items during + ``@workflow.init`` as test convenience, so the error-path tests + have deterministic log content without an extra round trip to + publish from the client. + """ + + @workflow.init + def __init__(self, prepub_count: int = 0) -> None: + self.stream = WorkflowStream() + self._closed = False + for i in range(prepub_count): + self.stream.topic("events", type=bytes).publish(f"item-{i}".encode()) + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.update + async def truncate(self, up_to_offset: int) -> None: + # Recipe from README.md "Gotcha" section: yield once so any + # buffered __temporal_workflow_stream_publish in the same activation applies + # before we read self._log. asyncio.sleep(0) is a pure asyncio + # yield โ€” no Temporal timer, no history event. + await asyncio.sleep(0) + self.stream.truncate(up_to_offset) + + @workflow.run + async def run(self, _prepub_count: int = 0) -> None: + # _prepub_count is consumed in @workflow.init above. @workflow.run + # must accept the same positional args, but the names are free + # to differ. + del _prepub_count + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class TruncateRaceWorkflow: + """Workflow that exposes ``publish`` and ``publish_then_truncate`` + updates so a test can deterministically interleave a waiting + ``__temporal_workflow_stream_poll`` update against a truncate that + advances ``base_offset`` past the poll's ``from_offset``. + + The ``publish_then_truncate`` handler runs publish loop and truncate + in a single workflow activation (no awaits between them), so a poll + parked at ``wait_condition`` sees the post-truncate state on its + next predicate evaluation rather than firing on an intermediate + publish. 
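+
+    Timeline, with global offsets (numbers match the test above)::
+
+        publish(5)                      log holds offsets 0..4, base_offset=0
+        poll(from_offset=10)            parks at wait_condition
+        publish_then_truncate((7, 11))  log grows to offsets 0..11, then
+                                        truncate sets base_offset=11
+        parked poll wakes               from_offset 10 < base_offset 11,
+                                        so the update fails: TruncatedOffset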
+ """ + + @workflow.init + def __init__(self) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.update + async def publish(self, count: int) -> None: + await asyncio.sleep(0) + topic = self.stream.topic("events", type=bytes) + for i in range(count): + topic.publish(f"item-{i}".encode()) + + @workflow.update + async def publish_then_truncate(self, args: tuple[int, int]) -> None: + await asyncio.sleep(0) + publish_count, truncate_to = args + topic = self.stream.topic("events", type=bytes) + for i in range(publish_count): + topic.publish(f"prepub-{i}".encode()) + self.stream.truncate(truncate_to) + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class TTLTestWorkflow: + """Workflow that exposes WorkflowStream.get_state via query for TTL testing.""" + + @workflow.init + def __init__(self) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.query + def get_state_with_ttl(self, ttl_seconds: float) -> WorkflowStreamState: + # Query arg is passed as float because the default JSON payload + # converter does not serialize ``timedelta``; convert here. + return self.stream.get_state(publisher_ttl=timedelta(seconds=ttl_seconds)) + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self._closed) + + +# --------------------------------------------------------------------------- +# Continue-as-new workflow and test +# --------------------------------------------------------------------------- + + +@dataclass +class CANWorkflowInputTyped: + """Uses proper typing.""" + + stream_state: WorkflowStreamState | None = None + + +@workflow.defn +class ContinueAsNewTypedWorkflow: + """CAN workflow using properly-typed stream_state.""" + + @workflow.init + def __init__(self, input: CANWorkflowInputTyped) -> None: + self.stream = WorkflowStream(prior_state=input.stream_state) + self._should_continue = False + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.signal + def trigger_continue(self) -> None: + self._should_continue = True + + @workflow.query + def publisher_sequences(self) -> dict[str, int]: + return {pid: ps.sequence for pid, ps in self.stream._publishers.items()} + + @workflow.run + async def run(self, _input: CANWorkflowInputTyped) -> None: + # _input is consumed in @workflow.init above. @workflow.run must + # accept the same positional args, but the names are free to differ. 
+ del _input + while True: + await workflow.wait_condition(lambda: self._should_continue or self._closed) + if self._closed: + return + if self._should_continue: + self._should_continue = False + self.stream.detach_pollers() + await workflow.wait_condition(workflow.all_handlers_finished) + workflow.continue_as_new( + args=[ + CANWorkflowInputTyped( + stream_state=self.stream.get_state(), + ) + ] + ) + + +@pytest.mark.asyncio +async def test_continue_as_new_properly_typed(client: Client) -> None: + """CAN preserves the log, global offsets, AND publisher dedup state + when stream_state is properly typed as ``WorkflowStreamState | None``.""" + async with new_worker( + client, + ContinueAsNewTypedWorkflow, + ) as worker: + handle = await client.start_workflow( + ContinueAsNewTypedWorkflow.run, + CANWorkflowInputTyped(), + id=f"workflow-stream-can-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Publish 3 items with an explicit publisher_id/sequence so dedup + # state is seeded and we can verify it survives CAN. + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"item-0")), + PublishEntry(topic="events", data=_wire_bytes(b"item-1")), + PublishEntry(topic="events", data=_wire_bytes(b"item-2")), + ], + publisher_id="pub", + sequence=1, + ), + ) + + items_before = await collect_items(client, handle, None, 0, 3) + assert len(items_before) == 3 + + await handle.signal(ContinueAsNewTypedWorkflow.trigger_continue) + + new_handle = client.get_workflow_handle(handle.id) + await assert_eq_eventually( + True, + lambda: _is_different_run(handle, new_handle), + ) + + # Log contents and offsets preserved across CAN. + items_after = await collect_items(client, new_handle, None, 0, 3) + assert [i.data for i in items_after] == [b"item-0", b"item-1", b"item-2"] + assert [i.offset for i in items_after] == [0, 1, 2] + + # Dedup state preserved: the carried publisher_sequences dict has + # pub -> 1 after CAN. + seqs_after_can = await new_handle.query( + ContinueAsNewTypedWorkflow.publisher_sequences + ) + assert seqs_after_can == {"pub": 1} + + # Re-sending publisher_id="pub", sequence=1 must be rejected by + # dedup โ€” both the log and the publisher_sequences entry stay put. + await new_handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"dup")), + ], + publisher_id="pub", + sequence=1, + ), + ) + seqs_after_dup = await new_handle.query( + ContinueAsNewTypedWorkflow.publisher_sequences + ) + assert seqs_after_dup == {"pub": 1} + + # A fresh sequence from the same publisher is accepted, advances + # publisher_sequences to 2, and the new item gets offset 3. 
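+        # (The offset counter itself survives CAN via the carried stream
+        # state, which is why the accept below lands at offset 3 rather
+        # than restarting from 0.)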
+ await new_handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"item-3")), + ], + publisher_id="pub", + sequence=2, + ), + ) + seqs_after_accept = await new_handle.query( + ContinueAsNewTypedWorkflow.publisher_sequences + ) + assert seqs_after_accept == {"pub": 2} + items_all = await collect_items(client, new_handle, None, 0, 4) + assert [i.data for i in items_all] == [ + b"item-0", + b"item-1", + b"item-2", + b"item-3", + ] + assert items_all[3].offset == 3 + + await new_handle.signal(ContinueAsNewTypedWorkflow.close) + + +@workflow.defn +class ContinueAsNewHelperWorkflow: + """CAN workflow that uses the packaged ``WorkflowStream.continue_as_new`` helper.""" + + @workflow.init + def __init__(self, input: CANWorkflowInputTyped) -> None: + self.stream = WorkflowStream(prior_state=input.stream_state) + self._should_continue = False + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.signal + def trigger_continue(self) -> None: + self._should_continue = True + + @workflow.run + async def run(self, _input: CANWorkflowInputTyped) -> None: + del _input + while True: + await workflow.wait_condition(lambda: self._should_continue or self._closed) + if self._closed: + return + if self._should_continue: + self._should_continue = False + await self.stream.continue_as_new( + lambda state: [CANWorkflowInputTyped(stream_state=state)], + ) + + +@pytest.mark.asyncio +async def test_continue_as_new_helper(client: Client) -> None: + """The ``WorkflowStream.continue_as_new`` helper preserves log and dedup state + just like the explicit detach_pollers/wait/CAN recipe.""" + async with new_worker( + client, + ContinueAsNewHelperWorkflow, + ) as worker: + handle = await client.start_workflow( + ContinueAsNewHelperWorkflow.run, + CANWorkflowInputTyped(), + id=f"workflow-stream-can-helper-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + await handle.signal( + "__temporal_workflow_stream_publish", + PublishInput( + items=[ + PublishEntry(topic="events", data=_wire_bytes(b"item-0")), + PublishEntry(topic="events", data=_wire_bytes(b"item-1")), + ], + publisher_id="pub", + sequence=1, + ), + ) + + items_before = await collect_items(client, handle, None, 0, 2) + assert [i.data for i in items_before] == [b"item-0", b"item-1"] + + await handle.signal(ContinueAsNewHelperWorkflow.trigger_continue) + + new_handle = client.get_workflow_handle(handle.id) + await assert_eq_eventually( + True, + lambda: _is_different_run(handle, new_handle), + ) + + items_after = await collect_items(client, new_handle, None, 0, 2) + assert [i.data for i in items_after] == [b"item-0", b"item-1"] + assert [i.offset for i in items_after] == [0, 1] + + await new_handle.signal(ContinueAsNewHelperWorkflow.close) + + +# --------------------------------------------------------------------------- +# Cross-workflow workflow stream (Scenario 1) +# --------------------------------------------------------------------------- + + +@dataclass +class CrossWorkflowInput: + broker_workflow_id: str + expected_count: int + + +@workflow.defn +class BrokerWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> None: + for i in range(count): + self.stream.topic("events", type=bytes).publish(f"broker-{i}".encode()) + await 
workflow.wait_condition(lambda: self._closed) + + +@workflow.defn +class SubscriberWorkflow: + @workflow.run + async def run(self, input: CrossWorkflowInput) -> list[str]: + return await workflow.execute_activity( + "subscribe_to_broker", + input, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + + +@activity.defn(name="subscribe_to_broker") +async def subscribe_to_broker(input: CrossWorkflowInput) -> list[str]: + client = WorkflowStreamClient.create( + client=activity.client(), + workflow_id=input.broker_workflow_id, + ) + items: list[str] = [] + async with _async_timeout(15.0): + async for item in client.subscribe( + topics=["events"], + from_offset=0, + poll_cooldown=timedelta(0), + result_type=bytes, + ): + items.append(item.data.decode()) + activity.heartbeat() + if len(items) >= input.expected_count: + break + return items + + +@pytest.mark.asyncio +async def test_cross_workflow_stream(client: Client) -> None: + """Workflow B's activity subscribes to events published by Workflow A.""" + count = 5 + task_queue = str(uuid.uuid4()) + + async with new_worker( + client, + BrokerWorkflow, + SubscriberWorkflow, + activities=[subscribe_to_broker], + task_queue=task_queue, + ): + broker_id = f"workflow-stream-broker-{uuid.uuid4()}" + broker_handle = await client.start_workflow( + BrokerWorkflow.run, + count, + id=broker_id, + task_queue=task_queue, + ) + + sub_handle = await client.start_workflow( + SubscriberWorkflow.run, + CrossWorkflowInput( + broker_workflow_id=broker_id, + expected_count=count, + ), + id=f"workflow-stream-subscriber-{uuid.uuid4()}", + task_queue=task_queue, + ) + + result = await sub_handle.result() + assert result == [f"broker-{i}" for i in range(count)] + + # Also verify external subscription still works + external_items = await collect_items( + client, broker_handle, ["events"], 0, count + ) + assert len(external_items) == count + + await broker_handle.signal(BrokerWorkflow.close) + + +# --------------------------------------------------------------------------- +# Standalone activity (started directly via Client, no parent workflow) +# --------------------------------------------------------------------------- + + +@dataclass +class StandalonePublishInput: + broker_workflow_id: str + count: int + + +@activity.defn(name="standalone_publish_to_broker") +async def standalone_publish_to_broker(input: StandalonePublishInput) -> None: + """Publish to a broker workflow from a standalone activity. + + Same usage as in any external program: build a Client (here taken + via ``activity.client()``), pass an explicit workflow id to + ``WorkflowStreamClient.create``. ``from_within_activity`` is not usable + here because the activity has no parent workflow. 
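+
+    An explicit 500ms ``batch_interval`` keeps the flush cadence short and
+    test-friendly while still exercising client-side batching of the
+    published items.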
+ """ + assert ( + activity.info().workflow_id is None + ), "test bug: this activity should be standalone" + client = WorkflowStreamClient.create( + client=activity.client(), + workflow_id=input.broker_workflow_id, + batch_interval=timedelta(milliseconds=500), + ) + async with client: + for i in range(input.count): + activity.heartbeat() + client.topic("events", type=bytes).publish(f"standalone-{i}".encode()) + + +@activity.defn(name="standalone_subscribe_to_broker") +async def standalone_subscribe_to_broker(input: CrossWorkflowInput) -> list[str]: + assert ( + activity.info().workflow_id is None + ), "test bug: this activity should be standalone" + client = WorkflowStreamClient.create( + client=activity.client(), + workflow_id=input.broker_workflow_id, + ) + items: list[str] = [] + async with _async_timeout(15.0): + async for item in client.subscribe( + topics=["events"], + from_offset=0, + poll_cooldown=timedelta(0), + result_type=bytes, + ): + items.append(item.data.decode()) + activity.heartbeat() + if len(items) >= input.expected_count: + break + return items + + +@activity.defn(name="standalone_from_within_activity_misuse") +async def standalone_from_within_activity_misuse() -> str: + """Calling from_within_activity in a standalone activity must raise a clear error.""" + try: + WorkflowStreamClient.from_within_activity() + except RuntimeError as e: + return str(e) + return "" + + +@pytest.mark.asyncio +async def test_standalone_activity_publish( + client: Client, env: WorkflowEnvironment +) -> None: + """Activity started directly via Client.start_activity publishes via create().""" + if env.supports_time_skipping: + pytest.skip( + "Java test server does not support Client.start_activity: " + "https://github.com/temporalio/sdk-java/issues/2741" + ) + count = 5 + task_queue = str(uuid.uuid4()) + + async with new_worker( + client, + BasicWorkflowStreamWorkflow, + activities=[standalone_publish_to_broker], + task_queue=task_queue, + ): + broker_id = f"workflow-stream-standalone-broker-{uuid.uuid4()}" + broker_handle = await client.start_workflow( + BasicWorkflowStreamWorkflow.run, + id=broker_id, + task_queue=task_queue, + ) + + activity_handle = await client.start_activity( + standalone_publish_to_broker, + StandalonePublishInput(broker_workflow_id=broker_id, count=count), + id=f"standalone-publish-{uuid.uuid4()}", + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + await activity_handle.result() + + items = await collect_items(client, broker_handle, ["events"], 0, count) + assert [i.data for i in items] == [ + f"standalone-{i}".encode() for i in range(count) + ] + + await broker_handle.signal(BasicWorkflowStreamWorkflow.close) + + +@pytest.mark.asyncio +async def test_standalone_activity_subscribe( + client: Client, env: WorkflowEnvironment +) -> None: + """Standalone activity subscribes to a broker workflow via create().""" + if env.supports_time_skipping: + pytest.skip( + "Java test server does not support Client.start_activity: " + "https://github.com/temporalio/sdk-java/issues/2741" + ) + count = 5 + task_queue = str(uuid.uuid4()) + + async with new_worker( + client, + BrokerWorkflow, + activities=[standalone_subscribe_to_broker], + task_queue=task_queue, + ): + broker_id = f"workflow-stream-standalone-sub-broker-{uuid.uuid4()}" + broker_handle = await client.start_workflow( + BrokerWorkflow.run, + count, + id=broker_id, + task_queue=task_queue, + ) + + activity_handle = await client.start_activity( + 
standalone_subscribe_to_broker, + CrossWorkflowInput( + broker_workflow_id=broker_id, + expected_count=count, + ), + id=f"standalone-subscribe-{uuid.uuid4()}", + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=30), + heartbeat_timeout=timedelta(seconds=10), + ) + result = await activity_handle.result() + assert result == [f"broker-{i}" for i in range(count)] + + await broker_handle.signal(BrokerWorkflow.close) + + +@pytest.mark.asyncio +async def test_from_within_activity_in_standalone_activity_raises( + client: Client, env: WorkflowEnvironment +) -> None: + """from_within_activity() raises a clear error pointing at create() when used in a + standalone activity (one without a parent workflow).""" + if env.supports_time_skipping: + pytest.skip( + "Java test server does not support Client.start_activity: " + "https://github.com/temporalio/sdk-java/issues/2741" + ) + task_queue = str(uuid.uuid4()) + + async with new_worker( + client, + activities=[standalone_from_within_activity_misuse], + task_queue=task_queue, + ): + activity_handle = await client.start_activity( + standalone_from_within_activity_misuse, + id=f"standalone-misuse-{uuid.uuid4()}", + task_queue=task_queue, + start_to_close_timeout=timedelta(seconds=10), + ) + msg = await activity_handle.result() + assert "no parent workflow" in msg + assert "WorkflowStreamClient.create" in msg + + +# --------------------------------------------------------------------------- +# Cross-namespace workflow stream via Nexus (Scenario 2) +# --------------------------------------------------------------------------- + + +@dataclass +class StartBrokerInput: + count: int + broker_id: str + + +@dataclass +class NexusCallerInput: + count: int + broker_id: str + endpoint: str + + +@workflow.defn +class NexusBrokerWorkflow: + @workflow.init + def __init__(self, count: int) -> None: + self.stream = WorkflowStream() + self._closed = False + + @workflow.signal + def close(self) -> None: + self._closed = True + + @workflow.run + async def run(self, count: int) -> str: + for i in range(count): + self.stream.topic("events", type=bytes).publish(f"nexus-{i}".encode()) + await workflow.wait_condition(lambda: self._closed) + return "done" + + +@nexusrpc.service +class WorkflowStreamNexusService: + start_broker: nexusrpc.Operation[StartBrokerInput, str] + + +@nexusrpc.handler.service_handler(service=WorkflowStreamNexusService) +class WorkflowStreamNexusHandler: + @workflow_run_operation + async def start_broker( + self, ctx: WorkflowRunOperationContext, input: StartBrokerInput + ) -> nexus.WorkflowHandle[str]: + return await ctx.start_workflow( + NexusBrokerWorkflow.run, + input.count, + id=input.broker_id, + ) + + +@workflow.defn +class NexusCallerWorkflow: + @workflow.run + async def run(self, input: NexusCallerInput) -> str: + nc = workflow.create_nexus_client( + service=WorkflowStreamNexusService, + endpoint=input.endpoint, + ) + return await nc.execute_operation( + WorkflowStreamNexusService.start_broker, + StartBrokerInput(count=input.count, broker_id=input.broker_id), + ) + + +async def create_cross_namespace_endpoint( + client: Client, + endpoint_name: str, + target_namespace: str, + task_queue: str, +) -> None: + await client.operator_service.create_nexus_endpoint( + temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( + spec=temporalio.api.nexus.v1.EndpointSpec( + name=endpoint_name, + target=temporalio.api.nexus.v1.EndpointTarget( + worker=temporalio.api.nexus.v1.EndpointTarget.Worker( + namespace=target_namespace, + 
task_queue=task_queue,
+                    )
+                ),
+            )
+        )
+    )
+
+
+# ---------------------------------------------------------------------------
+# Poll size limit and more_ready pagination
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_poll_more_ready_when_response_exceeds_size_limit(
+    client: Client,
+) -> None:
+    """Poll response sets more_ready=True when items exceed ~1MB wire size."""
+    async with new_worker(
+        client,
+        BasicWorkflowStreamWorkflow,
+    ) as worker:
+        handle = await client.start_workflow(
+            BasicWorkflowStreamWorkflow.run,
+            id=f"workflow-stream-more-ready-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+
+        # Publish items that total well over 1MB in the poll response.
+        # Send in separate signals to stay under the RPC size limit.
+        # Each item is 200KB raw; 8 items = 1.6MB, which base64 inflates
+        # by ~33% to roughly 2.1MB on the wire.
+        chunk = b"x" * 200_000
+        for _ in range(8):
+            await handle.signal(
+                "__temporal_workflow_stream_publish",
+                PublishInput(
+                    items=[PublishEntry(topic="big", data=_wire_bytes(chunk))]
+                ),
+            )
+
+        # First poll from offset 0 — should get some items but not all.
+        # (The update acts as a barrier for all prior publish signals.)
+        result1: PollResult = await handle.execute_update(
+            "__temporal_workflow_stream_poll",
+            PollInput(topics=[], from_offset=0),
+            result_type=PollResult,
+        )
+        assert result1.more_ready is True
+        assert len(result1.items) < 8
+        assert result1.next_offset < 8
+
+        # Continue polling until we have all items.
+        all_items = list(result1.items)
+        offset = result1.next_offset
+        last_result: PollResult = result1
+        while len(all_items) < 8:
+            last_result = await handle.execute_update(
+                "__temporal_workflow_stream_poll",
+                PollInput(topics=[], from_offset=offset),
+                result_type=PollResult,
+            )
+            all_items.extend(last_result.items)
+            offset = last_result.next_offset
+        assert len(all_items) == 8
+        # The final poll that drained the log should set more_ready=False.
+        assert last_result.more_ready is False
+
+        await handle.signal(BasicWorkflowStreamWorkflow.close)
+
+
+@pytest.mark.asyncio
+async def test_subscribe_iterates_through_more_ready(client: Client) -> None:
+    """Subscriber correctly yields all items when polls are size-truncated."""
+    async with new_worker(
+        client,
+        BasicWorkflowStreamWorkflow,
+    ) as worker:
+        handle = await client.start_workflow(
+            BasicWorkflowStreamWorkflow.run,
+            id=f"workflow-stream-more-ready-iter-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+
+        # Publish 8 x 200KB items (~2MB+ wire, exceeds the 1MB cap).
+        chunk = b"x" * 200_000
+        for _ in range(8):
+            await handle.signal(
+                "__temporal_workflow_stream_publish",
+                PublishInput(
+                    items=[PublishEntry(topic="big", data=_wire_bytes(chunk))]
+                ),
+            )
+
+        # subscribe() should seamlessly iterate through all 8 items.
+        items = await collect_items(client, handle, None, 0, 8, timeout=10.0)
+        assert len(items) == 8
+        for item in items:
+            assert item.data == chunk
+
+        await handle.signal(BasicWorkflowStreamWorkflow.close)
+
+
+# ---------------------------------------------------------------------------
+# Cross-namespace Nexus stream test (uses the Scenario 2 definitions above)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_cross_namespace_nexus_stream(
+    client: Client, env: WorkflowEnvironment
+) -> None:
+    """Nexus operation starts a workflow stream broker in another namespace; test subscribes."""
+    if env.supports_time_skipping:
+        pytest.skip("Nexus not supported with time-skipping server")
+
+    count = 5
+    handler_ns = f"handler-ns-{uuid.uuid4().hex[:8]}"
+    task_queue = str(uuid.uuid4())
+    endpoint_name = make_nexus_endpoint_name(task_queue)
+    broker_id = f"nexus-broker-{uuid.uuid4()}"
+
+    # Register the handler namespace with the dev server
+    await client.workflow_service.register_namespace(
+        temporalio.api.workflowservice.v1.RegisterNamespaceRequest(
+            namespace=handler_ns,
+            
workflow_execution_retention_period=google.protobuf.duration_pb2.Duration( + seconds=86400, + ), + ) + ) + + handler_client = await Client.connect( + client.service_client.config.target_host, + namespace=handler_ns, + ) + + # Create endpoint targeting the handler namespace + await create_cross_namespace_endpoint( + client, + endpoint_name, + target_namespace=handler_ns, + task_queue=task_queue, + ) + + # Handler worker in handler namespace + async with Worker( + handler_client, + task_queue=task_queue, + workflows=[NexusBrokerWorkflow], + nexus_service_handlers=[WorkflowStreamNexusHandler()], + ): + # Caller worker in default namespace + caller_tq = str(uuid.uuid4()) + async with new_worker( + client, + NexusCallerWorkflow, + task_queue=caller_tq, + ): + # Start caller โ€” invokes Nexus op which starts broker in handler ns + caller_handle = await client.start_workflow( + NexusCallerWorkflow.run, + NexusCallerInput( + count=count, + broker_id=broker_id, + endpoint=endpoint_name, + ), + id=f"nexus-caller-{uuid.uuid4()}", + task_queue=caller_tq, + ) + + # Wait for the broker workflow to be started by the Nexus operation + broker_handle = handler_client.get_workflow_handle(broker_id) + + async def broker_started() -> bool: + try: + await broker_handle.describe() + return True + except Exception: + return False + + await assert_eq_eventually( + True, broker_started, timeout=timedelta(seconds=15) + ) + + # Subscribe to broker events from the handler namespace + items = await collect_items( + handler_client, broker_handle, ["events"], 0, count + ) + assert len(items) == count + for i in range(count): + assert items[i].topic == "events" + assert items[i].data == f"nexus-{i}".encode() + + # Clean up โ€” signal broker to close so caller can complete + await broker_handle.signal("close") + result = await caller_handle.result() + assert result == "done" diff --git a/uv.lock b/uv.lock index ecb7e38f9..eb9dc75a6 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-23T15:55:57.051193Z" +exclude-newer = "2026-04-23T17:46:27.746666Z" exclude-newer-span = "P1W" [options.exclude-newer-package] @@ -970,7 +970,7 @@ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } wheels = [ @@ -5195,6 +5195,7 @@ pydantic = [ [package.dev-dependencies] dev = [ + { name = "async-timeout", marker = "python_full_version < '3.11'" }, { name = "basedpyright" }, { name = "cibuildwheel" }, { name = "googleapis-common-protos" }, @@ -5259,6 +5260,7 @@ provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google [package.metadata.requires-dev] dev = [ + { name = "async-timeout", marker = "python_full_version < '3.11'", specifier = ">=4.0,<6" }, { name = "basedpyright", specifier = "==1.34.0" }, { name = "cibuildwheel", specifier = ">=2.22.0,<3" }, { name = "googleapis-common-protos", specifier = "==1.70.0" },