|
4 | 4 | Contains tests for both server and client sides of the StreamableHTTP transport. |
5 | 5 | """ |
6 | 6 |
|
| 7 | +import contextlib |
7 | 8 | import json |
8 | 9 | import multiprocessing |
9 | 10 | import socket |
@@ -2393,3 +2394,269 @@ async def test_streamablehttp_client_deprecation_warning(basic_server: None, bas |
2393 | 2394 | await session.initialize() |
2394 | 2395 | tools = await session.list_tools() |
2395 | 2396 | assert len(tools.tools) > 0 |
| 2397 | + |
| 2398 | + |
| 2399 | +@pytest.mark.anyio |
| 2400 | +async def test_sse_stream_ends_without_completing_no_event_id() -> None: |
| 2401 | + """Test that SSE stream ending without completing and no event ID sends error response.""" |
| 2402 | + from unittest.mock import MagicMock, patch |
| 2403 | + |
| 2404 | + from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport |
| 2405 | + from mcp.shared.message import SessionMessage |
| 2406 | + from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest |
| 2407 | + |
| 2408 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 2409 | + |
| 2410 | + # Create a mock response that returns an empty SSE stream (no events) |
| 2411 | + mock_response = MagicMock() |
| 2412 | + |
| 2413 | + async def mock_aclose() -> None: |
| 2414 | + pass |
| 2415 | + |
| 2416 | + mock_response.aclose = mock_aclose |
| 2417 | + |
| 2418 | + # Create a mock EventSource that yields no events |
| 2419 | + async def empty_iter(): |
| 2420 | + return |
| 2421 | + yield # Make it an async generator that yields nothing |
| 2422 | + |
| 2423 | + mock_event_source = MagicMock() |
| 2424 | + mock_event_source.aiter_sse = empty_iter |
| 2425 | + |
| 2426 | + # Create streams for testing |
| 2427 | + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) |
| 2428 | + |
| 2429 | + # Create a request context |
| 2430 | + mock_client = MagicMock() |
| 2431 | + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test")) |
| 2432 | + session_message = SessionMessage(message=mock_message) |
| 2433 | + |
| 2434 | + ctx = RequestContext( |
| 2435 | + client=mock_client, |
| 2436 | + session_id="test-session", |
| 2437 | + session_message=session_message, |
| 2438 | + metadata=None, |
| 2439 | + read_stream_writer=write_stream, |
| 2440 | + ) |
| 2441 | + |
| 2442 | + try: |
| 2443 | + with patch("mcp.client.streamable_http.EventSource", return_value=mock_event_source): |
| 2444 | + await transport._handle_sse_response(mock_response, ctx, is_initialization=False) |
| 2445 | + |
| 2446 | + # Should have received an error response |
| 2447 | + received = await read_stream.receive() |
| 2448 | + assert isinstance(received, SessionMessage) |
| 2449 | + assert isinstance(received.message.root, JSONRPCError) |
| 2450 | + assert "SSE stream ended without completing" in received.message.root.error.message |
| 2451 | + finally: |
| 2452 | + await write_stream.aclose() |
| 2453 | + await read_stream.aclose() |
| 2454 | + |
| 2455 | + |
| 2456 | +@pytest.mark.anyio |
| 2457 | +async def test_handle_post_request_non_init_error_sends_error_response() -> None: |
| 2458 | + """Test that non-initialization request errors send error response instead of raising.""" |
| 2459 | + from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport |
| 2460 | + from mcp.shared.message import SessionMessage |
| 2461 | + from mcp.types import JSONRPCError, JSONRPCMessage, JSONRPCRequest |
| 2462 | + |
| 2463 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 2464 | + |
| 2465 | + # Create streams for testing |
| 2466 | + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) |
| 2467 | + |
| 2468 | + # Create a non-initialization request |
| 2469 | + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list")) |
| 2470 | + session_message = SessionMessage(message=mock_message) |
| 2471 | + |
| 2472 | + # Create a mock client that raises an exception |
| 2473 | + mock_client = MagicMock() |
| 2474 | + |
| 2475 | + # Create an async context manager that raises |
| 2476 | + class FailingStream: |
| 2477 | + async def __aenter__(self) -> None: |
| 2478 | + raise httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500)) |
| 2479 | + |
| 2480 | + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: |
| 2481 | + pass |
| 2482 | + |
| 2483 | + mock_client.stream = MagicMock(return_value=FailingStream()) |
| 2484 | + |
| 2485 | + ctx = RequestContext( |
| 2486 | + client=mock_client, |
| 2487 | + session_id="test-session", |
| 2488 | + session_message=session_message, |
| 2489 | + metadata=None, |
| 2490 | + read_stream_writer=write_stream, |
| 2491 | + ) |
| 2492 | + |
| 2493 | + try: |
| 2494 | + # This should NOT raise, but send an error response |
| 2495 | + await transport._handle_post_request(ctx) |
| 2496 | + |
| 2497 | + # Should have received an error response |
| 2498 | + received = await read_stream.receive() |
| 2499 | + assert isinstance(received, SessionMessage) |
| 2500 | + assert isinstance(received.message.root, JSONRPCError) |
| 2501 | + finally: |
| 2502 | + await write_stream.aclose() |
| 2503 | + await read_stream.aclose() |
| 2504 | + |
| 2505 | + |
| 2506 | +@pytest.mark.anyio |
| 2507 | +async def test_handle_post_request_init_error_raises() -> None: |
| 2508 | + """Test that initialization request errors are raised, not sent as error response.""" |
| 2509 | + from mcp.client.streamable_http import RequestContext, StreamableHTTPTransport |
| 2510 | + from mcp.shared.message import SessionMessage |
| 2511 | + from mcp.types import JSONRPCMessage, JSONRPCRequest |
| 2512 | + |
| 2513 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 2514 | + |
| 2515 | + # Create streams for testing |
| 2516 | + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) |
| 2517 | + |
| 2518 | + # Create an initialization request |
| 2519 | + mock_message = JSONRPCMessage( |
| 2520 | + root=JSONRPCRequest( |
| 2521 | + jsonrpc="2.0", |
| 2522 | + id="init-1", |
| 2523 | + method="initialize", |
| 2524 | + params={ |
| 2525 | + "clientInfo": {"name": "test", "version": "1.0"}, |
| 2526 | + "protocolVersion": "2025-03-26", |
| 2527 | + "capabilities": {}, |
| 2528 | + }, |
| 2529 | + ) |
| 2530 | + ) |
| 2531 | + session_message = SessionMessage(message=mock_message) |
| 2532 | + |
| 2533 | + # Create a mock client that raises an exception |
| 2534 | + mock_client = MagicMock() |
| 2535 | + |
| 2536 | + class FailingStream: |
| 2537 | + async def __aenter__(self) -> None: |
| 2538 | + raise httpx.HTTPStatusError("Server error", request=MagicMock(), response=MagicMock(status_code=500)) |
| 2539 | + |
| 2540 | + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: |
| 2541 | + pass |
| 2542 | + |
| 2543 | + mock_client.stream = MagicMock(return_value=FailingStream()) |
| 2544 | + |
| 2545 | + ctx = RequestContext( |
| 2546 | + client=mock_client, |
| 2547 | + session_id=None, |
| 2548 | + session_message=session_message, |
| 2549 | + metadata=None, |
| 2550 | + read_stream_writer=write_stream, |
| 2551 | + ) |
| 2552 | + |
| 2553 | + try: |
| 2554 | + # This SHOULD raise for initialization requests |
| 2555 | + with pytest.raises(httpx.HTTPStatusError): |
| 2556 | + await transport._handle_post_request(ctx) |
| 2557 | + finally: |
| 2558 | + await write_stream.aclose() |
| 2559 | + await read_stream.aclose() |
| 2560 | + |
| 2561 | + |
| 2562 | +@pytest.mark.anyio |
| 2563 | +async def test_handle_reconnection_max_attempts_exceeded() -> None: |
| 2564 | + """Test that _handle_reconnection raises when max attempts exceeded.""" |
| 2565 | + from mcp.client.streamable_http import ( |
| 2566 | + MAX_RECONNECTION_ATTEMPTS, |
| 2567 | + RequestContext, |
| 2568 | + StreamableHTTPTransport, |
| 2569 | + ) |
| 2570 | + from mcp.shared.message import SessionMessage |
| 2571 | + from mcp.types import JSONRPCMessage, JSONRPCRequest |
| 2572 | + |
| 2573 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 2574 | + |
| 2575 | + # Create streams for testing |
| 2576 | + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) |
| 2577 | + |
| 2578 | + # Create a request context |
| 2579 | + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test")) |
| 2580 | + session_message = SessionMessage(message=mock_message) |
| 2581 | + |
| 2582 | + ctx = RequestContext( |
| 2583 | + client=MagicMock(), |
| 2584 | + session_id="test-session", |
| 2585 | + session_message=session_message, |
| 2586 | + metadata=None, |
| 2587 | + read_stream_writer=write_stream, |
| 2588 | + ) |
| 2589 | + |
| 2590 | + try: |
| 2591 | + # Call with attempt >= MAX_RECONNECTION_ATTEMPTS should raise |
| 2592 | + with pytest.raises(Exception, match="SSE stream reconnection failed"): |
| 2593 | + await transport._handle_reconnection( |
| 2594 | + ctx, |
| 2595 | + last_event_id="test-event-id", |
| 2596 | + retry_interval_ms=1, # Use 1ms to speed up test |
| 2597 | + attempt=MAX_RECONNECTION_ATTEMPTS, |
| 2598 | + ) |
| 2599 | + finally: |
| 2600 | + await write_stream.aclose() |
| 2601 | + await read_stream.aclose() |
| 2602 | + |
| 2603 | + |
| 2604 | +@pytest.mark.anyio |
| 2605 | +async def test_handle_reconnection_failure_retries() -> None: |
| 2606 | + """Test that _handle_reconnection retries on failure and eventually raises.""" |
| 2607 | + from collections.abc import AsyncGenerator |
| 2608 | + from unittest.mock import MagicMock, patch |
| 2609 | + |
| 2610 | + from mcp.client.streamable_http import ( |
| 2611 | + MAX_RECONNECTION_ATTEMPTS, |
| 2612 | + RequestContext, |
| 2613 | + StreamableHTTPTransport, |
| 2614 | + ) |
| 2615 | + from mcp.shared.message import SessionMessage |
| 2616 | + from mcp.types import JSONRPCMessage, JSONRPCRequest |
| 2617 | + |
| 2618 | + transport = StreamableHTTPTransport(url="http://localhost:8000/mcp") |
| 2619 | + |
| 2620 | + # Create streams for testing |
| 2621 | + write_stream, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) |
| 2622 | + |
| 2623 | + # Create a mock client |
| 2624 | + mock_client = MagicMock() |
| 2625 | + |
| 2626 | + # Create a request context |
| 2627 | + mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="test")) |
| 2628 | + session_message = SessionMessage(message=mock_message) |
| 2629 | + |
| 2630 | + ctx = RequestContext( |
| 2631 | + client=mock_client, |
| 2632 | + session_id="test-session", |
| 2633 | + session_message=session_message, |
| 2634 | + metadata=None, |
| 2635 | + read_stream_writer=write_stream, |
| 2636 | + ) |
| 2637 | + |
| 2638 | + # Track how many times aconnect_sse is called |
| 2639 | + call_count = 0 |
| 2640 | + |
| 2641 | + @contextlib.asynccontextmanager |
| 2642 | + async def failing_aconnect_sse(*args: Any, **kwargs: Any) -> AsyncGenerator[None, None]: |
| 2643 | + nonlocal call_count |
| 2644 | + call_count += 1 |
| 2645 | + raise httpx.HTTPStatusError("Connection failed", request=MagicMock(), response=MagicMock(status_code=503)) |
| 2646 | + yield # Make it an async generator |
| 2647 | + |
| 2648 | + try: |
| 2649 | + with patch("mcp.client.streamable_http.aconnect_sse", failing_aconnect_sse): |
| 2650 | + with pytest.raises(Exception, match="SSE stream reconnection failed"): |
| 2651 | + await transport._handle_reconnection( |
| 2652 | + ctx, |
| 2653 | + last_event_id="test-event-id", |
| 2654 | + retry_interval_ms=1, # Use 1ms to speed up test |
| 2655 | + attempt=0, |
| 2656 | + ) |
| 2657 | + |
| 2658 | + # Should have tried MAX_RECONNECTION_ATTEMPTS times |
| 2659 | + assert call_count == MAX_RECONNECTION_ATTEMPTS |
| 2660 | + finally: |
| 2661 | + await write_stream.aclose() |
| 2662 | + await read_stream.aclose() |
0 commit comments