diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index cbb611419..66b980989 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -430,7 +430,8 @@ async def post_writer( """Handle writing requests to the server.""" try: async with write_stream_reader: - async for session_message in write_stream_reader: + + async def handle_message(session_message: SessionMessage) -> None: message = session_message.message metadata = ( session_message.metadata @@ -467,6 +468,10 @@ async def handle_request_async(): else: await handle_request_async() + async for session_message in write_stream_reader: + async with anyio.create_task_group() as tg_local: + session_message.context.run(tg_local.start_soon, handle_message, session_message) + except Exception: logger.exception("Error in post_writer") # pragma: no cover finally: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 96dcaf1c7..d89eb088f 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -683,7 +683,14 @@ async def run( async for message in session.incoming_messages: logger.debug("Received message: %s", message) - tg.start_soon( + if isinstance(message, RequestResponder) and message.context is not None: + logger.debug("Got a context to propagate, %s", message.context) + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, self._handle_message, message, session, @@ -739,9 +746,7 @@ async def _handle_request( request_data = None close_sse_stream_cb = None close_standalone_sse_stream_cb = None - if message.message_metadata is not None and isinstance( - message.message_metadata, ServerMessageMetadata - ): # pragma: no cover + if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): request_data = message.message_metadata.request_context close_sse_stream_cb = message.message_metadata.close_sse_stream close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index e9156f7ba..dc68bc3fc 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -328,11 +328,11 @@ def _create_json_response( headers=response_headers, ) - def _get_session_id(self, request: Request) -> str | None: # pragma: no cover + def _get_session_id(self, request: Request) -> str | None: """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) - def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", @@ -340,7 +340,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # } # If an event ID was provided, include it - if event_message.event_id: + if event_message.event_id: # pragma: no cover event_data["id"] = event_message.event_id return event_data @@ -381,9 +381,9 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No if request.method == "POST": await self._handle_post_request(scope, request, receive, send) - elif request.method == "GET": # pragma: no cover + elif request.method == "GET": await self._handle_get_request(request, send) - elif request.method == "DELETE": # pragma: no cover + elif request.method == "DELETE": await self._handle_delete_request(request, send) else: # pragma: no cover await self._handle_unsupported_request(request, send) @@ -427,7 +427,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se return False return True - async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: + async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: # noqa: PLR0915 """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer if writer is None: # pragma: no cover @@ -470,14 +470,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Check if this is an initialization request is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" - if is_initialization_request: # pragma: no cover + if is_initialization_request: # Check if the server already has an established session if self.mcp_session_id: # Check if request has a session ID request_session_id = self._get_session_id(request) # If request has a session ID but doesn't match, return 404 - if request_session_id and request_session_id != self.mcp_session_id: + if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover response = self._create_error_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, @@ -488,7 +488,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return # For notifications and responses only, return 202 Accepted - if not isinstance(message, JSONRPCRequest): # pragma: no cover + if not isinstance(message, JSONRPCRequest): # Create response object and send it response = self._create_json_response( None, @@ -561,14 +561,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) finally: await self._clean_up_memory_streams(request_id) - else: # pragma: no cover + else: # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): + async def sse_writer(): # pragma: lax no cover # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -617,11 +617,12 @@ async def sse_writer(): # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: + except Exception: # pragma: no cover logger.exception("SSE response error") await sse_stream_writer.aclose() - await sse_stream_reader.aclose() await self._clean_up_memory_streams(request_id) + finally: + await sse_stream_reader.aclose() except Exception as err: # pragma: no cover logger.exception("Error handling POST request") @@ -635,7 +636,7 @@ async def sse_writer(): await writer.send(Exception(err)) return - async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. This allows the server to communicate to the client without the client @@ -643,13 +644,13 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr and notifications on this stream. """ writer = self._read_stream_writer - if writer is None: + if writer is None: # pragma: no cover raise ValueError("No read stream writer available. Ensure connect() is called first.") # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: + if not has_sse: # pragma: no cover response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -657,11 +658,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): + if not await self._validate_request_headers(request, send): # pragma: no cover return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover await self._replay_events(last_event_id, request, send) return @@ -675,7 +676,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: + if GET_STREAM_KEY in self._request_streams: # pragma: no cover response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -695,7 +696,7 @@ async def standalone_sse_writer(): async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for event_message in standalone_stream_reader: + async for event_message in standalone_stream_reader: # pragma: lax no cover # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -704,7 +705,7 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: + except Exception: # pragma: no cover logger.exception("Error in standalone SSE writer") finally: logger.debug("Closing standalone SSE writer") @@ -720,16 +721,17 @@ async def standalone_sse_writer(): try: # This will send headers immediately and establish the SSE connection await response(request.scope, request.receive, send) - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in standalone SSE response") + await self._clean_up_memory_streams(GET_STREAM_KEY) + finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - await self._clean_up_memory_streams(GET_STREAM_KEY) - async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_delete_request(self, request: Request, send: Send) -> None: """Handle DELETE requests for explicit session termination.""" # Validate session ID - if not self.mcp_session_id: + if not self.mcp_session_id: # pragma: no cover # If no session ID set, return Method Not Allowed response = self._create_error_response( "Method Not Allowed: Session termination not supported", @@ -738,7 +740,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: # await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): + if not await self._validate_request_headers(request, send): # pragma: no cover return await self.terminate() @@ -796,16 +798,16 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: no cover + async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): return False return True - async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: no cover + async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: + if not self.mcp_session_id: # pragma: no cover # If we're not using session IDs, return True return True @@ -813,7 +815,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: + if not request_session_id: # pragma: no cover response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -822,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag return False # If session ID doesn't match, return error - if request_session_id != self.mcp_session_id: + if request_session_id != self.mcp_session_id: # pragma: no cover response = self._create_error_response( "Not Found: Invalid or expired session ID", HTTPStatus.NOT_FOUND, @@ -832,17 +834,17 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag return True - async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: no cover + async def _validate_protocol_version(self, request: Request, send: Send) -> bool: """Validate the protocol version header in the request.""" # Get the protocol version from the request headers protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: + if protocol_version is None: # pragma: no cover protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -1004,10 +1006,10 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except ( # pragma: no cover + except ( anyio.BrokenResourceError, anyio.ClosedResourceError, - ): + ): # pragma: no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) else: # pragma: no cover diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index a954b24a4..ddc6e5014 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -181,7 +181,7 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) # Existing session case - if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover + if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") await transport.handle_request(scope, receive, send) @@ -261,5 +261,5 @@ class StreamableHTTPASGIApp: def __init__(self, session_manager: StreamableHTTPSessionManager): self.session_manager = session_manager - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.session_manager.handle_request(scope, receive, send) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 9dedd2e5d..4ef1448af 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -4,8 +4,9 @@ to support transport-specific features like resumability. """ +import contextvars from collections.abc import Awaitable, Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from mcp.types import JSONRPCMessage, RequestId @@ -46,4 +47,5 @@ class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage + context: contextvars.Context = field(default_factory=contextvars.copy_context) metadata: MessageMetadata = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 453e36274..16eee10e3 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -77,11 +78,13 @@ def __init__( session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, + context: contextvars.Context | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self.message_metadata = message_metadata + self.context = context self._session = session self._completed = False self._cancel_scope = anyio.CancelScope() @@ -330,10 +333,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: - async for message in self._read_stream: - if isinstance(message, Exception): # pragma: no cover - await self._handle_incoming(message) - elif isinstance(message.message, JSONRPCRequest): + + async def handle_message(message: SessionMessage) -> None: + if isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), @@ -346,6 +348,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + context=message.context, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -403,6 +406,13 @@ async def _receive_loop(self) -> None: else: # Response or error await self._handle_response(message) + async for message in self._read_stream: + if isinstance(message, Exception): # pragma: no cover + await self._handle_incoming(message) + else: + async with anyio.create_task_group() as tg: + message.context.run(tg.start_soon, handle_message, message) + except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. # Without this handler, the exception would propagate up and diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index af1b23619..5f46977d5 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -5,9 +5,13 @@ from unittest.mock import AsyncMock, patch import anyio +import httpx import pytest from starlette.types import Message +from mcp import types +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client from mcp.server import streamable_http_manager from mcp.server.lowlevel import Server from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport @@ -313,3 +317,24 @@ async def mock_receive(): assert error_data["id"] == "server-error" assert error_data["error"]["code"] == INVALID_REQUEST assert error_data["error"]["message"] == "Session not found" + + +@pytest.mark.anyio +async def test_e2e_streamable_http_server_cleanup(): + host = "testserver" + app = Server("test-server") + + @app.list_tools() + async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult: + return types.ListToolsResult(tools=[]) + + mcp_app = app.streamable_http_app(host=host) + async with ( + mcp_app.router.lifespan_context(mcp_app), + httpx.ASGITransport(mcp_app) as transport, + httpx.AsyncClient(transport=transport) as client, + streamable_http_client(f"http://{host}/mcp", http_client=client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + await session.list_tools() diff --git a/tests/test_context_propagation.py b/tests/test_context_propagation.py new file mode 100644 index 000000000..1e56d72ab --- /dev/null +++ b/tests/test_context_propagation.py @@ -0,0 +1,98 @@ +import contextvars +from collections.abc import Iterator +from contextlib import contextmanager + +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +import mcp.types as types +from mcp import Client +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client +from mcp.server import MCPServer + +TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial") +HOST = "testserver" + + +@contextmanager +def set_test_contextvar(value: str) -> Iterator[None]: + token = TEST_CONTEXTVAR.set(value) + try: + yield + finally: + TEST_CONTEXTVAR.reset(token) + + +@pytest.fixture +def server() -> MCPServer: + mcp = MCPServer("test_server") + + # tool that returns the value of TEST_CONTEXT_VAR. + @mcp.tool() + async def my_tool() -> str: + return TEST_CONTEXTVAR.get() + + return mcp + + +@pytest.mark.anyio +async def test_memory_transport_client_to_server(server: MCPServer): + async with Client(server) as client: + with set_test_contextvar("client_value"): + result = await client.call_tool(name="my_tool") + + assert isinstance(result, types.CallToolResult) + assert result.content == snapshot([types.TextContent(text="client_value")]) + + +@pytest.mark.anyio +async def test_streamable_http_asgi_to_mcpserver(server: MCPServer): + mcp_app = server.streamable_http_app(host=HOST) + + # Wrap it in a middleware that sets the contextvar + async def middleware_app(scope: Scope, receive: Receive, send: Send): + with set_test_contextvar("from_middleware"): + await mcp_app(scope, receive, send) + + async with ( + mcp_app.router.lifespan_context(middleware_app), + httpx.ASGITransport(app=middleware_app) as transport, + httpx.AsyncClient(transport=transport) as client, + streamable_http_client(f"http://{HOST}/mcp", http_client=client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + await session.initialize() + result = await session.call_tool("my_tool") + assert result.content == snapshot([types.TextContent(text="from_middleware")]) + + +@pytest.mark.anyio +async def test_streamable_http_mcpclient_to_httpx(server: MCPServer): + mcp_app = server.streamable_http_app(host=HOST) + + captured_context_var = None + + # Intercepts the httpx call and capture the contextvar's value + class ContextCapturingASGITransport(httpx.ASGITransport): + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + nonlocal captured_context_var + captured_context_var = TEST_CONTEXTVAR.get() + return await super().handle_async_request(request) + + async with ( + mcp_app.router.lifespan_context(mcp_app), + ContextCapturingASGITransport(app=mcp_app) as transport, + httpx.AsyncClient(transport=transport) as client, + streamable_http_client(f"http://{HOST}/mcp", http_client=client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with set_test_contextvar("client_value_initialize"): + await session.initialize() + assert captured_context_var == snapshot("client_value_initialize") + + with set_test_contextvar("client_value_call_tool"): + await session.call_tool("my_tool") + assert captured_context_var == snapshot("client_value_call_tool")