@@ -80,7 +80,7 @@ def __init__(self):
8080 self ._events : list [tuple [StreamId , EventId , types .JSONRPCMessage | None ]] = []
8181 self ._event_id_counter = 0
8282
83- async def store_event (self , stream_id : StreamId , message : types .JSONRPCMessage | None ) -> EventId : # pragma: no cover
83+ async def store_event (self , stream_id : StreamId , message : types .JSONRPCMessage | None ) -> EventId :
8484 """Store an event and return its ID."""
8585 self ._event_id_counter += 1
8686 event_id = str (self ._event_id_counter )
@@ -175,6 +175,17 @@ async def handle_list_tools() -> list[Tool]:
175175 description = "Tool that sends notification1, closes stream, sends notification2, notification3" ,
176176 inputSchema = {"type" : "object" , "properties" : {}},
177177 ),
178+ Tool (
179+ name = "tool_with_multiple_stream_closes" ,
180+ description = "Tool that closes SSE stream multiple times during execution" ,
181+ inputSchema = {
182+ "type" : "object" ,
183+ "properties" : {
184+ "checkpoints" : {"type" : "integer" , "default" : 3 },
185+ "sleep_time" : {"type" : "number" , "default" : 0.2 },
186+ },
187+ },
188+ ),
178189 ]
179190
180191 @self .call_tool ()
@@ -314,6 +325,25 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
314325 )
315326 return [TextContent (type = "text" , text = "All notifications sent" )]
316327
328+ elif name == "tool_with_multiple_stream_closes" :
329+ num_checkpoints = args .get ("checkpoints" , 3 )
330+ sleep_time = args .get ("sleep_time" , 0.2 )
331+
332+ for i in range (num_checkpoints ):
333+ await ctx .session .send_log_message (
334+ level = "info" ,
335+ data = f"checkpoint_{ i } " ,
336+ logger = "multi_close_tool" ,
337+ related_request_id = ctx .request_id ,
338+ )
339+
340+ if ctx .close_sse_stream :
341+ await ctx .close_sse_stream ()
342+
343+ await anyio .sleep (sleep_time )
344+
345+ return [TextContent (type = "text" , text = f"Completed { num_checkpoints } checkpoints" )]
346+
317347 return [TextContent (type = "text" , text = f"Called { name } " )]
318348
319349
@@ -950,7 +980,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
950980 """Test client tool invocation."""
951981 # First list tools
952982 tools = await initialized_client_session .list_tools ()
953- assert len (tools .tools ) == 8
983+ assert len (tools .tools ) == 9
954984 assert tools .tools [0 ].name == "test_tool"
955985
956986 # Call the tool
@@ -987,7 +1017,7 @@ async def test_streamablehttp_client_session_persistence(basic_server: None, bas
9871017
9881018 # Make multiple requests to verify session persistence
9891019 tools = await session .list_tools ()
990- assert len (tools .tools ) == 8
1020+ assert len (tools .tools ) == 9
9911021
9921022 # Read a resource
9931023 resource = await session .read_resource (uri = AnyUrl ("foobar://test-persist" ))
@@ -1016,7 +1046,7 @@ async def test_streamablehttp_client_json_response(json_response_server: None, j
10161046
10171047 # Check tool listing
10181048 tools = await session .list_tools ()
1019- assert len (tools .tools ) == 8
1049+ assert len (tools .tools ) == 9
10201050
10211051 # Call a tool and verify JSON response handling
10221052 result = await session .call_tool ("test_tool" , {})
@@ -1086,7 +1116,7 @@ async def test_streamablehttp_client_session_termination(basic_server: None, bas
10861116
10871117 # Make a request to confirm session is working
10881118 tools = await session .list_tools ()
1089- assert len (tools .tools ) == 8
1119+ assert len (tools .tools ) == 9
10901120
10911121 headers : dict [str , str ] = {} # pragma: no cover
10921122 if captured_session_id : # pragma: no cover
@@ -1152,7 +1182,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11521182
11531183 # Make a request to confirm session is working
11541184 tools = await session .list_tools ()
1155- assert len (tools .tools ) == 8
1185+ assert len (tools .tools ) == 9
11561186
11571187 headers : dict [str , str ] = {} # pragma: no cover
11581188 if captured_session_id : # pragma: no cover
@@ -1944,3 +1974,56 @@ async def message_handler(
19441974 assert result .content [0 ].type == "text"
19451975 assert isinstance (result .content [0 ], TextContent )
19461976 assert result .content [0 ].text == "All notifications sent"
1977+
1978+
1979+ @pytest .mark .anyio
1980+ async def test_streamablehttp_multiple_reconnections (
1981+ event_server : tuple [SimpleEventStore , str ],
1982+ ):
1983+ """Verify multiple close_sse_stream() calls each trigger a client reconnect.
1984+
1985+ Server uses retry_interval=500ms, tool sleeps 600ms after each close to ensure
1986+ client has time to reconnect before the next checkpoint.
1987+
1988+ With 3 checkpoints, we expect 8 resumption tokens:
1989+ - 1 priming (initial POST connection)
1990+ - 3 notifications (checkpoint_0, checkpoint_1, checkpoint_2)
1991+ - 3 priming (one per reconnect after each close)
1992+ - 1 response
1993+ """
1994+ _ , server_url = event_server
1995+ resumption_tokens : list [str ] = []
1996+
1997+ async def on_resumption_token (token : str ) -> None :
1998+ resumption_tokens .append (token )
1999+
2000+ async with streamablehttp_client (f"{ server_url } /mcp" ) as (read_stream , write_stream , _ ):
2001+ async with ClientSession (read_stream , write_stream ) as session :
2002+ await session .initialize ()
2003+
2004+ # Use send_request with metadata to track resumption tokens
2005+ metadata = ClientMessageMetadata (on_resumption_token_update = on_resumption_token )
2006+ result = await session .send_request (
2007+ types .ClientRequest (
2008+ types .CallToolRequest (
2009+ method = "tools/call" ,
2010+ params = types .CallToolRequestParams (
2011+ name = "tool_with_multiple_stream_closes" ,
2012+ # retry_interval=500ms, so sleep 600ms to ensure reconnect completes
2013+ arguments = {"checkpoints" : 3 , "sleep_time" : 0.6 },
2014+ ),
2015+ )
2016+ ),
2017+ types .CallToolResult ,
2018+ metadata = metadata ,
2019+ )
2020+
2021+ assert result .content [0 ].type == "text"
2022+ assert isinstance (result .content [0 ], TextContent )
2023+ assert "Completed 3 checkpoints" in result .content [0 ].text
2024+
2025+ # 4 priming + 3 notifications + 1 response = 8 tokens
2026+ assert len (resumption_tokens ) == 8 , (
2027+ f"Expected 8 resumption tokens (4 priming + 3 notifs + 1 response), "
2028+ f"got { len (resumption_tokens )} : { resumption_tokens } "
2029+ )
0 commit comments