Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions splunklib/ai/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
from splunklib.ai.conversation_store import ConversationStore
from splunklib.ai.hooks import (
DEFAULT_STEP_LIMIT,
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TOKEN_LIMIT,
StepLimitMiddleware,
StructuredOutputRetryLimitMiddleware,
TimeoutLimitMiddleware,
TokenLimitMiddleware,
)
Expand Down Expand Up @@ -79,16 +81,24 @@ def __init__(
self._output_schema = output_schema
user_middleware = tuple(middleware) if middleware else ()
user_middleware_types = {type(m) for m in user_middleware}

# NOTE: we're creating separate instances per agent - TimeoutLimitMiddleware is stateful
# and sharing one would cause agents to overwrite each other's deadline.
predefined: list[AgentMiddleware] = [
predefined_before: list[AgentMiddleware] = [
StructuredOutputRetryLimitMiddleware(DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT),
]
predefined_after: list[AgentMiddleware] = [
TokenLimitMiddleware(DEFAULT_TOKEN_LIMIT),
StepLimitMiddleware(DEFAULT_STEP_LIMIT),
TimeoutLimitMiddleware(DEFAULT_TIMEOUT_SECONDS),
]
# Append predefined middlewares by default if not provided already.
default_middleware = [m for m in predefined if type(m) not in user_middleware_types]
self._middleware = (*user_middleware, *default_middleware)

self._middleware = (
*{m for m in predefined_before if type(m) not in user_middleware_types},
*user_middleware,
*{m for m in predefined_after if type(m) not in user_middleware_types},
)

self._trace_id = secrets.token_hex(16) # 32 Hex characters
self._conversation_store = conversation_store
self._thread_id = thread_id
Expand Down
5 changes: 0 additions & 5 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,11 +882,6 @@ async def llm_handler(req: ModelRequest) -> ModelResponse:
except StructuredOutputGenerationException as e:
# Structured output generation failed, retry.

# TODO: we should provide a mechanism to limit the amount of retries
# thath happen sequentially (say 3), otherwise raise a different exception.
# For now this can be done with the use of model middleware that counts
# the amount of StructuredOutputGenerationException that were raised.

ai_msg = _map_message_to_langchain(e.message)
assert isinstance(ai_msg, LC_AIMessage)

Expand Down
49 changes: 49 additions & 0 deletions splunklib/ai/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
ModelRequest,
ModelResponse,
)
from splunklib.ai.structured_output import StructuredOutputGenerationException

DEFAULT_TIMEOUT_SECONDS: float = 600.0
DEFAULT_STEP_LIMIT: int = 100
DEFAULT_TOKEN_LIMIT: int = 200_000
DEFAULT_STRUCTURED_OUTPUT_RETRY_LIMIT: int = 3


class AgentStopException(Exception):
Expand Down Expand Up @@ -43,6 +45,13 @@ def __init__(self, timeout_seconds: float) -> None:
super().__init__(f"Timed out after {timeout_seconds} seconds.")


class StructuredOutputRetryLimitExceededException(AgentStopException):
    """Raised by `Agent.invoke` when the structured output retry limit is exceeded."""

    def __init__(self, retry_count: int) -> None:
        # retry_count is the configured limit that was exceeded.
        super().__init__(f"Structured output retry limit of {retry_count} exceeded")


def before_model(
func: Callable[[ModelRequest], None | Awaitable[None]],
) -> AgentMiddleware:
Expand Down Expand Up @@ -199,3 +208,43 @@ async def model_middleware(
if self._deadline is not None and monotonic() >= self._deadline:
raise TimeoutExceededException(timeout_seconds=self._seconds)
return await handler(request)


class StructuredOutputRetryLimitMiddleware(AgentMiddleware):
    """Stops agent execution when the agent exceeds the structured output
    retry limit during a single agent loop invocation.

    Retries are counted per thread id and reset at the start of every agent
    loop, so failures in one `Agent.invoke` call never count against a later
    one. The counter entry is removed when the loop finishes to avoid leaking
    memory for finished threads.

    NOTE(review): the counter dict is shared across concurrent invocations;
    two simultaneous loops with the same thread id would share (and clobber)
    one counter — presumably thread ids are unique per in-flight invocation;
    verify against `Agent.invoke`.
    """

    _limit: int
    # thread_id -> number of structured-output failures seen in the current loop.
    _retries_per_thread_id: dict[str, int]

    def __init__(self, limit: int) -> None:
        # limit: maximum number of retries per agent loop; the (limit + 1)-th
        # structured-output failure raises instead of retrying.
        self._limit = limit
        self._retries_per_thread_id = {}

    @override
    async def agent_middleware(
        self,
        request: AgentRequest,
        handler: AgentMiddlewareHandler,
    ) -> AgentResponse[Any | None]:
        # Agent loop starting: reset the retry counter for this thread.
        # The reset sits outside the try so the finally never removes an
        # entry that was never created.
        self._retries_per_thread_id[request.thread_id] = 0
        try:
            return await handler(request)
        finally:
            # pop() (rather than del) tolerates the entry already being gone,
            # e.g. after a concurrent loop with the same thread id finished.
            self._retries_per_thread_id.pop(request.thread_id, None)

    @override
    async def model_middleware(
        self,
        request: ModelRequest,
        handler: ModelMiddlewareHandler,
    ) -> ModelResponse:
        try:
            return await handler(request)
        except StructuredOutputGenerationException:
            thread_id = request.state.thread_id
            # .get() guards against a model call arriving without a prior
            # agent_middleware reset (previously a KeyError here).
            retries = self._retries_per_thread_id.get(thread_id, 0) + 1
            self._retries_per_thread_id[thread_id] = retries
            if retries > self._limit:
                raise StructuredOutputRetryLimitExceededException(self._limit)
            raise  # re-raise, to retry structured output generation
147 changes: 147 additions & 0 deletions tests/integration/ai/test_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from pydantic.dataclasses import dataclass

from splunklib.ai import Agent
from splunklib.ai.hooks import (
StructuredOutputRetryLimitExceededException,
StructuredOutputRetryLimitMiddleware,
)
from splunklib.ai.messages import (
AgentResponse,
AIMessage,
Expand Down Expand Up @@ -930,5 +934,148 @@ async def _model_middleware(
assert len(result.messages) == 3
assert result.structured_output.name == "MIKE"

@pytest.mark.asyncio
@ai_snapshot_test()
async def test_default_retry_limit(self) -> None:
pytest.importorskip("langchain_openai")

class Person(BaseModel):
name: str = Field(description="The person's full name", min_length=1)

model_call_count = 0

@model_middleware
async def _model_middleware(
_request: ModelRequest,
_handler: ModelMiddlewareHandler,
) -> ModelResponse:
nonlocal model_call_count
model_call_count += 1

raise StructuredOutputGenerationException(
message=AIMessage(content="", calls=[]),
error=StructuredOutputValidationError(
validation_error="Invalid output"
),
)

async with Agent(
model=(await self.model()),
system_prompt="Respond with structured data",
output_schema=Person,
service=self.service,
middleware=[_model_middleware],
) as agent:
with pytest.raises(
StructuredOutputRetryLimitExceededException,
match="Structured output retry limit of 3 exceeded",
):
await agent.invoke(
[HumanMessage(content="My name is Mike, what is my name?")]
)

assert model_call_count == 4

@pytest.mark.asyncio
@ai_snapshot_test()
async def test_custom_retry_limit_retry(self) -> None:
pytest.importorskip("langchain_openai")

class Person(BaseModel):
name: str = Field(description="The person's full name", min_length=1)

limits = [0, 1, 20]
for limit in limits:
with self.subTest(limit):
model_call_count = 0

@model_middleware
async def _model_middleware(
_request: ModelRequest,
_handler: ModelMiddlewareHandler,
) -> ModelResponse:
nonlocal model_call_count
model_call_count += 1

raise StructuredOutputGenerationException(
message=AIMessage(content="", calls=[]),
error=StructuredOutputValidationError(
validation_error="Invalid output"
),
)

async with Agent(
model=(await self.model()),
system_prompt="Respond with structured data",
output_schema=Person,
service=self.service,
middleware=[
StructuredOutputRetryLimitMiddleware(limit),
_model_middleware,
],
) as agent:
with pytest.raises(
StructuredOutputRetryLimitExceededException,
match=f"Structured output retry limit of {limit} exceeded",
):
await agent.invoke(
[HumanMessage(content="My name is Mike, what is my name?")]
)

# We expect limit + 1, since first LLM call is not a retry.
assert model_call_count == limit + 1

@pytest.mark.asyncio
@ai_snapshot_test()
async def test_retry_limit_is_per_agent_loop(self) -> None:
pytest.importorskip("langchain_openai")

class Person(BaseModel):
name: str = Field(description="The person's full name", min_length=1)

after_first_call = False

@model_middleware
async def _model_middleware(
_request: ModelRequest,
_handler: ModelMiddlewareHandler,
) -> ModelResponse:
if after_first_call:
return ModelResponse(
message=AIMessage(content="", calls=[]),
structured_output=Person(name="Mike"),
)
else:
raise StructuredOutputGenerationException(
message=AIMessage(content="", calls=[]),
error=StructuredOutputValidationError(
validation_error="Invalid output"
),
)

async with Agent(
model=(await self.model()),
system_prompt="Respond with structured data",
output_schema=Person,
service=self.service,
middleware=[
_model_middleware,
],
) as agent:
with pytest.raises(
StructuredOutputRetryLimitExceededException,
match="Structured output retry limit of 3 exceeded",
):
await agent.invoke(
[HumanMessage(content="My name is Mike, what is my name?")]
)

after_first_call = True

# Since structured output retry limit is per agent loop, this should not fail.
await agent.invoke(
[HumanMessage(content="My name is Mike, what is my name?")]
)

# TODO: add a test for when model/agent middleware strips the structured_output
# from the response. We should detect that in `invoke` and raise an explicit
# error indicating the output was removed.
Loading