From 1c31114e77aea6fafbbe71b1c974c9e3b74a35d9 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 15:35:13 -0600 Subject: [PATCH 01/10] Wire MCP server, SSE transport, and Dash app integration --- dash/_configs.py | 2 + dash/dash.py | 31 + dash/mcp/__init__.py | 7 + dash/mcp/_server.py | 277 +++++ dash/mcp/_sse.py | 67 ++ dash/mcp/notifications/__init__.py | 7 + .../notification_tools_changed.py | 30 + dash/mcp/primitives/__init__.py | 17 + .../tools/callback_adapter_collection.py | 2 - tests/integration/mcp/conftest.py | 53 + .../primitives/resources/test_resources.py | 51 + .../tools/test_callback_signatures.py | 958 ++++++++++++++++++ .../tools/test_duplicate_outputs.py | 128 +++ .../primitives/tools/test_input_schemas.py | 66 ++ .../tools/test_tool_get_dash_component.py | 54 + .../mcp/primitives/tools/test_tools_list.py | 118 +++ tests/integration/mcp/test_server.py | 304 ++++++ tests/unit/mcp/test_server.py | 92 ++ tests/unit/mcp/tools/test_run_callback.py | 246 +++++ 19 files changed, 2508 insertions(+), 2 deletions(-) create mode 100644 dash/mcp/_server.py create mode 100644 dash/mcp/_sse.py create mode 100644 dash/mcp/notifications/__init__.py create mode 100644 dash/mcp/notifications/notification_tools_changed.py create mode 100644 tests/integration/mcp/conftest.py create mode 100644 tests/integration/mcp/primitives/resources/test_resources.py create mode 100644 tests/integration/mcp/primitives/tools/test_callback_signatures.py create mode 100644 tests/integration/mcp/primitives/tools/test_duplicate_outputs.py create mode 100644 tests/integration/mcp/primitives/tools/test_input_schemas.py create mode 100644 tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py create mode 100644 tests/integration/mcp/primitives/tools/test_tools_list.py create mode 100644 tests/integration/mcp/test_server.py create mode 100644 tests/unit/mcp/test_server.py create mode 100644 tests/unit/mcp/tools/test_run_callback.py diff --git a/dash/_configs.py b/dash/_configs.py index edbf7b50d1..f6df4001f1 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -32,6 +32,8 @@ def load_dash_env_vars(): "DASH_DISABLE_VERSION_CHECK", "DASH_PRUNE_ERRORS", "DASH_COMPRESS", + "DASH_MCP_ENABLED", + "DASH_MCP_PATH", "HOST", "PORT", ) diff --git a/dash/dash.py b/dash/dash.py index a9f9b6c757..2cc1f37c61 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -483,6 +483,8 @@ def __init__( # pylint: disable=too-many-statements health_endpoint: Optional[str] = None, csrf_token_name: str = "_csrf_token", csrf_header_name: str = "X-CSRFToken", + enable_mcp: Optional[bool] = None, + mcp_path: Optional[str] = None, **obsolete, ): @@ -593,6 +595,13 @@ def __init__( # pylint: disable=too-many-statements # keep title as a class property for backwards compatibility self.title = title + # MCP (Model Context Protocol) configuration + self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, True) + _mcp_path = get_combined_config("mcp_path", mcp_path, "_mcp") + self._mcp_path = ( + _mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path + ) + # list of dependencies - this one is used by the back end for dispatching self.callback_map: dict = {} # same deps as a list to catch duplicate outputs, and to send to the front end @@ -813,6 +822,21 @@ def _setup_routes(self): hook.data["methods"], ) + if self._enable_mcp: + from .mcp import ( # pylint: disable=import-outside-toplevel + enable_mcp_server, + ) + + try: + enable_mcp_server(self, self._mcp_path) + except Exception as e: # pylint: disable=broad-exception-caught + self._enable_mcp = False + self.logger.warning( + "MCP server could not be started at '%s': %s", + self._mcp_path, + e, + ) + # catch-all for front-end routes, used by dcc.Location self._add_url("", self.index) @@ -2548,6 +2572,13 @@ def verify_url_part(served_part, url_part, part_name): if not jupyter_dash or not jupyter_dash.in_ipython: self.logger.info("Dash is running on %s://%s%s%s\n", *display_url) + if self._enable_mcp: + self.logger.info( + " * MCP available at %s://%s%s%s%s\n", + *display_url[:3], + self.config.routes_pathname_prefix, + self._mcp_path, + ) if self.config.extra_hot_reload_paths: extra_files = flask_run_options["extra_files"] = [] diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py index e69de29bb2..2677ea141b 100644 --- a/dash/mcp/__init__.py +++ b/dash/mcp/__init__.py @@ -0,0 +1,7 @@ +"""Dash MCP (Model Context Protocol) server integration.""" + +from dash.mcp._server import enable_mcp_server + +__all__ = [ + enable_mcp_server, +] diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py new file mode 100644 index 0000000000..1c6279290b --- /dev/null +++ b/dash/mcp/_server.py @@ -0,0 +1,277 @@ +"""Flask route setup, Streamable HTTP transport, and MCP message handling.""" + +from __future__ import annotations + +import atexit +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from flask import Response, request + +from dash.mcp.types import MCPError + +if TYPE_CHECKING: + from dash import Dash + +from dash import get_app + +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ErrorData, + Implementation, + InitializeResult, + JSONRPCError, + JSONRPCResponse, + ResourcesCapability, + ServerCapabilities, + ToolsCapability, +) + +from dash.version import __version__ +from dash.mcp._sse import ( + close_sse_stream, + create_sse_stream, + shutdown_all_streams, +) +from dash.mcp.primitives import ( + call_tool, + list_resource_templates, + list_resources, + list_tools, + read_resource, +) +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + +logger = logging.getLogger(__name__) + + +def enable_mcp_server(app: Dash, mcp_path: str) -> None: + """ + Add MCP routes to a Dash/Flask app. + + Registers a single Streamable HTTP endpoint for the MCP protocol. + Uses ``app._add_url()`` so that ``routes_pathname_prefix`` is applied + automatically. + + Args: + app: The Dash application instance. + mcp_path: Route prefix for MCP endpoints. + """ + # Session storage: session_id -> metadata + sessions: dict[str, dict[str, Any]] = {} + + def _create_session() -> str: + sid = str(uuid.uuid4()) + sessions[sid] = {} + return sid + + # -- Streamable HTTP endpoint -------------------------------------------- + + def mcp_handler() -> Response: + if request.method == "POST": + return _handle_post() + if request.method == "GET": + return _handle_get() + if request.method == "DELETE": + return _handle_delete() + return Response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) + + def _handle_get() -> Response: + session_id = request.headers.get("mcp-session-id") + if not session_id or session_id not in sessions: + return Response( + json.dumps({"error": "Session not found"}), + content_type="application/json", + status=404, + ) + return create_sse_stream(sessions, session_id) + + def _handle_post() -> Response: + content_type = request.content_type or "" + if "application/json" not in content_type: + return Response( + json.dumps({"error": "Content-Type must be application/json"}), + content_type="application/json", + status=415, + ) + + try: + data = request.get_json() + except Exception: + return Response( + json.dumps({"error": "Invalid JSON"}), + content_type="application/json", + status=400, + ) + + method = data.get("method", "") + request_id = data.get("id") + session_id = request.headers.get("mcp-session-id") + + stale_session = False + if method == "initialize": + session_id = _create_session() + elif session_id and session_id not in sessions: + stale_session = True + sessions[session_id] = {} + elif not session_id: + session_id = _create_session() + + response_data = _process_mcp_message(data) + + if response_data is None: + return Response("", status=202) + + if stale_session: + _inject_warning(response_data, _STALE_SESSION_WARNING) + + return Response( + json.dumps(response_data), + content_type="application/json", + status=200, + headers={"mcp-session-id": session_id}, + ) + + def _handle_delete() -> Response: + session_id = request.headers.get("mcp-session-id") + if not session_id or session_id not in sessions: + return Response( + json.dumps({"error": "Session not found"}), + content_type="application/json", + status=404, + ) + close_sse_stream(sessions[session_id]) + del sessions[session_id] + logger.info("MCP session terminated: %s", session_id) + return Response("", status=204) + + # -- Register routes ----------------------------------------------------- + + from dash._get_app import with_app_context_factory + + app._add_url( + mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] + ) + + # Close all SSE streams on server shutdown so MCP clients see a + # clean stream end and can reconnect promptly. + atexit.register(shutdown_all_streams, sessions) + + logger.info( + "MCP routes registered at %s%s", + app.config.routes_pathname_prefix, + mcp_path, + ) + + +_STALE_SESSION_WARNING = ( + "[Warning: your session was not recognised" + " — the app may have restarted." + " Please call tools/list to refresh your tool list." + " Please ask the user to reconnect to the MCP server.]" +) + + +def _inject_warning(response_data: dict[str, Any], warning: str) -> None: + """Append a warning to a JSON-RPC response dict. + + For successful ``tools/call`` responses the warning is added as an + extra text content block so the agent sees it alongside the result. + For error responses the warning is appended to the error message. + Other responses (tools/list, resources/*) are left unchanged — the + JSON-RPC spec forbids extra top-level keys. + """ + # tools/call success: result has a "content" list + result = response_data.get("result") + if isinstance(result, dict) and isinstance(result.get("content"), list): + result["content"].append({"type": "text", "text": warning}) + return + + # Error response + error = response_data.get("error") + if isinstance(error, dict) and "message" in error: + error["message"] += " " + warning + + +def _handle_initialize() -> InitializeResult: + return InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + tools=ToolsCapability(listChanged=True), + resources=ResourcesCapability(), + ), + serverInfo=Implementation(name="Plotly Dash", version=__version__), + instructions=( + "This is a Dash web application. " + "Dash apps are stateless: calling a tool executes " + "a callback and returns its result to you, but does " + "NOT update the user's browser. " + "Use tool results to answer questions about what " + "the app would produce for given inputs." + ), + ) + + +def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: + """ + Process an MCP JSON-RPC message and return the response dict. + + Returns ``None`` for notifications (no ``id`` field). + """ + method = data.get("method", "") + params = data.get("params", {}) or {} + request_id = data.get("id") + + app = get_app() + if not hasattr(app, "mcp_callback_map"): + app.mcp_callback_map = CallbackAdapterCollection(app) + + mcp_methods = { + "initialize": _handle_initialize, + "tools/list": lambda: list_tools(), + "tools/call": lambda: call_tool( + params.get("name", ""), params.get("arguments", {}) + ), + "resources/list": lambda: list_resources(), + "resources/templates/list": lambda: list_resource_templates(), + "resources/read": lambda: read_resource(params.get("uri", "")), + } + + try: + handler = mcp_methods.get(method) + if handler is None: + if method.startswith("notifications/"): + return None + raise ValueError(f"Unknown method: {method}") + + result = handler() + + response = JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=result.model_dump(exclude_none=True, mode="json"), + ) + return response.model_dump(exclude_none=True, mode="json") + + except MCPError as e: + logger.error("MCP error: %s", e) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=e.code, message=str(e)), + ).model_dump(exclude_none=True) + except Exception as e: + logger.error("MCP error: %s", e, exc_info=True) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=-32603, message=f"{type(e).__name__}: {e}"), + ).model_dump(exclude_none=True) diff --git a/dash/mcp/_sse.py b/dash/mcp/_sse.py new file mode 100644 index 0000000000..4928dc68b2 --- /dev/null +++ b/dash/mcp/_sse.py @@ -0,0 +1,67 @@ +"""SSE stream generation and queue management.""" + +from __future__ import annotations + +import queue +from typing import Any + +from flask import Response + + +def create_sse_stream(sessions: dict[str, dict[str, Any]], session_id: str) -> Response: + """Create a Server-Sent Events stream for the given session. + + Stores a :class:`queue.Queue` in ``sessions[session_id]["sse_queue"]`` + and returns a Flask streaming ``Response``. The generator yields + events pushed to the queue, with keepalive comments every 30 seconds. + """ + event_queue: queue.Queue[str | None] = queue.Queue() + # Replace any prior SSE queue for this session (client reconnect). + sessions[session_id]["sse_queue"] = event_queue + + def _generate(): + try: + while True: + try: + event = event_queue.get(timeout=30) + if event is None: + return # Sentinel: server closing stream + yield f"event: message\ndata: {event}\n\n" + except queue.Empty: + yield ": keepalive\n\n" + except GeneratorExit: + pass + finally: + # Clean up queue reference if it's still ours. + if sessions.get(session_id, {}).get("sse_queue") is event_queue: + sessions[session_id].pop("sse_queue", None) + + return Response( + _generate(), + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "mcp-session-id": session_id, + }, + ) + + +def close_sse_stream(session_data: dict[str, Any]) -> None: + """Send a sentinel to shut down the session's SSE stream cleanly.""" + sse_queue = session_data.get("sse_queue") + if sse_queue is not None: + try: + sse_queue.put_nowait(None) + except queue.Full: + pass + + +def shutdown_all_streams(sessions: dict[str, dict[str, Any]]) -> None: + """Close all active SSE streams. + + Called during server shutdown (via ``atexit``) so that connected + MCP clients see a clean stream end and can reconnect promptly. + """ + for session_data in list(sessions.values()): + close_sse_stream(session_data) diff --git a/dash/mcp/notifications/__init__.py b/dash/mcp/notifications/__init__.py new file mode 100644 index 0000000000..b1fe9e8665 --- /dev/null +++ b/dash/mcp/notifications/__init__.py @@ -0,0 +1,7 @@ +"""Server-initiated MCP notifications.""" + +from .notification_tools_changed import broadcast_tools_changed + +__all__ = [ + "broadcast_tools_changed", +] diff --git a/dash/mcp/notifications/notification_tools_changed.py b/dash/mcp/notifications/notification_tools_changed.py new file mode 100644 index 0000000000..1970667d1a --- /dev/null +++ b/dash/mcp/notifications/notification_tools_changed.py @@ -0,0 +1,30 @@ +"""Tool list change notifications.""" + +from __future__ import annotations + +import json +import queue +from typing import Any + + +def broadcast_tools_changed( + sessions: dict[str, dict[str, Any]], +) -> None: + """Push a tools/list_changed notification to all active SSE streams. + + Not called automatically yet — available for future hot-reload + or dynamic callback registration. + """ + notification = json.dumps( + { + "jsonrpc": "2.0", + "method": "notifications/tools/list_changed", + } + ) + for data in sessions.values(): + sse_queue = data.get("sse_queue") + if sse_queue is not None: + try: + sse_queue.put_nowait(notification) + except queue.Full: + pass diff --git a/dash/mcp/primitives/__init__.py b/dash/mcp/primitives/__init__.py index e69de29bb2..b14839f1e1 100644 --- a/dash/mcp/primitives/__init__.py +++ b/dash/mcp/primitives/__init__.py @@ -0,0 +1,17 @@ +from .resources import ( + list_resources, + list_resource_templates, + read_resource, +) +from .tools import ( + call_tool, + list_tools, +) + +__all__ = [ + call_tool, + list_resources, + list_resource_templates, + list_tools, + read_resource, +] diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index 0304394f63..68a1813da1 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -35,8 +35,6 @@ def __init__(self, app): CallbackAdapter(callback_output_id=output_id) for output_id in self._tool_names_map ] - # TODO: enable_mcp_server() will replace this with a direct assignment on app - app.mcp_callback_map = self @staticmethod def _sanitize_name(name: str) -> str: diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py new file mode 100644 index 0000000000..0f212d1763 --- /dev/null +++ b/tests/integration/mcp/conftest.py @@ -0,0 +1,53 @@ +"""Shared helpers for MCP integration tests.""" + +import requests + + +def _mcp_post(server_url, method, params=None, session_id=None, request_id=1): + headers = {"Content-Type": "application/json"} + if session_id: + headers["mcp-session-id"] = session_id + return requests.post( + f"{server_url}/_mcp", + json={ + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + }, + headers=headers, + timeout=5, + ) + + +def _mcp_session(server_url): + resp = _mcp_post(server_url, "initialize") + resp.raise_for_status() + return resp.headers["mcp-session-id"] + + +def _mcp_tools(server_url): + sid = _mcp_session(server_url) + resp = _mcp_post(server_url, "tools/list", session_id=sid, request_id=2) + resp.raise_for_status() + return resp.json()["result"]["tools"] + + +def _mcp_call_tool(server_url, tool_name, arguments=None): + sid = _mcp_session(server_url) + resp = _mcp_post( + server_url, + "tools/call", + {"name": tool_name, "arguments": arguments or {}}, + session_id=sid, + request_id=2, + ) + resp.raise_for_status() + return resp.json() + + +def _mcp_method(server_url, method, params=None): + sid = _mcp_session(server_url) + resp = _mcp_post(server_url, method, params, session_id=sid, request_id=2) + resp.raise_for_status() + return resp.json() diff --git a/tests/integration/mcp/primitives/resources/test_resources.py b/tests/integration/mcp/primitives/resources/test_resources.py new file mode 100644 index 0000000000..dfc1e09f9b --- /dev/null +++ b/tests/integration/mcp/primitives/resources/test_resources.py @@ -0,0 +1,51 @@ +"""Integration tests for MCP resources.""" + +import json + +from dash import Dash, dcc, html + +from tests.integration.mcp.conftest import _mcp_method + + +def test_resources_list_includes_layout(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method(dash_duo.server.url, "resources/list") + + assert "result" in result + uris = [r["uri"] for r in result["result"]["resources"]] + assert "dash://layout" in uris + + +def test_read_layout_resource(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="res-dd", options=["x", "y"], value="x"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method( + dash_duo.server.url, + "resources/read", + {"uri": "dash://layout"}, + ) + + assert "result" in result + layout = json.loads(result["result"]["contents"][0]["text"]) + assert layout["type"] == "Div" + children = layout["props"]["children"] + dd = next( + c for c in children if isinstance(c, dict) and c.get("type") == "Dropdown" + ) + assert dd["props"]["id"] == "res-dd" + assert dd["props"]["options"] == ["x", "y"] diff --git a/tests/integration/mcp/primitives/tools/test_callback_signatures.py b/tests/integration/mcp/primitives/tools/test_callback_signatures.py new file mode 100644 index 0000000000..db325f2046 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_callback_signatures.py @@ -0,0 +1,958 @@ +""" +Integration tests for all Dash callback signature types. + +Each test verifies that: +1. The MCP tool schema accurately reflects the callback's parameters +2. Calling the tool with those parameters produces the expected result + +Assertions are derived from the callback definition, not the implementation. + +See: https://dash.plotly.com/flexible-callback-signatures +""" + +from dash import Dash, Input, Output, State, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + + +def _find_tool(tools, name): + return next(t for t in tools if t["name"] == name) + + +def _get_response(result): + return result["result"]["structuredContent"]["response"] + + +def test_positional_callback(dash_duo): + """Standard positional: Input("fruit", "value") → param named 'value'.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="fruit", options=["apple", "banana"], value="apple"), + html.Div(id="out"), + ] + ) + + # Callback: 1 Input → 1 param named "value" (from function signature) + # Returns string → Output("out", "children") + @app.callback(Output("out", "children"), Input("fruit", "value")) + def show_fruit(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Selected: apple") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_fruit") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"value"} + assert any(s.get("type") == "string" for s in props["value"]["anyOf"]) + + # Tool description reflects initial state + value_desc = props["value"].get("description", "") + assert "value: 'apple'" in value_desc + assert "options: ['apple', 'banana']" in value_desc + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "apple"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: apple" + + # MCP tool with different inputs + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "banana"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: banana" + + +def test_positional_with_state(dash_duo): + """Positional with State: Input + State both appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + # Callback: 1 Input + 1 State → 2 params named "n_clicks" and "value" + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(n_clicks, value): + return f"Clicked {n_clicks} with {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Clicked None with hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"n_clicks", "value"} + assert any(s.get("type") == "number" for s in props["n_clicks"]["anyOf"]) + + # Tool description reflects initial state + assert "value: 'hello'" in props["value"].get("description", "") + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": None, "value": "hello"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked None with hello" + + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": 3, "value": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked 3 with world" + + +def test_multi_output_positional(dash_duo): + """Multi-output: returns tuple → both outputs updated in response.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="test"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + # Callback: 1 Input → 2 Outputs via tuple return + @app.callback( + Output("out1", "children"), + Output("out2", "children"), + Input("inp", "value"), + ) + def split_case(value): + return value.upper(), value.lower() + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out1", "TEST") + dash_duo.wait_for_text_to_equal("#out2", "test") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "split_case") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"value"} + + # Tool description reflects initial state + assert "value: 'test'" in props["value"].get("description", "") + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "split_case", {"value": "test"}) + response = _get_response(result) + assert response["out1"]["children"] == "TEST" + assert response["out2"]["children"] == "test" + + +def test_dict_based_inputs_and_state(dash_duo): + """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are param names.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="name-input", value="world"), + html.Div(id="out"), + ] + ) + + # Callback: dict keys "trigger" and "name" become param names + @app.callback( + Output("out", "children"), + inputs=dict(trigger=Input("btn", "n_clicks")), + state=dict(name=State("name-input", "value")), + ) + def greet(trigger, name): + return f"Hello, {name}!" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Hello, world!") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"trigger", "name"} + assert any(s.get("type") == "number" for s in props["trigger"]["anyOf"]) + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": None, "name": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, world!" + + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": 1, "name": "Dash"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, Dash!" + + +def test_dict_based_outputs(dash_duo): + """Dict-based outputs: output=dict(...) → callback returns dict, both outputs updated.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="upper-out"), + html.Div(id="lower-out"), + ] + ) + + # Callback: dict output keys "upper" and "lower" map to components + @app.callback( + output=dict( + upper=Output("upper-out", "children"), + lower=Output("lower-out", "children"), + ), + inputs=dict(val=Input("inp", "value")), + ) + def transform(val): + return dict(upper=val.upper(), lower=val.lower()) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#upper-out", "HELLO") + dash_duo.wait_for_text_to_equal("#lower-out", "hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "transform") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"val"} + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "transform", {"val": "hello"}) + response = _get_response(result) + assert response["upper-out"]["children"] == "HELLO" + assert response["lower-out"]["children"] == "hello" + + +def test_mixed_input_state_in_inputs(dash_duo): + """Mixed: State inside inputs=dict alongside Input → all appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="first", value="Jane"), + dcc.Input(id="last", value="Doe"), + html.Div(id="out"), + ] + ) + + # Callback: Input and State mixed in same dict → all keys are params + @app.callback( + Output("out", "children"), + inputs=dict( + clicks=Input("btn", "n_clicks"), + first=State("first", "value"), + last=State("last", "value"), + ), + ) + def full_name(clicks, first, last): + return f"{first} {last}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Jane Doe") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "full_name") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"clicks", "first", "last"} + assert any(s.get("type") == "number" for s in props["clicks"]["anyOf"]) + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": None, "first": "Jane", "last": "Doe"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "Jane Doe" + + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": 1, "first": "John", "last": "Smith"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "John Smith" + + +def test_tuple_grouped_inputs(dash_duo): + """Tuple grouping: pair=(Input("a",...), Input("b",...)) → expands to two named params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a", value="1"), + dcc.Input(id="b", value="2"), + html.Div(id="out"), + ] + ) + + # Callback: tuple group "pair" maps to 2 deps → 2 params named pair___ + @app.callback( + Output("out", "children"), + inputs=dict(pair=(Input("a", "value"), Input("b", "value"))), + ) + def combine(pair): + return f"{pair[0]}+{pair[1]}" + + dash_duo.start_server(app) + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + props = tool["inputSchema"]["properties"] + + # Tuple expands: one param per dep, named with group prefix + component info + assert set(props.keys()) == {"pair_a__value", "pair_b__value"} + for schema in props.values(): + assert any(s.get("type") == "string" for s in schema["anyOf"]) + + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"pair_a__value": "x", "pair_b__value": "y"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "x+y" + + +def test_initial_values_from_chained_callbacks(dash_duo): + """Querying components reflects post-initial-callback values. + + 3-link chain: country (default "France") → update_states → + state (should become "Ile-de-France") → update_cities → + city (should become "Paris"). + """ + DATA = { + "France": { + "Ile-de-France": ["Paris", "Versailles"], + "Provence": ["Marseille", "Nice"], + }, + "Germany": { + "Bavaria": ["Munich", "Nuremberg"], + "Berlin": ["Berlin"], + }, + } + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=list(DATA.keys()), value="France"), + dcc.Dropdown(id="state"), + dcc.Dropdown(id="city"), + ] + ) + + @app.callback( + Output("state", "options"), + Output("state", "value"), + Input("country", "value"), + ) + def update_states(country): + if not country: + return [], None + states = list(DATA[country].keys()) + return [{"label": s, "value": s} for s in states], states[0] + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("state", "value"), + Input("country", "value"), + ) + def update_cities(state, country): + if not state or not country: + return [], None + cities = DATA[country][state] + return [{"label": c, "value": c} for c in cities], cities[0] + + dash_duo.start_server(app) + + # Tool descriptions should reflect post-initial-callback state + tools = _mcp_tools(dash_duo.server.url) + update_cities_tool = _find_tool(tools, "update_cities") + state_desc = update_cities_tool["inputSchema"]["properties"]["state"].get( + "description", "" + ) + # state.value was set to "Ile-de-France" by update_states initial callback + assert "Ile-de-France" in state_desc + + # state.value should be "Ile-de-France" (first state for France) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "state", "property": "value"}, + ) + state_props = result["result"]["structuredContent"]["properties"] + assert state_props["value"]["initial_value"] == "Ile-de-France" + + # city.value should be "Paris" (first city for Ile-de-France) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "city", "property": "value"}, + ) + city_props = result["result"]["structuredContent"]["properties"] + assert city_props["value"]["initial_value"] == "Paris" + + +def test_dict_based_reordered_state_input(dash_duo): + """Dict-based callback with State before Input: call works, schema types correct. + + State is listed before Input in the dict. The callback should still + work correctly via MCP, and the schema types should match the + function annotations (name: str, trigger: int), not be swapped. + """ + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="World"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + inputs=dict(name=State("inp", "value"), trigger=Input("btn", "n_clicks")), + ) + def greet(name: str, trigger: int): + return f"Hello {name}" + + dash_duo.start_server(app) + + # First: verify the callback actually works with these args + result = _mcp_call_tool( + dash_duo.server.url, + "greet", + {"name": "Dash", "trigger": 1}, + ) + assert _get_response(result)["out"]["children"] == "Hello Dash" + + # Second: verify schema types match annotations + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + assert props["trigger"]["type"] == "integer" + assert props["name"]["type"] == "string" + + # Third: verify each param describes the correct component + trigger_desc = props["trigger"].get("description", "") + assert "number of times that this element has been clicked on" in trigger_desc + name_desc = props["name"].get("description", "") + assert "The value of the input" in name_desc + + +def test_pattern_matching_callback(dash_duo): + """Pattern-matching dict IDs: tool works with correct params and results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="hello"), + dcc.Input(id={"type": "field", "index": 1}, value="world"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + Input({"type": "field", "index": 1}, "value"), + ) + def combine(first, second): + return f"{first} {second}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + assert tool is not None + props = tool["inputSchema"]["properties"] + assert "first" in props + assert "second" in props + + # Verify initial output matches what the browser shows + dash_duo.wait_for_text_to_equal("#out", "hello world") + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "hello", "second": "world"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "hello world" + + # Verify with different values + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "foo", "second": "bar"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "foo bar" + + +def test_pattern_matching_with_all_wildcard(dash_duo): + """ALL wildcard: one callback receives values from all matching components.""" + from dash import ALL + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "input", "index": 0}, value="alpha"), + dcc.Input(id={"type": "input", "index": 1}, value="beta"), + html.Div(id="summary"), + ] + ) + + @app.callback( + Output("summary", "children"), + Input({"type": "input", "index": ALL}, "value"), + ) + def summarize(values): + return ", ".join(v for v in values if v) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#summary", "alpha, beta") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "summarize") + assert tool is not None + + # Schema must describe values as an array of {id, property, value} objects + values_schema = tool["inputSchema"]["properties"]["values"] + assert ( + values_schema["type"] == "array" + ), f"ALL wildcard param should be typed as array, got: {values_schema}" + assert "items" in values_schema, "Array schema should include items definition" + items = values_schema["items"] + assert items["type"] == "object" + assert "id" in items["properties"] + assert "value" in items["properties"] + assert "Pattern-matching input (ALL)" in values_schema.get( + "description", "" + ), "ALL wildcard param description should explain the pattern-matching behavior" + + # MCP tool call with browser-like format: concrete IDs + values + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "alpha", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "beta", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "alpha, beta" + + # Different values + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "one", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "two", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "one, two" + + +def test_pattern_matching_mixed_outputs(dash_duo): + """Mixed outputs: one regular + one ALL wildcard in the same callback.""" + from dash import ALL + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + dcc.Input(id={"type": "field", "index": 1}, value="b"), + html.Div(id={"type": "echo", "index": 0}), + html.Div(id={"type": "echo", "index": 1}), + html.Div(id="total"), + ] + ) + + @app.callback( + Output({"type": "echo", "index": ALL}, "children"), + Output("total", "children"), + Input({"type": "field", "index": ALL}, "value"), + ) + def echo_and_total(values): + echoes = [f"Echo: {v}" for v in values] + total = f"Total: {len(values)} items" + return echoes, total + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#total", "Total: 2 items") + + result = _mcp_call_tool( + dash_duo.server.url, + "echo_and_total", + { + "values": [ + { + "id": {"type": "field", "index": 0}, + "property": "value", + "value": "x", + }, + { + "id": {"type": "field", "index": 1}, + "property": "value", + "value": "y", + }, + ] + }, + ) + response = _get_response(result) + assert response["total"]["children"] == "Total: 2 items" + + +def test_pattern_matching_with_match_wildcard(dash_duo): + """MATCH wildcard: callback fires per-component with matching index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + from dash import MATCH + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "NYC", + id={"type": "city-dd", "index": 0}, + ), + html.Div(id={"type": "city-out", "index": 0}), + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "LA", + id={"type": "city-dd", "index": 1}, + ), + html.Div(id={"type": "city-out", "index": 1}), + ] + ) + + @app.callback( + Output({"type": "city-out", "index": MATCH}, "children"), + Input({"type": "city-dd", "index": MATCH}, "value"), + ) + def show_city(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + assert tool is not None + + # Schema describes MATCH input + value_schema = tool["inputSchema"]["properties"]["value"] + assert "Pattern-matching input (MATCH)" in value_schema.get("description", "") + + # Call with concrete ID for index 0 (MATCH takes a single entry, not an array) + result = _mcp_call_tool( + dash_duo.server.url, + "show_city", + { + "value": { + "id": {"type": "city-dd", "index": 0}, + "property": "value", + "value": "MTL", + } + }, + ) + response = _get_response(result) + # Find the output key containing "city-out" (Dash may serialize dict IDs differently) + out_key = next(k for k in response if "city-out" in k) + assert response[out_key]["children"] == "Selected: MTL" + + +def test_pattern_matching_with_allsmaller_wildcard(dash_duo): + """ALLSMALLER wildcard: receives values from components with smaller index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + from dash import MATCH, ALLSMALLER + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["France", "Germany", "Japan"], + "France", + id={"type": "country-dd", "index": 0}, + ), + html.Div(id={"type": "country-out", "index": 0}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Germany", + id={"type": "country-dd", "index": 1}, + ), + html.Div(id={"type": "country-out", "index": 1}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Japan", + id={"type": "country-dd", "index": 2}, + ), + html.Div(id={"type": "country-out", "index": 2}), + ] + ) + + @app.callback( + Output({"type": "country-out", "index": MATCH}, "children"), + Input({"type": "country-dd", "index": MATCH}, "value"), + Input({"type": "country-dd", "index": ALLSMALLER}, "value"), + ) + def show_countries(current, previous): + all_selected = [current] + list(reversed(previous)) + return f"All: {', '.join(all_selected)}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_countries") + assert tool is not None + + # Schema describes both MATCH and ALLSMALLER inputs + props = tool["inputSchema"]["properties"] + assert "Pattern-matching input (MATCH)" in props["current"].get("description", "") + assert "Pattern-matching input (ALLSMALLER)" in props["previous"].get( + "description", "" + ) + + # Call for index 2: MATCH is a single dict, ALLSMALLER is a list + result = _mcp_call_tool( + dash_duo.server.url, + "show_countries", + { + "current": { + "id": {"type": "country-dd", "index": 2}, + "property": "value", + "value": "Japan", + }, + "previous": [ + { + "id": {"type": "country-dd", "index": 0}, + "property": "value", + "value": "France", + }, + { + "id": {"type": "country-dd", "index": 1}, + "property": "value", + "value": "Germany", + }, + ], + }, + ) + response = _get_response(result) + out_key = next(k for k in response if "country-out" in k) + assert response[out_key]["children"] == "All: Japan, Germany, France" + + +def test_prevent_initial_call_uses_layout_default(dash_duo): + """prevent_initial_call=True: initial value stays as the layout default. + + The dropdown has value="original" in the layout. The callback has + prevent_initial_call=True so it doesn't run on page load. The MCP + tool description should show value: 'a' (layout default). + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b", "c"], value="a"), + html.Div(id="out", children="not yet"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + prevent_initial_call=True, + ) + def update(val): + return f"Changed to: {val}" + + dash_duo.start_server(app) + # Browser shows layout default — callback hasn't fired + dash_duo.wait_for_text_to_equal("#out", "not yet") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + val_desc = tool["inputSchema"]["properties"]["val"].get("description", "") + + # Tool description reflects layout default, not callback output + assert "value: 'a'" in val_desc + + +def test_initial_callback_overrides_layout_value(dash_duo): + """Initial callback overrides layout value in tool description. + + The city dropdown has value="default-city" in the layout. + update_city runs on page load (no prevent_initial_call) and + sets city.value to "Paris". The MCP tool should show "Paris" + as the default, not "default-city". + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["France", "Germany"], value="France"), + dcc.Dropdown(id="city", options=[], value="default-city"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_city(country): + if country == "France": + return [{"label": "Paris", "value": "Paris"}], "Paris" + return [{"label": "Berlin", "value": "Berlin"}], "Berlin" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"City: {city}" + + dash_duo.start_server(app) + # Browser shows "Paris" — the initial callback overrode "default-city" + dash_duo.wait_for_text_to_equal("#out", "City: Paris") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + city_desc = tool["inputSchema"]["properties"]["city"].get("description", "") + + # Tool description should show the post-initial-callback value + assert "value: 'Paris'" in city_desc + assert "default-city" not in city_desc + + +def test_callback_context_triggered_id(dash_duo): + """Callbacks using dash.ctx.triggered_id work via MCP. + + Based on https://dash.plotly.com/determining-which-callback-input-changed + """ + from dash import ctx + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Button 1", id="btn-1"), + html.Button("Button 2", id="btn-2"), + html.Button("Button 3", id="btn-3"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn-1", "n_clicks"), + Input("btn-2", "n_clicks"), + Input("btn-3", "n_clicks"), + ) + def display(btn1, btn2, btn3): + if not ctx.triggered_id: + return "No button clicked yet" + return f"Last clicked: {ctx.triggered_id}" + + dash_duo.start_server(app) + + # Browser initial state: no button clicked + dash_duo.wait_for_text_to_equal("#output", "No button clicked yet") + + # Tool should have all three button params + tool = _find_tool(_mcp_tools(dash_duo.server.url), "display") + props = tool["inputSchema"]["properties"] + assert "btn1" in props + assert "btn2" in props + assert "btn3" in props + + # Click btn-2 via MCP — ctx.triggered_id should be "btn-2" + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": 1, "btn3": None}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-2" + + # Click btn-3 via MCP + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": None, "btn3": 5}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-3" + + +def test_no_output_callback_does_not_crash_tools_list(dash_duo): + """A callback with no Output should not crash tools/list. + + No-output callbacks use set_props for side effects. They produce + a hash-only output_id with no dot separator. + """ + from dash import set_props + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Log", id="log-btn"), + dcc.Dropdown(id="picker", options=["a", "b"], value="a"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("log-btn", "n_clicks"), prevent_initial_call=True) + def log_click(n): + set_props("display", {"children": f"Logged {n} clicks"}) + + @app.callback(Output("display", "children"), Input("picker", "value")) + def show_selection(val): + return f"Selected: {val}" + + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + tool_names = [t["name"] for t in tools] + + # show_selection should appear as a tool + assert "show_selection" in tool_names + + # log_click has no declared output but uses set_props — still a valid tool + assert "log_click" in tool_names + + # Call log_click — sideUpdate should show the set_props effect + result = _mcp_call_tool( + dash_duo.server.url, + "log_click", + {"n": 3}, + ) + structured = result["result"]["structuredContent"] + assert "sideUpdate" in structured + assert structured["sideUpdate"]["display"]["children"] == "Logged 3 clicks" + + # get_dash_component shows show_selection as modifier (declared output). + # log_click uses set_props which bypasses the declarative graph — + # its effect is only visible via sideUpdate in tool call results. + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "display", "property": "children"}, + ) + prop_info = result["result"]["structuredContent"]["properties"]["children"] + assert "show_selection" in prop_info["modified_by_tool"] diff --git a/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py new file mode 100644 index 0000000000..4ad00641f8 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py @@ -0,0 +1,128 @@ +"""Integration test for duplicate callback outputs. + +Multiple callbacks can output to the same component.property +when using ``allow_duplicate=True``. The MCP server must handle +this correctly — both callbacks should appear as tools, and +calling either should work. +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + + +def _find_tool(tools, name): + return next((t for t in tools if t["name"] == name), None) + + +def _get_response(result): + return result["result"]["structuredContent"]["response"] + + +def test_duplicate_outputs_both_tools_listed(dash_duo): + """Both callbacks outputting to the same component appear as tools.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + first_tool = _find_tool(tools, "greet_by_first") + last_tool = _find_tool(tools, "greet_by_last") + + assert first_tool is not None, "greet_by_first should be listed" + assert last_tool is not None, "greet_by_last should be listed" + + +def test_duplicate_outputs_both_callable(dash_duo): + """Both callbacks can be called and produce correct results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) + assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" + + result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) + assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" + + +def test_duplicate_outputs_find_by_output_returns_primary(dash_duo): + """find_by_output returns the primary (non-duplicate) callback.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + # Query the component — should reflect initial callback (greet_by_first) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "greeting", "property": "children"}, + ) + structured = result["result"]["structuredContent"] + assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" diff --git a/tests/integration/mcp/primitives/tools/test_input_schemas.py b/tests/integration/mcp/primitives/tools/test_input_schemas.py new file mode 100644 index 0000000000..6ee3510ddd --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_input_schemas.py @@ -0,0 +1,66 @@ +""" +Integration tests for MCP tool schema generation. + +Starts a real Dash server via ``dash_duo`` and verifies that tools +are generated with correct inputSchema, descriptions, and labels. +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_tools + + +def test_mcp_tool_with_label_and_date_picker_schema(dash_duo): + """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" + + # -- Test data: change these to update the test -- + label_text = "Departure Date" + component_id = "dp" + min_date = "2020-01-01" + max_date = "2025-12-31" + default_date = "2024-06-15" + func_name = "select_date" + param_name = "date" # function parameter name + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label(label_text, htmlFor=component_id), + dcc.DatePickerSingle( + id=component_id, + min_date_allowed=min_date, + max_date_allowed=max_date, + date=default_date, + ), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input(component_id, "date")) + def select_date(date): + return f"Selected: {date}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + # Find the callback tool + tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) + + # -- Tool-level fields -- + assert func_name in tool["name"] + + # -- inputSchema structure -- + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert param_name in schema["required"] + assert param_name in schema["properties"] + + # -- Property schema: type + format + description -- + prop = schema["properties"][param_name] + assert prop["type"] == "string" + assert prop["format"] == "date" + + # description includes all source values (label, constraints, default) + desc = prop["description"] + for expected in (label_text, min_date, max_date, default_date): + assert expected in desc, f"Expected {expected!r} in description: {desc!r}" diff --git a/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py new file mode 100644 index 0000000000..97472a16d7 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py @@ -0,0 +1,54 @@ +"""Integration tests for the get_dash_component tool.""" + +from dash import Dash, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool + +EXPECTED_DROPDOWN_OPTIONS = { + "component_id": "my-dropdown", + "component_type": "Dropdown", + "label": None, + "properties": { + "options": { + "initial_value": [ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + "modified_by_tool": [], + "input_to_tool": [], + }, + }, +} + + +def test_query_component_returns_structured_output(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="my-dropdown", + options=[ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + value="NYC", + ), + ] + ) + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "my-dropdown", "property": "options"}, + ) + + assert "result" in result, f"Expected result in response: {result}" + structured = result["result"]["structuredContent"] + assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] + assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] + assert ( + structured["properties"]["options"] + == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] + ) diff --git a/tests/integration/mcp/primitives/tools/test_tools_list.py b/tests/integration/mcp/primitives/tools/test_tools_list.py new file mode 100644 index 0000000000..dc3d977146 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tools_list.py @@ -0,0 +1,118 @@ +"""Integration tests for tools/list — naming, dedup, and spec compliance.""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_tools + + +def test_tool_names_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + for param_name in tool.get("inputSchema", {}).get("properties", {}): + assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" + + +def test_long_callback_ids_within_64_chars(dash_duo): + app = Dash(__name__) + long_id = "a" * 120 + app.layout = html.Div( + [ + dcc.Input(id=long_id, value="test"), + html.Div(id=f"{long_id}-output"), + ] + ) + + @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) + def process(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_pattern_matching_ids_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div( + [ + dcc.Input( + id={"type": "filter-input", "index": i, "category": "primary"}, + value=f"val-{i}", + ) + for i in range(3) + ] + ), + html.Div(id="pm-output"), + ] + ) + + @app.callback( + Output("pm-output", "children"), + Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), + ) + def filter_update(v0): + return str(v0) + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_duplicate_func_names_produce_unique_tools(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd1", options=["a"], value="a"), + html.Div(id="dd1-output"), + dcc.Dropdown(id="dd2", options=["b"], value="b"), + html.Div(id="dd2-output"), + dcc.Dropdown(id="dd3", options=["c"], value="c"), + html.Div(id="dd3-output"), + ] + ) + + @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) + def cb(value): + return f"first: {value}" + + @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) + def cb(value): # noqa: F811 + return f"second: {value}" + + @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) + def cb(value): # noqa: F811 + return f"third: {value}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] + tool_names = [t["name"] for t in cb_tools] + + assert ( + len(tool_names) == 3 + ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" + assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" + + +def test_builtin_tools_always_present(dash_duo): + app = Dash(__name__) + app.layout = html.Div(id="root") + + dash_duo.start_server(app) + tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] + assert "get_dash_component" in tool_names diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py new file mode 100644 index 0000000000..7af88bfbff --- /dev/null +++ b/tests/integration/mcp/test_server.py @@ -0,0 +1,304 @@ +"""Integration tests for the MCP Streamable HTTP endpoint. + +These tests use Flask's test_client to exercise the HTTP transport layer +(POST/GET/DELETE at /_mcp), session management, content-type handling, +and route registration/configuration. +""" + +import json +import os + +from dash import Dash, Input, Output, html +from mcp.types import LATEST_PROTOCOL_VERSION + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestMCPEndpoint: + """Tests for the Streamable HTTP MCP endpoint at /_mcp.""" + + def test_post_initialize_creates_session(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert "mcp-session-id" in r.headers + data = json.loads(r.data) + assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + + def test_post_without_session_auto_assigns(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert "mcp-session-id" in r.headers + data = json.loads(r.data) + assert "tools" in data["result"] + + def test_stale_session_error_includes_hint(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": "tools/call", + "id": 1, + "params": {"name": "no_such_tool", "arguments": {}}, + } + ), + content_type="application/json", + headers={"mcp-session-id": "old-session-from-before-restart"}, + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert "session was not recognised" in data["error"]["message"] + assert "tools/list" in data["error"]["message"] + + def test_post_with_valid_session(self): + app = _make_app() + client = app.server.test_client() + # Initialize to get session + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Use session for tools/list + r2 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} + ), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 200 + data = json.loads(r2.data) + assert "result" in data + assert "tools" in data["result"] + + def test_notification_returns_202(self): + app = _make_app() + client = app.server.test_client() + # Initialize to get session + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Send notification (no id field) + r2 = client.post( + f"/{MCP_PATH}", + data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 202 + + def test_delete_terminates_session(self): + app = _make_app() + client = app.server.test_client() + # Initialize + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Delete + r2 = client.delete( + f"/{MCP_PATH}", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 204 + # Post-delete requests still succeed + r3 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} + ), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r3.status_code == 200 + + def test_delete_nonexistent_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.delete( + f"/{MCP_PATH}", + headers={"mcp-session-id": "nonexistent"}, + ) + assert r.status_code == 404 + + def test_get_without_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.get(f"/{MCP_PATH}") + assert r.status_code == 404 + + def test_get_with_stale_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.get( + f"/{MCP_PATH}", + headers={"mcp-session-id": "nonexistent"}, + ) + assert r.status_code == 404 + + def test_get_returns_sse_stream(self): + app = _make_app() + client = app.server.test_client() + # First create a session via POST initialize + init = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = init.headers["mcp-session-id"] + # GET with valid session returns SSE stream + r = client.get( + f"/{MCP_PATH}", + headers={"mcp-session-id": session_id}, + ) + assert r.status_code == 200 + assert r.content_type == "text/event-stream" + assert r.headers.get("Cache-Control") == "no-cache" + + def test_post_rejects_wrong_content_type(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data="not json", + content_type="text/plain", + ) + assert r.status_code == 415 + + def test_routes_not_registered_when_disabled(self): + app = _make_app(enable_mcp=False) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + # With MCP disabled, the route doesn't exist — response is HTML, not JSON + assert r.content_type != "application/json" + + def test_routes_respect_pathname_prefix(self): + app = _make_app(routes_pathname_prefix="/app/") + client = app.server.test_client() + + ok = client.post( + f"/app/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert ok.status_code == 200 + + miss = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert miss.status_code == 404 + + def test_enable_mcp_env_var_false(self): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.content_type != "application/json" + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old + + def test_constructor_overrides_env_var(self): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app(enable_mcp=True) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert b"protocolVersion" in r.data + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/unit/mcp/test_server.py b/tests/unit/mcp/test_server.py new file mode 100644 index 0000000000..93238faf19 --- /dev/null +++ b/tests/unit/mcp/test_server.py @@ -0,0 +1,92 @@ +"""Tests for MCP server (_server.py) — JSON-RPC message processing.""" + +from dash._get_app import app_context +from dash.mcp._server import _process_mcp_message +from mcp.types import LATEST_PROTOCOL_VERSION + +from tests.unit.mcp.conftest import _make_app, _setup_mcp + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool(app, tool_name, arguments=None, request_id=1): + return _mcp( + app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id + ) + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + result = _call_tool(app, tool_name, arguments) + structured = result["result"]["structuredContent"] + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +class TestProcessMCPMessage: + def test_initialize(self): + app = _make_app() + result = _mcp(app, "initialize") + + assert result is not None + assert result["id"] == 1 + assert result["jsonrpc"] == "2.0" + assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert "serverInfo" in result["result"] + + def test_initialize_advertises_list_changed(self): + app = _make_app() + result = _mcp(app, "initialize") + caps = result["result"]["capabilities"] + assert caps["tools"]["listChanged"] is True + + def test_tools_call(self): + app = _make_app() + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) + + result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) + + assert result is not None + assert result["id"] == 2 + assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" + + def test_tools_call_unknown_tool_returns_error(self): + app = _make_app() + result = _call_tool(app, "nonexistent_tool") + + assert result is not None + assert "error" in result + assert result["error"]["code"] == -32601 + + def test_unknown_method_returns_error(self): + app = _make_app() + result = _mcp(app, "unknown/method") + + assert result is not None + assert "error" in result + + def test_notification_returns_none(self): + app = _make_app() + data = {"jsonrpc": "2.0", "method": "notifications/initialized"} + with app.server.test_request_context(): + app_context.set(app) + result = _process_mcp_message(data) + assert result is None diff --git a/tests/unit/mcp/tools/test_run_callback.py b/tests/unit/mcp/tools/test_run_callback.py new file mode 100644 index 0000000000..00f4e5b7b1 --- /dev/null +++ b/tests/unit/mcp/tools/test_run_callback.py @@ -0,0 +1,246 @@ +"""Tests for callback dispatch execution via MCP tools.""" + +from dash import Dash, Input, Output, State, dcc, html +from dash.exceptions import PreventUpdate +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool_structured(app, tool_name, arguments=None): + result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) + return result["result"]["structuredContent"] + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + structured = _call_tool_structured(app, tool_name, arguments) + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +class TestRunCallback: + def test_multi_output(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [{"label": val, "value": val}], f"selected: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + structured = _call_tool_structured(app, tool_name, {"val": "b"}) + assert structured["response"]["dd2"]["options"] == [ + {"label": "b", "value": "b"} + ] + assert structured["response"]["out"]["children"] == "selected: b" + + def test_omitted_kwargs_default_to_none(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"]), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + State("inp", "value"), + ) + def update(selected, text): + return f"{selected}-{text}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" + + def test_no_output_callback(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("btn", "n_clicks")) + def server_cb(n): + from dash import set_props + + set_props("display", {"children": f"Clicked {n} times"}) + + tools = _tools_list(app) + tool_names = [t["name"] for t in tools] + assert "server_cb" in tool_names + + def test_prevent_update(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + if val == "block": + raise PreventUpdate + return f"got: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" + + def test_with_state(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="trigger"), + html.Div(id="store"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input("trigger", "children"), + State("store", "children"), + ) + def with_state(trigger, store): + return f"{trigger}-{store}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "trigger": "click", + "store": "data", + }, + "result", + ) + == "click-data" + ) + + def test_dict_inputs(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="x-input", value="hello"), + dcc.Input(id="y-input", value="world"), + html.Div(id="dict-out"), + ] + ) + + @app.callback( + Output("dict-out", "children"), + inputs={ + "x_val": Input("x-input", "value"), + "y_val": Input("y-input", "value"), + }, + ) + def combine(**kwargs): + return f"{kwargs['x_val']}-{kwargs['y_val']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "combine" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "x_val": "foo", + "y_val": "bar", + }, + "dict-out", + ) + == "foo-bar" + ) + + def test_positional_inputs(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a-input", value="A"), + html.Div(id="pos-out"), + ] + ) + + @app.callback(Output("pos-out", "children"), Input("a-input", "value")) + def echo(val): + return f"got:{val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + assert ( + _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" + ) + + def test_dict_inputs_with_state(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hi"), + html.Div(id="st", children="state-val"), + html.Div(id="ds-out"), + ] + ) + + @app.callback( + Output("ds-out", "children"), + inputs={"trigger": Input("inp", "value")}, + state={"kept": State("st", "children")}, + ) + def with_dict_state(**kwargs): + return f"{kwargs['trigger']}+{kwargs['kept']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "trigger": "hey", + "kept": "saved", + }, + "ds-out", + ) + == "hey+saved" + ) From 395e75d1269f15003e87766ad0ad802975a095e6 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 9 Apr 2026 09:40:45 -0600 Subject: [PATCH 02/10] Enforce session management per MCP spec (404 for unknown sessions, 400 for missing session) --- dash/mcp/_server.py | 17 ++++++----- tests/integration/mcp/test_server.py | 43 ++++++---------------------- 2 files changed, 18 insertions(+), 42 deletions(-) diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 1c6279290b..95d00578a2 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -116,23 +116,26 @@ def _handle_post() -> Response: request_id = data.get("id") session_id = request.headers.get("mcp-session-id") - stale_session = False if method == "initialize": session_id = _create_session() elif session_id and session_id not in sessions: - stale_session = True - sessions[session_id] = {} + return Response( + json.dumps({"error": "Session not found. Please reinitialize."}), + content_type="application/json", + status=404, + ) elif not session_id: - session_id = _create_session() + return Response( + json.dumps({"error": "Missing session ID. Send an initialize request first."}), + content_type="application/json", + status=400, + ) response_data = _process_mcp_message(data) if response_data is None: return Response("", status=202) - if stale_session: - _inject_warning(response_data, _STALE_SESSION_WARNING) - return Response( json.dumps(response_data), content_type="application/json", diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py index 7af88bfbff..8917d0f5ab 100644 --- a/tests/integration/mcp/test_server.py +++ b/tests/integration/mcp/test_server.py @@ -60,7 +60,7 @@ def test_post_initialize_creates_session(self): data = json.loads(r.data) assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - def test_post_without_session_auto_assigns(self): + def test_post_without_session_returns_400(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -70,12 +70,9 @@ def test_post_without_session_auto_assigns(self): ), content_type="application/json", ) - assert r.status_code == 200 - assert "mcp-session-id" in r.headers - data = json.loads(r.data) - assert "tools" in data["result"] + assert r.status_code == 400 - def test_stale_session_error_includes_hint(self): + def test_stale_session_returns_404(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -83,18 +80,15 @@ def test_stale_session_error_includes_hint(self): data=json.dumps( { "jsonrpc": "2.0", - "method": "tools/call", + "method": "tools/list", "id": 1, - "params": {"name": "no_such_tool", "arguments": {}}, + "params": {}, } ), content_type="application/json", headers={"mcp-session-id": "old-session-from-before-restart"}, ) - assert r.status_code == 200 - data = json.loads(r.data) - assert "session was not recognised" in data["error"]["message"] - assert "tools/list" in data["error"]["message"] + assert r.status_code == 404 def test_post_with_valid_session(self): app = _make_app() @@ -161,7 +155,7 @@ def test_delete_terminates_session(self): headers={"mcp-session-id": session_id}, ) assert r2.status_code == 204 - # Post-delete requests still succeed + # Post-delete requests return 404 r3 = client.post( f"/{MCP_PATH}", data=json.dumps( @@ -170,7 +164,7 @@ def test_delete_terminates_session(self): content_type="application/json", headers={"mcp-session-id": session_id}, ) - assert r3.status_code == 200 + assert r3.status_code == 404 def test_delete_nonexistent_session_returns_404(self): app = _make_app() @@ -196,27 +190,6 @@ def test_get_with_stale_session_returns_404(self): ) assert r.status_code == 404 - def test_get_returns_sse_stream(self): - app = _make_app() - client = app.server.test_client() - # First create a session via POST initialize - init = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = init.headers["mcp-session-id"] - # GET with valid session returns SSE stream - r = client.get( - f"/{MCP_PATH}", - headers={"mcp-session-id": session_id}, - ) - assert r.status_code == 200 - assert r.content_type == "text/event-stream" - assert r.headers.get("Cache-Control") == "no-cache" - def test_post_rejects_wrong_content_type(self): app = _make_app() client = app.server.test_client() From d11c57ec0a6c43685c4eb1cb98852940070633b4 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 15 Apr 2026 16:55:23 -0600 Subject: [PATCH 03/10] remove unused code --- dash/mcp/_server.py | 33 +++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 95d00578a2..1060d3b27f 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -126,7 +126,9 @@ def _handle_post() -> Response: ) elif not session_id: return Response( - json.dumps({"error": "Missing session ID. Send an initialize request first."}), + json.dumps( + {"error": "Missing session ID. Send an initialize request first."} + ), content_type="application/json", status=400, ) @@ -175,35 +177,6 @@ def _handle_delete() -> Response: ) -_STALE_SESSION_WARNING = ( - "[Warning: your session was not recognised" - " — the app may have restarted." - " Please call tools/list to refresh your tool list." - " Please ask the user to reconnect to the MCP server.]" -) - - -def _inject_warning(response_data: dict[str, Any], warning: str) -> None: - """Append a warning to a JSON-RPC response dict. - - For successful ``tools/call`` responses the warning is added as an - extra text content block so the agent sees it alongside the result. - For error responses the warning is appended to the error message. - Other responses (tools/list, resources/*) are left unchanged — the - JSON-RPC spec forbids extra top-level keys. - """ - # tools/call success: result has a "content" list - result = response_data.get("result") - if isinstance(result, dict) and isinstance(result.get("content"), list): - result["content"].append({"type": "text", "text": warning}) - return - - # Error response - error = response_data.get("error") - if isinstance(error, dict) and "message" in error: - error["message"] += " " + warning - - def _handle_initialize() -> InitializeResult: return InitializeResult( protocolVersion=LATEST_PROTOCOL_VERSION, From 35b1c188ae1cb53d6b8c9b8dce3a89e871147728 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 16 Apr 2026 18:31:33 -0600 Subject: [PATCH 04/10] Fix types for mypy --- dash/mcp/__init__.py | 2 +- dash/mcp/_server.py | 3 +- dash/mcp/primitives/__init__.py | 10 ++-- .../resource_clientside_callbacks.py | 5 +- .../resources/resource_components.py | 7 +-- .../primitives/resources/resource_layout.py | 5 +- .../resources/resource_page_layout.py | 3 +- .../primitives/resources/resource_pages.py | 5 +- dash/mcp/primitives/tools/callback_adapter.py | 47 +++++++++++-------- .../description_pattern_matching.py | 2 +- .../schema_component_proptypes_overrides.py | 9 ++-- .../input_schemas/schema_pattern_matching.py | 5 +- dash/mcp/primitives/tools/results/__init__.py | 2 +- .../tools/results/result_plotly_figure.py | 2 +- .../tools/tool_get_dash_component.py | 2 +- dash/types.py | 9 ++-- 16 files changed, 70 insertions(+), 48 deletions(-) diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py index 2677ea141b..2bc4757f13 100644 --- a/dash/mcp/__init__.py +++ b/dash/mcp/__init__.py @@ -3,5 +3,5 @@ from dash.mcp._server import enable_mcp_server __all__ = [ - enable_mcp_server, + "enable_mcp_server", ] diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 1060d3b27f..24bbef4aeb 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -204,7 +204,8 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: """ method = data.get("method", "") params = data.get("params", {}) or {} - request_id = data.get("id") + _id = data.get("id") + request_id: str | int = _id if isinstance(_id, (str, int)) else "" app = get_app() if not hasattr(app, "mcp_callback_map"): diff --git a/dash/mcp/primitives/__init__.py b/dash/mcp/primitives/__init__.py index b14839f1e1..e6b46a9af3 100644 --- a/dash/mcp/primitives/__init__.py +++ b/dash/mcp/primitives/__init__.py @@ -9,9 +9,9 @@ ) __all__ = [ - call_tool, - list_resources, - list_resource_templates, - list_tools, - read_resource, + "call_tool", + "list_resources", + "list_resource_templates", + "list_tools", + "read_resource", ] diff --git a/dash/mcp/primitives/resources/resource_clientside_callbacks.py b/dash/mcp/primitives/resources/resource_clientside_callbacks.py index 127c0f9adc..a8c0a0076a 100644 --- a/dash/mcp/primitives/resources/resource_clientside_callbacks.py +++ b/dash/mcp/primitives/resources/resource_clientside_callbacks.py @@ -10,6 +10,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash import get_app from dash._utils import clean_property_name, split_callback_id @@ -25,7 +26,7 @@ def get_resource(cls) -> Resource | None: if not _get_clientside_callbacks(): return None return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_clientside_callbacks", description=( "Actions the user can take manually in the browser " @@ -52,7 +53,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=json.dumps(data, default=str), ) diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py index 9d035a855f..1f80c8bda2 100644 --- a/dash/mcp/primitives/resources/resource_components.py +++ b/dash/mcp/primitives/resources/resource_components.py @@ -9,6 +9,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash import get_app from dash._layout_utils import traverse @@ -22,7 +23,7 @@ class ComponentsResource(MCPResourceProvider): @classmethod def get_resource(cls) -> Resource | None: return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_components", description=( "All components with IDs in the app layout. " @@ -41,7 +42,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: components = sorted( [ { - "id": str(comp.id), + "id": str(getattr(comp, "id", None)), "type": getattr(comp, "_type", type(comp).__name__), } for comp, _ in traverse(layout) @@ -53,7 +54,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=json.dumps(components), ) diff --git a/dash/mcp/primitives/resources/resource_layout.py b/dash/mcp/primitives/resources/resource_layout.py index 753e2b9229..7659d1fd8f 100644 --- a/dash/mcp/primitives/resources/resource_layout.py +++ b/dash/mcp/primitives/resources/resource_layout.py @@ -7,6 +7,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash import get_app from dash._utils import to_json @@ -20,7 +21,7 @@ class LayoutResource(MCPResourceProvider): @classmethod def get_resource(cls) -> Resource | None: return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_app_layout", description=( "Full component tree of the Dash app. " @@ -35,7 +36,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=to_json(app.get_layout()), ) diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py index 613f0b41b9..bbfca411bc 100644 --- a/dash/mcp/primitives/resources/resource_page_layout.py +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -7,6 +7,7 @@ ResourceTemplate, TextResourceContents, ) +from pydantic import AnyUrl from dash import html from dash._pages import PAGE_REGISTRY @@ -55,7 +56,7 @@ def read_resource(cls, uri: str) -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=uri, + uri=AnyUrl(uri), mimeType="application/json", text=to_json(page_layout), ) diff --git a/dash/mcp/primitives/resources/resource_pages.py b/dash/mcp/primitives/resources/resource_pages.py index 27c39013f3..21fa27679f 100644 --- a/dash/mcp/primitives/resources/resource_pages.py +++ b/dash/mcp/primitives/resources/resource_pages.py @@ -9,6 +9,7 @@ Resource, TextResourceContents, ) +from pydantic import AnyUrl from dash._pages import PAGE_REGISTRY @@ -23,7 +24,7 @@ def get_resource(cls) -> Resource | None: if not PAGE_REGISTRY: return None return Resource( - uri=cls.uri, + uri=AnyUrl(cls.uri), name="dash_app_pages", description=( "List of all pages in this multi-page Dash app " @@ -51,7 +52,7 @@ def read_resource(cls, uri: str = "") -> ReadResourceResult: return ReadResourceResult( contents=[ TextResourceContents( - uri=cls.uri, + uri=AnyUrl(cls.uri), mimeType="application/json", text=json.dumps(pages, default=str), ) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index c94ba32f38..3000541a06 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -10,7 +10,7 @@ import json import typing from functools import cached_property -from typing import Any +from typing import Any, cast from mcp.types import Tool @@ -27,6 +27,7 @@ CallbackDependency, CallbackExecutionBody, CallbackInput, + CallbackInputs, CallbackOutput, CallbackOutputTarget, WildcardId, @@ -361,9 +362,7 @@ def _param_annotations(self) -> list[Any | None]: return [hints.get(func_name) for func_name, _ in self._dep_param_map] -def _expand_dep( - dep: CallbackDependency, value: Any -) -> CallbackInput | list[CallbackInput]: +def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInputs: """Attach a concrete value to a callback dependency to produce a valid callback input. For regular deps, returns ``{id, property, value}``. @@ -372,20 +371,20 @@ def _expand_dep( """ pattern = parse_wildcard_id(dep.get("id", "")) if pattern is None: - return {**dep, "value": value} + return CallbackInput(id=dep["id"], property=dep["property"], value=value) # LLM provides browser-like format if isinstance(value, list): - return value + return cast(list[CallbackInput], value) if isinstance(value, dict) and "id" in value: - return value - return {**dep, "value": value} + return cast(CallbackInput, value) + return CallbackInput(id=dep["id"], property=dep["property"], value=value) def _expand_output_spec( output_id: str, cb_info: dict, - resolved_inputs: list[CallbackInput], + resolved_inputs: list[CallbackInputs], ) -> CallbackOutputTarget | list[CallbackOutputTarget]: """Build the outputs spec, expanding wildcards to concrete IDs. @@ -408,15 +407,19 @@ def _expand_output_spec( if pattern is not None: concrete_ids = _derive_output_ids(pattern, resolved_inputs) if not concrete_ids: - concrete_ids = [comp.id for comp in find_matching_components(pattern)] - expanded = [{"id": cid, "property": prop} for cid in concrete_ids] + concrete_ids = [ + getattr(comp, "id") for comp in find_matching_components(pattern) + ] + expanded: list[CallbackDependency] = [ + CallbackDependency(id=cid, property=prop) for cid in concrete_ids + ] # ALL/ALLSMALLER → nested list; MATCH → single dict if len(expanded) == 1: results.append(expanded[0]) else: results.append(expanded) else: - results.append({"id": pid, "property": prop}) + results.append(CallbackDependency(id=pid, property=prop)) # Mirror the Dash renderer: single-output callbacks send a bare dict, # multi-output callbacks send a list. The framework's output value @@ -428,7 +431,7 @@ def _expand_output_spec( def _derive_output_ids( output_pattern: WildcardId, - resolved_inputs: list[CallbackInput], + resolved_inputs: list[CallbackInputs], ) -> list[WildcardId] | None: """Derive concrete output IDs from the resolved input entries. @@ -457,15 +460,19 @@ def _substitute(item_id: WildcardId) -> WildcardId | None: if isinstance(entry, list) and entry: concrete_ids = [] for item in entry: - out = _substitute(item.get("id")) - if out: - concrete_ids.append(out) + item_id = item.get("id") + if isinstance(item_id, dict): + out = _substitute(item_id) + if out: + concrete_ids.append(out) if concrete_ids: return concrete_ids # MATCH: single {id, property, value} dict - elif isinstance(entry, dict) and isinstance(entry.get("id"), dict): - out = _substitute(entry["id"]) - if out: - return [out] + elif isinstance(entry, dict): + entry_id = entry.get("id") + if isinstance(entry_id, dict): + out = _substitute(entry_id) + if out: + return [out] return None diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py index 221423aa50..d9a1a5a26a 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py @@ -23,7 +23,7 @@ def describe(cls, param: MCPInput) -> list[str]: return [] wildcard_key, wildcard_type = _find_wildcard(dep_id) - if wildcard_key is None: + if wildcard_key is None or wildcard_type is None: return [] non_wildcard = {k: v for k, v in dep_id.items() if k != wildcard_key} diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py index 984d493d69..ee42042415 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Callable, Union from dash.mcp.types import MCPInput @@ -44,7 +44,10 @@ def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any] | None: return refined -_OVERRIDES: dict[tuple[str, str], dict[str, Any] | callable] = { +_OVERRIDES: dict[ + tuple[str, str], + Union[dict[str, Any], Callable[[MCPInput], dict[str, Any] | None]], +] = { ("DatePickerSingle", "date"): _DATE_SCHEMA, ("DatePickerRange", "start_date"): _DATE_SCHEMA, ("DatePickerRange", "end_date"): _DATE_SCHEMA, @@ -65,7 +68,7 @@ class OverrideSchema(InputSchemaSource): @classmethod def get_schema(cls, param: MCPInput) -> dict[str, Any] | None: - key = (param.get("component_type"), param["property"]) + key = (param.get("component_type") or "", param["property"]) override = _OVERRIDES.get(key) if override is None: return None diff --git a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py index 52e16cf58b..093dc197b8 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py @@ -66,7 +66,10 @@ def _get_wildcard_type(dep_id: dict) -> str | None: def _infer_value_schema(param: MCPInput) -> dict[str, Any] | None: """Infer the JSON Schema for the ``value`` field from a matching component.""" - matches = find_matching_components(parse_wildcard_id(param["component_id"])) + pattern = parse_wildcard_id(param["component_id"]) + if pattern is None: + return None + matches = find_matching_components(pattern) if not matches: return None diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index ae3517919c..09e86410a7 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -48,5 +48,5 @@ def format_callback_response( return CallToolResult( content=content, - structuredContent=response, + structuredContent=dict(response), ) diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py index d3b98376f9..382acde422 100644 --- a/dash/mcp/primitives/tools/results/result_plotly_figure.py +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -6,7 +6,7 @@ import logging from typing import Any -import plotly.graph_objects as go +import plotly.graph_objects as go # type: ignore[import-untyped] from mcp.types import ImageContent, TextContent from dash.mcp.types import MCPOutput diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index 69b6276d5a..f03b93293f 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -127,5 +127,5 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: content=[ TextContent(type="text", text=json.dumps(structured, default=str)) ], - structuredContent=structured, + structuredContent=dict(structured), ) diff --git a/dash/types.py b/dash/types.py index cbc94b8151..9da246b16c 100644 --- a/dash/types.py +++ b/dash/types.py @@ -71,11 +71,14 @@ class CallbackInput(TypedDict): value: Any +CallbackInputs = Union[CallbackInput, List[CallbackInput]] + + class CallbackExecutionBody(TypedDict): output: str - outputs: List[CallbackOutputTarget] - inputs: List[CallbackInput] - state: List[CallbackInput] + outputs: Union[CallbackOutputTarget, List[CallbackOutputTarget]] + inputs: List[CallbackInputs] + state: List[CallbackInputs] changedPropIds: List[str] From 9863f08fa548819dfd365352fbf98bbe222936ae Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 20 Apr 2026 15:30:16 -0600 Subject: [PATCH 05/10] Remove unused SSE code for initial MCP implementation --- dash/mcp/_server.py | 87 +++---------- dash/mcp/_sse.py | 67 ---------- dash/mcp/notifications/__init__.py | 7 -- .../notification_tools_changed.py | 30 ----- tests/integration/mcp/conftest.py | 22 +--- tests/integration/mcp/test_server.py | 116 ++---------------- tests/unit/mcp/test_server.py | 6 - 7 files changed, 30 insertions(+), 305 deletions(-) delete mode 100644 dash/mcp/_sse.py delete mode 100644 dash/mcp/notifications/__init__.py delete mode 100644 dash/mcp/notifications/notification_tools_changed.py diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 24bbef4aeb..c00ec6c398 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -2,10 +2,8 @@ from __future__ import annotations -import atexit import json import logging -import uuid from typing import TYPE_CHECKING, Any from flask import Response, request @@ -30,11 +28,6 @@ ) from dash.version import __version__ -from dash.mcp._sse import ( - close_sse_stream, - create_sse_stream, - shutdown_all_streams, -) from dash.mcp.primitives import ( call_tool, list_resource_templates, @@ -50,25 +43,7 @@ def enable_mcp_server(app: Dash, mcp_path: str) -> None: - """ - Add MCP routes to a Dash/Flask app. - - Registers a single Streamable HTTP endpoint for the MCP protocol. - Uses ``app._add_url()`` so that ``routes_pathname_prefix`` is applied - automatically. - - Args: - app: The Dash application instance. - mcp_path: Route prefix for MCP endpoints. - """ - # Session storage: session_id -> metadata - sessions: dict[str, dict[str, Any]] = {} - - def _create_session() -> str: - sid = str(uuid.uuid4()) - sessions[sid] = {} - return sid - + """Add MCP routes to a Dash/Flask app.""" # -- Streamable HTTP endpoint -------------------------------------------- def mcp_handler() -> Response: @@ -85,14 +60,13 @@ def mcp_handler() -> Response: ) def _handle_get() -> Response: - session_id = request.headers.get("mcp-session-id") - if not session_id or session_id not in sessions: - return Response( - json.dumps({"error": "Session not found"}), - content_type="application/json", - status=404, - ) - return create_sse_stream(sessions, session_id) + # MCP spec allows servers to opt out of GET-initiated SSE streams + # by returning 405. We don't push server-initiated events. + return Response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) def _handle_post() -> Response: content_type = request.content_type or "" @@ -112,27 +86,6 @@ def _handle_post() -> Response: status=400, ) - method = data.get("method", "") - request_id = data.get("id") - session_id = request.headers.get("mcp-session-id") - - if method == "initialize": - session_id = _create_session() - elif session_id and session_id not in sessions: - return Response( - json.dumps({"error": "Session not found. Please reinitialize."}), - content_type="application/json", - status=404, - ) - elif not session_id: - return Response( - json.dumps( - {"error": "Missing session ID. Send an initialize request first."} - ), - content_type="application/json", - status=400, - ) - response_data = _process_mcp_message(data) if response_data is None: @@ -142,21 +95,15 @@ def _handle_post() -> Response: json.dumps(response_data), content_type="application/json", status=200, - headers={"mcp-session-id": session_id}, ) def _handle_delete() -> Response: - session_id = request.headers.get("mcp-session-id") - if not session_id or session_id not in sessions: - return Response( - json.dumps({"error": "Session not found"}), - content_type="application/json", - status=404, - ) - close_sse_stream(sessions[session_id]) - del sessions[session_id] - logger.info("MCP session terminated: %s", session_id) - return Response("", status=204) + # No sessions to terminate — server is stateless. + return Response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) # -- Register routes ----------------------------------------------------- @@ -166,10 +113,6 @@ def _handle_delete() -> Response: mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] ) - # Close all SSE streams on server shutdown so MCP clients see a - # clean stream end and can reconnect promptly. - atexit.register(shutdown_all_streams, sessions) - logger.info( "MCP routes registered at %s%s", app.config.routes_pathname_prefix, @@ -181,7 +124,7 @@ def _handle_initialize() -> InitializeResult: return InitializeResult( protocolVersion=LATEST_PROTOCOL_VERSION, capabilities=ServerCapabilities( - tools=ToolsCapability(listChanged=True), + tools=ToolsCapability(listChanged=False), resources=ResourcesCapability(), ), serverInfo=Implementation(name="Plotly Dash", version=__version__), diff --git a/dash/mcp/_sse.py b/dash/mcp/_sse.py deleted file mode 100644 index 4928dc68b2..0000000000 --- a/dash/mcp/_sse.py +++ /dev/null @@ -1,67 +0,0 @@ -"""SSE stream generation and queue management.""" - -from __future__ import annotations - -import queue -from typing import Any - -from flask import Response - - -def create_sse_stream(sessions: dict[str, dict[str, Any]], session_id: str) -> Response: - """Create a Server-Sent Events stream for the given session. - - Stores a :class:`queue.Queue` in ``sessions[session_id]["sse_queue"]`` - and returns a Flask streaming ``Response``. The generator yields - events pushed to the queue, with keepalive comments every 30 seconds. - """ - event_queue: queue.Queue[str | None] = queue.Queue() - # Replace any prior SSE queue for this session (client reconnect). - sessions[session_id]["sse_queue"] = event_queue - - def _generate(): - try: - while True: - try: - event = event_queue.get(timeout=30) - if event is None: - return # Sentinel: server closing stream - yield f"event: message\ndata: {event}\n\n" - except queue.Empty: - yield ": keepalive\n\n" - except GeneratorExit: - pass - finally: - # Clean up queue reference if it's still ours. - if sessions.get(session_id, {}).get("sse_queue") is event_queue: - sessions[session_id].pop("sse_queue", None) - - return Response( - _generate(), - content_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "mcp-session-id": session_id, - }, - ) - - -def close_sse_stream(session_data: dict[str, Any]) -> None: - """Send a sentinel to shut down the session's SSE stream cleanly.""" - sse_queue = session_data.get("sse_queue") - if sse_queue is not None: - try: - sse_queue.put_nowait(None) - except queue.Full: - pass - - -def shutdown_all_streams(sessions: dict[str, dict[str, Any]]) -> None: - """Close all active SSE streams. - - Called during server shutdown (via ``atexit``) so that connected - MCP clients see a clean stream end and can reconnect promptly. - """ - for session_data in list(sessions.values()): - close_sse_stream(session_data) diff --git a/dash/mcp/notifications/__init__.py b/dash/mcp/notifications/__init__.py deleted file mode 100644 index b1fe9e8665..0000000000 --- a/dash/mcp/notifications/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Server-initiated MCP notifications.""" - -from .notification_tools_changed import broadcast_tools_changed - -__all__ = [ - "broadcast_tools_changed", -] diff --git a/dash/mcp/notifications/notification_tools_changed.py b/dash/mcp/notifications/notification_tools_changed.py deleted file mode 100644 index 1970667d1a..0000000000 --- a/dash/mcp/notifications/notification_tools_changed.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Tool list change notifications.""" - -from __future__ import annotations - -import json -import queue -from typing import Any - - -def broadcast_tools_changed( - sessions: dict[str, dict[str, Any]], -) -> None: - """Push a tools/list_changed notification to all active SSE streams. - - Not called automatically yet — available for future hot-reload - or dynamic callback registration. - """ - notification = json.dumps( - { - "jsonrpc": "2.0", - "method": "notifications/tools/list_changed", - } - ) - for data in sessions.values(): - sse_queue = data.get("sse_queue") - if sse_queue is not None: - try: - sse_queue.put_nowait(notification) - except queue.Full: - pass diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index 0f212d1763..b833bc1dea 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -3,10 +3,7 @@ import requests -def _mcp_post(server_url, method, params=None, session_id=None, request_id=1): - headers = {"Content-Type": "application/json"} - if session_id: - headers["mcp-session-id"] = session_id +def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", json={ @@ -15,39 +12,28 @@ def _mcp_post(server_url, method, params=None, session_id=None, request_id=1): "id": request_id, "params": params or {}, }, - headers=headers, + headers={"Content-Type": "application/json"}, timeout=5, ) -def _mcp_session(server_url): - resp = _mcp_post(server_url, "initialize") - resp.raise_for_status() - return resp.headers["mcp-session-id"] - - def _mcp_tools(server_url): - sid = _mcp_session(server_url) - resp = _mcp_post(server_url, "tools/list", session_id=sid, request_id=2) + resp = _mcp_post(server_url, "tools/list") resp.raise_for_status() return resp.json()["result"]["tools"] def _mcp_call_tool(server_url, tool_name, arguments=None): - sid = _mcp_session(server_url) resp = _mcp_post( server_url, "tools/call", {"name": tool_name, "arguments": arguments or {}}, - session_id=sid, - request_id=2, ) resp.raise_for_status() return resp.json() def _mcp_method(server_url, method, params=None): - sid = _mcp_session(server_url) - resp = _mcp_post(server_url, method, params, session_id=sid, request_id=2) + resp = _mcp_post(server_url, method, params) resp.raise_for_status() return resp.json() diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py index 8917d0f5ab..4f0d0fca00 100644 --- a/tests/integration/mcp/test_server.py +++ b/tests/integration/mcp/test_server.py @@ -45,7 +45,7 @@ def update_output(value): class TestMCPEndpoint: """Tests for the Streamable HTTP MCP endpoint at /_mcp.""" - def test_post_initialize_creates_session(self): + def test_post_initialize_returns_protocol_version(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -56,11 +56,10 @@ def test_post_initialize_creates_session(self): content_type="application/json", ) assert r.status_code == 200 - assert "mcp-session-id" in r.headers data = json.loads(r.data) assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - def test_post_without_session_returns_400(self): + def test_post_tools_list(self): app = _make_app() client = app.server.test_client() r = client.post( @@ -70,125 +69,32 @@ def test_post_without_session_returns_400(self): ), content_type="application/json", ) - assert r.status_code == 400 - - def test_stale_session_returns_404(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - { - "jsonrpc": "2.0", - "method": "tools/list", - "id": 1, - "params": {}, - } - ), - content_type="application/json", - headers={"mcp-session-id": "old-session-from-before-restart"}, - ) - assert r.status_code == 404 - - def test_post_with_valid_session(self): - app = _make_app() - client = app.server.test_client() - # Initialize to get session - r1 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = r1.headers["mcp-session-id"] - # Use session for tools/list - r2 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} - ), - content_type="application/json", - headers={"mcp-session-id": session_id}, - ) - assert r2.status_code == 200 - data = json.loads(r2.data) + assert r.status_code == 200 + data = json.loads(r.data) assert "result" in data assert "tools" in data["result"] def test_notification_returns_202(self): app = _make_app() client = app.server.test_client() - # Initialize to get session - r1 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = r1.headers["mcp-session-id"] - # Send notification (no id field) - r2 = client.post( + r = client.post( f"/{MCP_PATH}", data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), content_type="application/json", - headers={"mcp-session-id": session_id}, ) - assert r2.status_code == 202 + assert r.status_code == 202 - def test_delete_terminates_session(self): + def test_delete_returns_405(self): app = _make_app() client = app.server.test_client() - # Initialize - r1 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - session_id = r1.headers["mcp-session-id"] - # Delete - r2 = client.delete( - f"/{MCP_PATH}", - headers={"mcp-session-id": session_id}, - ) - assert r2.status_code == 204 - # Post-delete requests return 404 - r3 = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} - ), - content_type="application/json", - headers={"mcp-session-id": session_id}, - ) - assert r3.status_code == 404 + r = client.delete(f"/{MCP_PATH}") + assert r.status_code == 405 - def test_delete_nonexistent_session_returns_404(self): - app = _make_app() - client = app.server.test_client() - r = client.delete( - f"/{MCP_PATH}", - headers={"mcp-session-id": "nonexistent"}, - ) - assert r.status_code == 404 - - def test_get_without_session_returns_404(self): + def test_get_returns_405(self): app = _make_app() client = app.server.test_client() r = client.get(f"/{MCP_PATH}") - assert r.status_code == 404 - - def test_get_with_stale_session_returns_404(self): - app = _make_app() - client = app.server.test_client() - r = client.get( - f"/{MCP_PATH}", - headers={"mcp-session-id": "nonexistent"}, - ) - assert r.status_code == 404 + assert r.status_code == 405 def test_post_rejects_wrong_content_type(self): app = _make_app() diff --git a/tests/unit/mcp/test_server.py b/tests/unit/mcp/test_server.py index 93238faf19..23c99c50ad 100644 --- a/tests/unit/mcp/test_server.py +++ b/tests/unit/mcp/test_server.py @@ -51,12 +51,6 @@ def test_initialize(self): assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION assert "serverInfo" in result["result"] - def test_initialize_advertises_list_changed(self): - app = _make_app() - result = _mcp(app, "initialize") - caps = result["result"]["capabilities"] - assert caps["tools"]["listChanged"] is True - def test_tools_call(self): app = _make_app() tools = _tools_list(app) From 486a3dc86296a0a1620792ed1a07fa775d369c52 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 21 Apr 2026 11:41:01 -0600 Subject: [PATCH 06/10] add app-level config for exposing callback docstrings in MCP tools --- dash/_configs.py | 1 + dash/dash.py | 4 +++ tests/unit/mcp/tools/test_mcp_tools.py | 45 ++++++++++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/dash/_configs.py b/dash/_configs.py index f6df4001f1..cfc5552986 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -34,6 +34,7 @@ def load_dash_env_vars(): "DASH_COMPRESS", "DASH_MCP_ENABLED", "DASH_MCP_PATH", + "DASH_MCP_EXPOSE_DOCSTRINGS", "HOST", "PORT", ) diff --git a/dash/dash.py b/dash/dash.py index 2cc1f37c61..dbdafd0f2c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -485,6 +485,7 @@ def __init__( # pylint: disable=too-many-statements csrf_header_name: str = "X-CSRFToken", enable_mcp: Optional[bool] = None, mcp_path: Optional[str] = None, + mcp_expose_docstrings: Optional[bool] = None, **obsolete, ): @@ -565,6 +566,9 @@ def __init__( # pylint: disable=too-many-statements hide_all_callbacks=False, csrf_token_name=csrf_token_name, csrf_header_name=csrf_header_name, + mcp_expose_docstrings=get_combined_config( + "mcp_expose_docstrings", mcp_expose_docstrings, False + ), ) self.config.set_read_only( [ diff --git a/tests/unit/mcp/tools/test_mcp_tools.py b/tests/unit/mcp/tools/test_mcp_tools.py index cacaf13b14..3255809982 100644 --- a/tests/unit/mcp/tools/test_mcp_tools.py +++ b/tests/unit/mcp/tools/test_mcp_tools.py @@ -309,6 +309,51 @@ def test_mcpt014_typed_annotation_narrows_schema(typed_app): assert tool.inputSchema["properties"]["val"]["type"] == "string" +def test_mcpt016_app_level_opt_in_exposes_docstrings(): + """Dash(mcp_expose_docstrings=True) exposes docstrings for all callbacks.""" + app = Dash(__name__, mcp_expose_docstrings=True) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """intentionally-exposed callback docstring text for the LLM""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "intentionally-exposed callback docstring text for the LLM" in tool.description + ) + + +def test_mcpt017_per_callback_false_overrides_app_level_opt_in(): + """Per-callback mcp_expose_docstring=False wins over app-level opt-in.""" + app = Dash(__name__, mcp_expose_docstrings=True) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + mcp_expose_docstring=False, + ) + def update(val): + """sensitive callback docstring text that must not leak to LLMs""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert ( + "sensitive callback docstring text that must not leak to LLMs" + not in tool.description + ) + + # --------------------------------------------------------------------------- # Tests — end-to-end Tool shape # --------------------------------------------------------------------------- From 1414cd4087cad0add38fc02d2e1f21b98c8c9cab Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 21 Apr 2026 11:48:43 -0600 Subject: [PATCH 07/10] Disable MCP server by default --- dash/dash.py | 2 +- tests/integration/mcp/conftest.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index dbdafd0f2c..fd84e6ef89 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -600,7 +600,7 @@ def __init__( # pylint: disable=too-many-statements self.title = title # MCP (Model Context Protocol) configuration - self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, True) + self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, False) _mcp_path = get_combined_config("mcp_path", mcp_path, "_mcp") self._mcp_path = ( _mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index b833bc1dea..c81ffceb48 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -1,8 +1,15 @@ """Shared helpers for MCP integration tests.""" +import pytest import requests +@pytest.fixture(autouse=True) +def _enable_mcp_for_integration_tests(monkeypatch): + """MCP is off by default; integration tests need it on.""" + monkeypatch.setenv("DASH_MCP_ENABLED", "true") + + def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", From 14ffb6e794bbced7c08b7fb8d3015061a29992b3 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 22 Apr 2026 16:04:45 -0600 Subject: [PATCH 08/10] lint --- dash/_layout_utils.py | 2 +- dash/mcp/_server.py | 36 +++++++++---------- dash/mcp/primitives/tools/callback_adapter.py | 2 +- tests/integration/mcp/conftest.py | 6 ++++ 4 files changed, 26 insertions(+), 20 deletions(-) diff --git a/dash/_layout_utils.py b/dash/_layout_utils.py index fdca86edca..d421771afd 100644 --- a/dash/_layout_utils.py +++ b/dash/_layout_utils.py @@ -117,7 +117,7 @@ def _collect_components(value: Any) -> list[Component]: if isinstance(value, Component): return [value] if isinstance(value, (list, tuple)): - return [item for item in value if isinstance(item, (Component, list, tuple))] + return [item for item in value if isinstance(item, Component)] return [] diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index c00ec6c398..07b35cd373 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -1,5 +1,9 @@ """Flask route setup, Streamable HTTP transport, and MCP message handling.""" +# pylint: disable=cyclic-import +# The MCP server imports dash primitives to dispatch callbacks, and dash +# lazy-imports this module to wire the MCP endpoint. Cycle is managed here. + from __future__ import annotations import json @@ -7,14 +11,6 @@ from typing import TYPE_CHECKING, Any from flask import Response, request - -from dash.mcp.types import MCPError - -if TYPE_CHECKING: - from dash import Dash - -from dash import get_app - from mcp.types import ( LATEST_PROTOCOL_VERSION, ErrorData, @@ -27,7 +23,8 @@ ToolsCapability, ) -from dash.version import __version__ +from dash import get_app +from dash._get_app import with_app_context_factory from dash.mcp.primitives import ( call_tool, list_resource_templates, @@ -38,6 +35,11 @@ from dash.mcp.primitives.tools.callback_adapter_collection import ( CallbackAdapterCollection, ) +from dash.mcp.types import MCPError +from dash.version import __version__ + +if TYPE_CHECKING: + from dash import Dash logger = logging.getLogger(__name__) @@ -77,9 +79,8 @@ def _handle_post() -> Response: status=415, ) - try: - data = request.get_json() - except Exception: + data = request.get_json(silent=True) + if data is None: return Response( json.dumps({"error": "Invalid JSON"}), content_type="application/json", @@ -107,8 +108,7 @@ def _handle_delete() -> Response: # -- Register routes ----------------------------------------------------- - from dash._get_app import with_app_context_factory - + # pylint: disable-next=protected-access app._add_url( mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] ) @@ -156,12 +156,12 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: mcp_methods = { "initialize": _handle_initialize, - "tools/list": lambda: list_tools(), + "tools/list": list_tools, "tools/call": lambda: call_tool( params.get("name", ""), params.get("arguments", {}) ), - "resources/list": lambda: list_resources(), - "resources/templates/list": lambda: list_resource_templates(), + "resources/list": list_resources, + "resources/templates/list": list_resource_templates, "resources/read": lambda: read_resource(params.get("uri", "")), } @@ -188,7 +188,7 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: id=request_id, error=ErrorData(code=e.code, message=str(e)), ).model_dump(exclude_none=True) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.error("MCP error: %s", e, exc_info=True) return JSONRPCError( jsonrpc="2.0", diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 3000541a06..8130c6a8a1 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -375,7 +375,7 @@ def _expand_dep(dep: CallbackDependency, value: Any) -> CallbackInputs: # LLM provides browser-like format if isinstance(value, list): - return cast(list[CallbackInput], value) + return cast("list[CallbackInput]", value) if isinstance(value, dict) and "id" in value: return cast(CallbackInput, value) return CallbackInput(id=dep["id"], property=dep["property"], value=value) diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index c81ffceb48..5030211ed8 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -1,8 +1,14 @@ """Shared helpers for MCP integration tests.""" +import sys + import pytest import requests +collect_ignore_glob = [] +if sys.version_info < (3, 10): + collect_ignore_glob.append("*") + @pytest.fixture(autouse=True) def _enable_mcp_for_integration_tests(monkeypatch): From d961671588e9191cd05c6731b27789882b31ef8a Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 23 Apr 2026 10:51:46 -0600 Subject: [PATCH 09/10] fix leaky state in tests --- tests/integration/mcp/conftest.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py index 5030211ed8..0db1b775e8 100644 --- a/tests/integration/mcp/conftest.py +++ b/tests/integration/mcp/conftest.py @@ -5,6 +5,8 @@ import pytest import requests +from dash import _get_app + collect_ignore_glob = [] if sys.version_info < (3, 10): collect_ignore_glob.append("*") @@ -16,6 +18,17 @@ def _enable_mcp_for_integration_tests(monkeypatch): monkeypatch.setenv("DASH_MCP_ENABLED", "true") +@pytest.fixture(autouse=True) +def _reset_dash_app_state(): + """Reset Dash module-level state after each MCP test. + + TODO: this can be removed when 4.2 backend work lands + """ + yield + _get_app.APP = None + _get_app.app_context.set(None) + + def _mcp_post(server_url, method, params=None, request_id=1): return requests.post( f"{server_url}/_mcp", From bb2599108640d115b052a1ec9e54115fed14886a Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Thu, 23 Apr 2026 14:49:31 -0600 Subject: [PATCH 10/10] Refactor unit tests to conform to existing test patterns --- .../tools/test_duplicate_outputs.py | 128 ----- .../primitives/tools/test_input_schemas.py | 66 --- .../tools/test_tool_get_dash_component.py | 54 -- .../mcp/primitives/tools/test_tools_list.py | 118 ---- ...tures.py => test_mcp_callback_behavior.py} | 508 ++++++++++++++---- tests/integration/mcp/test_mcp_endpoint.py | 189 +++++++ ...est_resources.py => test_mcp_resources.py} | 6 +- tests/integration/mcp/test_server.py | 183 ------- tests/unit/mcp/test_mcp_server.py | 99 ++++ tests/unit/mcp/test_server.py | 86 --- tests/unit/mcp/tools/test_mcp_run_callback.py | 253 +++++++++ tests/unit/mcp/tools/test_run_callback.py | 246 --------- 12 files changed, 947 insertions(+), 989 deletions(-) delete mode 100644 tests/integration/mcp/primitives/tools/test_duplicate_outputs.py delete mode 100644 tests/integration/mcp/primitives/tools/test_input_schemas.py delete mode 100644 tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py delete mode 100644 tests/integration/mcp/primitives/tools/test_tools_list.py rename tests/integration/mcp/{primitives/tools/test_callback_signatures.py => test_mcp_callback_behavior.py} (66%) create mode 100644 tests/integration/mcp/test_mcp_endpoint.py rename tests/integration/mcp/{primitives/resources/test_resources.py => test_mcp_resources.py} (86%) delete mode 100644 tests/integration/mcp/test_server.py create mode 100644 tests/unit/mcp/test_mcp_server.py delete mode 100644 tests/unit/mcp/test_server.py create mode 100644 tests/unit/mcp/tools/test_mcp_run_callback.py delete mode 100644 tests/unit/mcp/tools/test_run_callback.py diff --git a/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py deleted file mode 100644 index 4ad00641f8..0000000000 --- a/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Integration test for duplicate callback outputs. - -Multiple callbacks can output to the same component.property -when using ``allow_duplicate=True``. The MCP server must handle -this correctly — both callbacks should appear as tools, and -calling either should work. -""" - -from dash import Dash, Input, Output, dcc, html - -from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools - - -def _find_tool(tools, name): - return next((t for t in tools if t["name"] == name), None) - - -def _get_response(result): - return result["result"]["structuredContent"]["response"] - - -def test_duplicate_outputs_both_tools_listed(dash_duo): - """Both callbacks outputting to the same component appear as tools.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="first-name", value="Jane"), - dcc.Input(id="last-name", value="Doe"), - html.Div(id="greeting"), - ] - ) - - @app.callback( - Output("greeting", "children"), - Input("first-name", "value"), - ) - def greet_by_first(first): - return f"Hello, {first}!" - - @app.callback( - Output("greeting", "children", allow_duplicate=True), - Input("last-name", "value"), - prevent_initial_call=True, - ) - def greet_by_last(last): - return f"Hi, {last}!" - - dash_duo.start_server(app) - tools = _mcp_tools(dash_duo.server.url) - - first_tool = _find_tool(tools, "greet_by_first") - last_tool = _find_tool(tools, "greet_by_last") - - assert first_tool is not None, "greet_by_first should be listed" - assert last_tool is not None, "greet_by_last should be listed" - - -def test_duplicate_outputs_both_callable(dash_duo): - """Both callbacks can be called and produce correct results.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="first-name", value="Jane"), - dcc.Input(id="last-name", value="Doe"), - html.Div(id="greeting"), - ] - ) - - @app.callback( - Output("greeting", "children"), - Input("first-name", "value"), - ) - def greet_by_first(first): - return f"Hello, {first}!" - - @app.callback( - Output("greeting", "children", allow_duplicate=True), - Input("last-name", "value"), - prevent_initial_call=True, - ) - def greet_by_last(last): - return f"Hi, {last}!" - - dash_duo.start_server(app) - - result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) - assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" - - result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) - assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" - - -def test_duplicate_outputs_find_by_output_returns_primary(dash_duo): - """find_by_output returns the primary (non-duplicate) callback.""" - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="first-name", value="Jane"), - dcc.Input(id="last-name", value="Doe"), - html.Div(id="greeting"), - ] - ) - - @app.callback( - Output("greeting", "children"), - Input("first-name", "value"), - ) - def greet_by_first(first): - return f"Hello, {first}!" - - @app.callback( - Output("greeting", "children", allow_duplicate=True), - Input("last-name", "value"), - prevent_initial_call=True, - ) - def greet_by_last(last): - return f"Hi, {last}!" - - dash_duo.start_server(app) - - # Query the component — should reflect initial callback (greet_by_first) - result = _mcp_call_tool( - dash_duo.server.url, - "get_dash_component", - {"component_id": "greeting", "property": "children"}, - ) - structured = result["result"]["structuredContent"] - assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" diff --git a/tests/integration/mcp/primitives/tools/test_input_schemas.py b/tests/integration/mcp/primitives/tools/test_input_schemas.py deleted file mode 100644 index 6ee3510ddd..0000000000 --- a/tests/integration/mcp/primitives/tools/test_input_schemas.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Integration tests for MCP tool schema generation. - -Starts a real Dash server via ``dash_duo`` and verifies that tools -are generated with correct inputSchema, descriptions, and labels. -""" - -from dash import Dash, Input, Output, dcc, html - -from tests.integration.mcp.conftest import _mcp_tools - - -def test_mcp_tool_with_label_and_date_picker_schema(dash_duo): - """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" - - # -- Test data: change these to update the test -- - label_text = "Departure Date" - component_id = "dp" - min_date = "2020-01-01" - max_date = "2025-12-31" - default_date = "2024-06-15" - func_name = "select_date" - param_name = "date" # function parameter name - - app = Dash(__name__) - app.layout = html.Div( - [ - html.Label(label_text, htmlFor=component_id), - dcc.DatePickerSingle( - id=component_id, - min_date_allowed=min_date, - max_date_allowed=max_date, - date=default_date, - ), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input(component_id, "date")) - def select_date(date): - return f"Selected: {date}" - - dash_duo.start_server(app) - tools = _mcp_tools(dash_duo.server.url) - - # Find the callback tool - tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) - - # -- Tool-level fields -- - assert func_name in tool["name"] - - # -- inputSchema structure -- - schema = tool["inputSchema"] - assert schema["type"] == "object" - assert param_name in schema["required"] - assert param_name in schema["properties"] - - # -- Property schema: type + format + description -- - prop = schema["properties"][param_name] - assert prop["type"] == "string" - assert prop["format"] == "date" - - # description includes all source values (label, constraints, default) - desc = prop["description"] - for expected in (label_text, min_date, max_date, default_date): - assert expected in desc, f"Expected {expected!r} in description: {desc!r}" diff --git a/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py deleted file mode 100644 index 97472a16d7..0000000000 --- a/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Integration tests for the get_dash_component tool.""" - -from dash import Dash, dcc, html - -from tests.integration.mcp.conftest import _mcp_call_tool - -EXPECTED_DROPDOWN_OPTIONS = { - "component_id": "my-dropdown", - "component_type": "Dropdown", - "label": None, - "properties": { - "options": { - "initial_value": [ - {"label": "New York", "value": "NYC"}, - {"label": "Montreal", "value": "MTL"}, - ], - "modified_by_tool": [], - "input_to_tool": [], - }, - }, -} - - -def test_query_component_returns_structured_output(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown( - id="my-dropdown", - options=[ - {"label": "New York", "value": "NYC"}, - {"label": "Montreal", "value": "MTL"}, - ], - value="NYC", - ), - ] - ) - - dash_duo.start_server(app) - - result = _mcp_call_tool( - dash_duo.server.url, - "get_dash_component", - {"component_id": "my-dropdown", "property": "options"}, - ) - - assert "result" in result, f"Expected result in response: {result}" - structured = result["result"]["structuredContent"] - assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] - assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] - assert ( - structured["properties"]["options"] - == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] - ) diff --git a/tests/integration/mcp/primitives/tools/test_tools_list.py b/tests/integration/mcp/primitives/tools/test_tools_list.py deleted file mode 100644 index dc3d977146..0000000000 --- a/tests/integration/mcp/primitives/tools/test_tools_list.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Integration tests for tools/list — naming, dedup, and spec compliance.""" - -from dash import Dash, Input, Output, dcc, html - -from tests.integration.mcp.conftest import _mcp_tools - - -def test_tool_names_within_64_chars(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a"], value="a"), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("dd", "value")) - def update(val): - return val - - dash_duo.start_server(app) - for tool in _mcp_tools(dash_duo.server.url): - assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" - for param_name in tool.get("inputSchema", {}).get("properties", {}): - assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" - - -def test_long_callback_ids_within_64_chars(dash_duo): - app = Dash(__name__) - long_id = "a" * 120 - app.layout = html.Div( - [ - dcc.Input(id=long_id, value="test"), - html.Div(id=f"{long_id}-output"), - ] - ) - - @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) - def process(val): - return val - - dash_duo.start_server(app) - for tool in _mcp_tools(dash_duo.server.url): - assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" - - -def test_pattern_matching_ids_within_64_chars(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div( - [ - dcc.Input( - id={"type": "filter-input", "index": i, "category": "primary"}, - value=f"val-{i}", - ) - for i in range(3) - ] - ), - html.Div(id="pm-output"), - ] - ) - - @app.callback( - Output("pm-output", "children"), - Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), - ) - def filter_update(v0): - return str(v0) - - dash_duo.start_server(app) - for tool in _mcp_tools(dash_duo.server.url): - assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" - - -def test_duplicate_func_names_produce_unique_tools(dash_duo): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd1", options=["a"], value="a"), - html.Div(id="dd1-output"), - dcc.Dropdown(id="dd2", options=["b"], value="b"), - html.Div(id="dd2-output"), - dcc.Dropdown(id="dd3", options=["c"], value="c"), - html.Div(id="dd3-output"), - ] - ) - - @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) - def cb(value): - return f"first: {value}" - - @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) - def cb(value): # noqa: F811 - return f"second: {value}" - - @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) - def cb(value): # noqa: F811 - return f"third: {value}" - - dash_duo.start_server(app) - tools = _mcp_tools(dash_duo.server.url) - cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] - tool_names = [t["name"] for t in cb_tools] - - assert ( - len(tool_names) == 3 - ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" - assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" - - -def test_builtin_tools_always_present(dash_duo): - app = Dash(__name__) - app.layout = html.Div(id="root") - - dash_duo.start_server(app) - tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] - assert "get_dash_component" in tool_names diff --git a/tests/integration/mcp/primitives/tools/test_callback_signatures.py b/tests/integration/mcp/test_mcp_callback_behavior.py similarity index 66% rename from tests/integration/mcp/primitives/tools/test_callback_signatures.py rename to tests/integration/mcp/test_mcp_callback_behavior.py index db325f2046..7778111386 100644 --- a/tests/integration/mcp/primitives/tools/test_callback_signatures.py +++ b/tests/integration/mcp/test_mcp_callback_behavior.py @@ -1,29 +1,56 @@ +"""Callback behaviors surfaced through MCP tools (end-to-end). + +Covers the full pipeline — a real Dash server via ``dash_duo`` + the MCP +HTTP endpoint — for every callback signature variant and the surrounding +tool-list conventions: + +- Positional, dict-based, and tuple-grouped ``inputs`` / ``state`` / + ``output`` forms. +- ``State``, multi-output, ``PreventUpdate``-style no-output callbacks, + ``ctx.triggered_id``, pattern-matching (``ALL``/``MATCH``/``ALLSMALLER``). +- Initial values: ``prevent_initial_call`` vs. initial-callback overrides. +- Duplicate outputs (``allow_duplicate=True``) appearing as separate tools. +- ``tools/list`` naming rules (64-char limit, uniqueness, built-ins). +- A representative input-schema smoke test (label + DatePicker). +- ``get_dash_component`` structured output via HTTP. """ -Integration tests for all Dash callback signature types. -Each test verifies that: -1. The MCP tool schema accurately reflects the callback's parameters -2. Calling the tool with those parameters produces the expected result +from dash import ( + ALL, + ALLSMALLER, + MATCH, + Dash, + Input, + Output, + State, + ctx, + dcc, + html, + set_props, +) -Assertions are derived from the callback definition, not the implementation. - -See: https://dash.plotly.com/flexible-callback-signatures -""" +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools -from dash import Dash, Input, Output, State, dcc, html -from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- def _find_tool(tools, name): - return next(t for t in tools if t["name"] == name) + return next((t for t in tools if t["name"] == name), None) def _get_response(result): return result["result"]["structuredContent"]["response"] -def test_positional_callback(dash_duo): +# --------------------------------------------------------------------------- +# Callback signatures — positional, multi-output, State, dict-based, tuples +# --------------------------------------------------------------------------- + + +def test_mcpb001_positional_callback(dash_duo): """Standard positional: Input("fruit", "value") → param named 'value'.""" app = Dash(__name__) app.layout = html.Div( @@ -33,8 +60,6 @@ def test_positional_callback(dash_duo): ] ) - # Callback: 1 Input → 1 param named "value" (from function signature) - # Returns string → Output("out", "children") @app.callback(Output("out", "children"), Input("fruit", "value")) def show_fruit(value): return f"Selected: {value}" @@ -48,23 +73,20 @@ def show_fruit(value): assert set(props.keys()) == {"value"} assert any(s.get("type") == "string" for s in props["value"]["anyOf"]) - # Tool description reflects initial state value_desc = props["value"].get("description", "") assert "value: 'apple'" in value_desc assert "options: ['apple', 'banana']" in value_desc - # MCP tool with initial inputs matches browser result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "apple"}) response = _get_response(result) assert response["out"]["children"] == "Selected: apple" - # MCP tool with different inputs result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "banana"}) response = _get_response(result) assert response["out"]["children"] == "Selected: banana" -def test_positional_with_state(dash_duo): +def test_mcpb002_positional_with_state(dash_duo): """Positional with State: Input + State both appear as params.""" app = Dash(__name__) app.layout = html.Div( @@ -75,7 +97,6 @@ def test_positional_with_state(dash_duo): ] ) - # Callback: 1 Input + 1 State → 2 params named "n_clicks" and "value" @app.callback( Output("out", "children"), Input("btn", "n_clicks"), @@ -93,10 +114,8 @@ def update(n_clicks, value): assert set(props.keys()) == {"n_clicks", "value"} assert any(s.get("type") == "number" for s in props["n_clicks"]["anyOf"]) - # Tool description reflects initial state assert "value: 'hello'" in props["value"].get("description", "") - # MCP tool with initial inputs matches browser result = _mcp_call_tool( dash_duo.server.url, "update", {"n_clicks": None, "value": "hello"} ) @@ -110,7 +129,7 @@ def update(n_clicks, value): assert response["out"]["children"] == "Clicked 3 with world" -def test_multi_output_positional(dash_duo): +def test_mcpb003_multi_output_positional(dash_duo): """Multi-output: returns tuple → both outputs updated in response.""" app = Dash(__name__) app.layout = html.Div( @@ -121,7 +140,6 @@ def test_multi_output_positional(dash_duo): ] ) - # Callback: 1 Input → 2 Outputs via tuple return @app.callback( Output("out1", "children"), Output("out2", "children"), @@ -138,18 +156,16 @@ def split_case(value): props = tool["inputSchema"]["properties"] assert set(props.keys()) == {"value"} - # Tool description reflects initial state assert "value: 'test'" in props["value"].get("description", "") - # MCP tool with initial inputs matches browser result = _mcp_call_tool(dash_duo.server.url, "split_case", {"value": "test"}) response = _get_response(result) assert response["out1"]["children"] == "TEST" assert response["out2"]["children"] == "test" -def test_dict_based_inputs_and_state(dash_duo): - """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are param names.""" +def test_mcpb004_dict_based_inputs_and_state(dash_duo): + """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are params.""" app = Dash(__name__) app.layout = html.Div( [ @@ -159,7 +175,6 @@ def test_dict_based_inputs_and_state(dash_duo): ] ) - # Callback: dict keys "trigger" and "name" become param names @app.callback( Output("out", "children"), inputs=dict(trigger=Input("btn", "n_clicks")), @@ -177,7 +192,6 @@ def greet(trigger, name): assert set(props.keys()) == {"trigger", "name"} assert any(s.get("type") == "number" for s in props["trigger"]["anyOf"]) - # MCP tool with initial inputs matches browser result = _mcp_call_tool( dash_duo.server.url, "greet", {"trigger": None, "name": "world"} ) @@ -191,7 +205,7 @@ def greet(trigger, name): assert response["out"]["children"] == "Hello, Dash!" -def test_dict_based_outputs(dash_duo): +def test_mcpb005_dict_based_outputs(dash_duo): """Dict-based outputs: output=dict(...) → callback returns dict, both outputs updated.""" app = Dash(__name__) app.layout = html.Div( @@ -202,7 +216,6 @@ def test_dict_based_outputs(dash_duo): ] ) - # Callback: dict output keys "upper" and "lower" map to components @app.callback( output=dict( upper=Output("upper-out", "children"), @@ -221,14 +234,13 @@ def transform(val): props = tool["inputSchema"]["properties"] assert set(props.keys()) == {"val"} - # MCP tool with initial inputs matches browser result = _mcp_call_tool(dash_duo.server.url, "transform", {"val": "hello"}) response = _get_response(result) assert response["upper-out"]["children"] == "HELLO" assert response["lower-out"]["children"] == "hello" -def test_mixed_input_state_in_inputs(dash_duo): +def test_mcpb006_mixed_input_state_in_inputs(dash_duo): """Mixed: State inside inputs=dict alongside Input → all appear as params.""" app = Dash(__name__) app.layout = html.Div( @@ -240,7 +252,6 @@ def test_mixed_input_state_in_inputs(dash_duo): ] ) - # Callback: Input and State mixed in same dict → all keys are params @app.callback( Output("out", "children"), inputs=dict( @@ -261,7 +272,6 @@ def full_name(clicks, first, last): assert set(props.keys()) == {"clicks", "first", "last"} assert any(s.get("type") == "number" for s in props["clicks"]["anyOf"]) - # MCP tool with initial inputs matches browser result = _mcp_call_tool( dash_duo.server.url, "full_name", @@ -279,7 +289,7 @@ def full_name(clicks, first, last): assert response["out"]["children"] == "John Smith" -def test_tuple_grouped_inputs(dash_duo): +def test_mcpb007_tuple_grouped_inputs(dash_duo): """Tuple grouping: pair=(Input("a",...), Input("b",...)) → expands to two named params.""" app = Dash(__name__) app.layout = html.Div( @@ -290,7 +300,6 @@ def test_tuple_grouped_inputs(dash_duo): ] ) - # Callback: tuple group "pair" maps to 2 deps → 2 params named pair___ @app.callback( Output("out", "children"), inputs=dict(pair=(Input("a", "value"), Input("b", "value"))), @@ -302,7 +311,6 @@ def combine(pair): tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") props = tool["inputSchema"]["properties"] - # Tuple expands: one param per dep, named with group prefix + component info assert set(props.keys()) == {"pair_a__value", "pair_b__value"} for schema in props.values(): assert any(s.get("type") == "string" for s in schema["anyOf"]) @@ -316,7 +324,7 @@ def combine(pair): assert response["out"]["children"] == "x+y" -def test_initial_values_from_chained_callbacks(dash_duo): +def test_mcpb008_initial_values_from_chained_callbacks(dash_duo): """Querying components reflects post-initial-callback values. 3-link chain: country (default "France") → update_states → @@ -368,16 +376,13 @@ def update_cities(state, country): dash_duo.start_server(app) - # Tool descriptions should reflect post-initial-callback state tools = _mcp_tools(dash_duo.server.url) update_cities_tool = _find_tool(tools, "update_cities") state_desc = update_cities_tool["inputSchema"]["properties"]["state"].get( "description", "" ) - # state.value was set to "Ile-de-France" by update_states initial callback assert "Ile-de-France" in state_desc - # state.value should be "Ile-de-France" (first state for France) result = _mcp_call_tool( dash_duo.server.url, "get_dash_component", @@ -386,7 +391,6 @@ def update_cities(state, country): state_props = result["result"]["structuredContent"]["properties"] assert state_props["value"]["initial_value"] == "Ile-de-France" - # city.value should be "Paris" (first city for Ile-de-France) result = _mcp_call_tool( dash_duo.server.url, "get_dash_component", @@ -396,7 +400,7 @@ def update_cities(state, country): assert city_props["value"]["initial_value"] == "Paris" -def test_dict_based_reordered_state_input(dash_duo): +def test_mcpb009_dict_based_reordered_state_input(dash_duo): """Dict-based callback with State before Input: call works, schema types correct. State is listed before Input in the dict. The callback should still @@ -421,7 +425,6 @@ def greet(name: str, trigger: int): dash_duo.start_server(app) - # First: verify the callback actually works with these args result = _mcp_call_tool( dash_duo.server.url, "greet", @@ -429,20 +432,23 @@ def greet(name: str, trigger: int): ) assert _get_response(result)["out"]["children"] == "Hello Dash" - # Second: verify schema types match annotations tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") props = tool["inputSchema"]["properties"] assert props["trigger"]["type"] == "integer" assert props["name"]["type"] == "string" - # Third: verify each param describes the correct component trigger_desc = props["trigger"].get("description", "") assert "number of times that this element has been clicked on" in trigger_desc name_desc = props["name"].get("description", "") assert "The value of the input" in name_desc -def test_pattern_matching_callback(dash_duo): +# --------------------------------------------------------------------------- +# Pattern-matching callbacks (ALL / MATCH / ALLSMALLER) +# --------------------------------------------------------------------------- + + +def test_mcpb010_pattern_matching_callback(dash_duo): """Pattern-matching dict IDs: tool works with correct params and results.""" app = Dash(__name__) app.layout = html.Div( @@ -469,7 +475,6 @@ def combine(first, second): assert "first" in props assert "second" in props - # Verify initial output matches what the browser shows dash_duo.wait_for_text_to_equal("#out", "hello world") result = _mcp_call_tool( dash_duo.server.url, @@ -479,7 +484,6 @@ def combine(first, second): response = _get_response(result) assert response["out"]["children"] == "hello world" - # Verify with different values result = _mcp_call_tool( dash_duo.server.url, "combine", @@ -489,10 +493,8 @@ def combine(first, second): assert response["out"]["children"] == "foo bar" -def test_pattern_matching_with_all_wildcard(dash_duo): +def test_mcpb011_pattern_matching_with_all_wildcard(dash_duo): """ALL wildcard: one callback receives values from all matching components.""" - from dash import ALL - app = Dash(__name__) app.layout = html.Div( [ @@ -515,7 +517,6 @@ def summarize(values): tool = _find_tool(_mcp_tools(dash_duo.server.url), "summarize") assert tool is not None - # Schema must describe values as an array of {id, property, value} objects values_schema = tool["inputSchema"]["properties"]["values"] assert ( values_schema["type"] == "array" @@ -529,7 +530,6 @@ def summarize(values): "description", "" ), "ALL wildcard param description should explain the pattern-matching behavior" - # MCP tool call with browser-like format: concrete IDs + values result = _mcp_call_tool( dash_duo.server.url, "summarize", @@ -551,7 +551,6 @@ def summarize(values): response = _get_response(result) assert response["summary"]["children"] == "alpha, beta" - # Different values result = _mcp_call_tool( dash_duo.server.url, "summarize", @@ -574,10 +573,8 @@ def summarize(values): assert response["summary"]["children"] == "one, two" -def test_pattern_matching_mixed_outputs(dash_duo): +def test_mcpb012_pattern_matching_mixed_outputs(dash_duo): """Mixed outputs: one regular + one ALL wildcard in the same callback.""" - from dash import ALL - app = Dash(__name__) app.layout = html.Div( [ @@ -624,13 +621,11 @@ def echo_and_total(values): assert response["total"]["children"] == "Total: 2 items" -def test_pattern_matching_with_match_wildcard(dash_duo): +def test_mcpb013_pattern_matching_with_match_wildcard(dash_duo): """MATCH wildcard: callback fires per-component with matching index. Based on https://dash.plotly.com/pattern-matching-callbacks """ - from dash import MATCH - app = Dash(__name__) app.layout = html.Div( [ @@ -661,11 +656,9 @@ def show_city(value): tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") assert tool is not None - # Schema describes MATCH input value_schema = tool["inputSchema"]["properties"]["value"] assert "Pattern-matching input (MATCH)" in value_schema.get("description", "") - # Call with concrete ID for index 0 (MATCH takes a single entry, not an array) result = _mcp_call_tool( dash_duo.server.url, "show_city", @@ -678,18 +671,15 @@ def show_city(value): }, ) response = _get_response(result) - # Find the output key containing "city-out" (Dash may serialize dict IDs differently) out_key = next(k for k in response if "city-out" in k) assert response[out_key]["children"] == "Selected: MTL" -def test_pattern_matching_with_allsmaller_wildcard(dash_duo): +def test_mcpb014_pattern_matching_with_allsmaller_wildcard(dash_duo): """ALLSMALLER wildcard: receives values from components with smaller index. Based on https://dash.plotly.com/pattern-matching-callbacks """ - from dash import MATCH, ALLSMALLER - app = Dash(__name__) app.layout = html.Div( [ @@ -728,14 +718,12 @@ def show_countries(current, previous): tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_countries") assert tool is not None - # Schema describes both MATCH and ALLSMALLER inputs props = tool["inputSchema"]["properties"] assert "Pattern-matching input (MATCH)" in props["current"].get("description", "") assert "Pattern-matching input (ALLSMALLER)" in props["previous"].get( "description", "" ) - # Call for index 2: MATCH is a single dict, ALLSMALLER is a list result = _mcp_call_tool( dash_duo.server.url, "show_countries", @@ -764,13 +752,13 @@ def show_countries(current, previous): assert response[out_key]["children"] == "All: Japan, Germany, France" -def test_prevent_initial_call_uses_layout_default(dash_duo): - """prevent_initial_call=True: initial value stays as the layout default. +# --------------------------------------------------------------------------- +# Initial values: prevent_initial_call vs. initial-callback overrides +# --------------------------------------------------------------------------- - The dropdown has value="original" in the layout. The callback has - prevent_initial_call=True so it doesn't run on page load. The MCP - tool description should show value: 'a' (layout default). - """ + +def test_mcpb015_prevent_initial_call_uses_layout_default(dash_duo): + """prevent_initial_call=True: initial value stays as the layout default.""" app = Dash(__name__) app.layout = html.Div( [ @@ -788,24 +776,16 @@ def update(val): return f"Changed to: {val}" dash_duo.start_server(app) - # Browser shows layout default — callback hasn't fired dash_duo.wait_for_text_to_equal("#out", "not yet") tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") val_desc = tool["inputSchema"]["properties"]["val"].get("description", "") - # Tool description reflects layout default, not callback output assert "value: 'a'" in val_desc -def test_initial_callback_overrides_layout_value(dash_duo): - """Initial callback overrides layout value in tool description. - - The city dropdown has value="default-city" in the layout. - update_city runs on page load (no prevent_initial_call) and - sets city.value to "Paris". The MCP tool should show "Paris" - as the default, not "default-city". - """ +def test_mcpb016_initial_callback_overrides_layout_value(dash_duo): + """Initial callback overrides layout value in tool description.""" app = Dash(__name__) app.layout = html.Div( [ @@ -830,24 +810,20 @@ def show_city(city): return f"City: {city}" dash_duo.start_server(app) - # Browser shows "Paris" — the initial callback overrode "default-city" dash_duo.wait_for_text_to_equal("#out", "City: Paris") tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") city_desc = tool["inputSchema"]["properties"]["city"].get("description", "") - # Tool description should show the post-initial-callback value assert "value: 'Paris'" in city_desc assert "default-city" not in city_desc -def test_callback_context_triggered_id(dash_duo): +def test_mcpb017_callback_context_triggered_id(dash_duo): """Callbacks using dash.ctx.triggered_id work via MCP. Based on https://dash.plotly.com/determining-which-callback-input-changed """ - from dash import ctx - app = Dash(__name__) app.layout = html.Div( [ @@ -871,17 +847,14 @@ def display(btn1, btn2, btn3): dash_duo.start_server(app) - # Browser initial state: no button clicked dash_duo.wait_for_text_to_equal("#output", "No button clicked yet") - # Tool should have all three button params tool = _find_tool(_mcp_tools(dash_duo.server.url), "display") props = tool["inputSchema"]["properties"] assert "btn1" in props assert "btn2" in props assert "btn3" in props - # Click btn-2 via MCP — ctx.triggered_id should be "btn-2" result = _mcp_call_tool( dash_duo.server.url, "display", @@ -890,7 +863,6 @@ def display(btn1, btn2, btn3): response = _get_response(result) assert response["output"]["children"] == "Last clicked: btn-2" - # Click btn-3 via MCP result = _mcp_call_tool( dash_duo.server.url, "display", @@ -900,14 +872,12 @@ def display(btn1, btn2, btn3): assert response["output"]["children"] == "Last clicked: btn-3" -def test_no_output_callback_does_not_crash_tools_list(dash_duo): +def test_mcpb018_no_output_callback_does_not_crash_tools_list(dash_duo): """A callback with no Output should not crash tools/list. No-output callbacks use set_props for side effects. They produce a hash-only output_id with no dot separator. """ - from dash import set_props - app = Dash(__name__) app.layout = html.Div( [ @@ -930,13 +900,9 @@ def show_selection(val): tools = _mcp_tools(dash_duo.server.url) tool_names = [t["name"] for t in tools] - # show_selection should appear as a tool assert "show_selection" in tool_names - - # log_click has no declared output but uses set_props — still a valid tool assert "log_click" in tool_names - # Call log_click — sideUpdate should show the set_props effect result = _mcp_call_tool( dash_duo.server.url, "log_click", @@ -946,9 +912,6 @@ def show_selection(val): assert "sideUpdate" in structured assert structured["sideUpdate"]["display"]["children"] == "Logged 3 clicks" - # get_dash_component shows show_selection as modifier (declared output). - # log_click uses set_props which bypasses the declarative graph — - # its effect is only visible via sideUpdate in tool call results. result = _mcp_call_tool( dash_duo.server.url, "get_dash_component", @@ -956,3 +919,338 @@ def show_selection(val): ) prop_info = result["result"]["structuredContent"]["properties"]["children"] assert "show_selection" in prop_info["modified_by_tool"] + + +# --------------------------------------------------------------------------- +# Duplicate outputs (allow_duplicate=True) +# --------------------------------------------------------------------------- + + +def test_mcpb019_duplicate_outputs_both_tools_listed(dash_duo): + """Both callbacks outputting to the same component appear as tools.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + first_tool = _find_tool(tools, "greet_by_first") + last_tool = _find_tool(tools, "greet_by_last") + + assert first_tool is not None, "greet_by_first should be listed" + assert last_tool is not None, "greet_by_last should be listed" + + +def test_mcpb020_duplicate_outputs_both_callable(dash_duo): + """Both callbacks can be called and produce correct results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) + assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" + + result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) + assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" + + +def test_mcpb021_duplicate_outputs_find_by_output_returns_primary(dash_duo): + """find_by_output returns the primary (non-duplicate) callback.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "greeting", "property": "children"}, + ) + structured = result["result"]["structuredContent"] + assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" + + +# --------------------------------------------------------------------------- +# tools/list — naming rules (64-char limit, uniqueness, built-ins) +# --------------------------------------------------------------------------- + + +def test_mcpb022_tool_names_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + for param_name in tool.get("inputSchema", {}).get("properties", {}): + assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" + + +def test_mcpb023_long_callback_ids_within_64_chars(dash_duo): + app = Dash(__name__) + long_id = "a" * 120 + app.layout = html.Div( + [ + dcc.Input(id=long_id, value="test"), + html.Div(id=f"{long_id}-output"), + ] + ) + + @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) + def process(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_mcpb024_pattern_matching_ids_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div( + [ + dcc.Input( + id={"type": "filter-input", "index": i, "category": "primary"}, + value=f"val-{i}", + ) + for i in range(3) + ] + ), + html.Div(id="pm-output"), + ] + ) + + @app.callback( + Output("pm-output", "children"), + Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), + ) + def filter_update(v0): + return str(v0) + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_mcpb025_duplicate_func_names_produce_unique_tools(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd1", options=["a"], value="a"), + html.Div(id="dd1-output"), + dcc.Dropdown(id="dd2", options=["b"], value="b"), + html.Div(id="dd2-output"), + dcc.Dropdown(id="dd3", options=["c"], value="c"), + html.Div(id="dd3-output"), + ] + ) + + @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) + def cb(value): + return f"first: {value}" + + @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) + def cb(value): # noqa: F811 + return f"second: {value}" + + @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) + def cb(value): # noqa: F811 + return f"third: {value}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] + tool_names = [t["name"] for t in cb_tools] + + assert ( + len(tool_names) == 3 + ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" + assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" + + +def test_mcpb026_builtin_tools_always_present(dash_duo): + app = Dash(__name__) + app.layout = html.Div(id="root") + + dash_duo.start_server(app) + tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] + assert "get_dash_component" in tool_names + + +# --------------------------------------------------------------------------- +# Input schema smoke test + get_dash_component HTTP structured output +# --------------------------------------------------------------------------- + + +def test_mcpb027_mcp_tool_with_label_and_date_picker_schema(dash_duo): + """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" + label_text = "Departure Date" + component_id = "dp" + min_date = "2020-01-01" + max_date = "2025-12-31" + default_date = "2024-06-15" + func_name = "select_date" + param_name = "date" # function parameter name + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label(label_text, htmlFor=component_id), + dcc.DatePickerSingle( + id=component_id, + min_date_allowed=min_date, + max_date_allowed=max_date, + date=default_date, + ), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input(component_id, "date")) + def select_date(date): + return f"Selected: {date}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) + + assert func_name in tool["name"] + + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert param_name in schema["required"] + assert param_name in schema["properties"] + + prop = schema["properties"][param_name] + assert prop["type"] == "string" + assert prop["format"] == "date" + + desc = prop["description"] + for expected in (label_text, min_date, max_date, default_date): + assert expected in desc, f"Expected {expected!r} in description: {desc!r}" + + +EXPECTED_DROPDOWN_OPTIONS = { + "component_id": "my-dropdown", + "component_type": "Dropdown", + "label": None, + "properties": { + "options": { + "initial_value": [ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + "modified_by_tool": [], + "input_to_tool": [], + }, + }, +} + + +def test_mcpb028_query_component_returns_structured_output(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="my-dropdown", + options=[ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + value="NYC", + ), + ] + ) + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "my-dropdown", "property": "options"}, + ) + + assert "result" in result, f"Expected result in response: {result}" + structured = result["result"]["structuredContent"] + assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] + assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] + assert ( + structured["properties"]["options"] + == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] + ) diff --git a/tests/integration/mcp/test_mcp_endpoint.py b/tests/integration/mcp/test_mcp_endpoint.py new file mode 100644 index 0000000000..44b358c25d --- /dev/null +++ b/tests/integration/mcp/test_mcp_endpoint.py @@ -0,0 +1,189 @@ +"""MCP Streamable HTTP endpoint — transport-layer behavior. + +Uses Flask's test_client to exercise POST/GET/DELETE at /_mcp, +session management, content-type handling, and route registration +driven by ``enable_mcp`` / ``DASH_MCP_ENABLED`` / ``routes_pathname_prefix``. +""" + +import json +import os + +from dash import Dash, Input, Output, html +from mcp.types import LATEST_PROTOCOL_VERSION + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpe001_post_initialize_returns_protocol_version(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + + +def test_mcpe002_post_tools_list(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert "result" in data + assert "tools" in data["result"] + + +def test_mcpe003_notification_returns_202(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), + content_type="application/json", + ) + assert r.status_code == 202 + + +def test_mcpe004_delete_returns_405(): + app = _make_app() + client = app.server.test_client() + r = client.delete(f"/{MCP_PATH}") + assert r.status_code == 405 + + +def test_mcpe005_get_returns_405(): + app = _make_app() + client = app.server.test_client() + r = client.get(f"/{MCP_PATH}") + assert r.status_code == 405 + + +def test_mcpe006_post_rejects_wrong_content_type(): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data="not json", + content_type="text/plain", + ) + assert r.status_code == 415 + + +def test_mcpe007_routes_not_registered_when_disabled(): + app = _make_app(enable_mcp=False) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + # With MCP disabled, the route doesn't exist — response is HTML, not JSON + assert r.content_type != "application/json" + + +def test_mcpe008_routes_respect_pathname_prefix(): + app = _make_app(routes_pathname_prefix="/app/") + client = app.server.test_client() + + ok = client.post( + f"/app/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert ok.status_code == 200 + + miss = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert miss.status_code == 404 + + +def test_mcpe009_enable_mcp_env_var_false(): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.content_type != "application/json" + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old + + +def test_mcpe010_constructor_overrides_env_var(): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app(enable_mcp=True) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert b"protocolVersion" in r.data + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/integration/mcp/primitives/resources/test_resources.py b/tests/integration/mcp/test_mcp_resources.py similarity index 86% rename from tests/integration/mcp/primitives/resources/test_resources.py rename to tests/integration/mcp/test_mcp_resources.py index dfc1e09f9b..41519578d1 100644 --- a/tests/integration/mcp/primitives/resources/test_resources.py +++ b/tests/integration/mcp/test_mcp_resources.py @@ -1,4 +1,4 @@ -"""Integration tests for MCP resources.""" +"""MCP resources — ``resources/list`` and ``resources/read`` via HTTP.""" import json @@ -7,7 +7,7 @@ from tests.integration.mcp.conftest import _mcp_method -def test_resources_list_includes_layout(dash_duo): +def test_mcpz001_resources_list_includes_layout(dash_duo): app = Dash(__name__) app.layout = html.Div( [ @@ -24,7 +24,7 @@ def test_resources_list_includes_layout(dash_duo): assert "dash://layout" in uris -def test_read_layout_resource(dash_duo): +def test_mcpz002_read_layout_resource(dash_duo): app = Dash(__name__) app.layout = html.Div( [ diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py deleted file mode 100644 index 4f0d0fca00..0000000000 --- a/tests/integration/mcp/test_server.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Integration tests for the MCP Streamable HTTP endpoint. - -These tests use Flask's test_client to exercise the HTTP transport layer -(POST/GET/DELETE at /_mcp), session management, content-type handling, -and route registration/configuration. -""" - -import json -import os - -from dash import Dash, Input, Output, html -from mcp.types import LATEST_PROTOCOL_VERSION - -MCP_PATH = "_mcp" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_app(**kwargs): - """Create a minimal Dash app with a layout and one callback.""" - app = Dash(__name__, **kwargs) - app.layout = html.Div( - [ - html.Div(id="my-input"), - html.Div(id="my-output"), - ] - ) - - @app.callback(Output("my-output", "children"), Input("my-input", "children")) - def update_output(value): - """Test callback docstring.""" - return f"echo: {value}" - - return app - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -class TestMCPEndpoint: - """Tests for the Streamable HTTP MCP endpoint at /_mcp.""" - - def test_post_initialize_returns_protocol_version(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.status_code == 200 - data = json.loads(r.data) - assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - - def test_post_tools_list(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.status_code == 200 - data = json.loads(r.data) - assert "result" in data - assert "tools" in data["result"] - - def test_notification_returns_202(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), - content_type="application/json", - ) - assert r.status_code == 202 - - def test_delete_returns_405(self): - app = _make_app() - client = app.server.test_client() - r = client.delete(f"/{MCP_PATH}") - assert r.status_code == 405 - - def test_get_returns_405(self): - app = _make_app() - client = app.server.test_client() - r = client.get(f"/{MCP_PATH}") - assert r.status_code == 405 - - def test_post_rejects_wrong_content_type(self): - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data="not json", - content_type="text/plain", - ) - assert r.status_code == 415 - - def test_routes_not_registered_when_disabled(self): - app = _make_app(enable_mcp=False) - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - # With MCP disabled, the route doesn't exist — response is HTML, not JSON - assert r.content_type != "application/json" - - def test_routes_respect_pathname_prefix(self): - app = _make_app(routes_pathname_prefix="/app/") - client = app.server.test_client() - - ok = client.post( - f"/app/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert ok.status_code == 200 - - miss = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert miss.status_code == 404 - - def test_enable_mcp_env_var_false(self): - old = os.environ.get("DASH_MCP_ENABLED") - try: - os.environ["DASH_MCP_ENABLED"] = "false" - app = _make_app() - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.content_type != "application/json" - finally: - if old is None: - os.environ.pop("DASH_MCP_ENABLED", None) - else: - os.environ["DASH_MCP_ENABLED"] = old - - def test_constructor_overrides_env_var(self): - old = os.environ.get("DASH_MCP_ENABLED") - try: - os.environ["DASH_MCP_ENABLED"] = "false" - app = _make_app(enable_mcp=True) - client = app.server.test_client() - r = client.post( - f"/{MCP_PATH}", - data=json.dumps( - {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} - ), - content_type="application/json", - ) - assert r.status_code == 200 - assert b"protocolVersion" in r.data - finally: - if old is None: - os.environ.pop("DASH_MCP_ENABLED", None) - else: - os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/unit/mcp/test_mcp_server.py b/tests/unit/mcp/test_mcp_server.py new file mode 100644 index 0000000000..f4bb595dce --- /dev/null +++ b/tests/unit/mcp/test_mcp_server.py @@ -0,0 +1,99 @@ +"""MCP server JSON-RPC message processing (``_process_mcp_message``).""" + +from dash._get_app import app_context +from dash.mcp._server import _process_mcp_message +from mcp.types import LATEST_PROTOCOL_VERSION + +from tests.unit.mcp.conftest import _make_app, _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool(app, tool_name, arguments=None, request_id=1): + return _mcp( + app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id + ) + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + result = _call_tool(app, tool_name, arguments) + structured = result["result"]["structuredContent"] + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcps001_initialize(): + app = _make_app() + result = _mcp(app, "initialize") + + assert result is not None + assert result["id"] == 1 + assert result["jsonrpc"] == "2.0" + assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert "serverInfo" in result["result"] + + +def test_mcps002_tools_call(): + app = _make_app() + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) + + result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) + + assert result is not None + assert result["id"] == 2 + assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" + + +def test_mcps003_tools_call_unknown_tool_returns_error(): + app = _make_app() + result = _call_tool(app, "nonexistent_tool") + + assert result is not None + assert "error" in result + assert result["error"]["code"] == -32601 + + +def test_mcps004_unknown_method_returns_error(): + app = _make_app() + result = _mcp(app, "unknown/method") + + assert result is not None + assert "error" in result + + +def test_mcps005_notification_returns_none(): + app = _make_app() + data = {"jsonrpc": "2.0", "method": "notifications/initialized"} + with app.server.test_request_context(): + app_context.set(app) + result = _process_mcp_message(data) + assert result is None diff --git a/tests/unit/mcp/test_server.py b/tests/unit/mcp/test_server.py deleted file mode 100644 index 23c99c50ad..0000000000 --- a/tests/unit/mcp/test_server.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Tests for MCP server (_server.py) — JSON-RPC message processing.""" - -from dash._get_app import app_context -from dash.mcp._server import _process_mcp_message -from mcp.types import LATEST_PROTOCOL_VERSION - -from tests.unit.mcp.conftest import _make_app, _setup_mcp - - -def _msg(method, params=None, request_id=1): - d = {"jsonrpc": "2.0", "method": method, "id": request_id} - d["params"] = params if params is not None else {} - return d - - -def _mcp(app, method, params=None, request_id=1): - with app.server.test_request_context(): - _setup_mcp(app) - return _process_mcp_message(_msg(method, params, request_id)) - - -def _tools_list(app): - return _mcp(app, "tools/list")["result"]["tools"] - - -def _call_tool(app, tool_name, arguments=None, request_id=1): - return _mcp( - app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id - ) - - -def _call_tool_output( - app, tool_name, arguments=None, component_id=None, prop="children" -): - result = _call_tool(app, tool_name, arguments) - structured = result["result"]["structuredContent"] - response = structured["response"] - if component_id is None: - component_id = next(iter(response)) - return response[component_id][prop] - - -class TestProcessMCPMessage: - def test_initialize(self): - app = _make_app() - result = _mcp(app, "initialize") - - assert result is not None - assert result["id"] == 1 - assert result["jsonrpc"] == "2.0" - assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION - assert "serverInfo" in result["result"] - - def test_tools_call(self): - app = _make_app() - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) - - result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) - - assert result is not None - assert result["id"] == 2 - assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" - - def test_tools_call_unknown_tool_returns_error(self): - app = _make_app() - result = _call_tool(app, "nonexistent_tool") - - assert result is not None - assert "error" in result - assert result["error"]["code"] == -32601 - - def test_unknown_method_returns_error(self): - app = _make_app() - result = _mcp(app, "unknown/method") - - assert result is not None - assert "error" in result - - def test_notification_returns_none(self): - app = _make_app() - data = {"jsonrpc": "2.0", "method": "notifications/initialized"} - with app.server.test_request_context(): - app_context.set(app) - result = _process_mcp_message(data) - assert result is None diff --git a/tests/unit/mcp/tools/test_mcp_run_callback.py b/tests/unit/mcp/tools/test_mcp_run_callback.py new file mode 100644 index 0000000000..e345b9682e --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_run_callback.py @@ -0,0 +1,253 @@ +"""Callback dispatch execution via MCP tools (``run_callback``). + +Exercises how the MCP tool pipeline runs a Dash callback through +``_process_mcp_message`` with various signatures: multi-output, State, +positional vs. dict-based ``inputs``, ``PreventUpdate``, and no-output +set_props-style callbacks. +""" + +from dash import Dash, Input, Output, State, dcc, html, set_props +from dash.exceptions import PreventUpdate +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool_structured(app, tool_name, arguments=None): + result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) + return result["result"]["structuredContent"] + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + structured = _call_tool_structured(app, tool_name, arguments) + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpx001_multi_output(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [{"label": val, "value": val}], f"selected: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + structured = _call_tool_structured(app, tool_name, {"val": "b"}) + assert structured["response"]["dd2"]["options"] == [{"label": "b", "value": "b"}] + assert structured["response"]["out"]["children"] == "selected: b" + + +def test_mcpx002_omitted_kwargs_default_to_none(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"]), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + State("inp", "value"), + ) + def update(selected, text): + return f"{selected}-{text}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" + + +def test_mcpx003_no_output_callback(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("btn", "n_clicks")) + def server_cb(n): + set_props("display", {"children": f"Clicked {n} times"}) + + tools = _tools_list(app) + tool_names = [t["name"] for t in tools] + assert "server_cb" in tool_names + + +def test_mcpx004_prevent_update(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + if val == "block": + raise PreventUpdate + return f"got: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" + + +def test_mcpx005_with_state(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="trigger"), + html.Div(id="store"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input("trigger", "children"), + State("store", "children"), + ) + def with_state(trigger, store): + return f"{trigger}-{store}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"trigger": "click", "store": "data"}, + "result", + ) + == "click-data" + ) + + +def test_mcpx006_dict_inputs(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="x-input", value="hello"), + dcc.Input(id="y-input", value="world"), + html.Div(id="dict-out"), + ] + ) + + @app.callback( + Output("dict-out", "children"), + inputs={ + "x_val": Input("x-input", "value"), + "y_val": Input("y-input", "value"), + }, + ) + def combine(**kwargs): + return f"{kwargs['x_val']}-{kwargs['y_val']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "combine" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"x_val": "foo", "y_val": "bar"}, + "dict-out", + ) + == "foo-bar" + ) + + +def test_mcpx007_positional_inputs(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a-input", value="A"), + html.Div(id="pos-out"), + ] + ) + + @app.callback(Output("pos-out", "children"), Input("a-input", "value")) + def echo(val): + return f"got:{val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" + + +def test_mcpx008_dict_inputs_with_state(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hi"), + html.Div(id="st", children="state-val"), + html.Div(id="ds-out"), + ] + ) + + @app.callback( + Output("ds-out", "children"), + inputs={"trigger": Input("inp", "value")}, + state={"kept": State("st", "children")}, + ) + def with_dict_state(**kwargs): + return f"{kwargs['trigger']}+{kwargs['kept']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + {"trigger": "hey", "kept": "saved"}, + "ds-out", + ) + == "hey+saved" + ) diff --git a/tests/unit/mcp/tools/test_run_callback.py b/tests/unit/mcp/tools/test_run_callback.py deleted file mode 100644 index 00f4e5b7b1..0000000000 --- a/tests/unit/mcp/tools/test_run_callback.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Tests for callback dispatch execution via MCP tools.""" - -from dash import Dash, Input, Output, State, dcc, html -from dash.exceptions import PreventUpdate -from dash.mcp._server import _process_mcp_message - -from tests.unit.mcp.conftest import _setup_mcp - - -def _msg(method, params=None, request_id=1): - d = {"jsonrpc": "2.0", "method": method, "id": request_id} - d["params"] = params if params is not None else {} - return d - - -def _mcp(app, method, params=None, request_id=1): - with app.server.test_request_context(): - _setup_mcp(app) - return _process_mcp_message(_msg(method, params, request_id)) - - -def _tools_list(app): - return _mcp(app, "tools/list")["result"]["tools"] - - -def _call_tool_structured(app, tool_name, arguments=None): - result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) - return result["result"]["structuredContent"] - - -def _call_tool_output( - app, tool_name, arguments=None, component_id=None, prop="children" -): - structured = _call_tool_structured(app, tool_name, arguments) - response = structured["response"] - if component_id is None: - component_id = next(iter(response)) - return response[component_id][prop] - - -class TestRunCallback: - def test_multi_output(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a", "b"], value="a"), - dcc.Dropdown(id="dd2"), - html.Div(id="out"), - ] - ) - - @app.callback( - Output("dd2", "options"), - Output("out", "children"), - Input("dd", "value"), - ) - def update(val): - return [{"label": val, "value": val}], f"selected: {val}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update" in t["name"]) - structured = _call_tool_structured(app, tool_name, {"val": "b"}) - assert structured["response"]["dd2"]["options"] == [ - {"label": "b", "value": "b"} - ] - assert structured["response"]["out"]["children"] == "selected: b" - - def test_omitted_kwargs_default_to_none(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Dropdown(id="dd", options=["a"]), - dcc.Input(id="inp"), - html.Div(id="out"), - ] - ) - - @app.callback( - Output("out", "children"), - Input("dd", "value"), - State("inp", "value"), - ) - def update(selected, text): - return f"{selected}-{text}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update" in t["name"]) - assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" - - def test_no_output_callback(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Button(id="btn"), - html.Div(id="display"), - ] - ) - - @app.callback(Input("btn", "n_clicks")) - def server_cb(n): - from dash import set_props - - set_props("display", {"children": f"Clicked {n} times"}) - - tools = _tools_list(app) - tool_names = [t["name"] for t in tools] - assert "server_cb" in tool_names - - def test_prevent_update(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="inp", value="hello"), - html.Div(id="out"), - ] - ) - - @app.callback(Output("out", "children"), Input("inp", "value")) - def update(val): - if val == "block": - raise PreventUpdate - return f"got: {val}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "update" in t["name"]) - assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" - - def test_with_state(self): - app = Dash(__name__) - app.layout = html.Div( - [ - html.Div(id="trigger"), - html.Div(id="store"), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), - Input("trigger", "children"), - State("store", "children"), - ) - def with_state(trigger, store): - return f"{trigger}-{store}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) - assert ( - _call_tool_output( - app, - tool_name, - { - "trigger": "click", - "store": "data", - }, - "result", - ) - == "click-data" - ) - - def test_dict_inputs(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="x-input", value="hello"), - dcc.Input(id="y-input", value="world"), - html.Div(id="dict-out"), - ] - ) - - @app.callback( - Output("dict-out", "children"), - inputs={ - "x_val": Input("x-input", "value"), - "y_val": Input("y-input", "value"), - }, - ) - def combine(**kwargs): - return f"{kwargs['x_val']}-{kwargs['y_val']}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "combine" in t["name"]) - assert ( - _call_tool_output( - app, - tool_name, - { - "x_val": "foo", - "y_val": "bar", - }, - "dict-out", - ) - == "foo-bar" - ) - - def test_positional_inputs(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="a-input", value="A"), - html.Div(id="pos-out"), - ] - ) - - @app.callback(Output("pos-out", "children"), Input("a-input", "value")) - def echo(val): - return f"got:{val}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "echo" in t["name"]) - assert ( - _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" - ) - - def test_dict_inputs_with_state(self): - app = Dash(__name__) - app.layout = html.Div( - [ - dcc.Input(id="inp", value="hi"), - html.Div(id="st", children="state-val"), - html.Div(id="ds-out"), - ] - ) - - @app.callback( - Output("ds-out", "children"), - inputs={"trigger": Input("inp", "value")}, - state={"kept": State("st", "children")}, - ) - def with_dict_state(**kwargs): - return f"{kwargs['trigger']}+{kwargs['kept']}" - - tools = _tools_list(app) - tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) - assert ( - _call_tool_output( - app, - tool_name, - { - "trigger": "hey", - "kept": "saved", - }, - "ds-out", - ) - == "hey+saved" - )