Skip to content

Commit b5cb58e

Browse files
committed
Propagate contextvars through anyio streams
TODO: - Update a recipe to show it working - Consider adding an integration test of some kind
1 parent 373fbc7 commit b5cb58e

File tree

4 files changed

+31
-7
lines changed

4 files changed

+31
-7
lines changed

src/mcp/client/streamable_http.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,8 @@ async def post_writer(
430430
"""Handle writing requests to the server."""
431431
try:
432432
async with write_stream_reader:
433-
async for session_message in write_stream_reader:
433+
434+
async def handle_message(session_message: SessionMessage) -> None:
434435
message = session_message.message
435436
metadata = (
436437
session_message.metadata
@@ -467,6 +468,10 @@ async def handle_request_async():
467468
else:
468469
await handle_request_async()
469470

471+
async for session_message in write_stream_reader:
472+
async with anyio.create_task_group() as tg_local:
473+
session_message.context.run(tg_local.start_soon, handle_message, session_message)
474+
470475
except Exception:
471476
logger.exception("Error in post_writer") # pragma: no cover
472477
finally:

src/mcp/server/lowlevel/server.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,14 @@ async def run(
683683
async for message in session.incoming_messages:
684684
logger.debug("Received message: %s", message)
685685

686-
tg.start_soon(
686+
if isinstance(message, RequestResponder) and message.context is not None:
687+
logger.debug("Got a context to propagate, %s", message.context)
688+
context = message.context
689+
else:
690+
context = contextvars.copy_context()
691+
692+
context.run(
693+
tg.start_soon,
687694
self._handle_message,
688695
message,
689696
session,

src/mcp/shared/message.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
to support transport-specific features like resumability.
55
"""
66

7+
import contextvars
78
from collections.abc import Awaitable, Callable
8-
from dataclasses import dataclass
9+
from dataclasses import dataclass, field
910

1011
from mcp.types import JSONRPCMessage, RequestId
1112

@@ -46,4 +47,5 @@ class SessionMessage:
4647
"""A message with specific metadata for transport-specific features."""
4748

4849
message: JSONRPCMessage
50+
context: contextvars.Context = field(default_factory=contextvars.copy_context)
4951
metadata: MessageMetadata = None

src/mcp/shared/session.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextvars
34
import logging
45
from collections.abc import Callable
56
from contextlib import AsyncExitStack
@@ -77,11 +78,13 @@ def __init__(
7778
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
7879
on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any],
7980
message_metadata: MessageMetadata = None,
81+
context: contextvars.Context | None = None,
8082
) -> None:
8183
self.request_id = request_id
8284
self.request_meta = request_meta
8385
self.request = request
8486
self.message_metadata = message_metadata
87+
self.context = context
8588
self._session = session
8689
self._completed = False
8790
self._cancel_scope = anyio.CancelScope()
@@ -330,10 +333,9 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
330333
async def _receive_loop(self) -> None:
331334
async with self._read_stream, self._write_stream:
332335
try:
333-
async for message in self._read_stream:
334-
if isinstance(message, Exception): # pragma: no cover
335-
await self._handle_incoming(message)
336-
elif isinstance(message.message, JSONRPCRequest):
336+
337+
async def handle_message(message: SessionMessage) -> None:
338+
if isinstance(message.message, JSONRPCRequest):
337339
try:
338340
validated_request = self._receive_request_adapter.validate_python(
339341
message.message.model_dump(by_alias=True, mode="json", exclude_none=True),
@@ -346,6 +348,7 @@ async def _receive_loop(self) -> None:
346348
session=self,
347349
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
348350
message_metadata=message.metadata,
351+
context=message.context,
349352
)
350353
self._in_flight[responder.request_id] = responder
351354
await self._received_request(responder)
@@ -403,6 +406,13 @@ async def _receive_loop(self) -> None:
403406
else: # Response or error
404407
await self._handle_response(message)
405408

409+
async for message in self._read_stream:
410+
if isinstance(message, Exception): # pragma: no cover
411+
await self._handle_incoming(message)
412+
else:
413+
async with anyio.create_task_group() as tg:
414+
message.context.run(tg.start_soon, handle_message, message)
415+
406416
except anyio.ClosedResourceError:
407417
# This is expected when the client disconnects abruptly.
408418
# Without this handler, the exception would propagate up and

0 commit comments

Comments
 (0)