Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ async def post_writer(
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:

async def handle_message(session_message: SessionMessage) -> None:
message = session_message.message
metadata = (
session_message.metadata
Expand Down Expand Up @@ -467,6 +468,10 @@ async def handle_request_async():
else:
await handle_request_async()

async for session_message in write_stream_reader:
async with anyio.create_task_group() as tg_local:
session_message.context.run(tg_local.start_soon, handle_message, session_message)

except Exception:
logger.exception("Error in post_writer") # pragma: no cover
finally:
Expand Down
13 changes: 9 additions & 4 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,14 @@ async def run(
async for message in session.incoming_messages:
logger.debug("Received message: %s", message)

tg.start_soon(
if isinstance(message, RequestResponder) and message.context is not None:
logger.debug("Got a context to propagate, %s", message.context)
context = message.context
else:
context = contextvars.copy_context()

context.run(
tg.start_soon,
self._handle_message,
message,
session,
Expand Down Expand Up @@ -739,9 +746,7 @@ async def _handle_request(
request_data = None
close_sse_stream_cb = None
close_standalone_sse_stream_cb = None
if message.message_metadata is not None and isinstance(
message.message_metadata, ServerMessageMetadata
): # pragma: no cover
if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
request_data = message.message_metadata.request_context
close_sse_stream_cb = message.message_metadata.close_sse_stream
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream
Expand Down
74 changes: 38 additions & 36 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,19 +328,19 @@ def _create_json_response(
headers=response_headers,
)

def _get_session_id(self, request: Request) -> str | None: # pragma: no cover
def _get_session_id(self, request: Request) -> str | None:
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)

def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
"""Create event data dictionary from an EventMessage."""
event_data = {
"event": "message",
"data": event_message.message.model_dump_json(by_alias=True, exclude_none=True),
}

# If an event ID was provided, include it
if event_message.event_id:
if event_message.event_id: # pragma: no cover
event_data["id"] = event_message.event_id

return event_data
Expand Down Expand Up @@ -381,9 +381,9 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No

if request.method == "POST":
await self._handle_post_request(scope, request, receive, send)
elif request.method == "GET": # pragma: no cover
elif request.method == "GET":
await self._handle_get_request(request, send)
elif request.method == "DELETE": # pragma: no cover
elif request.method == "DELETE":
await self._handle_delete_request(request, send)
else: # pragma: no cover
await self._handle_unsupported_request(request, send)
Expand Down Expand Up @@ -427,7 +427,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se
return False
return True

async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: # noqa: PLR0915
"""Handle POST requests containing JSON-RPC messages."""
writer = self._read_stream_writer
if writer is None: # pragma: no cover
Expand Down Expand Up @@ -470,14 +470,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
# Check if this is an initialization request
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"

if is_initialization_request: # pragma: no cover
if is_initialization_request:
# Check if the server already has an established session
if self.mcp_session_id:
# Check if request has a session ID
request_session_id = self._get_session_id(request)

# If request has a session ID but doesn't match, return 404
if request_session_id and request_session_id != self.mcp_session_id:
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
Expand All @@ -488,7 +488,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
return

# For notifications and responses only, return 202 Accepted
if not isinstance(message, JSONRPCRequest): # pragma: no cover
if not isinstance(message, JSONRPCRequest):
# Create response object and send it
response = self._create_json_response(
None,
Expand Down Expand Up @@ -561,14 +561,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
await response(scope, receive, send)
finally:
await self._clean_up_memory_streams(request_id)
else: # pragma: no cover
else:
# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)

# Store writer reference so close_sse_stream() can close it
self._sse_stream_writers[request_id] = sse_stream_writer

async def sse_writer():
async def sse_writer(): # pragma: lax no cover
# Get the request ID from the incoming request message
try:
async with sse_stream_writer, request_stream_reader:
Expand Down Expand Up @@ -617,11 +617,12 @@ async def sse_writer():
# Then send the message to be processed by the server
session_message = self._create_session_message(message, request, request_id, protocol_version)
await writer.send(session_message)
except Exception:
except Exception: # pragma: no cover
logger.exception("SSE response error")
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(request_id)
finally:
await sse_stream_reader.aclose()

except Exception as err: # pragma: no cover
logger.exception("Error handling POST request")
Expand All @@ -635,33 +636,33 @@ async def sse_writer():
await writer.send(Exception(err))
return

async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover
async def _handle_get_request(self, request: Request, send: Send) -> None:
"""Handle GET request to establish SSE.

This allows the server to communicate to the client without the client
first sending data via HTTP POST. The server can send JSON-RPC requests
and notifications on this stream.
"""
writer = self._read_stream_writer
if writer is None:
if writer is None: # pragma: no cover
raise ValueError("No read stream writer available. Ensure connect() is called first.")

# Validate Accept header - must include text/event-stream
_, has_sse = self._check_accept_headers(request)

if not has_sse:
if not has_sse: # pragma: no cover
response = self._create_error_response(
"Not Acceptable: Client must accept text/event-stream",
HTTPStatus.NOT_ACCEPTABLE,
)
await response(request.scope, request.receive, send)
return

if not await self._validate_request_headers(request, send):
if not await self._validate_request_headers(request, send): # pragma: no cover
return

# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover
await self._replay_events(last_event_id, request, send)
return

Expand All @@ -675,7 +676,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Check if we already have an active GET stream
if GET_STREAM_KEY in self._request_streams:
if GET_STREAM_KEY in self._request_streams: # pragma: no cover
response = self._create_error_response(
"Conflict: Only one SSE stream is allowed per session",
HTTPStatus.CONFLICT,
Expand All @@ -695,7 +696,7 @@ async def standalone_sse_writer():

async with sse_stream_writer, standalone_stream_reader:
# Process messages from the standalone stream
async for event_message in standalone_stream_reader:
async for event_message in standalone_stream_reader: # pragma: lax no cover
# For the standalone stream, we handle:
# - JSONRPCNotification (server sends notifications to client)
# - JSONRPCRequest (server sends requests to client)
Expand All @@ -704,7 +705,7 @@ async def standalone_sse_writer():
# Send the message via SSE
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception:
except Exception: # pragma: no cover
logger.exception("Error in standalone SSE writer")
finally:
logger.debug("Closing standalone SSE writer")
Expand All @@ -720,16 +721,17 @@ async def standalone_sse_writer():
try:
# This will send headers immediately and establish the SSE connection
await response(request.scope, request.receive, send)
except Exception:
except Exception: # pragma: lax no cover
logger.exception("Error in standalone SSE response")
await self._clean_up_memory_streams(GET_STREAM_KEY)
finally:
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()
await self._clean_up_memory_streams(GET_STREAM_KEY)

async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: no cover
async def _handle_delete_request(self, request: Request, send: Send) -> None:
"""Handle DELETE requests for explicit session termination."""
# Validate session ID
if not self.mcp_session_id:
if not self.mcp_session_id: # pragma: no cover
# If no session ID set, return Method Not Allowed
response = self._create_error_response(
"Method Not Allowed: Session termination not supported",
Expand All @@ -738,7 +740,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: #
await response(request.scope, request.receive, send)
return

if not await self._validate_request_headers(request, send):
if not await self._validate_request_headers(request, send): # pragma: no cover
return

await self.terminate()
Expand Down Expand Up @@ -796,24 +798,24 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non
)
await response(request.scope, request.receive, send)

async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: no cover
async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover
if not await self._validate_session(request, send):
return False
if not await self._validate_protocol_version(request, send):
return False
return True

async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: no cover
async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id:
if not self.mcp_session_id: # pragma: no cover
# If we're not using session IDs, return True
return True

# Get the session ID from the request headers
request_session_id = self._get_session_id(request)

# If no session ID provided but required, return error
if not request_session_id:
if not request_session_id: # pragma: no cover
response = self._create_error_response(
"Bad Request: Missing session ID",
HTTPStatus.BAD_REQUEST,
Expand All @@ -822,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag
return False

# If session ID doesn't match, return error
if request_session_id != self.mcp_session_id:
if request_session_id != self.mcp_session_id: # pragma: no cover
response = self._create_error_response(
"Not Found: Invalid or expired session ID",
HTTPStatus.NOT_FOUND,
Expand All @@ -832,17 +834,17 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag

return True

async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: no cover
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
"""Validate the protocol version header in the request."""
# Get the protocol version from the request headers
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, assume default version
if protocol_version is None:
if protocol_version is None: # pragma: no cover
protocol_version = DEFAULT_NEGOTIATED_VERSION

# Check if the protocol version is supported
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
response = self._create_error_response(
f"Bad Request: Unsupported protocol version: {protocol_version}. "
Expand Down Expand Up @@ -1004,10 +1006,10 @@ async def message_router():
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
except ( # pragma: no cover
except (
anyio.BrokenResourceError,
anyio.ClosedResourceError,
):
): # pragma: no cover
# Stream might be closed, remove from registry
self._request_streams.pop(request_stream_id, None)
else: # pragma: no cover
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

# Existing session case
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
transport = self._server_instances[request_mcp_session_id]
logger.debug("Session already exists, handling request directly")
await transport.handle_request(scope, receive, send)
Expand Down Expand Up @@ -261,5 +261,5 @@ class StreamableHTTPASGIApp:
def __init__(self, session_manager: StreamableHTTPSessionManager):
self.session_manager = session_manager

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.session_manager.handle_request(scope, receive, send)
4 changes: 3 additions & 1 deletion src/mcp/shared/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
to support transport-specific features like resumability.
"""

import contextvars
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from dataclasses import dataclass, field

from mcp.types import JSONRPCMessage, RequestId

Expand Down Expand Up @@ -46,4 +47,5 @@ class SessionMessage:
"""A message with specific metadata for transport-specific features."""

message: JSONRPCMessage
context: contextvars.Context = field(default_factory=contextvars.copy_context)
metadata: MessageMetadata = None
Loading
Loading