Skip to content

Commit fdcd8f5

Browse files
Fix multiple close_sse_stream support and add reconnection logging (SEP-1699)
- Register SSE writer in _replay_events() so subsequent close_sse_stream() calls work
- Send priming event on each reconnection
- Handle ClosedResourceError gracefully in both POST and GET SSE writers
- Add disconnect/reconnect logging at INFO level for visibility
- Add test for multiple reconnections during long-running tool calls
- Remove pragma from store_event (now covered by tests)
1 parent 3161353 commit fdcd8f5

File tree

3 files changed

+111
-7
lines changed

3 files changed

+111
-7
lines changed

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ async def _handle_sse_response(
363363

364364
# Stream ended without response - reconnect if we received an event with ID
365365
if last_event_id is not None:
366+
logger.info("SSE stream disconnected, reconnecting...")
366367
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
367368

368369
async def _handle_reconnection(
@@ -399,7 +400,7 @@ async def _handle_reconnection(
399400
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
400401
) as event_source:
401402
event_source.response.raise_for_status()
402-
logger.debug("Reconnection GET SSE connection established")
403+
logger.info("Reconnected to SSE stream")
403404

404405
# Track for potential further reconnection
405406
reconnect_last_event_id: str | None = last_event_id
@@ -423,6 +424,7 @@ async def _handle_reconnection(
423424

424425
# Stream ended again without response - reconnect again (reset attempt counter)
425426
if reconnect_last_event_id is not None:
427+
logger.info("SSE stream disconnected, reconnecting...")
426428
await self._handle_reconnection(
427429
ctx, reconnect_last_event_id, reconnect_retry_ms, 0
428430
)

src/mcp/server/streamable_http.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@ def close_sse_stream(self, request_id: RequestId) -> None:
208208
if writer:
209209
writer.close()
210210

211+
# Also close and remove request streams
212+
if request_id in self._request_streams:
213+
send_stream, receive_stream = self._request_streams.pop(request_id)
214+
send_stream.close()
215+
receive_stream.close()
216+
211217
def _create_session_message(
212218
self,
213219
message: JSONRPCMessage,
@@ -545,6 +551,9 @@ async def sse_writer():
545551
JSONRPCResponse | JSONRPCError,
546552
):
547553
break
554+
except anyio.ClosedResourceError:
555+
# Expected when close_sse_stream() is called
556+
logger.debug("SSE stream closed by close_sse_stream()")
548557
except Exception:
549558
logger.exception("Error in SSE writer")
550559
finally:
@@ -848,6 +857,13 @@ async def send_event(event_message: EventMessage) -> None:
848857

849858
# If stream ID not in mapping, create it
850859
if stream_id and stream_id not in self._request_streams:
860+
# Register SSE writer so close_sse_stream() can close it
861+
self._sse_stream_writers[stream_id] = sse_stream_writer
862+
863+
# Send priming event for this new connection
864+
await self._send_priming_event(stream_id, sse_stream_writer)
865+
866+
# Create new request streams for this connection
851867
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
852868
msg_reader = self._request_streams[stream_id][1]
853869

@@ -857,6 +873,9 @@ async def send_event(event_message: EventMessage) -> None:
857873
event_data = self._create_event_data(event_message)
858874

859875
await sse_stream_writer.send(event_data)
876+
except anyio.ClosedResourceError:
877+
# Expected when close_sse_stream() is called
878+
logger.debug("Replay SSE stream closed by close_sse_stream()")
860879
except Exception:
861880
logger.exception("Error in replay sender")
862881

tests/shared/test_streamable_http.py

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self):
8080
self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage | None]] = []
8181
self._event_id_counter = 0
8282

83-
async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId: # pragma: no cover
83+
async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | None) -> EventId:
8484
"""Store an event and return its ID."""
8585
self._event_id_counter += 1
8686
event_id = str(self._event_id_counter)
@@ -175,6 +175,17 @@ async def handle_list_tools() -> list[Tool]:
175175
description="Tool that sends notification1, closes stream, sends notification2, notification3",
176176
inputSchema={"type": "object", "properties": {}},
177177
),
178+
Tool(
179+
name="tool_with_multiple_stream_closes",
180+
description="Tool that closes SSE stream multiple times during execution",
181+
inputSchema={
182+
"type": "object",
183+
"properties": {
184+
"checkpoints": {"type": "integer", "default": 3},
185+
"sleep_time": {"type": "number", "default": 0.2},
186+
},
187+
},
188+
),
178189
]
179190

180191
@self.call_tool()
@@ -314,6 +325,25 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
314325
)
315326
return [TextContent(type="text", text="All notifications sent")]
316327

328+
elif name == "tool_with_multiple_stream_closes":
329+
num_checkpoints = args.get("checkpoints", 3)
330+
sleep_time = args.get("sleep_time", 0.2)
331+
332+
for i in range(num_checkpoints):
333+
await ctx.session.send_log_message(
334+
level="info",
335+
data=f"checkpoint_{i}",
336+
logger="multi_close_tool",
337+
related_request_id=ctx.request_id,
338+
)
339+
340+
if ctx.close_sse_stream:
341+
await ctx.close_sse_stream()
342+
343+
await anyio.sleep(sleep_time)
344+
345+
return [TextContent(type="text", text=f"Completed {num_checkpoints} checkpoints")]
346+
317347
return [TextContent(type="text", text=f"Called {name}")]
318348

319349

@@ -950,7 +980,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
950980
"""Test client tool invocation."""
951981
# First list tools
952982
tools = await initialized_client_session.list_tools()
953-
assert len(tools.tools) == 8
983+
assert len(tools.tools) == 9
954984
assert tools.tools[0].name == "test_tool"
955985

956986
# Call the tool
@@ -987,7 +1017,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas
9871017

9881018
# Make multiple requests to verify session persistence
9891019
tools = await session.list_tools()
990-
assert len(tools.tools) == 8
1020+
assert len(tools.tools) == 9
9911021

9921022
# Read a resource
9931023
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -1016,7 +1046,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j
10161046

10171047
# Check tool listing
10181048
tools = await session.list_tools()
1019-
assert len(tools.tools) == 8
1049+
assert len(tools.tools) == 9
10201050

10211051
# Call a tool and verify JSON response handling
10221052
result = await session.call_tool("test_tool", {})
@@ -1086,7 +1116,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas
10861116

10871117
# Make a request to confirm session is working
10881118
tools = await session.list_tools()
1089-
assert len(tools.tools) == 8
1119+
assert len(tools.tools) == 9
10901120

10911121
headers: dict[str, str] = {} # pragma: no cover
10921122
if captured_session_id: # pragma: no cover
@@ -1152,7 +1182,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11521182

11531183
# Make a request to confirm session is working
11541184
tools = await session.list_tools()
1155-
assert len(tools.tools) == 8
1185+
assert len(tools.tools) == 9
11561186

11571187
headers: dict[str, str] = {} # pragma: no cover
11581188
if captured_session_id: # pragma: no cover
@@ -1944,3 +1974,56 @@ async def message_handler(
19441974
assert result.content[0].type == "text"
19451975
assert isinstance(result.content[0], TextContent)
19461976
assert result.content[0].text == "All notifications sent"
1977+
1978+
1979+
@pytest.mark.anyio
1980+
async def test_streamablehttp_multiple_reconnections(
1981+
event_server: tuple[SimpleEventStore, str],
1982+
):
1983+
"""Verify multiple close_sse_stream() calls each trigger a client reconnect.
1984+
1985+
Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure
1986+
client has time to reconnect before the next checkpoint.
1987+
1988+
With 3 checkpoints, we expect 8 resumption tokens:
1989+
- 1 priming (initial POST connection)
1990+
- 3 notifications (checkpoint_0, checkpoint_1, checkpoint_2)
1991+
- 3 priming (one per reconnect after each close)
1992+
- 1 response
1993+
"""
1994+
_, server_url = event_server
1995+
resumption_tokens: list[str] = []
1996+
1997+
async def on_resumption_token(token: str) -> None:
1998+
resumption_tokens.append(token)
1999+
2000+
async with streamablehttp_client(f"{server_url}/mcp") as (read_stream, write_stream, _):
2001+
async with ClientSession(read_stream, write_stream) as session:
2002+
await session.initialize()
2003+
2004+
# Use send_request with metadata to track resumption tokens
2005+
metadata = ClientMessageMetadata(on_resumption_token_update=on_resumption_token)
2006+
result = await session.send_request(
2007+
types.ClientRequest(
2008+
types.CallToolRequest(
2009+
method="tools/call",
2010+
params=types.CallToolRequestParams(
2011+
name="tool_with_multiple_stream_closes",
2012+
# retry_interval=500ms, so sleep 600ms to ensure reconnect completes
2013+
arguments={"checkpoints": 3, "sleep_time": 0.6},
2014+
),
2015+
)
2016+
),
2017+
types.CallToolResult,
2018+
metadata=metadata,
2019+
)
2020+
2021+
assert result.content[0].type == "text"
2022+
assert isinstance(result.content[0], TextContent)
2023+
assert "Completed 3 checkpoints" in result.content[0].text
2024+
2025+
# 4 priming + 3 notifications + 1 response = 8 tokens
2026+
assert len(resumption_tokens) == 8, (
2027+
f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
2028+
f"got {len(resumption_tokens)}: {resumption_tokens}"
2029+
)

0 commit comments

Comments (0)