Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client handling for sampling, list roots, ping #218

Merged
merged 17 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,21 @@ server_params = StdioServerParameters(
env=None # Optional environment variables
)

# Optional: create a sampling callback
async def handle_sampling_message(message: types.CreateMessageRequestParams) -> types.CreateMessageResult:
return types.CreateMessageResult(
role="assistant",
content=types.TextContent(
type="text",
text="Hello, world! from model",
),
model="gpt-3.5-turbo",
stopReason="endTurn",
)

async def run():
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
async with ClientSession(read, write, sampling_callback=handle_sampling_message) as session:
# Initialize the connection
await session.initialize()

Expand Down
98 changes: 89 additions & 9 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,51 @@
from datetime import timedelta
from typing import Any, Protocol

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from pydantic import AnyUrl, TypeAdapter

import mcp.types as types
from mcp.shared.session import BaseSession
from mcp.shared.context import RequestContext
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS


class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ...


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)


ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)


class ClientSession(
BaseSession[
types.ClientRequest,
Expand All @@ -22,6 +60,8 @@ def __init__(
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
) -> None:
super().__init__(
read_stream,
Expand All @@ -30,23 +70,34 @@ def __init__(
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback

async def initialize(self) -> types.InitializeResult:
sampling = (
types.SamplingCapability() if self._sampling_callback is not None else None
)
roots = (
types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
if self._list_roots_callback is not None
else None
)

result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
method="initialize",
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=None,
sampling=sampling,
experimental=None,
roots=types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True
),
roots=roots,
),
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
),
Expand Down Expand Up @@ -243,3 +294,32 @@ async def send_roots_list_changed(self) -> None:
)
)
)

async def _received_request(
self, responder: RequestResponder[types.ServerRequest, types.ClientResult]
) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)

match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = await self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.PingRequest():
with responder:
return await responder.respond(
types.ClientResult(root=types.EmptyResult())
)
6 changes: 5 additions & 1 deletion src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.client.session import ClientSession
from mcp.client.session import ClientSession, ListRootsFnT, SamplingFnT
from mcp.server import Server
from mcp.types import JSONRPCMessage

Expand Down Expand Up @@ -54,6 +54,8 @@ async def create_client_server_memory_streams() -> (
async def create_connected_server_and_client_session(
server: Server,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
raise_exceptions: bool = False,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
Expand All @@ -80,6 +82,8 @@ async def create_connected_server_and_client_session(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=read_timeout_seconds,
sampling_callback=sampling_callback,
list_roots_callback=list_roots_callback,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
70 changes: 70 additions & 0 deletions tests/client/test_list_roots_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from pydantic import FileUrl

from mcp.client.session import ClientSession
from mcp.server.fastmcp.server import Context
from mcp.shared.context import RequestContext
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.types import (
ListRootsResult,
Root,
TextContent,
)


@pytest.mark.anyio
async def test_list_roots_callback():
from mcp.server.fastmcp import FastMCP

server = FastMCP("test")

callback_return = ListRootsResult(
roots=[
Root(
uri=FileUrl("file://users/fake/test"),
name="Test Root 1",
),
Root(
uri=FileUrl("file://users/fake/test/2"),
name="Test Root 2",
),
]
)

async def list_roots_callback(
context: RequestContext[ClientSession, None],
) -> ListRootsResult:
return callback_return

@server.tool("test_list_roots")
async def test_list_roots(context: Context, message: str):
roots = await context.session.list_roots()
assert roots == callback_return
return True

# Test with list_roots callback
async with create_session(
server._mcp_server, list_roots_callback=list_roots_callback
) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_list_roots", {"message": "test message"}
)
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"

# Test without list_roots callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_list_roots", {"message": "test message"}
)
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert (
result.content[0].text
== "Error executing tool test_list_roots: List roots not supported"
)
73 changes: 73 additions & 0 deletions tests/client/test_sampling_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.shared.memory import (
create_connected_server_and_client_session as create_session,
)
from mcp.types import (
CreateMessageRequestParams,
CreateMessageResult,
SamplingMessage,
TextContent,
)


@pytest.mark.anyio
async def test_sampling_callback():
from mcp.server.fastmcp import FastMCP

server = FastMCP("test")

callback_return = CreateMessageResult(
role="assistant",
content=TextContent(
type="text", text="This is a response from the sampling callback"
),
model="test-model",
stopReason="endTurn",
)

async def sampling_callback(
context: RequestContext[ClientSession, None],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
return callback_return

@server.tool("test_sampling")
async def test_sampling_tool(message: str):
value = await server.get_context().session.create_message(
messages=[
SamplingMessage(
role="user", content=TextContent(type="text", text=message)
)
],
max_tokens=100,
)
assert value == callback_return
return True

# Test with sampling callback
async with create_session(
server._mcp_server, sampling_callback=sampling_callback
) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_sampling", {"message": "Test message for sampling"}
)
assert result.isError is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"

# Test without sampling callback
async with create_session(server._mcp_server) as client_session:
# Make a request to trigger sampling callback
result = await client_session.call_tool(
"test_sampling", {"message": "Test message for sampling"}
)
assert result.isError is True
assert isinstance(result.content[0], TextContent)
assert (
result.content[0].text
== "Error executing tool test_sampling: Sampling not supported"
)
7 changes: 6 additions & 1 deletion tests/client/test_stdio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import shutil

import pytest

from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse

tee: str = shutil.which("tee") # type: ignore


@pytest.mark.anyio
@pytest.mark.skipif(tee is None, reason="could not find tee command")
async def test_stdio_client():
server_parameters = StdioServerParameters(command="/usr/bin/tee")
server_parameters = StdioServerParameters(command=tee)

async with stdio_client(server_parameters) as (read_stream, write_stream):
# Test sending and receiving messages
Expand Down