From 6c47fde203131b6862a8981a920f395aa14e3ba3 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Thu, 30 Apr 2026 09:35:55 +0200 Subject: [PATCH] Use thread_ids in TimeoutLimitMiddleware --- splunklib/ai/hooks.py | 18 +++++++++++------- tests/unit/ai/test_default_limits.py | 24 +++++++++++++++++------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/splunklib/ai/hooks.py b/splunklib/ai/hooks.py index 8cdc8d86..262b85a1 100644 --- a/splunklib/ai/hooks.py +++ b/splunklib/ai/hooks.py @@ -182,11 +182,11 @@ class TimeoutLimitMiddleware(AgentMiddleware): """ _seconds: float - _deadline: float | None + _deadline_per_thread_id: dict[str, float] def __init__(self, seconds: float) -> None: self._seconds = seconds - self._deadline = None + self._deadline_per_thread_id = {} @override async def agent_middleware( @@ -194,10 +194,14 @@ async def agent_middleware( request: AgentRequest, handler: AgentMiddlewareHandler, ) -> AgentResponse[Any | None]: - # WARN: this might not work with agents handling - # different threads at the same time. - self._deadline = monotonic() + self._seconds - return await handler(request) + try: + # Agent loop starting. + self._deadline_per_thread_id[request.thread_id] = ( + monotonic() + self._seconds + ) + return await handler(request) + finally: + del self._deadline_per_thread_id[request.thread_id] # don't leak memory @override async def model_middleware( @@ -205,7 +209,7 @@ async def model_middleware( request: ModelRequest, handler: ModelMiddlewareHandler, ) -> ModelResponse: - if self._deadline is not None and monotonic() >= self._deadline: + if monotonic() >= self._deadline_per_thread_id[request.state.thread_id]: raise TimeoutExceededException(timeout_seconds=self._seconds) return await handler(request) diff --git a/tests/unit/ai/test_default_limits.py b/tests/unit/ai/test_default_limits.py index ce38e3ad..8714d592 100644 --- a/tests/unit/ai/test_default_limits.py +++ b/tests/unit/ai/test_default_limits.py @@ -124,23 +124,33 @@ async def test_deadline_reset_on_each_invoke(self) -> None: mw = TimeoutLimitMiddleware(60.0) request = _make_agent_request() - await mw.agent_middleware(request, _noop_agent_handler) - first_deadline = mw._deadline # pyright: ignore[reportPrivateUsage] + first_deadline: float | None = None + second_deadline: float | None = None - await mw.agent_middleware(request, _noop_agent_handler) - second_deadline = mw._deadline # pyright: ignore[reportPrivateUsage] + async def _first_agent_handler(_request: AgentRequest) -> AgentResponse[None]: + nonlocal first_deadline + first_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage] + return AgentResponse(messages=[], structured_output=None) + + async def _second_agent_handler(_request: AgentRequest) -> AgentResponse[None]: + nonlocal second_deadline + second_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage] + return AgentResponse(messages=[], structured_output=None) + + await mw.agent_middleware(request, _first_agent_handler) + await mw.agent_middleware(request, _second_agent_handler) assert first_deadline is not None - assert second_deadline is not None + assert second_deadline is not None # pyright: ignore[reportUnreachable] assert second_deadline >= first_deadline async def test_deadline_is_none_before_first_invoke(self) -> None: mw = TimeoutLimitMiddleware(60.0) - assert mw._deadline is None # pyright: ignore[reportPrivateUsage] + assert mw._deadline_per_thread_id.get("foo") is None # pyright: ignore[reportPrivateUsage] async def test_timeout_fires_when_deadline_exceeded(self) -> None: mw = TimeoutLimitMiddleware(60.0) - mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past + mw._deadline_per_thread_id["foo"] = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past state = AgentState(messages=[], total_steps=0, token_count=0, thread_id="foo") request = ModelRequest(system_message="", state=state)