Skip to content

Commit 0d2839e

Browse files
committed
test: add coverage for extra_headers in HTTP transport
- Add test_streamablehttp_client_tool_invocation_with_extra_headers for POST requests - Add test_streamablehttp_client_resumption_with_extra_headers for resumption with extra headers - Refactor common resumption setup code into _setup_resumption_test helper - Achieve 100% coverage for streamable_http.py
1 parent 1d58321 commit 0d2839e

File tree

2 files changed

+123
-18
lines changed

2 files changed

+123
-18
lines changed

tests/client/test_extra_headers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ async def mock_server():
557557
from mcp.types import ListToolsResult
558558

559559
result = ServerResult(ListToolsResult(tools=[]))
560-
else:
560+
else: # pragma: no cover
561561
continue
562562

563563
await server_to_client_send.send(
@@ -648,7 +648,7 @@ async def mock_server():
648648
assert jsonrpc_request.root.method == "logging/setLevel"
649649

650650
# Capture the metadata that was passed with the request
651-
if isinstance(session_message.metadata, ClientMessageMetadata):
651+
if isinstance(session_message.metadata, ClientMessageMetadata): # pragma: no branch
652652
captured_metadata.append(session_message.metadata)
653653

654654
# Send response

tests/shared/test_streamable_http.py

Lines changed: 121 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,20 @@ async def test_streamable_http_client_tool_invocation(initialized_client_session
10321032

10331033

10341034
@pytest.mark.anyio
1035-
async def test_streamable_http_client_error_handling(initialized_client_session: ClientSession):
1035+
async def test_streamablehttp_client_tool_invocation_with_extra_headers(initialized_client_session: ClientSession):
1036+
"""Test HTTP POST request with extra headers."""
1037+
result = await initialized_client_session.call_tool(
1038+
"test_tool",
1039+
{},
1040+
extra_headers={"X-Custom-Header": "test-value"},
1041+
)
1042+
assert len(result.content) == 1
1043+
assert result.content[0].type == "text"
1044+
assert result.content[0].text == "Called test_tool"
1045+
1046+
1047+
@pytest.mark.anyio
1048+
async def test_streamablehttp_client_error_handling(initialized_client_session: ClientSession):
10361049
"""Test error handling in client."""
10371050
with pytest.raises(McpError) as exc_info:
10381051
await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error"))
@@ -1245,26 +1258,27 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
12451258
await session.list_tools()
12461259

12471260

1248-
@pytest.mark.anyio
1249-
async def test_streamable_http_client_resumption(event_server: tuple[SimpleEventStore, str]):
1250-
"""Test client session resumption using sync primitives for reliable coordination."""
1251-
_, server_url = event_server
1261+
async def _setup_resumption_test(
1262+
server_url: str,
1263+
) -> tuple[str | None, str | None, str | int | None, list[types.ServerNotification]]:
1264+
"""Helper function to set up a resumption test by starting a session and capturing resumption state.
12521265
1253-
# Variables to track the state
1254-
captured_session_id = None
1255-
captured_resumption_token = None
1256-
captured_notifications: list[types.ServerNotification] = []
1257-
captured_protocol_version = None
1258-
first_notification_received = False
1266+
Returns:
1267+
Tuple of (session_id, resumption_token, protocol_version, notifications)
1268+
"""
1269+
captured_session_id = None # pragma: no cover
1270+
captured_resumption_token = None # pragma: no cover
1271+
captured_notifications: list[types.ServerNotification] = [] # pragma: no cover
1272+
captured_protocol_version = None # pragma: no cover
1273+
first_notification_received = False # pragma: no cover
12591274

12601275
async def message_handler( # pragma: no branch
12611276
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1262-
) -> None:
1277+
) -> None: # pragma: no cover
12631278
if isinstance(message, types.ServerNotification): # pragma: no branch
12641279
captured_notifications.append(message)
1265-
# Look for our first notification
12661280
if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch
1267-
if message.root.params.data == "First notification before lock":
1281+
if message.root.params.data == "First notification before lock": # pragma: no branch
12681282
nonlocal first_notification_received
12691283
first_notification_received = True
12701284

@@ -1320,8 +1334,32 @@ async def run_tool():
13201334
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover
13211335
assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover
13221336

1323-
# Clear notifications for the second phase
1324-
captured_notifications = [] # pragma: no cover
1337+
return (
1338+
captured_session_id,
1339+
captured_resumption_token,
1340+
captured_protocol_version,
1341+
captured_notifications,
1342+
) # pragma: no cover # noqa: E501
1343+
1344+
1345+
@pytest.mark.anyio
1346+
async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]):
1347+
"""Test client session resumption using sync primitives for reliable coordination."""
1348+
_, server_url = event_server
1349+
1350+
# Set up the initial session and capture resumption state
1351+
captured_session_id, captured_resumption_token, captured_protocol_version, _ = await _setup_resumption_test(
1352+
server_url
1353+
)
1354+
1355+
# Track notifications for the resumed session
1356+
captured_notifications: list[types.ServerNotification] = [] # pragma: no cover
1357+
1358+
async def message_handler( # pragma: no branch
1359+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1360+
) -> None: # pragma: no cover
1361+
if isinstance(message, types.ServerNotification): # pragma: no branch
1362+
captured_notifications.append(message)
13251363

13261364
# Now resume the session with the same mcp-session-id and protocol version
13271365
headers: dict[str, Any] = {} # pragma: no cover
@@ -1369,6 +1407,73 @@ async def run_tool():
13691407
assert captured_notifications[0].root.params.data == "Second notification after lock" # pragma: no cover
13701408

13711409

1410+
@pytest.mark.anyio
1411+
async def test_streamablehttp_client_resumption_with_extra_headers(event_server: tuple[SimpleEventStore, str]):
1412+
"""Test client session resumption with extra headers."""
1413+
_, server_url = event_server
1414+
1415+
# Set up the initial session and capture resumption state
1416+
captured_session_id, captured_resumption_token, captured_protocol_version, _ = await _setup_resumption_test(
1417+
server_url
1418+
)
1419+
1420+
# Track notifications for the resumed session
1421+
captured_notifications: list[types.ServerNotification] = [] # pragma: no cover
1422+
1423+
async def message_handler( # pragma: no branch
1424+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1425+
) -> None: # pragma: no cover
1426+
if isinstance(message, types.ServerNotification): # pragma: no branch
1427+
captured_notifications.append(message)
1428+
1429+
# Now resume the session with the same mcp-session-id and protocol version
1430+
headers: dict[str, Any] = {} # pragma: no cover
1431+
if captured_session_id: # pragma: no cover
1432+
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1433+
if captured_protocol_version: # pragma: no cover
1434+
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
1435+
1436+
async with create_mcp_http_client(headers=headers) as httpx_client:
1437+
async with streamable_http_client(f"{server_url}/mcp", http_client=httpx_client) as (
1438+
read_stream,
1439+
write_stream,
1440+
_,
1441+
):
1442+
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
1443+
result = await session.send_request(
1444+
types.ClientRequest(
1445+
types.CallToolRequest(
1446+
params=types.CallToolRequestParams(name="release_lock", arguments={}),
1447+
)
1448+
),
1449+
types.CallToolResult,
1450+
)
1451+
# Test resumption WITH extra_headers
1452+
metadata = ClientMessageMetadata(
1453+
resumption_token=captured_resumption_token,
1454+
extra_headers={"X-Resumption-Test": "test-value"},
1455+
)
1456+
1457+
result = await session.send_request(
1458+
types.ClientRequest(
1459+
types.CallToolRequest(
1460+
params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}),
1461+
)
1462+
),
1463+
types.CallToolResult,
1464+
metadata=metadata,
1465+
)
1466+
assert len(result.content) == 1
1467+
assert result.content[0].type == "text"
1468+
assert result.content[0].text == "Completed"
1469+
1470+
# We should have received the remaining notifications
1471+
assert len(captured_notifications) == 1
1472+
1473+
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification)
1474+
assert captured_notifications[0].root.params.data == "Second notification after lock"
1475+
1476+
13721477
@pytest.mark.anyio
13731478
async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str):
13741479
"""Test server-initiated sampling request through streamable HTTP transport."""

0 commit comments

Comments
 (0)