@@ -1032,7 +1032,20 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session
10321032
10331033
10341034@pytest .mark .anyio
1035- async def test_streamable_http_client_error_handling (initialized_client_session : ClientSession ):
1035+ async def test_streamablehttp_client_tool_invocation_with_extra_headers (initialized_client_session : ClientSession ):
1036+ """Test HTTP POST request with extra headers."""
1037+ result = await initialized_client_session .call_tool (
1038+ "test_tool" ,
1039+ {},
1040+ extra_headers = {"X-Custom-Header" : "test-value" },
1041+ )
1042+ assert len (result .content ) == 1
1043+ assert result .content [0 ].type == "text"
1044+ assert result .content [0 ].text == "Called test_tool"
1045+
1046+
1047+ @pytest .mark .anyio
1048+ async def test_streamablehttp_client_error_handling (initialized_client_session : ClientSession ):
10361049 """Test error handling in client."""
10371050 with pytest .raises (McpError ) as exc_info :
10381051 await initialized_client_session .read_resource (uri = AnyUrl ("unknown://test-error" ))
@@ -1245,26 +1258,27 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
12451258 await session .list_tools ()
12461259
12471260
1248- @ pytest . mark . anyio
1249- async def test_streamable_http_client_resumption ( event_server : tuple [ SimpleEventStore , str ]):
1250- """Test client session resumption using sync primitives for reliable coordination."""
1251- _ , server_url = event_server
1261+ async def _setup_resumption_test (
1262+ server_url : str ,
1263+ ) -> tuple [ str | None , str | None , str | int | None , list [ types . ServerNotification ]]:
1264+ """Helper function to set up a resumption test by starting a session and capturing resumption state.
12521265
1253- # Variables to track the state
1254- captured_session_id = None
1255- captured_resumption_token = None
1256- captured_notifications : list [types .ServerNotification ] = []
1257- captured_protocol_version = None
1258- first_notification_received = False
1266+ Returns:
1267+ Tuple of (session_id, resumption_token, protocol_version, notifications)
1268+ """
1269+ captured_session_id = None # pragma: no cover
1270+ captured_resumption_token = None # pragma: no cover
1271+ captured_notifications : list [types .ServerNotification ] = [] # pragma: no cover
1272+ captured_protocol_version = None # pragma: no cover
1273+ first_notification_received = False # pragma: no cover
12591274
12601275 async def message_handler ( # pragma: no branch
12611276 message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1262- ) -> None :
1277+ ) -> None : # pragma: no cover
12631278 if isinstance (message , types .ServerNotification ): # pragma: no branch
12641279 captured_notifications .append (message )
1265- # Look for our first notification
12661280 if isinstance (message .root , types .LoggingMessageNotification ): # pragma: no branch
1267- if message .root .params .data == "First notification before lock" :
1281+ if message .root .params .data == "First notification before lock" : # pragma: no branch
12681282 nonlocal first_notification_received
12691283 first_notification_received = True
12701284
@@ -1320,8 +1334,32 @@ async def run_tool():
13201334 assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification ) # pragma: no cover
13211335 assert captured_notifications [0 ].root .params .data == "First notification before lock" # pragma: no cover
13221336
1323- # Clear notifications for the second phase
1324- captured_notifications = [] # pragma: no cover
1337+ return (
1338+ captured_session_id ,
1339+ captured_resumption_token ,
1340+ captured_protocol_version ,
1341+ captured_notifications ,
1342+ ) # pragma: no cover # noqa: E501
1343+
1344+
1345+ @pytest .mark .anyio
1346+ async def test_streamablehttp_client_resumption (event_server : tuple [SimpleEventStore , str ]):
1347+ """Test client session resumption using sync primitives for reliable coordination."""
1348+ _ , server_url = event_server
1349+
1350+ # Set up the initial session and capture resumption state
1351+ captured_session_id , captured_resumption_token , captured_protocol_version , _ = await _setup_resumption_test (
1352+ server_url
1353+ )
1354+
1355+ # Track notifications for the resumed session
1356+ captured_notifications : list [types .ServerNotification ] = [] # pragma: no cover
1357+
1358+ async def message_handler ( # pragma: no branch
1359+ message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1360+ ) -> None : # pragma: no cover
1361+ if isinstance (message , types .ServerNotification ): # pragma: no branch
1362+ captured_notifications .append (message )
13251363
13261364 # Now resume the session with the same mcp-session-id and protocol version
13271365 headers : dict [str , Any ] = {} # pragma: no cover
@@ -1369,6 +1407,73 @@ async def run_tool():
13691407 assert captured_notifications [0 ].root .params .data == "Second notification after lock" # pragma: no cover
13701408
13711409
1410+ @pytest .mark .anyio
1411+ async def test_streamablehttp_client_resumption_with_extra_headers (event_server : tuple [SimpleEventStore , str ]):
1412+ """Test client session resumption with extra headers."""
1413+ _ , server_url = event_server
1414+
1415+ # Set up the initial session and capture resumption state
1416+ captured_session_id , captured_resumption_token , captured_protocol_version , _ = await _setup_resumption_test (
1417+ server_url
1418+ )
1419+
1420+ # Track notifications for the resumed session
1421+ captured_notifications : list [types .ServerNotification ] = [] # pragma: no cover
1422+
1423+ async def message_handler ( # pragma: no branch
1424+ message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1425+ ) -> None : # pragma: no cover
1426+ if isinstance (message , types .ServerNotification ): # pragma: no branch
1427+ captured_notifications .append (message )
1428+
1429+ # Now resume the session with the same mcp-session-id and protocol version
1430+ headers : dict [str , Any ] = {} # pragma: no cover
1431+ if captured_session_id : # pragma: no cover
1432+ headers [MCP_SESSION_ID_HEADER ] = captured_session_id
1433+ if captured_protocol_version : # pragma: no cover
1434+ headers [MCP_PROTOCOL_VERSION_HEADER ] = captured_protocol_version
1435+
1436+ async with create_mcp_http_client (headers = headers ) as httpx_client :
1437+ async with streamable_http_client (f"{ server_url } /mcp" , http_client = httpx_client ) as (
1438+ read_stream ,
1439+ write_stream ,
1440+ _ ,
1441+ ):
1442+ async with ClientSession (read_stream , write_stream , message_handler = message_handler ) as session :
1443+ result = await session .send_request (
1444+ types .ClientRequest (
1445+ types .CallToolRequest (
1446+ params = types .CallToolRequestParams (name = "release_lock" , arguments = {}),
1447+ )
1448+ ),
1449+ types .CallToolResult ,
1450+ )
1451+ # Test resumption WITH extra_headers
1452+ metadata = ClientMessageMetadata (
1453+ resumption_token = captured_resumption_token ,
1454+ extra_headers = {"X-Resumption-Test" : "test-value" },
1455+ )
1456+
1457+ result = await session .send_request (
1458+ types .ClientRequest (
1459+ types .CallToolRequest (
1460+ params = types .CallToolRequestParams (name = "wait_for_lock_with_notification" , arguments = {}),
1461+ )
1462+ ),
1463+ types .CallToolResult ,
1464+ metadata = metadata ,
1465+ )
1466+ assert len (result .content ) == 1
1467+ assert result .content [0 ].type == "text"
1468+ assert result .content [0 ].text == "Completed"
1469+
1470+ # We should have received the remaining notifications
1471+ assert len (captured_notifications ) == 1
1472+
1473+ assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification )
1474+ assert captured_notifications [0 ].root .params .data == "Second notification after lock"
1475+
1476+
13721477@pytest .mark .anyio
13731478async def test_streamablehttp_server_sampling (basic_server : None , basic_server_url : str ):
13741479 """Test server-initiated sampling request through streamable HTTP transport."""
0 commit comments