Skip to content

Commit 903bd3e

Browse files
committed
feat: expose sampling_capabilities parameter for SEP-1577 tool support
1 parent 8e02fc1 commit 903bd3e

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

src/mcp/client/session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def __init__(
115115
write_stream: MemoryObjectSendStream[SessionMessage],
116116
read_timeout_seconds: timedelta | None = None,
117117
sampling_callback: SamplingFnT | None = None,
118+
sampling_capabilities: types.SamplingCapability | None = None,
118119
elicitation_callback: ElicitationFnT | None = None,
119120
list_roots_callback: ListRootsFnT | None = None,
120121
logging_callback: LoggingFnT | None = None,
@@ -132,6 +133,7 @@ def __init__(
132133
)
133134
self._client_info = client_info or DEFAULT_CLIENT_INFO
134135
self._sampling_callback = sampling_callback or _default_sampling_callback
136+
self._sampling_capabilities = sampling_capabilities
135137
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
136138
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
137139
self._logging_callback = logging_callback or _default_logging_callback
@@ -144,7 +146,11 @@ def __init__(
144146
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
145147

146148
async def initialize(self) -> types.InitializeResult:
147-
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
149+
sampling = (
150+
(self._sampling_capabilities or types.SamplingCapability())
151+
if self._sampling_callback is not _default_sampling_callback
152+
else None
153+
)
148154
elicitation = (
149155
types.ElicitationCapability(
150156
form=types.FormElicitationCapability(),

tests/client/test_session.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,13 +497,93 @@ async def mock_server():
497497
# Custom sampling callback provided
498498
assert received_capabilities.sampling is not None
499499
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
500+
# Default sampling capabilities (no tools)
501+
assert received_capabilities.sampling.tools is None
500502
# Custom list_roots callback provided
501503
assert received_capabilities.roots is not None
502504
assert isinstance(received_capabilities.roots, types.RootsCapability)
503505
# Should be True for custom callback
504506
assert received_capabilities.roots.listChanged is True
505507

506508

509+
@pytest.mark.anyio
510+
async def test_client_capabilities_with_sampling_tools():
511+
"""Test that sampling capabilities with tools are properly advertised"""
512+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
513+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
514+
515+
received_capabilities = None
516+
517+
async def custom_sampling_callback( # pragma: no cover
518+
context: RequestContext["ClientSession", Any],
519+
params: types.CreateMessageRequestParams,
520+
) -> types.CreateMessageResult | types.ErrorData:
521+
return types.CreateMessageResult(
522+
role="assistant",
523+
content=types.TextContent(type="text", text="test"),
524+
model="test-model",
525+
)
526+
527+
async def mock_server():
528+
nonlocal received_capabilities
529+
530+
session_message = await client_to_server_receive.receive()
531+
jsonrpc_request = session_message.message
532+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
533+
request = ClientRequest.model_validate(
534+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
535+
)
536+
assert isinstance(request.root, InitializeRequest)
537+
received_capabilities = request.root.params.capabilities
538+
539+
result = ServerResult(
540+
InitializeResult(
541+
protocolVersion=LATEST_PROTOCOL_VERSION,
542+
capabilities=ServerCapabilities(),
543+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
544+
)
545+
)
546+
547+
async with server_to_client_send:
548+
await server_to_client_send.send(
549+
SessionMessage(
550+
JSONRPCMessage(
551+
JSONRPCResponse(
552+
jsonrpc="2.0",
553+
id=jsonrpc_request.root.id,
554+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
555+
)
556+
)
557+
)
558+
)
559+
# Receive initialized notification
560+
await client_to_server_receive.receive()
561+
562+
async with (
563+
ClientSession(
564+
server_to_client_receive,
565+
client_to_server_send,
566+
sampling_callback=custom_sampling_callback,
567+
sampling_capabilities=types.SamplingCapability(tools=types.SamplingToolsCapability()),
568+
) as session,
569+
anyio.create_task_group() as tg,
570+
client_to_server_send,
571+
client_to_server_receive,
572+
server_to_client_send,
573+
server_to_client_receive,
574+
):
575+
tg.start_soon(mock_server)
576+
await session.initialize()
577+
578+
# Assert that sampling capabilities with tools are properly advertised
579+
assert received_capabilities is not None
580+
assert received_capabilities.sampling is not None
581+
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
582+
# Tools capability should be present
583+
assert received_capabilities.sampling.tools is not None
584+
assert isinstance(received_capabilities.sampling.tools, types.SamplingToolsCapability)
585+
586+
507587
@pytest.mark.anyio
508588
async def test_get_server_capabilities():
509589
"""Test that get_server_capabilities returns None before init and capabilities after"""

0 commit comments

Comments
 (0)