diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py new file mode 100644 index 0000000000..ae3517919c --- /dev/null +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -0,0 +1,52 @@ +"""Tool result formatting for MCP tools/call responses. + +Each formatter is a ``ResultFormatter`` subclass that can enrich +a tool result with additional content. All formatters are accumulated. +""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import CallToolResult, TextContent + +from dash.types import CallbackExecutionResponse +from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + +from .base import ResultFormatter +from .result_dataframe import DataFrameResult +from .result_plotly_figure import PlotlyFigureResult + +_RESULT_FORMATTERS: list[type[ResultFormatter]] = [ + PlotlyFigureResult, + DataFrameResult, +] + + +def format_callback_response( + response: CallbackExecutionResponse, + callback: CallbackAdapter, +) -> CallToolResult: + """Format a callback response as a CallToolResult. + + The response is always returned as structuredContent. Result + formatters are called per output property and may add additional + content items (images, markdown, etc.). + """ + content: list[Any] = [ + TextContent(type="text", text=json.dumps(response, default=str)), + ] + + resp = response.get("response") or {} + for callback_output in callback.outputs: + value = resp.get(callback_output["component_id"], {}).get( + callback_output["property"] + ) + for formatter in _RESULT_FORMATTERS: + content.extend(formatter.format(callback_output, value)) + + return CallToolResult( + content=content, + structuredContent=response, + ) diff --git a/dash/mcp/primitives/tools/results/base.py b/dash/mcp/primitives/tools/results/base.py new file mode 100644 index 0000000000..1f7714ff6b --- /dev/null +++ b/dash/mcp/primitives/tools/results/base.py @@ -0,0 +1,24 @@ +"""Base class for result formatters.""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import ImageContent, TextContent + +from dash.mcp.types import MCPOutput + + +class ResultFormatter: + """A formatter that can enrich an MCP tool result with additional content. + + Subclasses implement ``format`` to return content items (text, images) + for a specific callback output. All formatters are accumulated — every + formatter can add content to the overall tool result. + """ + + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + raise NotImplementedError diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py new file mode 100644 index 0000000000..04b1d84b3e --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -0,0 +1,68 @@ +"""Tabular data result: render as a markdown table. + +Detects tabular output by component type and prop name: +- DataTable.data +- AgGrid.rowData +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import ImageContent, TextContent + +from dash.mcp.types import MCPOutput + +from .base import ResultFormatter + +MAX_ROWS = 50 + +_TABULAR_PROPS = { + ("DataTable", "data"), + ("AgGrid", "rowData"), +} + + +def _to_markdown_table(rows: list[dict], max_rows: int = MAX_ROWS) -> str: + """Render a list of row dicts as a markdown table.""" + columns = list(rows[0].keys()) + total = len(rows) + + lines: list[str] = [] + lines.append(f"*{total} rows \u00d7 {len(columns)} columns*") + lines.append("") + lines.append("| " + " | ".join(columns) + " |") + lines.append("| " + " | ".join("---" for _ in columns) + " |") + + for row in rows[:max_rows]: + cells = [ + str(row.get(col, "")).replace("|", "\\|").replace("\n", " ") + for col in columns + ] + lines.append("| " + " | ".join(cells) + " |") + + if total > max_rows: + lines.append(f"\n(\u2026 {total - max_rows} more rows)") + + return "\n".join(lines) + + +class DataFrameResult(ResultFormatter): + """Produce a markdown table for tabular component output values.""" + + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + key = (output.get("component_type"), output.get("property")) + if key not in _TABULAR_PROPS: + return [] + if ( + not isinstance(returned_output_value, list) + or not returned_output_value + or not isinstance(returned_output_value[0], dict) + ): + return [] + return [ + TextContent(type="text", text=_to_markdown_table(returned_output_value)) + ] diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py new file mode 100644 index 0000000000..ad2c057f89 --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -0,0 +1,62 @@ +"""Plotly figure tool result: rendered image.""" + +from __future__ import annotations + +import base64 +import logging +from typing import Any + +from mcp.types import ImageContent, TextContent + +from dash.mcp.types import MCPOutput + +from .base import ResultFormatter + +logger = logging.getLogger(__name__) + +IMAGE_WIDTH = 700 +IMAGE_HEIGHT = 450 + + +def _render_image(figure: Any) -> ImageContent | None: + """Render the figure as a base64 PNG ImageContent. + + Returns None if kaleido is not installed. + """ + try: + img_bytes = figure.to_image( + format="png", + width=IMAGE_WIDTH, + height=IMAGE_HEIGHT, + ) + except (ValueError, ImportError): + logger.debug("MCP: kaleido not available, skipping image render") + return None + + b64 = base64.b64encode(img_bytes).decode("ascii") + return ImageContent(type="image", data=b64, mimeType="image/png") + + +class PlotlyFigureResult(ResultFormatter): + """Produce a rendered PNG for Graph.figure output values.""" + + @classmethod + def format( + cls, output: MCPOutput, returned_output_value: Any + ) -> list[TextContent | ImageContent]: + if ( + output.get("component_type") != "Graph" + or output.get("property") != "figure" + ): + return [] + if not isinstance(returned_output_value, dict): + return [] + + try: + import plotly.graph_objects as go + except ImportError: + return [] + + fig = go.Figure(returned_output_value) + image = _render_image(fig) + return [image] if image is not None else [] diff --git a/tests/unit/mcp/tools/results/test_callback_response.py b/tests/unit/mcp/tools/results/test_callback_response.py new file mode 100644 index 0000000000..ff8cca5e20 --- /dev/null +++ b/tests/unit/mcp/tools/results/test_callback_response.py @@ -0,0 +1,98 @@ +"""Tests for the callback response formatter.""" + +from unittest.mock import Mock + +from dash.mcp.primitives.tools.results import format_callback_response + + +def _mock_callback(outputs=None): + cb = Mock() + cb.outputs = outputs or [] + return cb + + +class TestFormatCallbackResponse: + def test_wraps_as_structured_content(self): + response = { + "multi": True, + "response": {"out": {"children": "hello"}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent == response + + def test_content_has_json_text_fallback(self): + """Per MCP spec, structuredContent SHOULD include a TextContent fallback.""" + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert len(result.content) >= 1 + assert result.content[0].type == "text" + assert '"multi": true' in result.content[0].text + + def test_is_error_defaults_false(self): + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert result.isError is False + + def test_preserves_side_update(self): + response = { + "multi": True, + "response": {"out": {"children": "x"}}, + "sideUpdate": {"other": {"value": 42}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent["sideUpdate"] == {"other": {"value": 42}} + + def test_datatable_result_includes_markdown_table(self): + response = { + "multi": True, + "response": { + "my-table": {"data": [{"name": "Alice", "age": 30}]}, + }, + } + outputs = [ + { + "component_id": "my-table", + "component_type": "DataTable", + "property": "data", + "id_and_prop": "my-table.data", + "initial_value": None, + "tool_name": "update", + } + ] + result = format_callback_response(response, _mock_callback(outputs)) + texts = [c.text for c in result.content if c.type == "text"] + assert any("| name | age |" in t for t in texts) + + def test_plotly_figure_includes_image(self): + from unittest.mock import patch + + try: + import plotly.graph_objects as go + except ImportError: + return + + response = { + "multi": True, + "response": { + "my-graph": { + "figure": { + "data": [{"type": "bar", "x": ["A"], "y": [1]}], + "layout": {}, + } + } + }, + } + outputs = [ + { + "component_id": "my-graph", + "component_type": "Graph", + "property": "figure", + "id_and_prop": "my-graph.figure", + "initial_value": None, + "tool_name": "update", + } + ] + with patch.object(go.Figure, "to_image", return_value=b"\x89PNGfake"): + result = format_callback_response(response, _mock_callback(outputs)) + images = [c for c in result.content if c.type == "image"] + assert len(images) == 1 diff --git a/tests/unit/mcp/tools/results/test_dataframe.py b/tests/unit/mcp/tools/results/test_dataframe.py new file mode 100644 index 0000000000..65aef74d31 --- /dev/null +++ b/tests/unit/mcp/tools/results/test_dataframe.py @@ -0,0 +1,63 @@ +"""Tests for the tabular data result formatter.""" + +from dash.mcp.primitives.tools.results.result_dataframe import ( + MAX_ROWS, + DataFrameResult, +) + +EXPECTED_TABLE = ( + "*2 rows \u00d7 2 columns*\n" + "\n" + "| name | age |\n" + "| --- | --- |\n" + "| Alice | 30 |\n" + "| Bob | 25 |" +) + +SAMPLE_ROWS = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] + +DATATABLE_OUTPUT = { + "component_type": "DataTable", + "property": "data", + "component_id": "t", + "id_and_prop": "t.data", + "initial_value": None, + "tool_name": "update", +} + +AGGRID_OUTPUT = { + "component_type": "AgGrid", + "property": "rowData", + "component_id": "g", + "id_and_prop": "g.rowData", + "initial_value": None, + "tool_name": "update", +} + + +class TestDataframeResult: + def test_datatable_data_renders_markdown(self): + result = DataFrameResult.format(DATATABLE_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + def test_aggrid_rowdata_renders_markdown(self): + result = DataFrameResult.format(AGGRID_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + def test_ignores_non_tabular_props(self): + non_tabular = {**DATATABLE_OUTPUT, "property": "columns"} + assert DataFrameResult.format(non_tabular, SAMPLE_ROWS) == [] + + def test_ignores_empty_or_non_dict_rows(self): + assert DataFrameResult.format(DATATABLE_OUTPUT, []) == [] + assert DataFrameResult.format(DATATABLE_OUTPUT, ["a", "b"]) == [] + + def test_truncates_large_tables(self): + rows = [{"i": n} for n in range(MAX_ROWS + 50)] + result = DataFrameResult.format(DATATABLE_OUTPUT, rows) + text = result[0].text + assert f"| {MAX_ROWS - 1} |" in text + assert f"| {MAX_ROWS} |" not in text + assert "50 more rows" in text diff --git a/tests/unit/mcp/tools/results/test_plotly_figure.py b/tests/unit/mcp/tools/results/test_plotly_figure.py new file mode 100644 index 0000000000..e3c42af303 --- /dev/null +++ b/tests/unit/mcp/tools/results/test_plotly_figure.py @@ -0,0 +1,55 @@ +"""Tests for the Plotly figure tool result formatter.""" + +import base64 +from unittest.mock import patch + +import pytest + +from dash.mcp.primitives.tools.results.result_plotly_figure import ( + PlotlyFigureResult, +) + +go = pytest.importorskip("plotly.graph_objects") + +FAKE_PNG = b"\x89PNG\r\n\x1a\nfakedata" +FAKE_B64 = base64.b64encode(FAKE_PNG).decode("ascii") + +GRAPH_FIGURE_OUTPUT = { + "component_type": "Graph", + "property": "figure", + "component_id": "g", + "id_and_prop": "g.figure", + "initial_value": None, + "tool_name": "update", +} + + +class TestPlotlyFigureResult: + def test_returns_image_when_kaleido_available(self): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", return_value=FAKE_PNG): + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) + assert len(result) == 1 + assert result[0].type == "image" + assert result[0].data == FAKE_B64 + + def test_returns_empty_when_kaleido_unavailable(self): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", side_effect=ImportError): + result = PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, fig_dict) + assert result == [] + + def test_ignores_non_graph_components(self): + output = { + **GRAPH_FIGURE_OUTPUT, + "component_type": "Div", + "property": "children", + } + assert PlotlyFigureResult.format(output, {}) == [] + + def test_ignores_non_figure_props(self): + output = {**GRAPH_FIGURE_OUTPUT, "property": "clickData"} + assert PlotlyFigureResult.format(output, {}) == [] + + def test_ignores_non_dict_values(self): + assert PlotlyFigureResult.format(GRAPH_FIGURE_OUTPUT, "not a dict") == []