From 0cd88e83cae0c6daae1815c0dc079bdd2bec6605 Mon Sep 17 00:00:00 2001 From: aaronabbott Date: Fri, 9 Jan 2026 22:56:27 +0000 Subject: [PATCH 1/3] Test cases --- tests/test_context_propagation.py | 94 +++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 tests/test_context_propagation.py diff --git a/tests/test_context_propagation.py b/tests/test_context_propagation.py new file mode 100644 index 000000000..83704bc7d --- /dev/null +++ b/tests/test_context_propagation.py @@ -0,0 +1,94 @@ +import contextvars +from collections.abc import Iterator +from contextlib import contextmanager + +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +import mcp.types as types +from mcp import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import MCPServer + +TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial") +HOST = "testserver" + + +@contextmanager +def set_test_contextvar(value: str) -> Iterator[None]: + token = TEST_CONTEXTVAR.set(value) + try: + yield + finally: + TEST_CONTEXTVAR.reset(token) + + +@pytest.fixture +def server() -> MCPServer: + mcp = MCPServer("test_server") + + # tool that returns the value of TEST_CONTEXT_VAR. + @mcp.tool() + async def my_tool() -> str: + return TEST_CONTEXTVAR.get() + + return mcp + + +@pytest.mark.anyio +async def test_memory_transport_client_to_server(server: MCPServer): + async with Client(server) as client: + with set_test_contextvar("client_value"): + result = await client.call_tool(name="my_tool") + + assert isinstance(result, types.CallToolResult) + assert result.content == snapshot([types.TextContent(text="client_value")]) + + +@pytest.mark.anyio +async def test_streamable_http_asgi_to_mcpserver(server: MCPServer): + mcp_app = server.streamable_http_app(host=HOST) + + # Wrap it in a middleware that sets the contextvar + async def middleware_app(scope: Scope, receive: Receive, send: Send): + with set_test_contextvar("from_middleware"): + await mcp_app(scope, receive, send) + + async with ( + mcp_app.router.lifespan_context(middleware_app), + httpx.ASGITransport(app=middleware_app) as transport, + httpx.AsyncClient(transport=transport) as http_client, + Client(streamable_http_client(f"http://{HOST}/mcp", http_client=http_client)) as client, + ): + result = await client.call_tool("my_tool") + assert result.content == snapshot([types.TextContent(text="from_middleware")]) + + +@pytest.mark.anyio +async def test_streamable_http_mcpclient_to_httpx(server: MCPServer): + mcp_app = server.streamable_http_app(host=HOST) + + captured_context_var = None + + # Intercepts the httpx call and capture the contextvar's value + class ContextCapturingASGITransport(httpx.ASGITransport): + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + nonlocal captured_context_var + captured_context_var = TEST_CONTEXTVAR.get() + return await super().handle_async_request(request) + + async with ( + mcp_app.router.lifespan_context(mcp_app), + ContextCapturingASGITransport(app=mcp_app) as transport, + httpx.AsyncClient(transport=transport) as http_client, + Client(streamable_http_client(f"http://{HOST}/mcp", http_client=http_client)) as client, + ): + with set_test_contextvar("client_value_list"): + await client.list_tools() + assert captured_context_var == snapshot("client_value_list") + + with set_test_contextvar("client_value_call_tool"): + await client.call_tool("my_tool") + assert captured_context_var == snapshot("client_value_call_tool") From 73c18ac2ce57ad93560035cf0e17abf485a99843 Mon Sep 17 00:00:00 2001 From: aaronabbott Date: Fri, 6 Feb 2026 03:09:43 +0000 Subject: [PATCH 2/3] Temporary hack to fix tests until fix is merged --- tests/test_context_propagation.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_context_propagation.py b/tests/test_context_propagation.py index 83704bc7d..33b1c1005 100644 --- a/tests/test_context_propagation.py +++ b/tests/test_context_propagation.py @@ -12,6 +12,19 @@ from mcp.client.streamable_http import streamable_http_client from mcp.server import MCPServer +# TODO: remove once https://github.com/modelcontextprotocol/python-sdk/pull/1991 is merged +pytestmark = pytest.mark.filterwarnings("ignore::ResourceWarning") + + +# TODO: remove once https://github.com/modelcontextprotocol/python-sdk/pull/1991 is merged +@pytest.fixture(autouse=True) +def force_gc_after_test_resource_leak(): + yield + import gc + + gc.collect() + + TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial") HOST = "testserver" From 8cdb6494e4004f6e24e95695cb2a5f6696e9b524 Mon Sep 17 00:00:00 2001 From: aaronabbott Date: Fri, 9 Jan 2026 22:56:27 +0000 Subject: [PATCH 3/3] Propagate contextvars through anyio streams TODO: - Update a recipe to show it working - Consider adding an integration test of some kind --- src/mcp/client/streamable_http.py | 7 ++++++- src/mcp/server/lowlevel/server.py | 9 ++++++++- src/mcp/shared/message.py | 4 +++- src/mcp/shared/session.py | 18 ++++++++++++++---- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index cbb611419..66b980989 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -430,7 +430,8 @@ async def post_writer( """Handle writing requests to the server.""" try: async with write_stream_reader: - async for session_message in write_stream_reader: + + async def handle_message(session_message: SessionMessage) -> None: message = session_message.message metadata = ( session_message.metadata @@ -467,6 +468,10 @@ async def handle_request_async(): else: await handle_request_async() + async for session_message in write_stream_reader: + async with anyio.create_task_group() as tg_local: + session_message.context.run(tg_local.start_soon, handle_message, session_message) + except Exception: logger.exception("Error in post_writer") # pragma: no cover finally: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 96dcaf1c7..2f4b7cc0b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -683,7 +683,14 @@ async def run( async for message in session.incoming_messages: logger.debug("Received message: %s", message) - tg.start_soon( + if isinstance(message, RequestResponder) and message.context is not None: + logger.debug("Got a context to propagate, %s", message.context) + context = message.context + else: + context = contextvars.copy_context() + + context.run( + tg.start_soon, self._handle_message, message, session, diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 9dedd2e5d..4ef1448af 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -4,8 +4,9 @@ to support transport-specific features like resumability. """ +import contextvars from collections.abc import Awaitable, Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from mcp.types import JSONRPCMessage, RequestId @@ -46,4 +47,5 @@ class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage + context: contextvars.Context = field(default_factory=contextvars.copy_context) metadata: MessageMetadata = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 453e36274..16eee10e3 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextvars import logging from collections.abc import Callable from contextlib import AsyncExitStack @@ -77,11 +78,13 @@ def __init__( session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, + context: contextvars.Context | None = None, ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self.message_metadata = message_metadata + self.context = context self._session = session self._completed = False self._cancel_scope = anyio.CancelScope() @@ -330,10 +333,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: async def _receive_loop(self) -> None: async with self._read_stream, self._write_stream: try: - async for message in self._read_stream: - if isinstance(message, Exception): # pragma: no cover - await self._handle_incoming(message) - elif isinstance(message.message, JSONRPCRequest): + + async def handle_message(message: SessionMessage) -> None: + if isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_adapter.validate_python( message.message.model_dump(by_alias=True, mode="json", exclude_none=True), @@ -346,6 +348,7 @@ async def _receive_loop(self) -> None: session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, + context=message.context, ) self._in_flight[responder.request_id] = responder await self._received_request(responder) @@ -403,6 +406,13 @@ async def _receive_loop(self) -> None: else: # Response or error await self._handle_response(message) + async for message in self._read_stream: + if isinstance(message, Exception): # pragma: no cover + await self._handle_incoming(message) + else: + async with anyio.create_task_group() as tg: + message.context.run(tg.start_soon, handle_message, message) + except anyio.ClosedResourceError: # This is expected when the client disconnects abruptly. # Without this handler, the exception would propagate up and