Skip to content

Commit 9830794

Browse files
feat(client): Add capability_extensions parameter to ClientSession
Add a new `capability_extensions` parameter to `ClientSession.__init__()` that allows clients to include additional capability fields in the initialize request. This enables clients to advertise protocol extensions (like `io.modelcontextprotocol/ui`) without having to override the `initialize()` method. Example usage: ```python session = ClientSession( read_stream, write_stream, capability_extensions={ "extensions": { "io.modelcontextprotocol/ui": { "mimeTypes": ["text/html;profile=mcp-app"] } } } ) ``` The extensions are merged into `ClientCapabilities` using Pydantic's extra fields feature (`model_config = {'extra': 'allow'}`).
1 parent dcc9b4f commit 9830794

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

src/mcp/client/session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def __init__(
121121
*,
122122
sampling_capabilities: types.SamplingCapability | None = None,
123123
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
124+
capability_extensions: dict[str, Any] | None = None,
124125
) -> None:
125126
super().__init__(
126127
read_stream,
@@ -143,6 +144,10 @@ def __init__(
143144
# Experimental: Task handlers (use defaults if not provided)
144145
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
145146

147+
# Capability extensions to include in initialize request
148+
# These are merged into ClientCapabilities using Pydantic's extra fields
149+
self._capability_extensions = capability_extensions or {}
150+
146151
async def initialize(self) -> types.InitializeResult:
147152
sampling = (
148153
(self._sampling_capabilities or types.SamplingCapability())
@@ -177,6 +182,7 @@ async def initialize(self) -> types.InitializeResult:
177182
experimental=None,
178183
roots=roots,
179184
tasks=self._task_handlers.build_capability(),
185+
**self._capability_extensions,
180186
),
181187
client_info=self._client_info,
182188
),

tests/client/test_session.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,73 @@ async def mock_server():
768768
await session.initialize()
769769

770770
await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta)
771+
772+
773+
@pytest.mark.anyio
774+
async def test_client_session_capability_extensions():
775+
"""Test that capability_extensions are included in the initialize request."""
776+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
777+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
778+
779+
received_capabilities = None
780+
781+
# Define capability extensions (e.g., UI extension)
782+
capability_extensions = {"extensions": {"io.modelcontextprotocol/ui": {"mimeTypes": ["text/html;profile=mcp-app"]}}}
783+
784+
async def mock_server():
785+
nonlocal received_capabilities
786+
787+
session_message = await client_to_server_receive.receive()
788+
jsonrpc_request = session_message.message
789+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
790+
request = ClientRequest.model_validate(
791+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
792+
)
793+
assert isinstance(request.root, InitializeRequest)
794+
received_capabilities = request.root.params.capabilities
795+
796+
result = ServerResult(
797+
InitializeResult(
798+
protocol_version=LATEST_PROTOCOL_VERSION,
799+
capabilities=ServerCapabilities(),
800+
server_info=Implementation(name="mock-server", version="0.1.0"),
801+
)
802+
)
803+
804+
async with server_to_client_send:
805+
await server_to_client_send.send(
806+
SessionMessage(
807+
JSONRPCMessage(
808+
JSONRPCResponse(
809+
jsonrpc="2.0",
810+
id=jsonrpc_request.root.id,
811+
result=result.model_dump(by_alias=True, mode="json", exclude_none=True),
812+
)
813+
)
814+
)
815+
)
816+
# Receive initialized notification
817+
await client_to_server_receive.receive()
818+
819+
async with (
820+
ClientSession(
821+
server_to_client_receive,
822+
client_to_server_send,
823+
capability_extensions=capability_extensions,
824+
) as session,
825+
anyio.create_task_group() as tg,
826+
client_to_server_send,
827+
client_to_server_receive,
828+
server_to_client_send,
829+
server_to_client_receive,
830+
):
831+
tg.start_soon(mock_server)
832+
await session.initialize()
833+
834+
# Assert that the capability extensions were included in the request
835+
assert received_capabilities is not None
836+
# The extensions should be present via Pydantic's extra fields
837+
caps_dict = received_capabilities.model_dump()
838+
assert "extensions" in caps_dict
839+
assert "io.modelcontextprotocol/ui" in caps_dict["extensions"]
840+
assert caps_dict["extensions"]["io.modelcontextprotocol/ui"]["mimeTypes"] == ["text/html;profile=mcp-app"]

0 commit comments

Comments
 (0)