Skip to content

Commit 4cd7805

Browse files
Merge branch 'main' into feat/sampling-resources
2 parents 1819922 + 05b7156 commit 4cd7805

File tree

9 files changed

+548
-216
lines changed

9 files changed

+548
-216
lines changed

README.md

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -315,27 +315,42 @@ async def long_task(files: list[str], ctx: Context) -> str:
315315
Authentication can be used by servers that want to expose tools accessing protected resources.
316316

317317
`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by
318-
providing an implementation of the `OAuthServerProvider` protocol.
318+
providing an implementation of the `OAuthAuthorizationServerProvider` protocol.
319319

320-
```
321-
mcp = FastMCP("My App",
322-
auth_server_provider=MyOAuthServerProvider(),
323-
auth=AuthSettings(
324-
issuer_url="https://myapp.com",
325-
revocation_options=RevocationOptions(
326-
enabled=True,
327-
),
328-
client_registration_options=ClientRegistrationOptions(
329-
enabled=True,
330-
valid_scopes=["myscope", "myotherscope"],
331-
default_scopes=["myscope"],
332-
),
333-
required_scopes=["myscope"],
320+
```python
321+
from mcp import FastMCP
322+
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
323+
from mcp.server.auth.settings import (
324+
AuthSettings,
325+
ClientRegistrationOptions,
326+
RevocationOptions,
327+
)
328+
329+
330+
class MyOAuthServerProvider(OAuthAuthorizationServerProvider):
331+
# See an example on how to implement at `examples/servers/simple-auth`
332+
...
333+
334+
335+
mcp = FastMCP(
336+
"My App",
337+
auth_server_provider=MyOAuthServerProvider(),
338+
auth=AuthSettings(
339+
issuer_url="https://myapp.com",
340+
revocation_options=RevocationOptions(
341+
enabled=True,
342+
),
343+
client_registration_options=ClientRegistrationOptions(
344+
enabled=True,
345+
valid_scopes=["myscope", "myotherscope"],
346+
default_scopes=["myscope"],
334347
),
348+
required_scopes=["myscope"],
349+
),
335350
)
336351
```
337352

338-
See [OAuthServerProvider](src/mcp/server/auth/provider.py) for more details.
353+
See [OAuthAuthorizationServerProvider](src/mcp/server/auth/provider.py) for more details.
339354

340355
## Running Your Server
341356

@@ -462,15 +477,12 @@ For low level server with Streamable HTTP implementations, see:
462477
- Stateful server: [`examples/servers/simple-streamablehttp/`](examples/servers/simple-streamablehttp/)
463478
- Stateless server: [`examples/servers/simple-streamablehttp-stateless/`](examples/servers/simple-streamablehttp-stateless/)
464479

465-
466-
467480
The streamable HTTP transport supports:
468481
- Stateful and stateless operation modes
469482
- Resumability with event stores
470-
- JSON or SSE response formats
483+
- JSON or SSE response formats
471484
- Better scalability for multi-node deployments
472485

473-
474486
### Mounting to an Existing ASGI Server
475487

476488
> **Note**: SSE transport is being superseded by [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http).

src/mcp/client/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,18 @@ def __init__(
116116
self._message_handler = message_handler or _default_message_handler
117117

118118
async def initialize(self) -> types.InitializeResult:
119-
sampling = types.SamplingCapability()
120-
roots = types.RootsCapability(
119+
sampling = (
120+
types.SamplingCapability()
121+
if self._sampling_callback is not _default_sampling_callback
122+
else None
123+
)
124+
roots = (
121125
# TODO: Should this be based on whether we
122126
# _will_ send notifications, or only whether
123127
# they're supported?
124-
listChanged=True,
128+
types.RootsCapability(listChanged=True)
129+
if self._list_roots_callback is not _default_list_roots_callback
130+
else None
125131
)
126132

127133
result = await self.send_request(

src/mcp/client/stdio/__init__.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,28 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
108108
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
109109
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
110110

111-
command = _get_executable_command(server.command)
112-
113-
# Open process with stderr piped for capture
114-
process = await _create_platform_compatible_process(
115-
command=command,
116-
args=server.args,
117-
env=(
118-
{**get_default_environment(), **server.env}
119-
if server.env is not None
120-
else get_default_environment()
121-
),
122-
errlog=errlog,
123-
cwd=server.cwd,
124-
)
111+
try:
112+
command = _get_executable_command(server.command)
113+
114+
# Open process with stderr piped for capture
115+
process = await _create_platform_compatible_process(
116+
command=command,
117+
args=server.args,
118+
env=(
119+
{**get_default_environment(), **server.env}
120+
if server.env is not None
121+
else get_default_environment()
122+
),
123+
errlog=errlog,
124+
cwd=server.cwd,
125+
)
126+
except OSError:
127+
# Clean up streams if process creation fails
128+
await read_stream.aclose()
129+
await write_stream.aclose()
130+
await read_stream_writer.aclose()
131+
await write_stream_reader.aclose()
132+
raise
125133

126134
async def stdout_reader():
127135
assert process.stdout, "Opened process is missing stdout"
@@ -177,12 +185,18 @@ async def stdin_writer():
177185
yield read_stream, write_stream
178186
finally:
179187
# Clean up process to prevent any dangling orphaned processes
180-
if sys.platform == "win32":
181-
await terminate_windows_process(process)
182-
else:
183-
process.terminate()
188+
try:
189+
if sys.platform == "win32":
190+
await terminate_windows_process(process)
191+
else:
192+
process.terminate()
193+
except ProcessLookupError:
194+
# Process already exited, which is fine
195+
pass
184196
await read_stream.aclose()
185197
await write_stream.aclose()
198+
await read_stream_writer.aclose()
199+
await write_stream_reader.aclose()
186200

187201

188202
def _get_executable_command(command: str) -> str:

src/mcp/server/streamable_http.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ async def _handle_post_request(
397397
await response(scope, receive, send)
398398

399399
# Process the message after sending the response
400-
session_message = SessionMessage(message)
400+
metadata = ServerMessageMetadata(request_context=request)
401+
session_message = SessionMessage(message, metadata=metadata)
401402
await writer.send(session_message)
402403

403404
return
@@ -412,7 +413,8 @@ async def _handle_post_request(
412413

413414
if self.is_json_response_enabled:
414415
# Process the message
415-
session_message = SessionMessage(message)
416+
metadata = ServerMessageMetadata(request_context=request)
417+
session_message = SessionMessage(message, metadata=metadata)
416418
await writer.send(session_message)
417419
try:
418420
# Process messages from the request-specific stream
@@ -511,7 +513,8 @@ async def sse_writer():
511513
async with anyio.create_task_group() as tg:
512514
tg.start_soon(response, scope, receive, send)
513515
# Then send the message to be processed by the server
514-
session_message = SessionMessage(message)
516+
metadata = ServerMessageMetadata(request_context=request)
517+
session_message = SessionMessage(message, metadata=metadata)
515518
await writer.send(session_message)
516519
except Exception:
517520
logger.exception("SSE response error")

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ class RootsCapability(BaseModel):
218218

219219

220220
class SamplingCapability(BaseModel):
221-
"""Capability for logging operations."""
221+
"""Capability for sampling operations."""
222222

223223
model_config = ConfigDict(extra="allow")
224224

tests/client/test_session.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import Any
2+
13
import anyio
24
import pytest
35

46
import mcp.types as types
57
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
8+
from mcp.shared.context import RequestContext
69
from mcp.shared.message import SessionMessage
710
from mcp.shared.session import RequestResponder
811
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -380,3 +383,167 @@ async def mock_server():
380383
# Should raise RuntimeError for unsupported version
381384
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
382385
await session.initialize()
386+
387+
388+
@pytest.mark.anyio
389+
async def test_client_capabilities_default():
390+
"""Test that client capabilities are properly set with default callbacks"""
391+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
392+
SessionMessage
393+
](1)
394+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
395+
SessionMessage
396+
](1)
397+
398+
received_capabilities = None
399+
400+
async def mock_server():
401+
nonlocal received_capabilities
402+
403+
session_message = await client_to_server_receive.receive()
404+
jsonrpc_request = session_message.message
405+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
406+
request = ClientRequest.model_validate(
407+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
408+
)
409+
assert isinstance(request.root, InitializeRequest)
410+
received_capabilities = request.root.params.capabilities
411+
412+
result = ServerResult(
413+
InitializeResult(
414+
protocolVersion=LATEST_PROTOCOL_VERSION,
415+
capabilities=ServerCapabilities(),
416+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
417+
)
418+
)
419+
420+
async with server_to_client_send:
421+
await server_to_client_send.send(
422+
SessionMessage(
423+
JSONRPCMessage(
424+
JSONRPCResponse(
425+
jsonrpc="2.0",
426+
id=jsonrpc_request.root.id,
427+
result=result.model_dump(
428+
by_alias=True, mode="json", exclude_none=True
429+
),
430+
)
431+
)
432+
)
433+
)
434+
# Receive initialized notification
435+
await client_to_server_receive.receive()
436+
437+
async with (
438+
ClientSession(
439+
server_to_client_receive,
440+
client_to_server_send,
441+
) as session,
442+
anyio.create_task_group() as tg,
443+
client_to_server_send,
444+
client_to_server_receive,
445+
server_to_client_send,
446+
server_to_client_receive,
447+
):
448+
tg.start_soon(mock_server)
449+
await session.initialize()
450+
451+
# Assert that capabilities are properly set with defaults
452+
assert received_capabilities is not None
453+
assert received_capabilities.sampling is None # No custom sampling callback
454+
assert received_capabilities.roots is None # No custom list_roots callback
455+
456+
457+
@pytest.mark.anyio
458+
async def test_client_capabilities_with_custom_callbacks():
459+
"""Test that client capabilities are properly set with custom callbacks"""
460+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
461+
SessionMessage
462+
](1)
463+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
464+
SessionMessage
465+
](1)
466+
467+
received_capabilities = None
468+
469+
async def custom_sampling_callback(
470+
context: RequestContext["ClientSession", Any],
471+
params: types.CreateMessageRequestParams,
472+
) -> types.CreateMessageResult | types.ErrorData:
473+
return types.CreateMessageResult(
474+
role="assistant",
475+
content=types.TextContent(type="text", text="test"),
476+
model="test-model",
477+
)
478+
479+
async def custom_list_roots_callback(
480+
context: RequestContext["ClientSession", Any],
481+
) -> types.ListRootsResult | types.ErrorData:
482+
return types.ListRootsResult(roots=[])
483+
484+
async def mock_server():
485+
nonlocal received_capabilities
486+
487+
session_message = await client_to_server_receive.receive()
488+
jsonrpc_request = session_message.message
489+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
490+
request = ClientRequest.model_validate(
491+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
492+
)
493+
assert isinstance(request.root, InitializeRequest)
494+
received_capabilities = request.root.params.capabilities
495+
496+
result = ServerResult(
497+
InitializeResult(
498+
protocolVersion=LATEST_PROTOCOL_VERSION,
499+
capabilities=ServerCapabilities(),
500+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
501+
)
502+
)
503+
504+
async with server_to_client_send:
505+
await server_to_client_send.send(
506+
SessionMessage(
507+
JSONRPCMessage(
508+
JSONRPCResponse(
509+
jsonrpc="2.0",
510+
id=jsonrpc_request.root.id,
511+
result=result.model_dump(
512+
by_alias=True, mode="json", exclude_none=True
513+
),
514+
)
515+
)
516+
)
517+
)
518+
# Receive initialized notification
519+
await client_to_server_receive.receive()
520+
521+
async with (
522+
ClientSession(
523+
server_to_client_receive,
524+
client_to_server_send,
525+
sampling_callback=custom_sampling_callback,
526+
list_roots_callback=custom_list_roots_callback,
527+
) as session,
528+
anyio.create_task_group() as tg,
529+
client_to_server_send,
530+
client_to_server_receive,
531+
server_to_client_send,
532+
server_to_client_receive,
533+
):
534+
tg.start_soon(mock_server)
535+
await session.initialize()
536+
537+
# Assert that capabilities are properly set with custom callbacks
538+
assert received_capabilities is not None
539+
assert (
540+
received_capabilities.sampling is not None
541+
) # Custom sampling callback provided
542+
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
543+
assert (
544+
received_capabilities.roots is not None
545+
) # Custom list_roots callback provided
546+
assert isinstance(received_capabilities.roots, types.RootsCapability)
547+
assert (
548+
received_capabilities.roots.listChanged is True
549+
) # Should be True for custom callback

0 commit comments

Comments
 (0)