Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions splunklib/ai/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,30 +182,34 @@ class TimeoutLimitMiddleware(AgentMiddleware):
"""

_seconds: float
_deadline: float | None
_deadline_per_thread_id: dict[str, float]

def __init__(self, seconds: float) -> None:
    """Create the middleware with a per-run time budget.

    Args:
        seconds: Time budget, in seconds, that is added to ``monotonic()``
            when an agent run starts to compute that run's deadline.
    """
    self._seconds = seconds
    # Deadlines are tracked per thread id instead of in one shared
    # attribute, so runs on different threads do not overwrite each other.
    self._deadline_per_thread_id = {}

@override
async def agent_middleware(
    self,
    request: AgentRequest,
    handler: AgentMiddlewareHandler,
) -> AgentResponse[Any | None]:
    """Record a deadline for this run's thread, then delegate to ``handler``.

    The deadline is ``monotonic() + self._seconds`` and is stored under
    ``request.thread_id`` for the duration of the ``handler`` call.

    Args:
        request: The agent request; only ``request.thread_id`` is read here.
        handler: The next handler in the chain; awaited with ``request``.

    Returns:
        Whatever ``handler(request)`` returns.
    """
    thread_id = request.thread_id
    # Agent loop starting.
    self._deadline_per_thread_id[thread_id] = monotonic() + self._seconds
    try:
        return await handler(request)
    finally:
        # Don't leak memory. pop() with a default (rather than ``del``) so
        # that an entry which is already gone cannot raise KeyError here and
        # mask the handler's own result or exception.
        self._deadline_per_thread_id.pop(thread_id, None)

@override
async def model_middleware(
    self,
    request: ModelRequest,
    handler: ModelMiddlewareHandler,
) -> ModelResponse:
    """Refuse the model call once the thread's deadline has passed.

    Args:
        request: The model request; only ``request.state.thread_id`` is
            read here.
        handler: The next handler in the chain; awaited with ``request``.

    Returns:
        Whatever ``handler(request)`` returns.

    Raises:
        TimeoutExceededException: If a deadline is recorded for the thread
            and ``monotonic()`` has reached it.
    """
    # .get() instead of subscripting: when no deadline is recorded for this
    # thread the call passes through, rather than failing with KeyError.
    deadline = self._deadline_per_thread_id.get(request.state.thread_id)
    if deadline is not None and monotonic() >= deadline:
        raise TimeoutExceededException(timeout_seconds=self._seconds)
    return await handler(request)

Expand Down
24 changes: 17 additions & 7 deletions tests/unit/ai/test_default_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
assert len([m for m in mw if isinstance(m, TimeoutLimitMiddleware)]) == 1


async def _noop_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
    """Agent handler stub: ignore the request and return an empty response."""
    # NOTE(review): the CI lint stage (GitHub Actions / lint-stage) reports
    # this function as not accessed (reportUnusedFunction) at line 114 of
    # tests/unit/ai/test_default_limits.py - confirm whether it is still
    # needed before merging.
    return AgentResponse(messages=[], structured_output=None)


Expand All @@ -124,23 +124,33 @@
mw = TimeoutLimitMiddleware(60.0)
request = _make_agent_request()

await mw.agent_middleware(request, _noop_agent_handler)
first_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
first_deadline: float | None = None
second_deadline: float | None = None

await mw.agent_middleware(request, _noop_agent_handler)
second_deadline = mw._deadline # pyright: ignore[reportPrivateUsage]
async def _first_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
nonlocal first_deadline
first_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
return AgentResponse(messages=[], structured_output=None)

async def _second_agent_handler(_request: AgentRequest) -> AgentResponse[None]:
nonlocal second_deadline
second_deadline = mw._deadline_per_thread_id["foo"] # pyright: ignore[reportPrivateUsage]
return AgentResponse(messages=[], structured_output=None)

await mw.agent_middleware(request, _first_agent_handler)
await mw.agent_middleware(request, _second_agent_handler)

assert first_deadline is not None
assert second_deadline is not None
assert second_deadline is not None # pyright: ignore[reportUnreachable]
assert second_deadline >= first_deadline

async def test_deadline_is_none_before_first_invoke(self) -> None:
    """A freshly constructed middleware has no deadline recorded for "foo"."""
    mw = TimeoutLimitMiddleware(60.0)
    # .get() because a new instance starts with an empty per-thread map.
    assert mw._deadline_per_thread_id.get("foo") is None  # pyright: ignore[reportPrivateUsage]

async def test_timeout_fires_when_deadline_exceeded(self) -> None:
mw = TimeoutLimitMiddleware(60.0)
mw._deadline = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past
mw._deadline_per_thread_id["foo"] = monotonic() - 1.0 # pyright: ignore[reportPrivateUsage] # already in the past

state = AgentState(messages=[], total_steps=0, token_count=0, thread_id="foo")
request = ModelRequest(system_message="", state=state)
Expand Down
Loading