Skip to content

Commit dde8cd5

Browse files
committed
clean up server memory streams
1 parent 02f00c4 commit dde8cd5

File tree

3 files changed

+109
-68
lines changed

3 files changed

+109
-68
lines changed

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,20 +185,22 @@ async def handle_streamable_http(scope, receive, send):
185185
)
186186
server_instances[http_transport.mcp_session_id] = http_transport
187187
logger.info(f"Created new transport with session ID: {new_session_id}")
188-
async with http_transport.connect() as streams:
189-
read_stream, write_stream = streams
190188

191-
async def run_server():
192-
await app.run(
193-
read_stream,
194-
write_stream,
195-
app.create_initialization_options(),
196-
)
189+
async def run_server(task_status=None):
190+
async with http_transport.connect() as streams:
191+
read_stream, write_stream = streams
192+
if task_status:
193+
task_status.started()
194+
await app.run(
195+
read_stream,
196+
write_stream,
197+
app.create_initialization_options(),
198+
)
197199

198200
if not task_group:
199201
raise RuntimeError("Task group is not initialized")
200202

201-
task_group.start_soon(run_server)
203+
await task_group.start(run_server)
202204

203205
# Handle the HTTP request and return the response
204206
await http_transport.handle_request(scope, receive, send)

src/mcp/server/streamable_http.py

Lines changed: 82 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
129129
_read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = (
130130
None
131131
)
132+
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None
133+
_write_stream: MemoryObjectSendStream[SessionMessage] | None = None
132134
_write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None
133135

134136
def __init__(
@@ -163,7 +165,11 @@ def __init__(
163165
self.is_json_response_enabled = is_json_response_enabled
164166
self._event_store = event_store
165167
self._request_streams: dict[
166-
RequestId, MemoryObjectSendStream[EventMessage]
168+
RequestId,
169+
tuple[
170+
MemoryObjectSendStream[EventMessage],
171+
MemoryObjectReceiveStream[EventMessage],
172+
],
167173
] = {}
168174
self._terminated = False
169175

@@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
239245

240246
return event_data
241247

248+
async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
249+
"""Clean up memory streams for a given request ID."""
250+
if request_id in self._request_streams:
251+
try:
252+
# Close the request stream
253+
await self._request_streams[request_id][0].aclose()
254+
await self._request_streams[request_id][1].aclose()
255+
except Exception as e:
256+
logger.debug(f"Error closing memory streams: {e}")
257+
finally:
258+
# Remove the request stream from the mapping
259+
self._request_streams.pop(request_id, None)
260+
242261
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
243262
"""Application entry point that handles all HTTP requests"""
244263
request = Request(scope, receive)
@@ -386,13 +405,11 @@ async def _handle_post_request(
386405

387406
# Extract the request ID outside the try block for proper scope
388407
request_id = str(message.root.id)
389-
# Create promise stream for getting response
390-
request_stream_writer, request_stream_reader = (
391-
anyio.create_memory_object_stream[EventMessage](0)
392-
)
393-
394408
# Register this stream for the request ID
395-
self._request_streams[request_id] = request_stream_writer
409+
self._request_streams[request_id] = anyio.create_memory_object_stream[
410+
EventMessage
411+
](0)
412+
request_stream_reader = self._request_streams[request_id][1]
396413

397414
if self.is_json_response_enabled:
398415
# Process the message
@@ -441,11 +458,7 @@ async def _handle_post_request(
441458
)
442459
await response(scope, receive, send)
443460
finally:
444-
# Clean up the request stream
445-
if request_id in self._request_streams:
446-
self._request_streams.pop(request_id, None)
447-
await request_stream_reader.aclose()
448-
await request_stream_writer.aclose()
461+
await self._clean_up_memory_streams(request_id)
449462
else:
450463
# Create SSE stream
451464
sse_stream_writer, sse_stream_reader = (
@@ -467,16 +480,12 @@ async def sse_writer():
467480
event_message.message.root,
468481
JSONRPCResponse | JSONRPCError,
469482
):
470-
if request_id:
471-
self._request_streams.pop(request_id, None)
472483
break
473484
except Exception as e:
474485
logger.exception(f"Error in SSE writer: {e}")
475486
finally:
476487
logger.debug("Closing SSE writer")
477-
# Clean up the request-specific streams
478-
if request_id and request_id in self._request_streams:
479-
self._request_streams.pop(request_id, None)
488+
await self._clean_up_memory_streams(request_id)
480489

481490
# Create and start EventSourceResponse
482491
# SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ async def sse_writer():
507516
await writer.send(session_message)
508517
except Exception:
509518
logger.exception("SSE response error")
510-
# Clean up the request stream if something goes wrong
511-
if request_id and request_id in self._request_streams:
512-
self._request_streams.pop(request_id, None)
519+
await sse_stream_writer.aclose()
520+
await sse_stream_reader.aclose()
521+
await self._clean_up_memory_streams(request_id)
513522

514523
except Exception as err:
515524
logger.exception("Error handling POST request")
@@ -583,12 +592,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
583592
async def standalone_sse_writer():
584593
try:
585594
# Create a standalone message stream for server-initiated messages
586-
standalone_stream_writer, standalone_stream_reader = (
595+
596+
self._request_streams[GET_STREAM_KEY] = (
587597
anyio.create_memory_object_stream[EventMessage](0)
588598
)
589-
590-
# Register this stream using the special key
591-
self._request_streams[GET_STREAM_KEY] = standalone_stream_writer
599+
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]
592600

593601
async with sse_stream_writer, standalone_stream_reader:
594602
# Process messages from the standalone stream
@@ -605,8 +613,7 @@ async def standalone_sse_writer():
605613
logger.exception(f"Error in standalone SSE writer: {e}")
606614
finally:
607615
logger.debug("Closing standalone SSE writer")
608-
# Remove the stream from request_streams
609-
self._request_streams.pop(GET_STREAM_KEY, None)
616+
await self._clean_up_memory_streams(GET_STREAM_KEY)
610617

611618
# Create and start EventSourceResponse
612619
response = EventSourceResponse(
@@ -620,8 +627,9 @@ async def standalone_sse_writer():
620627
await response(request.scope, request.receive, send)
621628
except Exception as e:
622629
logger.exception(f"Error in standalone SSE response: {e}")
623-
# Clean up the request stream
624-
self._request_streams.pop(GET_STREAM_KEY, None)
630+
# await sse_stream_writer.aclose()
631+
# await sse_stream_reader.aclose()
632+
await self._clean_up_memory_streams(GET_STREAM_KEY)
625633

626634
async def _handle_delete_request(self, request: Request, send: Send) -> None:
627635
"""Handle DELETE requests for explicit session termination."""
@@ -638,15 +646,15 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
638646
if not await self._validate_session(request, send):
639647
return
640648

641-
self._terminate_session()
649+
await self._terminate_session()
642650

643651
response = self._create_json_response(
644652
None,
645653
HTTPStatus.OK,
646654
)
647655
await response(request.scope, request.receive, send)
648656

649-
def _terminate_session(self) -> None:
657+
async def _terminate_session(self) -> None:
650658
"""Terminate the current session, closing all streams.
651659
652660
Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -658,19 +666,27 @@ def _terminate_session(self) -> None:
658666
# We need a copy of the keys to avoid modification during iteration
659667
request_stream_keys = list(self._request_streams.keys())
660668

661-
# Close all request streams (synchronously)
669+
# Close all request streams asynchronously
662670
for key in request_stream_keys:
663671
try:
664-
# Get the stream
665-
stream = self._request_streams.get(key)
666-
if stream:
667-
# We must use close() here, not aclose() since this is a sync method
668-
stream.close()
672+
await self._clean_up_memory_streams(key)
669673
except Exception as e:
670674
logger.debug(f"Error closing stream {key} during termination: {e}")
671675

672676
# Clear the request streams dictionary immediately
673677
self._request_streams.clear()
678+
try:
679+
if self._read_stream_writer is not None:
680+
await self._read_stream_writer.aclose()
681+
if self._read_stream is not None:
682+
await self._read_stream.aclose()
683+
if self._write_stream_reader is not None:
684+
await self._write_stream_reader.aclose()
685+
if self._write_stream is not None:
686+
await self._write_stream.aclose()
687+
except Exception as e:
688+
logger.debug(f"Error closing streams: {e}")
689+
pass
674690

675691
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
676692
"""Handle unsupported HTTP methods."""
@@ -758,10 +774,10 @@ async def send_event(event_message: EventMessage) -> None:
758774

759775
# If stream ID not in mapping, create it
760776
if stream_id and stream_id not in self._request_streams:
761-
msg_writer, msg_reader = anyio.create_memory_object_stream[
762-
EventMessage
763-
](0)
764-
self._request_streams[stream_id] = msg_writer
777+
self._request_streams[stream_id] = (
778+
anyio.create_memory_object_stream[EventMessage](0)
779+
)
780+
msg_reader = self._request_streams[stream_id][1]
765781

766782
# Forward messages to SSE
767783
async with msg_reader:
@@ -783,6 +799,9 @@ async def send_event(event_message: EventMessage) -> None:
783799
await response(request.scope, request.receive, send)
784800
except Exception as e:
785801
logger.exception(f"Error in replay response: {e}")
802+
finally:
803+
await sse_stream_writer.aclose()
804+
await sse_stream_reader.aclose()
786805

787806
except Exception as e:
788807
logger.exception(f"Error replaying events: {e}")
@@ -820,7 +839,9 @@ async def connect(
820839

821840
# Store the streams
822841
self._read_stream_writer = read_stream_writer
842+
self._read_stream = read_stream
823843
self._write_stream_reader = write_stream_reader
844+
self._write_stream = write_stream
824845

825846
# Start a task group for message routing
826847
async with anyio.create_task_group() as tg:
@@ -865,7 +886,7 @@ async def message_router():
865886
if request_stream_id in self._request_streams:
866887
try:
867888
# Send both the message and the event ID
868-
await self._request_streams[request_stream_id].send(
889+
await self._request_streams[request_stream_id][0].send(
869890
EventMessage(message, event_id)
870891
)
871892
except (
@@ -874,6 +895,12 @@ async def message_router():
874895
):
875896
# Stream might be closed, remove from registry
876897
self._request_streams.pop(request_stream_id, None)
898+
else:
899+
logging.debug(
900+
f"""Request stream {request_stream_id} not found
901+
for message. Still processing message as the client
902+
might reconnect and replay."""
903+
)
877904
except Exception as e:
878905
logger.exception(f"Error in message router: {e}")
879906

@@ -884,9 +911,20 @@ async def message_router():
884911
# Yield the streams for the caller to use
885912
yield read_stream, write_stream
886913
finally:
887-
for stream in list(self._request_streams.values()):
914+
for stream_id in list(self._request_streams.keys()):
888915
try:
889-
await stream.aclose()
890-
except Exception:
916+
await self._clean_up_memory_streams(stream_id)
917+
except Exception as e:
918+
logger.debug(f"Error closing request stream: {e}")
891919
pass
892920
self._request_streams.clear()
921+
922+
# Clean up the read and write streams
923+
try:
924+
await read_stream_writer.aclose()
925+
await read_stream.aclose()
926+
await write_stream_reader.aclose()
927+
await write_stream.aclose()
928+
except Exception as e:
929+
logger.debug(f"Error closing streams: {e}")
930+
pass

tests/shared/test_streamable_http.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -231,29 +231,30 @@ async def handle_streamable_http(scope, receive, send):
231231
event_store=event_store,
232232
)
233233

234-
async with http_transport.connect() as streams:
235-
read_stream, write_stream = streams
236-
237-
async def run_server():
234+
async def run_server(task_status=None):
235+
async with http_transport.connect() as streams:
236+
read_stream, write_stream = streams
237+
if task_status:
238+
task_status.started()
238239
await server.run(
239240
read_stream,
240241
write_stream,
241242
server.create_initialization_options(),
242243
)
243244

244-
if task_group is None:
245-
response = Response(
246-
"Internal Server Error: Task group is not initialized",
247-
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
248-
)
249-
await response(scope, receive, send)
250-
return
245+
if task_group is None:
246+
response = Response(
247+
"Internal Server Error: Task group is not initialized",
248+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
249+
)
250+
await response(scope, receive, send)
251+
return
251252

252-
# Store the instance before starting the task to prevent races
253-
server_instances[http_transport.mcp_session_id] = http_transport
254-
task_group.start_soon(run_server)
253+
# Store the instance before starting the task to prevent races
254+
server_instances[http_transport.mcp_session_id] = http_transport
255+
await task_group.start(run_server)
255256

256-
await http_transport.handle_request(scope, receive, send)
257+
await http_transport.handle_request(scope, receive, send)
257258
else:
258259
response = Response(
259260
"Bad Request: No valid session ID provided",

0 commit comments

Comments
 (0)