From cce8519221b6174c76c99c9a17240a9a41866b5b Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 4 Jan 2025 13:28:36 +0000 Subject: [PATCH 01/12] add sampling callback paramater --- src/mcp/client/session.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 27ca74d8..254b5bc1 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,4 +1,6 @@ from datetime import timedelta +from inspect import iscoroutinefunction +from typing import Awaitable, Callable from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl @@ -7,6 +9,10 @@ from mcp.shared.session import BaseSession from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +sampling_function_signature = Callable[ + [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] +] + class ClientSession( BaseSession[ @@ -17,11 +23,14 @@ class ClientSession( types.ServerNotification, ] ): + sampling_callback: sampling_function_signature | None = None + def __init__( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, + sampling_callback: sampling_function_signature | None = None, ) -> None: super().__init__( read_stream, @@ -31,7 +40,21 @@ def __init__( read_timeout_seconds=read_timeout_seconds, ) + # validate sampling_callback + # use asserts here because this should be known at compile time + if sampling_callback is not None: + assert callable(sampling_callback), "sampling_callback must be callable" + assert iscoroutinefunction( + sampling_callback + ), "sampling_callback must be an async function" + + self.sampling_callback = sampling_callback + async def initialize(self) -> types.InitializeResult: + sampling = None + if self.sampling_callback is not None: + sampling = types.SamplingCapability() + result = await self.send_request( types.ClientRequest( types.InitializeRequest( @@ -39,7 +62,7 @@ async def initialize(self) -> types.InitializeResult: params=types.InitializeRequestParams( protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( - sampling=None, + sampling=sampling, experimental=None, roots=types.RootsCapability( # TODO: Should this be based on whether we From 368782c15543a655d06e6098afe81e3926275712 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 4 Jan 2025 16:29:37 +0000 Subject: [PATCH 02/12] add request handler --- src/mcp/client/session.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 254b5bc1..87df7800 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -6,7 +6,7 @@ from pydantic import AnyUrl import mcp.types as types -from mcp.shared.session import BaseSession +from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS sampling_function_signature = Callable[ @@ -255,3 +255,17 @@ async def send_roots_list_changed(self) -> None: ) ) ) + + async def _received_request( + self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] + ) -> None: + if isinstance(responder.request.root, types.CreateMessageRequest): + print("Received create message request") + if self.sampling_callback is None: + raise RuntimeError("Sampling callback is not set") + response = await self.sampling_callback(responder.request.root.params) + + client_response = types.ClientResult(**response.model_dump()) + + print(f"Response: {response.dict()}") + await responder.respond(client_response) From f9de5f096d30e379386b801e1c2865ce3ae6f940 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 7 Jan 2025 16:20:20 +0000 Subject: [PATCH 03/12] cleanup print statements --- src/mcp/client/session.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 87df7800..71284ab0 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -259,13 +259,13 @@ async def send_roots_list_changed(self) -> None: async def _received_request( self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] ) -> None: + if isinstance(responder.request.root, types.CreateMessageRequest): - print("Received create message request") + # handle create message request (sampling) + if self.sampling_callback is None: raise RuntimeError("Sampling callback is not set") + response = await self.sampling_callback(responder.request.root.params) - client_response = types.ClientResult(**response.model_dump()) - - print(f"Response: {response.dict()}") await responder.respond(client_response) From 9e68fa8f144d54a9771df6b8a0be43c3ab32d162 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 7 Jan 2025 18:32:35 +0000 Subject: [PATCH 04/12] add docs to readme --- README.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 335542c7..f8852d21 100644 --- a/README.md +++ b/README.md @@ -417,9 +417,21 @@ server_params = StdioServerParameters( env=None # Optional environment variables ) +# Optional: create a sampling callback +async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent( + type="text", + text="Hello, world! from model", + ), + model="gpt-3.5-turbo", + stopReason="endTurn", + ) + async def run(): async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: + async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session: # Initialize the connection await session.initialize() From f979ca9980689df7816a99bcfc391a653f73eca4 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Tue, 7 Jan 2025 18:47:15 +0000 Subject: [PATCH 05/12] ruff format --- src/mcp/client/session.py | 3 +-- tests/client/test_stdio.py | 7 ++++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 71284ab0..93206d72 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -259,13 +259,12 @@ async def send_roots_list_changed(self) -> None: async def _received_request( self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] ) -> None: - if isinstance(responder.request.root, types.CreateMessageRequest): # handle create message request (sampling) if self.sampling_callback is None: raise RuntimeError("Sampling callback is not set") - + response = await self.sampling_callback(responder.request.root.params) client_response = types.ClientResult(**response.model_dump()) await responder.respond(client_response) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 0bdec72d..ba9461e6 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,12 +1,17 @@ +import shutil + import pytest from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +tee: str = shutil.which("tee") # type: ignore +assert tee is not None, "could not find tee command" + @pytest.mark.anyio async def test_stdio_client(): - server_parameters = StdioServerParameters(command="/usr/bin/tee") + server_parameters = StdioServerParameters(command=tee) async with stdio_client(server_parameters) as (read_stream, write_stream): # Test sending and receiving messages From e2e2f4335c471a69d44c4b8435c0b8d9e334e7c3 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 13 Feb 2025 15:30:51 +0000 Subject: [PATCH 06/12] simplify the implementation --- src/mcp/client/session.py | 28 +++++++--------------------- 1 file changed, 7 insertions(+), 21 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 93206d72..cf1aee13 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,5 +1,4 @@ from datetime import timedelta -from inspect import iscoroutinefunction from typing import Awaitable, Callable from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -39,21 +38,12 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - - # validate sampling_callback - # use asserts here because this should be known at compile time - if sampling_callback is not None: - assert callable(sampling_callback), "sampling_callback must be callable" - assert iscoroutinefunction( - sampling_callback - ), "sampling_callback must be an async function" - self.sampling_callback = sampling_callback async def initialize(self) -> types.InitializeResult: - sampling = None - if self.sampling_callback is not None: - sampling = types.SamplingCapability() + sampling = ( + types.SamplingCapability() if self.sampling_callback is not None else None + ) result = await self.send_request( types.ClientRequest( @@ -260,11 +250,7 @@ async def _received_request( self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] ) -> None: if isinstance(responder.request.root, types.CreateMessageRequest): - # handle create message request (sampling) - - if self.sampling_callback is None: - raise RuntimeError("Sampling callback is not set") - - response = await self.sampling_callback(responder.request.root.params) - client_response = types.ClientResult(**response.model_dump()) - await responder.respond(client_response) + if self.sampling_callback is not None: + response = await self.sampling_callback(responder.request.root.params) + client_response = types.ClientResult(root=response) + await responder.respond(client_response) From 8f0f7c5d00ed00195c28d4f88d5e31fe6ceb33a7 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 13 Feb 2025 16:29:40 +0000 Subject: [PATCH 07/12] fix: simplify implementation and add test --- src/mcp/client/session.py | 9 +++-- src/mcp/shared/memory.py | 4 +- tests/client/test_sampling_callback.py | 53 ++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 tests/client/test_sampling_callback.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cf1aee13..caa8e0f2 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,7 +8,7 @@ from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -sampling_function_signature = Callable[ +SamplingFnT = Callable[ [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] ] @@ -22,14 +22,14 @@ class ClientSession( types.ServerNotification, ] ): - sampling_callback: sampling_function_signature | None = None + sampling_callback: SamplingFnT | None = None def __init__( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, - sampling_callback: sampling_function_signature | None = None, + sampling_callback: SamplingFnT | None = None, ) -> None: super().__init__( read_stream, @@ -253,4 +253,5 @@ async def _received_request( if self.sampling_callback is not None: response = await self.sampling_callback(responder.request.root.params) client_response = types.ClientResult(root=response) - await responder.respond(client_response) + with responder: + await responder.respond(client_response) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 72549925..0900cfd8 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -9,7 +9,7 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, SamplingFnT from mcp.server import Server from mcp.types import JSONRPCMessage @@ -54,6 +54,7 @@ async def create_client_server_memory_streams() -> ( async def create_connected_server_and_client_session( server: Server, read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -80,6 +81,7 @@ async def create_connected_server_and_client_session( read_stream=client_read, write_stream=client_write, read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py new file mode 100644 index 00000000..5f4ff30f --- /dev/null +++ b/tests/client/test_sampling_callback.py @@ -0,0 +1,53 @@ +import pytest + +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + SamplingMessage, + TextContent, +) + + +@pytest.mark.anyio +async def test_sampling_callback(): + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + callback_return = CreateMessageResult( + role="assistant", + content=TextContent( + type="text", text="This is a response from the sampling callback" + ), + model="test-model", + stopReason="endTurn", + ) + + async def sampling_callback( + message: CreateMessageRequestParams, + ) -> CreateMessageResult: + return callback_return + + @server.tool("test_sampling") + async def test_sampling_tool(message: str): + value = await server.get_context().session.create_message( + messages=[ + SamplingMessage( + role="user", content=TextContent(type="text", text=message) + ) + ], + max_tokens=100, + ) + assert value == callback_return + return True + + async with create_session( + server._mcp_server, sampling_callback=sampling_callback + ) as client_session: + # Make a request to trigger sampling callback + assert await client_session.call_tool( + "test_sampling", {"message": "Test message for sampling"} + ) From a48cb68aa7329eeff7a88aa68d2926a798e5415c Mon Sep 17 00:00:00 2001 From: Jerome Date: Wed, 19 Feb 2025 15:37:31 +0000 Subject: [PATCH 08/12] Added roots callback also --- src/mcp/client/session.py | 74 ++++++++++++++++++------ src/mcp/shared/memory.py | 4 +- tests/client/test_list_roots_callback.py | 50 ++++++++++++++++ tests/client/test_sampling_callback.py | 5 +- 4 files changed, 112 insertions(+), 21 deletions(-) create mode 100644 tests/client/test_list_roots_callback.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 974c61f7..fb3b4d7a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,16 +1,27 @@ from datetime import timedelta -from typing import Awaitable, Callable +from typing import Protocol, Any from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +from mcp.shared.context import RequestContext import mcp.types as types from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -SamplingFnT = Callable[ - [types.CreateMessageRequestParams], Awaitable[types.CreateMessageResult] -] + +class SamplingFnT(Protocol): + async def __call__( + self, context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: + ... + + +class ListRootsFnT(Protocol): + async def __call__( + self, context: RequestContext["ClientSession", Any] + ) -> types.ListRootsResult: + ... class ClientSession( @@ -22,7 +33,7 @@ class ClientSession( types.ServerNotification, ] ): - sampling_callback: SamplingFnT | None = None + _sampling_callback: SamplingFnT | None = None def __init__( self, @@ -30,6 +41,7 @@ def __init__( write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, ) -> None: super().__init__( read_stream, @@ -38,11 +50,22 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - self.sampling_callback = sampling_callback + self._sampling_callback = sampling_callback + self._list_roots_callback = list_roots_callback async def initialize(self) -> types.InitializeResult: sampling = ( - types.SamplingCapability() if self.sampling_callback is not None else None + types.SamplingCapability() if self._sampling_callback is not None else None + ) + roots = ( + types.RootsCapability( + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? + listChanged=True, + ) + if self._list_roots_callback is not None + else None ) result = await self.send_request( @@ -54,12 +77,7 @@ async def initialize(self) -> types.InitializeResult: capabilities=types.ClientCapabilities( sampling=sampling, experimental=None, - roots=types.RootsCapability( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? - listChanged=True - ), + roots=roots, ), clientInfo=types.Implementation(name="mcp", version="0.1.0"), ), @@ -258,11 +276,29 @@ async def send_roots_list_changed(self) -> None: ) async def _received_request( - self, responder: RequestResponder["types.ServerRequest", "types.ClientResult"] + self, responder: RequestResponder[types.ServerRequest, types.ClientResult] ) -> None: - if isinstance(responder.request.root, types.CreateMessageRequest): - if self.sampling_callback is not None: - response = await self.sampling_callback(responder.request.root.params) - client_response = types.ClientResult(root=response) + + ctx = RequestContext[ClientSession, Any]( + request_id=responder.request_id, + meta=responder.request_meta, + session=self, + lifespan_context=None, + ) + + match responder.request.root: + case types.CreateMessageRequest: + if self._sampling_callback is not None: + response = await self._sampling_callback(ctx, responder.request.root.params) + client_response = types.ClientResult(root=response) + with responder: + await responder.respond(client_response) + case types.ListRootsRequest: + if self._list_roots_callback is not None: + response = await self._list_roots_callback(ctx) + client_response = types.ClientResult(root=response) + with responder: + await responder.respond(client_response) + case types.PingRequest: with responder: - await responder.respond(client_response) + await responder.respond(types.ClientResult(root=types.EmptyResult())) \ No newline at end of file diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 0900cfd8..ae6b0be5 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -9,7 +9,7 @@ import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from mcp.client.session import ClientSession, SamplingFnT +from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT from mcp.server import Server from mcp.types import JSONRPCMessage @@ -55,6 +55,7 @@ async def create_connected_server_and_client_session( server: Server, read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, raise_exceptions: bool = False, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" @@ -82,6 +83,7 @@ async def create_connected_server_and_client_session( write_stream=client_write, read_timeout_seconds=read_timeout_seconds, sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py new file mode 100644 index 00000000..66d65084 --- /dev/null +++ b/tests/client/test_list_roots_callback.py @@ -0,0 +1,50 @@ +from pydantic import FileUrl +import pytest + +from mcp.client.session import ClientSession +from mcp.server.fastmcp.server import Context +from mcp.shared.context import RequestContext +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) +from mcp.types import ( + ListRootsResult, + Root, +) + + +@pytest.mark.anyio +async def test_list_roots_callback(): + from mcp.server.fastmcp import FastMCP + + server = FastMCP("test") + + callback_return = ListRootsResult(roots=[ + Root( + uri=FileUrl("test://users/fake/test"), + name="Test Root 1", + ), + Root( + uri=FileUrl("test://users/fake/test/2"), + name="Test Root 2", + ) + ]) + + async def list_roots_callback( + context: RequestContext[ClientSession, None] + ) -> ListRootsResult: + return callback_return + + @server.tool("test_list_roots") + async def test_list_roots(context: Context, message: str): + roots = context.session.list_roots() + assert roots == callback_return + return True + + async with create_session( + server._mcp_server, list_roots_callback=list_roots_callback + ) as client_session: + # Make a request to trigger sampling callback + assert await client_session.call_tool( + "test_list_roots", {"message": "test message"} + ) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 5f4ff30f..d46e5cdd 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,5 +1,7 @@ import pytest +from mcp.client.session import ClientSession +from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, ) @@ -27,7 +29,8 @@ async def test_sampling_callback(): ) async def sampling_callback( - message: CreateMessageRequestParams, + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return From 9277c691335c109220d468ecaa1f0d4b90e5bdf8 Mon Sep 17 00:00:00 2001 From: Jerome Date: Wed, 19 Feb 2025 15:42:34 +0000 Subject: [PATCH 09/12] Fixed match case for request types --- src/mcp/client/session.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fb3b4d7a..3af72a7d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -287,18 +287,18 @@ async def _received_request( ) match responder.request.root: - case types.CreateMessageRequest: + case types.CreateMessageRequest(params=params): if self._sampling_callback is not None: - response = await self._sampling_callback(ctx, responder.request.root.params) + response = await self._sampling_callback(ctx, params) client_response = types.ClientResult(root=response) with responder: await responder.respond(client_response) - case types.ListRootsRequest: + case types.ListRootsRequest(): if self._list_roots_callback is not None: response = await self._list_roots_callback(ctx) client_response = types.ClientResult(root=response) with responder: await responder.respond(client_response) - case types.PingRequest: + case types.PingRequest(): with responder: await responder.respond(types.ClientResult(root=types.EmptyResult())) \ No newline at end of file From 005483aee8aa90ad6fbb37923b663ea9c008da6e Mon Sep 17 00:00:00 2001 From: Jerome Date: Wed, 19 Feb 2025 22:20:01 +0000 Subject: [PATCH 10/12] Refactored default behaviour, updated tests --- src/mcp/client/session.py | 67 ++++++++++++++++-------- tests/client/test_list_roots_callback.py | 26 +++++++-- tests/client/test_sampling_callback.py | 16 +++++- 3 files changed, 80 insertions(+), 29 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3af72a7d..37036e2b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,27 +1,49 @@ from datetime import timedelta -from typing import Protocol, Any +from typing import Any, Protocol from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl +from pydantic import AnyUrl, TypeAdapter -from mcp.shared.context import RequestContext import mcp.types as types +from mcp.shared.context import RequestContext from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS class SamplingFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any], params: types.CreateMessageRequestParams - ) -> types.CreateMessageResult: - ... + self, + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: ... class ListRootsFnT(Protocol): async def __call__( self, context: RequestContext["ClientSession", Any] - ) -> types.ListRootsResult: - ... + ) -> types.ListRootsResult | types.ErrorData: ... + + +async def _default_sampling_callback( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, +) -> types.CreateMessageResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Sampling not supported", + ) + + +async def _default_list_roots_callback( + context: RequestContext["ClientSession", Any], +) -> types.ListRootsResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="List roots not supported", + ) + + +ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData) class ClientSession( @@ -33,8 +55,6 @@ class ClientSession( types.ServerNotification, ] ): - _sampling_callback: SamplingFnT | None = None - def __init__( self, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], @@ -50,8 +70,8 @@ def __init__( types.ServerNotification, read_timeout_seconds=read_timeout_seconds, ) - self._sampling_callback = sampling_callback - self._list_roots_callback = list_roots_callback + self._sampling_callback = sampling_callback or _default_sampling_callback + self._list_roots_callback = list_roots_callback or _default_list_roots_callback async def initialize(self) -> types.InitializeResult: sampling = ( @@ -278,27 +298,28 @@ async def send_roots_list_changed(self) -> None: async def _received_request( self, responder: RequestResponder[types.ServerRequest, types.ClientResult] ) -> None: - ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, meta=responder.request_meta, session=self, lifespan_context=None, ) - + match responder.request.root: case types.CreateMessageRequest(params=params): - if self._sampling_callback is not None: + with responder: response = await self._sampling_callback(ctx, params) - client_response = types.ClientResult(root=response) - with responder: - await responder.respond(client_response) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + case types.ListRootsRequest(): - if self._list_roots_callback is not None: + with responder: response = await self._list_roots_callback(ctx) - client_response = types.ClientResult(root=response) - with responder: - await responder.respond(client_response) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + case types.PingRequest(): with responder: - await responder.respond(types.ClientResult(root=types.EmptyResult())) \ No newline at end of file + return await responder.respond( + types.ClientResult(root=types.EmptyResult()) + ) diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 66d65084..989566c6 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,5 +1,5 @@ -from pydantic import FileUrl import pytest +from pydantic import FileUrl from mcp.client.session import ClientSession from mcp.server.fastmcp.server import Context @@ -10,6 +10,7 @@ from mcp.types import ( ListRootsResult, Root, + TextContent, ) @@ -21,11 +22,11 @@ async def test_list_roots_callback(): callback_return = ListRootsResult(roots=[ Root( - uri=FileUrl("test://users/fake/test"), + uri=FileUrl("file://users/fake/test"), name="Test Root 1", ), Root( - uri=FileUrl("test://users/fake/test/2"), + uri=FileUrl("file://users/fake/test/2"), name="Test Root 2", ) ]) @@ -37,14 +38,29 @@ async def list_roots_callback( @server.tool("test_list_roots") async def test_list_roots(context: Context, message: str): - roots = context.session.list_roots() + roots = await context.session.list_roots() assert roots == callback_return return True + # Test with list_roots callback async with create_session( server._mcp_server, list_roots_callback=list_roots_callback ) as client_session: # Make a request to trigger sampling callback - assert await client_session.call_tool( + result = await client_session.call_tool( "test_list_roots", {"message": "test message"} ) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == 'true' + + # Test without list_roots callback + async with create_session(server._mcp_server) as client_session: + # Make a request to trigger sampling callback + result = await client_session.call_tool( + "test_list_roots", {"message": "test message"} + ) + assert result.isError is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == 'Error executing tool test_list_roots: List roots not supported' + diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index d46e5cdd..3ddea5b0 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -47,10 +47,24 @@ async def test_sampling_tool(message: str): assert value == callback_return return True + # Test with sampling callback async with create_session( server._mcp_server, sampling_callback=sampling_callback ) as client_session: # Make a request to trigger sampling callback - assert await client_session.call_tool( + result = await client_session.call_tool( "test_sampling", {"message": "Test message for sampling"} ) + assert result.isError is False + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == 'true' + + # Test without sampling callback + async with create_session(server._mcp_server) as client_session: + # Make a request to trigger sampling callback + result = await client_session.call_tool( + "test_sampling", {"message": "Test message for sampling"} + ) + assert result.isError is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == 'Error executing tool test_sampling: Sampling not supported' From 56185976d0cd8ce9fbbda89bbb63b6a46f3ae38b Mon Sep 17 00:00:00 2001 From: Jerome Date: Wed, 19 Feb 2025 22:38:33 +0000 Subject: [PATCH 11/12] Formatted --- tests/client/test_list_roots_callback.py | 32 +++++++++++++----------- tests/client/test_sampling_callback.py | 7 ++++-- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 989566c6..384e7676 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -20,19 +20,21 @@ async def test_list_roots_callback(): server = FastMCP("test") - callback_return = ListRootsResult(roots=[ - Root( - uri=FileUrl("file://users/fake/test"), - name="Test Root 1", - ), - Root( - uri=FileUrl("file://users/fake/test/2"), - name="Test Root 2", - ) - ]) + callback_return = ListRootsResult( + roots=[ + Root( + uri=FileUrl("file://users/fake/test"), + name="Test Root 1", + ), + Root( + uri=FileUrl("file://users/fake/test/2"), + name="Test Root 2", + ), + ] + ) async def list_roots_callback( - context: RequestContext[ClientSession, None] + context: RequestContext[ClientSession, None], ) -> ListRootsResult: return callback_return @@ -52,7 +54,7 @@ async def test_list_roots(context: Context, message: str): ) assert result.isError is False assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'true' + assert result.content[0].text == "true" # Test without list_roots callback async with create_session(server._mcp_server) as client_session: @@ -62,5 +64,7 @@ async def test_list_roots(context: Context, message: str): ) assert result.isError is True assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'Error executing tool test_list_roots: List roots not supported' - + assert ( + result.content[0].text + == "Error executing tool test_list_roots: List roots not supported" + ) diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 3ddea5b0..ba586d4a 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -57,7 +57,7 @@ async def test_sampling_tool(message: str): ) assert result.isError is False assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'true' + assert result.content[0].text == "true" # Test without sampling callback async with create_session(server._mcp_server) as client_session: @@ -67,4 +67,7 @@ async def test_sampling_tool(message: str): ) assert result.isError is True assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'Error executing tool test_sampling: Sampling not supported' + assert ( + result.content[0].text + == "Error executing tool test_sampling: Sampling not supported" + ) From 9e3709104821490daf6f0bb8025a7ea47f59fd32 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 20 Feb 2025 10:47:31 +0000 Subject: [PATCH 12/12] fix: skip stdio test if `tee` cannot be found --- tests/client/test_stdio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index ba9461e6..95747ffd 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -6,10 +6,10 @@ from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore -assert tee is not None, "could not find tee command" @pytest.mark.anyio +@pytest.mark.skipif(tee is None, reason="could not find tee command") async def test_stdio_client(): server_parameters = StdioServerParameters(command=tee)