diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index abef0b73..b035cad5 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -20,7 +20,8 @@ from deprecated import deprecated from .itransport import ITransport -from .protocol import ToolSchema +from .mcp_transport import McpHttpTransport +from .protocol import Protocol, ToolSchema from .tool import ToolboxTool from .toolbox_transport import ToolboxTransport from .utils import identify_auth_requirements, resolve_value @@ -44,6 +45,7 @@ def __init__( client_headers: Optional[ Mapping[str, Union[Callable[[], str], Callable[[], Awaitable[str]], str]] ] = None, + protocol: Protocol = Protocol.TOOLBOX, ): """ Initializes the ToolboxClient. @@ -54,10 +56,15 @@ def __init__( If None (default), a new session is created internally. Note that if a session is provided, its lifecycle (including closing) should typically be managed externally. - client_headers: Headers to include in each request sent through this client. + client_headers: Headers to include in each request sent through this + client. + protocol: The communication protocol to use. """ + if protocol == Protocol.TOOLBOX: + self.__transport = ToolboxTransport(url, session) + else: + self.__transport = McpHttpTransport(url, session, protocol) - self.__transport = ToolboxTransport(url, session) self.__client_headers = client_headers if client_headers is not None else {} def __parse_tool( diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport.py b/packages/toolbox-core/src/toolbox_core/mcp_transport.py new file mode 100644 index 00000000..8e9f0035 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport.py @@ -0,0 +1,288 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import os +import uuid +from typing import Any, Mapping, Optional, Union + +from aiohttp import ClientSession + +from . import version +from .itransport import ITransport +from .protocol import ( + AdditionalPropertiesSchema, + ManifestSchema, + ParameterSchema, + Protocol, + ToolSchema, +) + + +class McpHttpTransport(ITransport): + """Transport for the MCP protocol.""" + + def __init__( + self, + base_url: str, + session: Optional[ClientSession] = None, + protocol: Protocol = Protocol.MCP, + ): + self.__mcp_base_url = base_url + "/mcp/" + # Will be updated after negotiation + self.__protocol_version = protocol.value + self.__server_version: Optional[str] = None + self.__session_id: Optional[str] = None + + self.__manage_session = session is None + self.__session = session or ClientSession() + self.__init_task = asyncio.create_task(self.__initialize_session()) + + @property + def base_url(self) -> str: + return self.__mcp_base_url + + def __convert_tool_schema(self, tool_data: dict) -> ToolSchema: + parameters = [] + input_schema = tool_data.get("inputSchema", {}) + properties = input_schema.get("properties", {}) + required = input_schema.get("required", []) + + for name, schema in properties.items(): + additional_props = schema.get("additionalProperties") + if isinstance(additional_props, dict): + additional_props = AdditionalPropertiesSchema( + type=additional_props["type"] + ) + else: + additional_props = True + parameters.append( + ParameterSchema( + name=name, + type=schema["type"], + description=schema.get("description", ""), + required=name in required, + additionalProperties=additional_props, + ) + ) + + return ToolSchema(description=tool_data["description"], parameters=parameters) + + async def __list_tools( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> Any: + """Private helper to fetch the raw tool list from the server.""" + if toolset_name: + url = self.__mcp_base_url + toolset_name + else: + url = self.__mcp_base_url + return await self.__send_request( + url=url, method="tools/list", params={}, headers=headers + ) + + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + """Gets a single tool from the server by listing all and filtering.""" + await self.__init_task + + if self.__server_version is None: + raise RuntimeError("Server version not available.") + + result = await self.__list_tools(headers=headers) + tool_def = None + for tool in result.get("tools", []): + if tool.get("name") == tool_name: + tool_def = self.__convert_tool_schema(tool) + break + + if tool_def is None: + raise ValueError(f"Tool '{tool_name}' not found.") + + tool_details = ManifestSchema( + serverVersion=self.__server_version, + tools={tool_name: tool_def}, + ) + return tool_details + + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + """Lists available tools from the server using the MCP protocol.""" + await self.__init_task + + if self.__server_version is None: + raise RuntimeError("Server version not available.") + + result = await self.__list_tools(toolset_name, headers) + tools = result.get("tools") + + return ManifestSchema( + serverVersion=self.__server_version, + tools={tool["name"]: self.__convert_tool_schema(tool) for tool in tools}, + ) + + async def tool_invoke( + self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]] + ) -> str: + """Invokes a specific tool on the server using the MCP protocol.""" + await self.__init_task + + url = self.__mcp_base_url + params = {"name": tool_name, "arguments": arguments} + result = await self.__send_request( + url=url, method="tools/call", params=params, headers=headers + ) + all_content = result.get("content", result) + content_str = "".join( + content.get("text", "") + for content in all_content + if isinstance(content, dict) + ) + return content_str or "null" + + async def close(self): + try: + await self.__init_task + except Exception: + # If initialization failed, we can still try to close the session. + pass + finally: + if self.__manage_session and self.__session and not self.__session.closed: + await self.__session.close() + + async def __initialize_session(self): + """Initializes the MCP session.""" + if self.__session is None and self.__manage_session: + self.__session = ClientSession() + + url = self.__mcp_base_url + + # Perform version negotitation + client_supported_versions = Protocol.get_supported_mcp_versions() + proposed_protocol_version = self.__protocol_version + params = { + "processId": os.getpid(), + "clientInfo": { + "name": "toolbox-python-sdk", + "version": version.__version__, + }, + "capabilities": {}, + "protocolVersion": proposed_protocol_version, + } + # Send initialize notification + initialize_result = await self.__send_request( + url=url, method="initialize", params=params + ) + + # Get the session id if the proposed version requires it + if proposed_protocol_version == "2025-03-26": + self.__session_id = initialize_result.get("Mcp-Session-Id") + if not self.__session_id: + if self.__manage_session: + await self.close() + raise RuntimeError( + "Server did not return a Mcp-Session-Id during initialization." + ) + server_info = initialize_result.get("serverInfo") + if not server_info: + raise RuntimeError("Server info not found in initialize response") + + self.__server_version = server_info.get("version") + if not self.__server_version: + raise RuntimeError("Server version not found in initialize response") + + # Perform version negotiation based on server response + server_protcol_version = initialize_result.get("protocolVersion") + if server_protcol_version: + if server_protcol_version not in client_supported_versions: + if self.__manage_session: + await self.close() + raise RuntimeError( + f"MCP version mismatch: client does not support server version {server_protcol_version}" + ) + # Update the protocol version to the one agreed upon by the server. + self.__protocol_version = server_protcol_version + else: + if self.__manage_session: + await self.close() + raise RuntimeError("MCP Protocol version not found in initialize response") + + server_capabilities = initialize_result.get("capabilities") + if not server_capabilities or "tools" not in server_capabilities: + if self.__manage_session: + await self.close() + raise RuntimeError("Server does not support the 'tools' capability.") + await self.__send_request( + url=url, method="notifications/initialized", params={} + ) + + async def __send_request( + self, + url: str, + method: str, + params: dict, + headers: Optional[Mapping[str, str]] = None, + ) -> Any: + """Sends a JSON-RPC request to the MCP server.""" + + request_params = params.copy() + req_headers = dict(headers or {}) + + # Check based on the NEGOTIATED version (self.__protocol_version) + if ( + self.__protocol_version == "2025-03-26" + and method != "initialize" + and self.__session_id + ): + request_params["Mcp-Session-Id"] = self.__session_id + + if self.__protocol_version == "2025-06-18": + req_headers["MCP-Protocol-Version"] = self.__protocol_version + + payload = { + "jsonrpc": "2.0", + "method": method, + "params": request_params, + } + + if not method.startswith("notifications/"): + payload["id"] = str(uuid.uuid4()) + + async with self.__session.post( + url, json=payload, headers=req_headers + ) as response: + if not response.ok: + error_text = await response.text() + raise RuntimeError( + f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}" + ) + + # Handle potential empty body (e.g. 204 No Content for notifications) + if response.status == 204 or response.content.at_eof(): + return None + + json_response = await response.json() + if "error" in json_response: + error = json_response["error"] + if error["code"] == -32000: + raise RuntimeError(f"MCP version mismatch: {error['message']}") + else: + raise RuntimeError( + f"MCP request failed with code {error['code']}: {error['message']}" + ) + return json_response.get("result") diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index 8cf563cc..5b14ce8e 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -11,12 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from enum import Enum from inspect import Parameter from typing import Any, Optional, Type, Union from pydantic import BaseModel + +class Protocol(str, Enum): + """Defines how the client should choose between communication protocols.""" + + TOOLBOX = "toolbox" + MCP_v20250618 = "2025-06-18" + MCP_v20250326 = "2025-03-26" + MCP_v20241105 = "2024-11-05" + MCP_LATEST = MCP_v20250618 + MCP = MCP_LATEST + + @classmethod + def get_supported_mcp_versions(cls): + """Returns a list of supported MCP versions, sorted from newest to oldest.""" + versions = [member for member in cls if member.name.startswith("MCP_v")] + # Sort by the version date in descending order + versions.sort(key=lambda x: x.value, reverse=True) + return [v.value for v in versions] + + __TYPE_MAP = { "string": str, "integer": int, diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py new file mode 100644 index 00000000..d4c64e21 --- /dev/null +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -0,0 +1,353 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from inspect import Parameter, signature +from typing import Any, Optional + +import pytest +import pytest_asyncio +from pydantic import ValidationError + +from toolbox_core.client import ToolboxClient +from toolbox_core.protocol import Protocol +from toolbox_core.tool import ToolboxTool + + +# --- Shared Fixtures Defined at Module Level --- +@pytest_asyncio.fixture(scope="function") +async def toolbox(): + """Creates a ToolboxClient instance shared by all tests in this module.""" + toolbox = ToolboxClient("http://localhost:5000", protocol=Protocol.MCP) + try: + yield toolbox + finally: + await toolbox.close() + + +@pytest_asyncio.fixture(scope="function") +async def get_n_rows_tool(toolbox: ToolboxClient) -> ToolboxTool: + """Load the 'get-n-rows' tool using the shared toolbox client.""" + tool = await toolbox.load_tool("get-n-rows") + assert tool.__name__ == "get-n-rows" + return tool + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestBasicE2E: + @pytest.mark.parametrize( + "toolset_name, expected_length, expected_tools", + [ + ("my-toolset", 1, ["get-row-by-id"]), + ("my-toolset-2", 2, ["get-n-rows", "get-row-by-id"]), + ], + ) + async def test_load_toolset_specific( + self, + toolbox: ToolboxClient, + toolset_name: str, + expected_length: int, + expected_tools: list[str], + ): + """Load a specific toolset""" + toolset = await toolbox.load_toolset(toolset_name) + assert len(toolset) == expected_length + tool_names = {tool.__name__ for tool in toolset} + assert tool_names == set(expected_tools) + + async def test_load_toolset_default(self, toolbox: ToolboxClient): + """Load the default toolset, i.e. all tools.""" + toolset = await toolbox.load_toolset() + assert len(toolset) == 7 + tool_names = {tool.__name__ for tool in toolset} + expected_tools = [ + "get-row-by-content-auth", + "get-row-by-email-auth", + "get-row-by-id-auth", + "get-row-by-id", + "get-n-rows", + "search-rows", + "process-data", + ] + assert tool_names == set(expected_tools) + + async def test_run_tool(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool.""" + response = await get_n_rows_tool(num_rows="2") + + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" not in response + + async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool with missing params.""" + with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): + await get_n_rows_tool() + + async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): + """Invoke a tool with wrong param type.""" + with pytest.raises( + ValidationError, + match=r"num_rows\s+Input should be a valid string\s+\[type=string_type,\s+input_value=2,\s+input_type=int\]", + ): + await get_n_rows_tool(num_rows=2) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestBindParams: + async def test_bind_params( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): + """Bind a param to an existing tool.""" + new_tool = get_n_rows_tool.bind_params({"num_rows": "3"}) + response = await new_tool() + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + + async def test_bind_params_callable( + self, toolbox: ToolboxClient, get_n_rows_tool: ToolboxTool + ): + """Bind a callable param to an existing tool.""" + new_tool = get_n_rows_tool.bind_params({"num_rows": lambda: "3"}) + response = await new_tool() + assert isinstance(response, str) + assert "row1" in response + assert "row2" in response + assert "row3" in response + assert "row4" not in response + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestOptionalParams: + """ + End-to-end tests for tools with optional parameters. + """ + + async def test_tool_signature_is_correct(self, toolbox: ToolboxClient): + """Verify the client correctly constructs the signature for a tool with optional params.""" + tool = await toolbox.load_tool("search-rows") + sig = signature(tool) + + assert "email" in sig.parameters + assert "data" in sig.parameters + assert "id" in sig.parameters + + # The required parameter should have no default + assert sig.parameters["email"].default is Parameter.empty + assert sig.parameters["email"].annotation is str + + # The optional parameter should have a default of None + assert sig.parameters["data"].default is None + assert sig.parameters["data"].annotation is Optional[str] + + # The optional parameter should have a default of None + assert sig.parameters["id"].default is None + assert sig.parameters["id"].annotation is Optional[int] + + async def test_run_tool_with_optional_params_omitted(self, toolbox: ToolboxClient): + """Invoke a tool providing only the required parameter.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data="row3") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" not in response + assert "row3" in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_data_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", data=None) + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_optional_id_provided(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=1) + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_optional_id_null(self, toolbox: ToolboxClient): + """Invoke a tool providing both required and optional parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=None) + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_missing_required_param(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(TypeError, match="missing a required argument: 'email'"): + await tool(id=5, data="row5") + + async def test_run_tool_with_required_param_null(self, toolbox: ToolboxClient): + """Invoke a tool without its required parameter.""" + tool = await toolbox.load_tool("search-rows") + with pytest.raises(ValidationError, match="email"): + await tool(email=None, id=5, data="row5") + + async def test_run_tool_with_all_default_params(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=0, data="row2") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" in response + assert "row3" not in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_all_valid_params(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=3, data="row3") + assert isinstance(response, str) + assert '"email":"twishabansal@google.com"' in response + assert "row1" not in response + assert "row2" not in response + assert "row3" in response + assert "row4" not in response + assert "row5" not in response + assert "row6" not in response + + async def test_run_tool_with_different_email(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different email.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="anubhavdhawan@google.com", id=3, data="row3") + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_different_data(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different data.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=3, data="row4") + assert isinstance(response, str) + assert response == "null" + + async def test_run_tool_with_different_id(self, toolbox: ToolboxClient): + """Invoke a tool providing all parameters but with a different data.""" + tool = await toolbox.load_tool("search-rows") + + response = await tool(email="twishabansal@google.com", id=4, data="row3") + assert isinstance(response, str) + assert response == "null" + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestMapParams: + """ + End-to-end tests for tools with map parameters. + """ + + async def test_tool_signature_with_map_params(self, toolbox: ToolboxClient): + """Verify the client correctly constructs the signature for a tool with map params.""" + tool = await toolbox.load_tool("process-data") + sig = signature(tool) + + assert "execution_context" in sig.parameters + assert sig.parameters["execution_context"].annotation == dict[str, Any] + assert sig.parameters["execution_context"].default is Parameter.empty + + assert "user_scores" in sig.parameters + assert sig.parameters["user_scores"].annotation == dict[str, int] + assert sig.parameters["user_scores"].default is Parameter.empty + + assert "feature_flags" in sig.parameters + assert sig.parameters["feature_flags"].annotation == Optional[dict[str, bool]] + assert sig.parameters["feature_flags"].default is None + + async def test_run_tool_with_map_params(self, toolbox: ToolboxClient): + """Invoke a tool with valid map parameters.""" + tool = await toolbox.load_tool("process-data") + + response = await tool( + execution_context={"env": "prod", "id": 1234, "user": 1234.5}, + user_scores={"user1": 100, "user2": 200}, + feature_flags={"new_feature": True}, + ) + assert isinstance(response, str) + assert '"execution_context":{"env":"prod","id":1234,"user":1234.5}' in response + assert '"user_scores":{"user1":100,"user2":200}' in response + assert '"feature_flags":{"new_feature":true}' in response + + async def test_run_tool_with_optional_map_param_omitted( + self, toolbox: ToolboxClient + ): + """Invoke a tool without the optional map parameter.""" + tool = await toolbox.load_tool("process-data") + + response = await tool( + execution_context={"env": "dev"}, user_scores={"user3": 300} + ) + assert isinstance(response, str) + assert '"execution_context":{"env":"dev"}' in response + assert '"user_scores":{"user3":300}' in response + assert '"feature_flags":null' in response + + async def test_run_tool_with_wrong_map_value_type(self, toolbox: ToolboxClient): + """Invoke a tool with a map parameter having the wrong value type.""" + tool = await toolbox.load_tool("process-data") + + with pytest.raises(ValidationError): + await tool( + execution_context={"env": "staging"}, + user_scores={"user4": "not-an-integer"}, + ) diff --git a/packages/toolbox-core/tests/test_mcp_transport.py b/packages/toolbox-core/tests/test_mcp_transport.py new file mode 100644 index 00000000..102d6097 --- /dev/null +++ b/packages/toolbox-core/tests/test_mcp_transport.py @@ -0,0 +1,364 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import AsyncGenerator + +import pytest +import pytest_asyncio +from aiohttp import ClientSession +from aioresponses import aioresponses +from yarl import URL + +from toolbox_core.mcp_transport import McpHttpTransport +from toolbox_core.protocol import ManifestSchema, Protocol + +TEST_BASE_URL = "http://fake-mcp-server.com" +TEST_TOOL_NAME = "test_tool" + + +@pytest_asyncio.fixture +async def http_session() -> AsyncGenerator[ClientSession, None]: + """Provides a real aiohttp ClientSession that is closed after the test.""" + async with ClientSession() as session: + yield session + + +@pytest.fixture +def mock_initialize_response() -> dict: + """Provides a valid sample dictionary for an initialize response.""" + data = { + "jsonrpc": "2.0", + "id": "1", + "result": { + "protocolVersion": "2025-06-18", + "serverInfo": { + "name": "Fake MCP Server", + "version": "1.0.0", + "protocolVersion": Protocol.MCP_LATEST.value, + }, + "capabilities": {"tools": {}}, + }, + } + return copy.deepcopy(data) + + +@pytest.fixture +def mock_tools_list_response() -> dict: + """Provides a valid sample dictionary for a tools/list response.""" + data = { + "jsonrpc": "2.0", + "id": "2", + "result": { + "tools": [ + { + "name": TEST_TOOL_NAME, + "description": "A test tool", + "inputSchema": { + "type": "object", + "properties": { + "param1": { + "type": "string", + "description": "A parameter", + } + }, + "required": ["param1"], + }, + } + ] + }, + } + return copy.deepcopy(data) + + +@pytest.mark.asyncio +async def test_successful_initialization( + http_session: ClientSession, + mock_initialize_response: dict, + mock_tools_list_response: dict, +): + """Tests that the transport initializes without errors.""" + url = f"{TEST_BASE_URL}/mcp/" + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) # initialized notification + m.post(url, status=200, payload=mock_tools_list_response) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_LATEST, + ) + # Trigger the lazy initialization by calling a method + await transport._McpHttpTransport__list_tools() + + +@pytest.mark.asyncio +async def test_tools_list_success( + http_session: ClientSession, + mock_initialize_response: dict, + mock_tools_list_response: dict, +): + """Tests a successful tools_list call.""" + url = f"{TEST_BASE_URL}/mcp/" + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_LATEST, + ) + + m.post(url, status=200, payload=mock_tools_list_response) + result = await transport.tools_list() + + assert isinstance(result, ManifestSchema) + assert result.serverVersion == "1.0.0" + assert TEST_TOOL_NAME in result.tools + + +@pytest.mark.asyncio +async def test_tool_get_success( + http_session: ClientSession, + mock_initialize_response: dict, + mock_tools_list_response: dict, +): + """Tests getting a single existing tool.""" + url = f"{TEST_BASE_URL}/mcp/" + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_LATEST, + ) + + m.post(url, status=200, payload=mock_tools_list_response) + result = await transport.tool_get(TEST_TOOL_NAME) + + assert len(result.tools) == 1 + assert TEST_TOOL_NAME in result.tools + + +@pytest.mark.asyncio +async def test_tool_get_not_found_raises_error( + http_session: ClientSession, + mock_initialize_response: dict, + mock_tools_list_response: dict, +): + """Tests that getting a non-existent tool raises ValueError.""" + url = f"{TEST_BASE_URL}/mcp/" + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_LATEST, + ) + + m.post(url, status=200, payload=mock_tools_list_response) + with pytest.raises(ValueError, match="Tool 'non_existent_tool' not found."): + await transport.tool_get("non_existent_tool") + + +@pytest.mark.asyncio +async def test_tool_invoke_success( + http_session: ClientSession, mock_initialize_response: dict +): + """Tests a successful tool_invoke call.""" + url = f"{TEST_BASE_URL}/mcp/" + invoke_response = { + "jsonrpc": "2.0", + "id": "4", + "result": {"content": [{"text": "success"}]}, + } + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_LATEST, + ) + + m.post(url, status=200, payload=invoke_response) + result = await transport.tool_invoke(TEST_TOOL_NAME, {"arg": "val"}, {}) + assert result == "success" + + +@pytest.mark.asyncio +async def test_http_request_failure( + http_session: ClientSession, mock_initialize_response: dict +): + """Tests that a non-200 response raises a RuntimeError.""" + url = f"{TEST_BASE_URL}/mcp/" + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_LATEST, + ) + m.post(url, status=500, body="Internal Server Error") + with pytest.raises(RuntimeError) as exc_info: + await transport.tools_list() + + assert "API request failed with status 500" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_json_rpc_error( + http_session: ClientSession, mock_initialize_response: dict +): + """Tests that a response with a JSON-RPC error raises a RuntimeError.""" + url = f"{TEST_BASE_URL}/mcp/" + error_response = { + "jsonrpc": "2.0", + "id": "5", + "error": {"code": -32601, "message": "Method not found"}, + } + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_LATEST, + ) + m.post(url, status=200, payload=error_response) + + with pytest.raises(RuntimeError, match="MCP request failed with code -32601"): + await transport.tools_list() + + +@pytest.mark.asyncio +async def test_v2025_06_18_adds_protocol_header( + http_session: ClientSession, + mock_tools_list_response: dict, + mock_initialize_response: dict, +): + """Tests that MCP v2025-06-18 adds the MCP-Protocol-Version header.""" + url = f"{TEST_BASE_URL}/mcp/" + protocol_version = "2025-06-18" + + mock_initialize_response["result"]["protocolVersion"] = protocol_version + + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_v20250618, + ) + + m.post(url, status=200, payload=mock_tools_list_response) + await transport.tools_list() + + calls = m.requests.get(("POST", URL(url))) + assert calls is not None + + # There will be 3 calls: initialize, initialized, and tools/list + assert len(calls) == 3 + + # Check the last call (tools/list) for the header + list_request = calls[2] + assert "MCP-Protocol-Version" in list_request.kwargs["headers"] + assert ( + list_request.kwargs["headers"]["MCP-Protocol-Version"] == protocol_version + ) + + +@pytest.mark.asyncio +async def test_v2025_03_26_session_id_handling( + http_session: ClientSession, + mock_tools_list_response: dict, + mock_initialize_response: dict, +): + """Tests that MCP v2025-03-26 correctly handles the session ID.""" + session_id = "test-session-123" + url = f"{TEST_BASE_URL}/mcp/" + protocol_version = "2025-03-26" + + # The client expects protocolVersion inside serverInfo + mock_initialize_response["result"]["protocolVersion"] = protocol_version + mock_initialize_response["result"]["Mcp-Session-Id"] = session_id + + with aioresponses() as m: + m.post(url, status=200, payload=mock_initialize_response) + m.post(url, status=204) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_v20250326, + ) + + m.post(url, status=200, payload=mock_tools_list_response) + await transport.tools_list() + + calls = m.requests.get(("POST", URL(url))) + assert calls is not None + assert len(calls) == 3 + + list_request = calls[2] + sent_payload = list_request.kwargs["json"] + assert "Mcp-Session-Id" in sent_payload["params"] + assert sent_payload["params"]["Mcp-Session-Id"] == session_id + + +@pytest.mark.asyncio +async def test_v2025_03_26_missing_session_id_raises_error( + http_session: ClientSession, +): + """Tests that initialization fails for v2025-03-26 if no session ID is returned.""" + url = f"{TEST_BASE_URL}/mcp/" + init_response_no_session = { + "jsonrpc": "2.0", + "id": "1", + "result": { + "serverInfo": { + "name": "Fake MCP Server", + "version": "1.0.0", + "protocolVersion": "2025-03-26", + }, + "capabilities": {"tools": {}}, + }, + } + + with aioresponses() as m: + m.post(url, status=200, payload=init_response_no_session) + + transport = McpHttpTransport( + base_url=TEST_BASE_URL, + session=http_session, + protocol=Protocol.MCP_v20250326, + ) + with pytest.raises( + RuntimeError, + match="Server did not return a Mcp-Session-Id during initialization.", + ): + # Trigger the lazy initialization to cause the error + await transport.tools_list() diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index dae95f61..b5f00067 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -18,7 +18,20 @@ import pytest -from toolbox_core.protocol import AdditionalPropertiesSchema, ParameterSchema +from toolbox_core.protocol import AdditionalPropertiesSchema, ParameterSchema, Protocol + + +def test_get_supported_mcp_versions(): + """ + Tests that get_supported_mcp_versions returns the correct list of versions, + sorted from newest to oldest. + """ + expected_versions = ["2025-06-18", "2025-03-26", "2024-11-05"] + supported_versions = Protocol.get_supported_mcp_versions() + + assert supported_versions == expected_versions + # Also verify that the non-MCP members are not included + assert "toolbox" not in supported_versions def test_parameter_schema_float():