Skip to content

Commit cf9e43e

Browse files
committed
Merge main and fix uv.lock
2 parents 1a943ad + 1f6e8ec commit cf9e43e

File tree

5 files changed

+155
-133
lines changed

5 files changed

+155
-133
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -739,9 +739,7 @@ async def _handle_request(
739739
request_data = None
740740
close_sse_stream_cb = None
741741
close_standalone_sse_stream_cb = None
742-
if message.message_metadata is not None and isinstance(
743-
message.message_metadata, ServerMessageMetadata
744-
): # pragma: no cover
742+
if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
745743
request_data = message.message_metadata.request_context
746744
close_sse_stream_cb = message.message_metadata.close_sse_stream
747745
close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream

src/mcp/server/streamable_http.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -328,19 +328,19 @@ def _create_json_response(
328328
headers=response_headers,
329329
)
330330

331-
def _get_session_id(self, request: Request) -> str | None: # pragma: no cover
331+
def _get_session_id(self, request: Request) -> str | None:
332332
"""Extract the session ID from request headers."""
333333
return request.headers.get(MCP_SESSION_ID_HEADER)
334334

335-
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: # pragma: no cover
335+
def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
336336
"""Create event data dictionary from an EventMessage."""
337337
event_data = {
338338
"event": "message",
339339
"data": event_message.message.model_dump_json(by_alias=True, exclude_none=True),
340340
}
341341

342342
# If an event ID was provided, include it
343-
if event_message.event_id:
343+
if event_message.event_id: # pragma: no cover
344344
event_data["id"] = event_message.event_id
345345

346346
return event_data
@@ -381,9 +381,9 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
381381

382382
if request.method == "POST":
383383
await self._handle_post_request(scope, request, receive, send)
384-
elif request.method == "GET": # pragma: no cover
384+
elif request.method == "GET":
385385
await self._handle_get_request(request, send)
386-
elif request.method == "DELETE": # pragma: no cover
386+
elif request.method == "DELETE":
387387
await self._handle_delete_request(request, send)
388388
else: # pragma: no cover
389389
await self._handle_unsupported_request(request, send)
@@ -427,7 +427,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se
427427
return False
428428
return True
429429

430-
async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None:
430+
async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: # noqa: PLR0915
431431
"""Handle POST requests containing JSON-RPC messages."""
432432
writer = self._read_stream_writer
433433
if writer is None: # pragma: no cover
@@ -470,14 +470,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
470470
# Check if this is an initialization request
471471
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"
472472

473-
if is_initialization_request: # pragma: no cover
473+
if is_initialization_request:
474474
# Check if the server already has an established session
475475
if self.mcp_session_id:
476476
# Check if request has a session ID
477477
request_session_id = self._get_session_id(request)
478478

479479
# If request has a session ID but doesn't match, return 404
480-
if request_session_id and request_session_id != self.mcp_session_id:
480+
if request_session_id and request_session_id != self.mcp_session_id: # pragma: no cover
481481
response = self._create_error_response(
482482
"Not Found: Invalid or expired session ID",
483483
HTTPStatus.NOT_FOUND,
@@ -488,7 +488,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
488488
return
489489

490490
# For notifications and responses only, return 202 Accepted
491-
if not isinstance(message, JSONRPCRequest): # pragma: no cover
491+
if not isinstance(message, JSONRPCRequest):
492492
# Create response object and send it
493493
response = self._create_json_response(
494494
None,
@@ -561,14 +561,14 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
561561
await response(scope, receive, send)
562562
finally:
563563
await self._clean_up_memory_streams(request_id)
564-
else: # pragma: no cover
564+
else:
565565
# Create SSE stream
566566
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
567567

568568
# Store writer reference so close_sse_stream() can close it
569569
self._sse_stream_writers[request_id] = sse_stream_writer
570570

571-
async def sse_writer():
571+
async def sse_writer(): # pragma: lax no cover
572572
# Get the request ID from the incoming request message
573573
try:
574574
async with sse_stream_writer, request_stream_reader:
@@ -617,11 +617,12 @@ async def sse_writer():
617617
# Then send the message to be processed by the server
618618
session_message = self._create_session_message(message, request, request_id, protocol_version)
619619
await writer.send(session_message)
620-
except Exception:
620+
except Exception: # pragma: no cover
621621
logger.exception("SSE response error")
622622
await sse_stream_writer.aclose()
623-
await sse_stream_reader.aclose()
624623
await self._clean_up_memory_streams(request_id)
624+
finally:
625+
await sse_stream_reader.aclose()
625626

626627
except Exception as err: # pragma: no cover
627628
logger.exception("Error handling POST request")
@@ -635,33 +636,33 @@ async def sse_writer():
635636
await writer.send(Exception(err))
636637
return
637638

638-
async def _handle_get_request(self, request: Request, send: Send) -> None: # pragma: no cover
639+
async def _handle_get_request(self, request: Request, send: Send) -> None:
639640
"""Handle GET request to establish SSE.
640641
641642
This allows the server to communicate to the client without the client
642643
first sending data via HTTP POST. The server can send JSON-RPC requests
643644
and notifications on this stream.
644645
"""
645646
writer = self._read_stream_writer
646-
if writer is None:
647+
if writer is None: # pragma: no cover
647648
raise ValueError("No read stream writer available. Ensure connect() is called first.")
648649

649650
# Validate Accept header - must include text/event-stream
650651
_, has_sse = self._check_accept_headers(request)
651652

652-
if not has_sse:
653+
if not has_sse: # pragma: no cover
653654
response = self._create_error_response(
654655
"Not Acceptable: Client must accept text/event-stream",
655656
HTTPStatus.NOT_ACCEPTABLE,
656657
)
657658
await response(request.scope, request.receive, send)
658659
return
659660

660-
if not await self._validate_request_headers(request, send):
661+
if not await self._validate_request_headers(request, send): # pragma: no cover
661662
return
662663

663664
# Handle resumability: check for Last-Event-ID header
664-
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
665+
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover
665666
await self._replay_events(last_event_id, request, send)
666667
return
667668

@@ -675,7 +676,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # pr
675676
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id
676677

677678
# Check if we already have an active GET stream
678-
if GET_STREAM_KEY in self._request_streams:
679+
if GET_STREAM_KEY in self._request_streams: # pragma: no cover
679680
response = self._create_error_response(
680681
"Conflict: Only one SSE stream is allowed per session",
681682
HTTPStatus.CONFLICT,
@@ -695,7 +696,7 @@ async def standalone_sse_writer():
695696

696697
async with sse_stream_writer, standalone_stream_reader:
697698
# Process messages from the standalone stream
698-
async for event_message in standalone_stream_reader:
699+
async for event_message in standalone_stream_reader: # pragma: lax no cover
699700
# For the standalone stream, we handle:
700701
# - JSONRPCNotification (server sends notifications to client)
701702
# - JSONRPCRequest (server sends requests to client)
@@ -704,7 +705,7 @@ async def standalone_sse_writer():
704705
# Send the message via SSE
705706
event_data = self._create_event_data(event_message)
706707
await sse_stream_writer.send(event_data)
707-
except Exception:
708+
except Exception: # pragma: no cover
708709
logger.exception("Error in standalone SSE writer")
709710
finally:
710711
logger.debug("Closing standalone SSE writer")
@@ -720,16 +721,17 @@ async def standalone_sse_writer():
720721
try:
721722
# This will send headers immediately and establish the SSE connection
722723
await response(request.scope, request.receive, send)
723-
except Exception:
724+
except Exception: # pragma: lax no cover
724725
logger.exception("Error in standalone SSE response")
726+
await self._clean_up_memory_streams(GET_STREAM_KEY)
727+
finally:
725728
await sse_stream_writer.aclose()
726729
await sse_stream_reader.aclose()
727-
await self._clean_up_memory_streams(GET_STREAM_KEY)
728730

729-
async def _handle_delete_request(self, request: Request, send: Send) -> None: # pragma: no cover
731+
async def _handle_delete_request(self, request: Request, send: Send) -> None:
730732
"""Handle DELETE requests for explicit session termination."""
731733
# Validate session ID
732-
if not self.mcp_session_id:
734+
if not self.mcp_session_id: # pragma: no cover
733735
# If no session ID set, return Method Not Allowed
734736
response = self._create_error_response(
735737
"Method Not Allowed: Session termination not supported",
@@ -738,7 +740,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: #
738740
await response(request.scope, request.receive, send)
739741
return
740742

741-
if not await self._validate_request_headers(request, send):
743+
if not await self._validate_request_headers(request, send): # pragma: no cover
742744
return
743745

744746
await self.terminate()
@@ -796,24 +798,24 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non
796798
)
797799
await response(request.scope, request.receive, send)
798800

799-
async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: no cover
801+
async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover
800802
if not await self._validate_session(request, send):
801803
return False
802804
if not await self._validate_protocol_version(request, send):
803805
return False
804806
return True
805807

806-
async def _validate_session(self, request: Request, send: Send) -> bool: # pragma: no cover
808+
async def _validate_session(self, request: Request, send: Send) -> bool:
807809
"""Validate the session ID in the request."""
808-
if not self.mcp_session_id:
810+
if not self.mcp_session_id: # pragma: no cover
809811
# If we're not using session IDs, return True
810812
return True
811813

812814
# Get the session ID from the request headers
813815
request_session_id = self._get_session_id(request)
814816

815817
# If no session ID provided but required, return error
816-
if not request_session_id:
818+
if not request_session_id: # pragma: no cover
817819
response = self._create_error_response(
818820
"Bad Request: Missing session ID",
819821
HTTPStatus.BAD_REQUEST,
@@ -822,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag
822824
return False
823825

824826
# If session ID doesn't match, return error
825-
if request_session_id != self.mcp_session_id:
827+
if request_session_id != self.mcp_session_id: # pragma: no cover
826828
response = self._create_error_response(
827829
"Not Found: Invalid or expired session ID",
828830
HTTPStatus.NOT_FOUND,
@@ -832,17 +834,17 @@ async def _validate_session(self, request: Request, send: Send) -> bool: # prag
832834

833835
return True
834836

835-
async def _validate_protocol_version(self, request: Request, send: Send) -> bool: # pragma: no cover
837+
async def _validate_protocol_version(self, request: Request, send: Send) -> bool:
836838
"""Validate the protocol version header in the request."""
837839
# Get the protocol version from the request headers
838840
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
839841

840842
# If no protocol version provided, assume default version
841-
if protocol_version is None:
843+
if protocol_version is None: # pragma: no cover
842844
protocol_version = DEFAULT_NEGOTIATED_VERSION
843845

844846
# Check if the protocol version is supported
845-
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
847+
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover
846848
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
847849
response = self._create_error_response(
848850
f"Bad Request: Unsupported protocol version: {protocol_version}. "
@@ -1004,10 +1006,10 @@ async def message_router():
10041006
try:
10051007
# Send both the message and the event ID
10061008
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
1007-
except ( # pragma: no cover
1009+
except (
10081010
anyio.BrokenResourceError,
10091011
anyio.ClosedResourceError,
1010-
):
1012+
): # pragma: no cover
10111013
# Stream might be closed, remove from registry
10121014
self._request_streams.pop(request_stream_id, None)
10131015
else: # pragma: no cover

src/mcp/server/streamable_http_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
181181
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
182182

183183
# Existing session case
184-
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: # pragma: no cover
184+
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
185185
transport = self._server_instances[request_mcp_session_id]
186186
logger.debug("Session already exists, handling request directly")
187187
await transport.handle_request(scope, receive, send)
@@ -261,5 +261,5 @@ class StreamableHTTPASGIApp:
261261
def __init__(self, session_manager: StreamableHTTPSessionManager):
262262
self.session_manager = session_manager
263263

264-
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
264+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
265265
await self.session_manager.handle_request(scope, receive, send)

tests/server/test_streamable_http_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
from unittest.mock import AsyncMock, patch
66

77
import anyio
8+
import httpx
89
import pytest
910
from starlette.types import Message
1011

12+
from mcp import Client, types
13+
from mcp.client.streamable_http import streamable_http_client
1114
from mcp.server import streamable_http_manager
1215
from mcp.server.lowlevel import Server
1316
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
@@ -313,3 +316,22 @@ async def mock_receive():
313316
assert error_data["id"] == "server-error"
314317
assert error_data["error"]["code"] == INVALID_REQUEST
315318
assert error_data["error"]["message"] == "Session not found"
319+
320+
321+
@pytest.mark.anyio
322+
async def test_e2e_streamable_http_server_cleanup():
323+
host = "testserver"
324+
app = Server("test-server")
325+
326+
@app.list_tools()
327+
async def list_tools(req: types.ListToolsRequest) -> types.ListToolsResult:
328+
return types.ListToolsResult(tools=[])
329+
330+
mcp_app = app.streamable_http_app(host=host)
331+
async with (
332+
mcp_app.router.lifespan_context(mcp_app),
333+
httpx.ASGITransport(mcp_app) as transport,
334+
httpx.AsyncClient(transport=transport) as http_client,
335+
Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client,
336+
):
337+
await client.list_tools()

0 commit comments

Comments
 (0)