diff --git a/splunklib/ai/engines/langchain.py b/splunklib/ai/engines/langchain.py index b53d1cd7..5a090e02 100644 --- a/splunklib/ai/engines/langchain.py +++ b/splunklib/ai/engines/langchain.py @@ -286,9 +286,13 @@ async def awrap_tool_call( assert resp.artifact is None, "artifact is already populated" if resp.name.startswith(AGENT_PREFIX): - resp.artifact = SubagentFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType] + resp.artifact = SubagentFailureResult( + error_message=str(resp.content) # pyright: ignore[reportUnknownArgumentType] + ) else: - resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType] + resp.artifact = ToolFailureResult( + error_message=str(resp.content) # pyright: ignore[reportUnknownArgumentType] + ) return resp @@ -862,7 +866,9 @@ async def llm_handler(req: ModelRequest) -> ModelResponse: case LC_StructuredOutputValidationError(): raise StructuredOutputGenerationException( message=msg, - error=StructuredOutputValidationError(str(e.source)), + error=StructuredOutputValidationError( + validation_error=str(e.source) + ), ) case LC_StructuredOutputError(): # Langchain only returns the above handled exceptions, LC_StructuredOutputError @@ -1012,7 +1018,7 @@ async def _sdk_handler(request: ToolRequest) -> ToolResponse: assert isinstance(sdk_result, ToolMessage), ( "Expected tool response from tool middleware handler" ) - return ToolResponse(sdk_result.result) + return ToolResponse(result=sdk_result.result) return _sdk_handler @@ -1032,7 +1038,7 @@ async def _sdk_handler( assert isinstance(sdk_result, SubagentMessage), ( "Expected subagent response from subagent middleware handler" ) - return SubagentResponse(sdk_result.result) + return SubagentResponse(result=sdk_result.result) return _sdk_handler @@ -1274,10 +1280,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe tool_strategy_messages = [ StructuredOutputMessage( - m.tool_call_id, - m.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX) if m.name else "", - m.status, - str(m.content), # pyright: ignore[reportUnknownArgumentType] + call_id=m.tool_call_id, + name=m.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX) if m.name else "", + status=m.status, + content=str(m.content), # pyright: ignore[reportUnknownArgumentType] ) for m in model_response.result if isinstance(m, LC_ToolMessage) @@ -1402,7 +1408,9 @@ async def _tool_call( "ToolException from LangChain should not be raised in tool.func" ) - artifact = ToolResult(result.content, result.structured_content) + artifact = ToolResult( + content=result.content, structured_content=result.structured_content + ) if result.structured_content: # For both local tools and remote tools (Splunk MCP Server App), the primary @@ -1719,9 +1727,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage: ], structured_output_calls=[ StructuredOutputCall( - tc["id"] or "", - tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX), - tc["args"], + id=tc["id"] or "", + name=tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX), + args=tc["args"], ) for tc in message.tool_calls if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX) diff --git a/splunklib/ai/messages.py b/splunklib/ai/messages.py index 614d9d04..4f8ff937 100644 --- a/splunklib/ai/messages.py +++ b/splunklib/ai/messages.py @@ -21,7 +21,7 @@ from splunklib.ai.tools import ToolType -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class TextBlock: """Plain text content block returned by a model.""" @@ -36,7 +36,7 @@ class TextBlock: """ -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class OpaqueBlock: """Content block of an unrecognized or unsupported type. @@ -62,7 +62,7 @@ class OpaqueBlock: ContentBlock = TextBlock | OpaqueBlock -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolCall: id: str name: str @@ -70,7 +70,7 @@ class ToolCall: args: dict[str, Any] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SubagentCall: id: str name: str @@ -78,14 +78,14 @@ class SubagentCall: thread_id: str | None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class StructuredOutputCall: id: str name: str args: dict[str, Any] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class BaseMessage: role: str = field(init=False) @@ -96,7 +96,7 @@ def __post_init__(self) -> None: ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class HumanMessage(BaseMessage): """ Message originating from a human user. @@ -110,7 +110,7 @@ class HumanMessage(BaseMessage): content: str -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AIMessage(BaseMessage): """ Message produced by an LLM. @@ -141,7 +141,7 @@ class AIMessage(BaseMessage): """ -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolResult: """ ToolResult represents a result of a successful tool call. @@ -151,7 +151,7 @@ class ToolResult: structured_content: dict[str, Any] | None -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SubagentStructuredResult: """ SubagentStructuredResult represents a result of a successful subagent call. @@ -161,7 +161,7 @@ class SubagentStructuredResult: structured_output: dict[str, Any] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SubagentTextResult: """ SubagentTextResult represents a result of a successful subagent call. @@ -171,7 +171,7 @@ class SubagentTextResult: content: str -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolFailureResult: """ Represents the result of a failed sub-agent call. @@ -183,7 +183,7 @@ class ToolFailureResult: error_message: str -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SubagentFailureResult: """ Represents the result of a failed tool call. @@ -195,7 +195,7 @@ class SubagentFailureResult: error_message: str -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolMessage(BaseMessage): """ToolMessage represents a response of a tool call""" @@ -208,7 +208,7 @@ class ToolMessage(BaseMessage): # TODO: do we have a test that uses this? -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SystemMessage(BaseMessage): """ A message used to prime or control agent behavior. @@ -218,7 +218,7 @@ class SystemMessage(BaseMessage): content: str -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SubagentMessage(BaseMessage): """ SubagentMessage represents a response of an agent invocation @@ -231,7 +231,7 @@ class SubagentMessage(BaseMessage): result: SubagentStructuredResult | SubagentTextResult | SubagentFailureResult -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class StructuredOutputMessage(BaseMessage): """ StructuredMessage represents a response to the StructuredOutputCall. @@ -254,7 +254,7 @@ class StructuredOutputMessage(BaseMessage): # where developers might want to store messages in say KV store. -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AgentResponse(Generic[OutputT]): # in case output_schema is provided, this will hold the parsed structured output structured_output: OutputT diff --git a/splunklib/ai/middleware.py b/splunklib/ai/middleware.py index 0231dbb6..e47affbf 100644 --- a/splunklib/ai/middleware.py +++ b/splunklib/ai/middleware.py @@ -30,7 +30,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AgentState: """AgentState is available through certain middlewares and contains information about the current state of an agent execution.""" @@ -42,13 +42,13 @@ class AgentState: token_count: int -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolRequest: call: ToolCall state: AgentState -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolResponse: result: ToolResult | ToolFailureResult @@ -56,13 +56,13 @@ class ToolResponse: ToolMiddlewareHandler = Callable[[ToolRequest], Awaitable[ToolResponse]] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SubagentRequest: call: SubagentCall state: AgentState -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class SubagentResponse: result: SubagentStructuredResult | SubagentTextResult | SubagentFailureResult @@ -73,13 +73,13 @@ class SubagentResponse: ] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ModelRequest: system_message: str state: AgentState -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ModelResponse: message: AIMessage structured_output: Any | None = None @@ -94,7 +94,7 @@ def __post_init__(self) -> None: ModelMiddlewareHandler = Callable[[ModelRequest], Awaitable[ModelResponse]] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AgentRequest: messages: Sequence[BaseMessage] diff --git a/splunklib/ai/model.py b/splunklib/ai/model.py index c701f5d0..2c1f59c7 100644 --- a/splunklib/ai/model.py +++ b/splunklib/ai/model.py @@ -18,14 +18,14 @@ import httpx -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class PredefinedModel: """Base class for models that are predefined in the SDK""" model: str -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class OpenAIModel(PredefinedModel): """Predefined OpenAI Model""" @@ -53,7 +53,7 @@ class OpenAIModel(PredefinedModel): """ -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class AnthropicModel(PredefinedModel): """Predefined Anthropic Model""" diff --git a/splunklib/ai/structured_output.py b/splunklib/ai/structured_output.py index 06fc9635..3c31fd49 100644 --- a/splunklib/ai/structured_output.py +++ b/splunklib/ai/structured_output.py @@ -17,12 +17,12 @@ from splunklib.ai.messages import AIMessage -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class StructuredOutputMultipleToolCallsError: pass -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class StructuredOutputValidationError: validation_error: str diff --git a/splunklib/ai/tool_settings.py b/splunklib/ai/tool_settings.py index fe5fdfae..22ce0c9e 100644 --- a/splunklib/ai/tool_settings.py +++ b/splunklib/ai/tool_settings.py @@ -18,7 +18,7 @@ from splunklib.ai.tools import ToolMetadata -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolAllowlist: """Holds tool names and tags allowed to be used by Agents. @@ -41,17 +41,17 @@ def is_allowed(self, tool: ToolMetadata) -> bool: return self.custom_predicate(tool) if self.custom_predicate else False -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class RemoteToolSettings: allowlist: ToolAllowlist -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class LocalToolSettings: allowlist: ToolAllowlist -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolSettings: local: LocalToolSettings | bool """Controls local tool loading (via ``bin/tools.py``). diff --git a/splunklib/ai/tools.py b/splunklib/ai/tools.py index 5846f08e..5dbbbd55 100644 --- a/splunklib/ai/tools.py +++ b/splunklib/ai/tools.py @@ -38,7 +38,7 @@ class ToolException(Exception): """Custom exception to indicate tool execution errors.""" -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolResult: content: str structured_content: dict[str, Any] | None @@ -49,7 +49,7 @@ class ToolType(Enum): REMOTE = "remote" -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ToolMetadata: name: str description: str @@ -58,7 +58,7 @@ class ToolMetadata: tags: list[str] -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class Tool(ToolMetadata): func: Callable[..., Awaitable[ToolResult]] diff --git a/tests/integration/ai/test_agent_mcp_tools.py b/tests/integration/ai/test_agent_mcp_tools.py index 7bd4518d..3bca0a63 100644 --- a/tests/integration/ai/test_agent_mcp_tools.py +++ b/tests/integration/ai/test_agent_mcp_tools.py @@ -771,7 +771,7 @@ class ToolResults(BaseModel): assert len(agent.tools) == 2 content = "Call tools to populate output." - response = await agent.invoke([HumanMessage(content)]) + response = await agent.invoke([HumanMessage(content=content)]) print(response.structured_output) assert response.structured_output.remote_temperature == "31.5C" assert response.structured_output.local_temperature == "22.1C" diff --git a/tests/integration/ai/test_conversation_store.py b/tests/integration/ai/test_conversation_store.py index f4616ca5..da468cb0 100644 --- a/tests/integration/ai/test_conversation_store.py +++ b/tests/integration/ai/test_conversation_store.py @@ -129,7 +129,7 @@ async def _agent_middleware( if not after_first_call: return AgentResponse( messages=[ - HumanMessage("My name is Mike"), + HumanMessage(content="My name is Mike"), AIMessage(content="Hi Mike!", calls=[]), ], structured_output=None, diff --git a/tests/integration/ai/test_middleware.py b/tests/integration/ai/test_middleware.py index c90c82ba..e9142cd1 100644 --- a/tests/integration/ai/test_middleware.py +++ b/tests/integration/ai/test_middleware.py @@ -191,7 +191,9 @@ async def test_middleware( call = request.call assert call.id, "Invalid call id received" - return ToolResponse(ToolResult(content="0.5C", structured_content=None)) + return ToolResponse( + result=ToolResult(content="0.5C", structured_content=None) + ) async with Agent( model=await self.model(), @@ -475,7 +477,9 @@ async def test_middleware( call = request.call assert call.id, "Invalid call id received" - return SubagentResponse(SubagentTextResult(content="Chris-superstar")) + return SubagentResponse( + result=SubagentTextResult(content="Chris-superstar") + ) async with ( Agent( diff --git a/tests/integration/ai/test_structured_output.py b/tests/integration/ai/test_structured_output.py index 242b5403..edac8cc8 100644 --- a/tests/integration/ai/test_structured_output.py +++ b/tests/integration/ai/test_structured_output.py @@ -699,7 +699,7 @@ async def _model_middleware( raise StructuredOutputGenerationException( message=resp.message, error=StructuredOutputValidationError( - "Validation error: name must have ALL letters capitalized" + validation_error="Validation error: name must have ALL letters capitalized" ), ) return resp @@ -736,7 +736,7 @@ async def _model_middleware( raise StructuredOutputGenerationException( message=resp.message, error=StructuredOutputValidationError( - "Validation error: name must have ALL letters capitalized" + validation_error="Validation error: name must have ALL letters capitalized" ), ) return resp diff --git a/tests/unit/ai/engine/test_langchain_backend.py b/tests/unit/ai/engine/test_langchain_backend.py index 1daa0add..ea0e199c 100644 --- a/tests/unit/ai/engine/test_langchain_backend.py +++ b/tests/unit/ai/engine/test_langchain_backend.py @@ -239,7 +239,7 @@ def test_map_message_from_langchain_tool(self) -> None: content="result", tool_call_id="call-1", status="error", - artifact=ToolFailureResult("result"), + artifact=ToolFailureResult(error_message="result"), ) mapped = lc._map_message_from_langchain(message) @@ -255,7 +255,7 @@ def test_map_message_from_langchain_subagent(self) -> None: content="subagent output", tool_call_id="call-2", status="error", - artifact=SubagentFailureResult("subagent output"), + artifact=SubagentFailureResult(error_message="subagent output"), ) mapped = lc._map_message_from_langchain(message) @@ -550,7 +550,7 @@ def test_map_message_to_langchain_tool(self) -> None: name="lookup", call_id="call-1", type=ToolType.REMOTE, - result=ToolFailureResult("result"), + result=ToolFailureResult(error_message="result"), ) mapped = lc._map_message_to_langchain(message) @@ -562,7 +562,9 @@ def test_map_message_to_langchain_tool(self) -> None: def test_map_message_to_langchain_subagent(self) -> None: message = SubagentMessage( - name="My Agent", call_id="call-2", result=SubagentFailureResult("ping") + name="My Agent", + call_id="call-2", + result=SubagentFailureResult(error_message="ping"), ) mapped = lc._map_message_to_langchain(message) diff --git a/tests/unit/ai/test_tool_settings.py b/tests/unit/ai/test_tool_settings.py index e6d5ac7f..0f592448 100644 --- a/tests/unit/ai/test_tool_settings.py +++ b/tests/unit/ai/test_tool_settings.py @@ -67,7 +67,7 @@ def test_filtering( initial_tools: Sequence[Tool], expected_tools: Sequence[Tool], ) -> None: - filters = ToolAllowlist(allowed_names, allowed_tags) + filters = ToolAllowlist(names=allowed_names, tags=allowed_tags) filtered_tools = [t for t in initial_tools if filters.is_allowed(t)] assert filtered_tools == expected_tools