Skip to content

Commit a51e6a1

Browse files
committed
add tests
1 parent c9c0872 commit a51e6a1

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed

tests/shared/test_streamable_http.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Contains tests for both server and client sides of the StreamableHTTP transport.
55
"""
66

7+
import contextlib
78
import json
89
import multiprocessing
910
import socket
@@ -2393,3 +2394,269 @@ async def test_streamablehttp_client_deprecation_warning(basic_server: None, bas
23932394
await session.initialize()
23942395
tools = await session.list_tools()
23952396
assert len(tools.tools) > 0
2397+
2398+
2399+
@pytest.mark.anyio
2400+
async def test_sse_stream_ends_without_completing_no_event_id() -> None:
2401+
"""Test that SSE stream ending without completing and no event ID sends error response."""
2402+
from unittest.mock import MagicMock, patch
2403+
2404+
from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport
2405+
from mcp.shared.message import SessionMessage
2406+
from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest
2407+
2408+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
2409+
2410+
# Create a mock response that returns an empty SSE stream (no events)
2411+
mock_response = MagicMock()
2412+
2413+
async def mock_aclose() -> None:
2414+
pass
2415+
2416+
mock_response.aclose = mock_aclose
2417+
2418+
# Create a mock EventSource that yields no events
2419+
async def empty_iter():
2420+
return
2421+
yield # Make it an async generator that yields nothing
2422+
2423+
mock_event_source = MagicMock()
2424+
mock_event_source.aiter_sse = empty_iter
2425+
2426+
# Create streams for testing
2427+
write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
2428+
2429+
# Create a request context
2430+
mock_client = MagicMock()
2431+
mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test"))
2432+
session_message = SessionMessage(message=mock_message)
2433+
2434+
ctx = RequestContext(
2435+
client=mock_client,
2436+
session_id="test-session",
2437+
session_message=session_message,
2438+
metadata=None,
2439+
read_stream_writer=write_stream,
2440+
)
2441+
2442+
try:
2443+
with patch("mcp.client.streamable_http.EventSource", return_value=mock_event_source):
2444+
await transport._handle_sse_response(mock_response, ctx, is_initialization=False)
2445+
2446+
# Should have received an error response
2447+
received = await read_stream.receive()
2448+
assert isinstance(received, SessionMessage)
2449+
assert isinstance(received.message.root, JSONRPCError)
2450+
assert "SSE stream ended without completing" in received.message.root.error.message
2451+
finally:
2452+
await write_stream.aclose()
2453+
await read_stream.aclose()
2454+
2455+
2456+
@pytest.mark.anyio
2457+
async def test_handle_post_request_non_init_error_sends_error_response() -> None:
2458+
"""Test that non-initialization request errors send error response instead of raising."""
2459+
from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport
2460+
from mcp.shared.message import SessionMessage
2461+
from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest
2462+
2463+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
2464+
2465+
# Create streams for testing
2466+
write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
2467+
2468+
# Create a non-initialization request
2469+
mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list"))
2470+
session_message = SessionMessage(message=mock_message)
2471+
2472+
# Create a mock client that raises an exception
2473+
mock_client = MagicMock()
2474+
2475+
# Create an async context manager that raises
2476+
class FailingStream:
2477+
async def __aenter__(self) -> None:
2478+
raise httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500))
2479+
2480+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
2481+
pass
2482+
2483+
mock_client.stream = MagicMock(return_value=FailingStream())
2484+
2485+
ctx = RequestContext(
2486+
client=mock_client,
2487+
session_id="test-session",
2488+
session_message=session_message,
2489+
metadata=None,
2490+
read_stream_writer=write_stream,
2491+
)
2492+
2493+
try:
2494+
# This should NOT raise, but send an error response
2495+
await transport._handle_post_request(ctx)
2496+
2497+
# Should have received an error response
2498+
received = await read_stream.receive()
2499+
assert isinstance(received, SessionMessage)
2500+
assert isinstance(received.message.root, JSONRPCError)
2501+
finally:
2502+
await write_stream.aclose()
2503+
await read_stream.aclose()
2504+
2505+
2506+
@pytest.mark.anyio
2507+
async def test_handle_post_request_init_error_raises() -> None:
2508+
"""Test that initialization request errors are raised, not sent as error response."""
2509+
from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport
2510+
from mcp.shared.message import SessionMessage
2511+
from mcp.types import JSONRPCMessage, JSONRPCRequest
2512+
2513+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
2514+
2515+
# Create streams for testing
2516+
write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
2517+
2518+
# Create an initialization request
2519+
mock_message = JSONRPCMessage(
2520+
root=JSONRPCRequest(
2521+
jsonrpc="2.0",
2522+
id="init-1",
2523+
method="initialize",
2524+
params={
2525+
"clientInfo": {"name": "test", "version": "1.0"},
2526+
"protocolVersion": "2025-03-26",
2527+
"capabilities": {},
2528+
},
2529+
)
2530+
)
2531+
session_message = SessionMessage(message=mock_message)
2532+
2533+
# Create a mock client that raises an exception
2534+
mock_client = MagicMock()
2535+
2536+
class FailingStream:
2537+
async def __aenter__(self) -> None:
2538+
raise httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500))
2539+
2540+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
2541+
pass
2542+
2543+
mock_client.stream = MagicMock(return_value=FailingStream())
2544+
2545+
ctx = RequestContext(
2546+
client=mock_client,
2547+
session_id=None,
2548+
session_message=session_message,
2549+
metadata=None,
2550+
read_stream_writer=write_stream,
2551+
)
2552+
2553+
try:
2554+
# This SHOULD raise for initialization requests
2555+
with pytest.raises(httpx.HTTPStatusError):
2556+
await transport._handle_post_request(ctx)
2557+
finally:
2558+
await write_stream.aclose()
2559+
await read_stream.aclose()
2560+
2561+
2562+
@pytest.mark.anyio
2563+
async def test_handle_reconnection_max_attempts_exceeded() -> None:
2564+
"""Test that _handle_reconnection raises when max attempts exceeded."""
2565+
from mcp.client.streamable_http import (
2566+
MAX_RECONNECTION_ATTEMPTS,
2567+
RequestContext,
2568+
StreamableHTTPTransport,
2569+
)
2570+
from mcp.shared.message import SessionMessage
2571+
from mcp.types import JSONRPCMessage, JSONRPCRequest
2572+
2573+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
2574+
2575+
# Create streams for testing
2576+
write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
2577+
2578+
# Create a request context
2579+
mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test"))
2580+
session_message = SessionMessage(message=mock_message)
2581+
2582+
ctx = RequestContext(
2583+
client=MagicMock(),
2584+
session_id="test-session",
2585+
session_message=session_message,
2586+
metadata=None,
2587+
read_stream_writer=write_stream,
2588+
)
2589+
2590+
try:
2591+
# Call with attempt >= MAX_RECONNECTION_ATTEMPTS should raise
2592+
with pytest.raises(Exception, match="SSE stream reconnection failed"):
2593+
await transport._handle_reconnection(
2594+
ctx,
2595+
last_event_id="test-event-id",
2596+
retry_interval_ms=1, # Use 1ms to speed up test
2597+
attempt=MAX_RECONNECTION_ATTEMPTS,
2598+
)
2599+
finally:
2600+
await write_stream.aclose()
2601+
await read_stream.aclose()
2602+
2603+
2604+
@pytest.mark.anyio
2605+
async def test_handle_reconnection_failure_retries() -> None:
2606+
"""Test that _handle_reconnection retries on failure and eventually raises."""
2607+
from collections.abc import AsyncGenerator
2608+
from unittest.mock import MagicMock, patch
2609+
2610+
from mcp.client.streamable_http import (
2611+
MAX_RECONNECTION_ATTEMPTS,
2612+
RequestContext,
2613+
StreamableHTTPTransport,
2614+
)
2615+
from mcp.shared.message import SessionMessage
2616+
from mcp.types import JSONRPCMessage, JSONRPCRequest
2617+
2618+
transport = StreamableHTTPTransport(url="http://localhost:8000/mcp")
2619+
2620+
# Create streams for testing
2621+
write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10)
2622+
2623+
# Create a mock client
2624+
mock_client = MagicMock()
2625+
2626+
# Create a request context
2627+
mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test"))
2628+
session_message = SessionMessage(message=mock_message)
2629+
2630+
ctx = RequestContext(
2631+
client=mock_client,
2632+
session_id="test-session",
2633+
session_message=session_message,
2634+
metadata=None,
2635+
read_stream_writer=write_stream,
2636+
)
2637+
2638+
# Track how many times aconnect_sse is called
2639+
call_count = 0
2640+
2641+
@contextlib.asynccontextmanager
2642+
async def failing_aconnect_sse(*args: Any, **kwargs: Any) -> AsyncGenerator[None, None]:
2643+
nonlocal call_count
2644+
call_count += 1
2645+
raise httpx.HTTPStatusError("Connection failed", request=MagicMock(), response=MagicMock(status_code=503))
2646+
yield # Make it an async generator
2647+
2648+
try:
2649+
with patch("mcp.client.streamable_http.aconnect_sse", failing_aconnect_sse):
2650+
with pytest.raises(Exception, match="SSE stream reconnection failed"):
2651+
await transport._handle_reconnection(
2652+
ctx,
2653+
last_event_id="test-event-id",
2654+
retry_interval_ms=1, # Use 1ms to speed up test
2655+
attempt=0,
2656+
)
2657+
2658+
# Should have tried MAX_RECONNECTION_ATTEMPTS times
2659+
assert call_count == MAX_RECONNECTION_ATTEMPTS
2660+
finally:
2661+
await write_stream.aclose()
2662+
await read_stream.aclose()

0 commit comments

Comments
 (0)