From 11db930f42ac29d301009044d7e8ca2bdd98d6b9 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Thu, 30 Apr 2026 09:35:55 +0200 Subject: [PATCH] Add StructuredOutputRetryLimitMiddleware and default retry limit --- splunklib/ai/base_agent.py | 18 ++- splunklib/ai/engines/langchain.py | 5 - splunklib/ai/hooks.py | 49 ++++++ .../integration/ai/test_structured_output.py | 147 ++++++++++++++++++ 4 files changed, 210 insertions(+), 9 deletions(-) diff --git a/splunklib/ai/base_agent.py b/splunklib/ai/base_agent.py index 04e1cae0..76731973 100644 --- a/splunklib/ai/base_agent.py +++ b/splunklib/ai/base_agent.py @@ -23,9 +23,11 @@ from splunklib.ai.conversation_store import ConversationStore from splunklib.ai.hooks import ( DEFAULT_STEP_LIMIT, + DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT, DEFAULT_TIMEOUT_SECONDS, DEFAULT_TOKEN_LIMIT, StepLimitMiddleware, + StructuredOutputRetryLimitMiddleware, TimeoutLimitMiddleware, TokenLimitMiddleware, ) @@ -79,16 +81,24 @@ def __init__( self._output_schema = output_schema user_middleware = tuple(middleware) if middleware else () user_middleware_types = {type(m) for m in user_middleware} + # NOTE: we're creating separate instances per agent - TimeoutLimitMiddleware is stateful # and sharing one would cause agents to overwrite each other's deadline. - predefined: list[AgentMiddleware] = [ + predefined_before: list[AgentMiddleware] = [ + StructuredOutputRetryLimitMiddleware(DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT), + ] + predefined_after: list[AgentMiddleware] = [ TokenLimitMiddleware(DEFAULT_TOKEN_LIMIT), StepLimitMiddleware(DEFAULT_STEP_LIMIT), TimeoutLimitMiddleware(DEFAULT_TIMEOUT_SECONDS), ] - # Append predefined middlewares by default if not provided already. 
- default_middleware = [m for m in predefined if type(m) not in user_middleware_types] - self._middleware = (*user_middleware, *default_middleware) + + self._middleware = ( + *[m for m in predefined_before if type(m) not in user_middleware_types], + *user_middleware, + *[m for m in predefined_after if type(m) not in user_middleware_types], + ) + + self._trace_id = secrets.token_hex(16) # 32 Hex characters self._conversation_store = conversation_store self._thread_id = thread_id diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index 2ecf106b..472afc30 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -882,11 +882,6 @@ async def llm_handler(req: ModelRequest) -> ModelResponse: except StructuredOutputGenerationException as e: # Structured output generation failed, retry. - # TODO: we should provide a mechanism to limit the amount of retries - # thath happen sequentially (say 3), otherwise raise a different exception. - # For now this can be done with the use of model middleware that counts - # the amount of StructuredOutputGenerationException that were raised. 
- ai_msg = _map_message_to_langchain(e.message) assert isinstance(ai_msg, LC_AIMessage) diff --git a/splunklib/ai/hooks.py b/splunklib/ai/hooks.py index 46206949..8cdc8d86 100644 --- a/splunklib/ai/hooks.py +++ b/splunklib/ai/hooks.py @@ -12,10 +12,12 @@ ModelRequest, ModelResponse, ) +from splunklib.ai.structured_output import StructuredOutputGenerationException DEFAULT_TIMEOUT_SECONDS: float = 600.0 DEFAULT_STEP_LIMIT: int = 100 DEFAULT_TOKEN_LIMIT: int = 200_000 +DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3 class AgentStopException(Exception): @@ -43,6 +45,13 @@ def __init__(self, timeout_seconds: float) -> None: super().__init__(f"Timed out after {timeout_seconds} seconds.") +class StructuredOutputRetryLimitExceededException(AgentStopException): + """Raised by `Agent.invoke`, when structured output retry limit exceeds""" + + def __init__(self, retry_count: int) -> None: + super().__init__(f"Structured output retry limit of {retry_count} exceeded") + + def before_model( func: Callable[[ModelRequest], None | Awaitable[None]], ) -> AgentMiddleware: @@ -199,3 +208,43 @@ async def model_middleware( if self._deadline is not None and monotonic() >= self._deadline: raise TimeoutExceededException(timeout_seconds=self._seconds) return await handler(request) + + +class StructuredOutputRetryLimitMiddleware(AgentMiddleware): + """Stops agent execution when the agent exceeds structured output + retry limit during a single agent loop invocation. + """ + + _limit: int + _retries_per_thread_id: dict[str, int] + + def __init__(self, limit: int) -> None: + self._limit = limit + self._retries_per_thread_id = {} + + @override + async def agent_middleware( + self, + request: AgentRequest, + handler: AgentMiddlewareHandler, + ) -> AgentResponse[Any | None]: + try: + # Agent loop starting. 
+ self._retries_per_thread_id[request.thread_id] = 0 + return await handler(request) + finally: + del self._retries_per_thread_id[request.thread_id] # don't leak memory + + @override + async def model_middleware( + self, + request: ModelRequest, + handler: ModelMiddlewareHandler, + ) -> ModelResponse: + try: + return await handler(request) + except StructuredOutputGenerationException: + self._retries_per_thread_id[request.state.thread_id] += 1 + if self._retries_per_thread_id[request.state.thread_id] > self._limit: + raise StructuredOutputRetryLimitExceededException(self._limit) + raise # re-raise, to retry structured output generation diff --git a/tests/integration/ai/test_structured_output.py b/tests/integration/ai/test_structured_output.py index 242b5403..72107f5f 100644 --- a/tests/integration/ai/test_structured_output.py +++ b/tests/integration/ai/test_structured_output.py @@ -21,6 +21,10 @@ from pydantic.dataclasses import dataclass from splunklib.ai import Agent +from splunklib.ai.hooks import ( + StructuredOutputRetryLimitExceededException, + StructuredOutputRetryLimitMiddleware, +) from splunklib.ai.messages import ( AgentResponse, AIMessage, @@ -930,5 +934,148 @@ async def _model_middleware( assert len(result.messages) == 3 assert result.structured_output.name == "MIKE" + @pytest.mark.asyncio + @ai_snapshot_test() + async def test_default_retry_limit(self) -> None: + pytest.importorskip("langchain_openai") + + class Person(BaseModel): + name: str = Field(description="The person's full name", min_length=1) + + model_call_count = 0 + + @model_middleware + async def _model_middleware( + _request: ModelRequest, + _handler: ModelMiddlewareHandler, + ) -> ModelResponse: + nonlocal model_call_count + model_call_count += 1 + + raise StructuredOutputGenerationException( + message=AIMessage(content="", calls=[]), + error=StructuredOutputValidationError( + validation_error="Invalid output" + ), + ) + + async with Agent( + model=(await self.model()), + 
system_prompt="Respond with structured data", + output_schema=Person, + service=self.service, + middleware=[_model_middleware], + ) as agent: + with pytest.raises( + StructuredOutputRetryLimitExceededException, + match="Structured output retry limit of 3 exceeded", + ): + await agent.invoke( + [HumanMessage(content="My name is Mike, what is my name?")] + ) + + assert model_call_count == 4 + + @pytest.mark.asyncio + @ai_snapshot_test() + async def test_custom_retry_limit_retry(self) -> None: + pytest.importorskip("langchain_openai") + + class Person(BaseModel): + name: str = Field(description="The person's full name", min_length=1) + + limits = [0, 1, 20] + for limit in limits: + with self.subTest(limit): + model_call_count = 0 + + @model_middleware + async def _model_middleware( + _request: ModelRequest, + _handler: ModelMiddlewareHandler, + ) -> ModelResponse: + nonlocal model_call_count + model_call_count += 1 + + raise StructuredOutputGenerationException( + message=AIMessage(content="", calls=[]), + error=StructuredOutputValidationError( + validation_error="Invalid output" + ), + ) + + async with Agent( + model=(await self.model()), + system_prompt="Respond with structured data", + output_schema=Person, + service=self.service, + middleware=[ + StructuredOutputRetryLimitMiddleware(limit), + _model_middleware, + ], + ) as agent: + with pytest.raises( + StructuredOutputRetryLimitExceededException, + match=f"Structured output retry limit of {limit} exceeded", + ): + await agent.invoke( + [HumanMessage(content="My name is Mike, what is my name?")] + ) + + # We expect limit + 1, since first LLM call is not a retry. 
+ assert model_call_count == limit + 1 + + @pytest.mark.asyncio + @ai_snapshot_test() + async def test_retry_limit_is_per_agent_loop(self) -> None: + pytest.importorskip("langchain_openai") + + class Person(BaseModel): + name: str = Field(description="The person's full name", min_length=1) + + after_first_call = False + + @model_middleware + async def _model_middleware( + _request: ModelRequest, + _handler: ModelMiddlewareHandler, + ) -> ModelResponse: + if after_first_call: + return ModelResponse( + message=AIMessage(content="", calls=[]), + structured_output=Person(name="Mike"), + ) + else: + raise StructuredOutputGenerationException( + message=AIMessage(content="", calls=[]), + error=StructuredOutputValidationError( + validation_error="Invalid output" + ), + ) + + async with Agent( + model=(await self.model()), + system_prompt="Respond with structured data", + output_schema=Person, + service=self.service, + middleware=[ + _model_middleware, + ], + ) as agent: + with pytest.raises( + StructuredOutputRetryLimitExceededException, + match="Structured output retry limit of 3 exceeded", + ): + await agent.invoke( + [HumanMessage(content="My name is Mike, what is my name?")] + ) + + after_first_call = True + + # Since structured output retry limit is per agent loop, this should not fail. + await agent.invoke( + [HumanMessage(content="My name is Mike, what is my name?")] + ) + # TODO: test what happens if model/agent middleware removes the structured_output. # do we detect that? We should and raise in invoke, that output was removed.