Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,13 @@ async def awrap_tool_call(
assert resp.artifact is None, "artifact is already populated"

if resp.name.startswith(AGENT_PREFIX):
resp.artifact = SubagentFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
resp.artifact = SubagentFailureResult(
error_message=str(resp.content) # pyright: ignore[reportUnknownArgumentType]
)
else:
resp.artifact = ToolFailureResult(str(resp.content)) # pyright: ignore[reportUnknownArgumentType]
resp.artifact = ToolFailureResult(
error_message=str(resp.content) # pyright: ignore[reportUnknownArgumentType]
)

return resp

Expand Down Expand Up @@ -862,7 +866,9 @@ async def llm_handler(req: ModelRequest) -> ModelResponse:
case LC_StructuredOutputValidationError():
raise StructuredOutputGenerationException(
message=msg,
error=StructuredOutputValidationError(str(e.source)),
error=StructuredOutputValidationError(
validation_error=str(e.source)
),
)
case LC_StructuredOutputError():
# Langchain only returns the above handled exceptions, LC_StructuredOutputError
Expand Down Expand Up @@ -1012,7 +1018,7 @@ async def _sdk_handler(request: ToolRequest) -> ToolResponse:
assert isinstance(sdk_result, ToolMessage), (
"Expected tool response from tool middleware handler"
)
return ToolResponse(sdk_result.result)
return ToolResponse(result=sdk_result.result)

return _sdk_handler

Expand All @@ -1032,7 +1038,7 @@ async def _sdk_handler(
assert isinstance(sdk_result, SubagentMessage), (
"Expected subagent response from subagent middleware handler"
)
return SubagentResponse(sdk_result.result)
return SubagentResponse(result=sdk_result.result)

return _sdk_handler

Expand Down Expand Up @@ -1274,10 +1280,10 @@ def _convert_model_result_from_lc(model_response: LC_ModelCallResult) -> ModelRe

tool_strategy_messages = [
StructuredOutputMessage(
m.tool_call_id,
m.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX) if m.name else "",
m.status,
str(m.content), # pyright: ignore[reportUnknownArgumentType]
call_id=m.tool_call_id,
name=m.name.removeprefix(TOOL_STRATEGY_TOOL_PREFIX) if m.name else "",
status=m.status,
content=str(m.content), # pyright: ignore[reportUnknownArgumentType]
)
for m in model_response.result
if isinstance(m, LC_ToolMessage)
Expand Down Expand Up @@ -1402,7 +1408,9 @@ async def _tool_call(
"ToolException from LangChain should not be raised in tool.func"
)

artifact = ToolResult(result.content, result.structured_content)
artifact = ToolResult(
content=result.content, structured_content=result.structured_content
)

if result.structured_content:
# For both local tools and remote tools (Splunk MCP Server App), the primary
Expand Down Expand Up @@ -1719,9 +1727,9 @@ def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
],
structured_output_calls=[
StructuredOutputCall(
tc["id"] or "",
tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
tc["args"],
id=tc["id"] or "",
name=tc["name"].removeprefix(TOOL_STRATEGY_TOOL_PREFIX),
args=tc["args"],
)
for tc in message.tool_calls
if tc["name"].startswith(TOOL_STRATEGY_TOOL_PREFIX)
Expand Down
36 changes: 18 additions & 18 deletions splunklib/ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from splunklib.ai.tools import ToolType


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class TextBlock:
"""Plain text content block returned by a model."""

Expand All @@ -36,7 +36,7 @@ class TextBlock:
"""


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class OpaqueBlock:
"""Content block of an unrecognized or unsupported type.

Expand All @@ -62,30 +62,30 @@ class OpaqueBlock:
ContentBlock = TextBlock | OpaqueBlock


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolCall:
id: str
name: str
type: ToolType
args: dict[str, Any]


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SubagentCall:
id: str
name: str
args: str | dict[str, Any]
thread_id: str | None


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class StructuredOutputCall:
id: str
name: str
args: dict[str, Any]


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class BaseMessage:
role: str = field(init=False)

Expand All @@ -96,7 +96,7 @@ def __post_init__(self) -> None:
)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class HumanMessage(BaseMessage):
"""
Message originating from a human user.
Expand All @@ -110,7 +110,7 @@ class HumanMessage(BaseMessage):
content: str


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class AIMessage(BaseMessage):
"""
Message produced by an LLM.
Expand Down Expand Up @@ -141,7 +141,7 @@ class AIMessage(BaseMessage):
"""


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolResult:
"""
ToolResult represents a result of a successful tool call.
Expand All @@ -151,7 +151,7 @@ class ToolResult:
structured_content: dict[str, Any] | None


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SubagentStructuredResult:
"""
SubagentStructuredResult represents a result of a successful subagent call.
Expand All @@ -161,7 +161,7 @@ class SubagentStructuredResult:
structured_output: dict[str, Any]


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SubagentTextResult:
"""
SubagentTextResult represents a result of a successful subagent call.
Expand All @@ -171,7 +171,7 @@ class SubagentTextResult:
content: str


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolFailureResult:
"""
Represents the result of a failed sub-agent call.
Expand All @@ -183,7 +183,7 @@ class ToolFailureResult:
error_message: str


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SubagentFailureResult:
"""
Represents the result of a failed tool call.
Expand All @@ -195,7 +195,7 @@ class SubagentFailureResult:
error_message: str


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolMessage(BaseMessage):
"""ToolMessage represents a response of a tool call"""

Expand All @@ -208,7 +208,7 @@ class ToolMessage(BaseMessage):


# TODO: do we have a test that uses this?
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SystemMessage(BaseMessage):
"""
A message used to prime or control agent behavior.
Expand All @@ -218,7 +218,7 @@ class SystemMessage(BaseMessage):
content: str


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SubagentMessage(BaseMessage):
"""
SubagentMessage represents a response of an agent invocation
Expand All @@ -231,7 +231,7 @@ class SubagentMessage(BaseMessage):
result: SubagentStructuredResult | SubagentTextResult | SubagentFailureResult


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class StructuredOutputMessage(BaseMessage):
"""
StructuredMessage represents a response to the StructuredOutputCall.
Expand All @@ -254,7 +254,7 @@ class StructuredOutputMessage(BaseMessage):
# where developers might want to store messages in say KV store.


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class AgentResponse(Generic[OutputT]):
# in case output_schema is provided, this will hold the parsed structured output
structured_output: OutputT
Expand Down
16 changes: 8 additions & 8 deletions splunklib/ai/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class AgentState:
"""AgentState is available through certain middlewares and contains information about the current state of an agent execution."""

Expand All @@ -42,27 +42,27 @@ class AgentState:
token_count: int


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolRequest:
call: ToolCall
state: AgentState


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolResponse:
result: ToolResult | ToolFailureResult


ToolMiddlewareHandler = Callable[[ToolRequest], Awaitable[ToolResponse]]


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SubagentRequest:
call: SubagentCall
state: AgentState


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class SubagentResponse:
result: SubagentStructuredResult | SubagentTextResult | SubagentFailureResult

Expand All @@ -73,13 +73,13 @@ class SubagentResponse:
]


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ModelRequest:
system_message: str
state: AgentState


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ModelResponse:
message: AIMessage
structured_output: Any | None = None
Expand All @@ -94,7 +94,7 @@ def __post_init__(self) -> None:
ModelMiddlewareHandler = Callable[[ModelRequest], Awaitable[ModelResponse]]


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class AgentRequest:
messages: Sequence[BaseMessage]

Expand Down
6 changes: 3 additions & 3 deletions splunklib/ai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import httpx


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class PredefinedModel:
"""Base class for models that are predefined in the SDK"""

model: str


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class OpenAIModel(PredefinedModel):
"""Predefined OpenAI Model"""

Expand Down Expand Up @@ -53,7 +53,7 @@ class OpenAIModel(PredefinedModel):
"""


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class AnthropicModel(PredefinedModel):
"""Predefined Anthropic Model"""

Expand Down
4 changes: 2 additions & 2 deletions splunklib/ai/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from splunklib.ai.messages import AIMessage


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class StructuredOutputMultipleToolCallsError:
pass


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class StructuredOutputValidationError:
validation_error: str

Expand Down
8 changes: 4 additions & 4 deletions splunklib/ai/tool_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from splunklib.ai.tools import ToolMetadata


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolAllowlist:
"""Holds tool names and tags allowed to be used by Agents.

Expand All @@ -41,17 +41,17 @@ def is_allowed(self, tool: ToolMetadata) -> bool:
return self.custom_predicate(tool) if self.custom_predicate else False


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class RemoteToolSettings:
allowlist: ToolAllowlist


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class LocalToolSettings:
allowlist: ToolAllowlist


@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True)
class ToolSettings:
local: LocalToolSettings | bool
"""Controls local tool loading (via ``bin/tools.py``).
Expand Down
Loading
Loading