1- import pytest
21import anyio
2+ import pytest
33from starlette .applications import Starlette
44from starlette .routing import Mount , Route
5- import uvicorn
6- from mcp .client .sse import sse_client
7- from exceptiongroup import ExceptionGroup
8- import asyncio
95import httpx
10- from httpx import ReadTimeout
6+ from httpx import ReadTimeout , ASGITransport
117
8+ from mcp .client .sse import sse_client
129from mcp .server .sse import SseServerTransport
10+ from mcp .types import JSONRPCMessage
11+
1312
1413@pytest .fixture
15- async def sse_server ():
14+ async def sse_transport ():
15+ """Fixture that creates an SSE transport instance."""
16+ return SseServerTransport ("/messages/" )
1617
17- # Create an SSE transport at an endpoint
18- sse = SseServerTransport ("/messages/" )
1918
20- # Create Starlette routes for SSE and message handling
19+ @pytest .fixture
20+ async def sse_app (sse_transport ):
21+ """Fixture that creates a Starlette app with SSE endpoints."""
22+ async def handle_sse (request ):
23+ """Handler for SSE connections."""
24+ async with sse_transport .connect_sse (
25+ request .scope , request .receive , request ._send
26+ ) as streams :
27+ client_to_server , server_to_client = streams
28+ async for message in client_to_server :
29+ # Echo messages back for testing
30+ await server_to_client .send (message )
31+
2132 routes = [
2233 Route ("/sse" , endpoint = handle_sse ),
23- Mount ("/messages/ " , app = sse .handle_post_message ),
34+ Mount ("/messages" , app = sse_transport .handle_post_message ),
2435 ]
25- #
26- # Create and run Starlette app
27- app = Starlette (routes = routes )
2836
29- # Define handler functions
30- async def handle_sse (request ):
31- async with sse .connect_sse (
32- request .scope , request .receive , request ._send
33- ) as streams :
34- await app .run (
35- streams [0 ], streams [1 ], app .create_initialization_options ()
36- )
37+ return Starlette (routes = routes )
3738
38- uvicorn .run (app , host = "127.0.0.1" , port = 34891 )
3939
40- async def sse_handler (request ):
41- response = httpx .Response (200 , content_type = "text/event-stream" )
42- response .send_headers ()
43- response .write ("data: test\n \n " )
44- await response .aclose ()
40+ @pytest .fixture
41+ async def test_client (sse_app ):
42+ """Create a test client with ASGI transport."""
43+ async with httpx .AsyncClient (
44+ transport = ASGITransport (app = sse_app ),
45+ base_url = "http://testserver" ,
46+ ) as client :
47+ yield client
4548
46- async with httpx .AsyncServer (sse_handler ) as server :
47- yield server .url
4849
50+ @pytest .mark .anyio
51+ async def test_sse_connection (test_client ):
52+ """Test basic SSE connection and message exchange."""
53+ async with sse_client (
54+ "http://testserver/sse" ,
55+ headers = {"Host" : "testserver" },
56+ timeout = 5 ,
57+ client = test_client ,
58+ ) as (read_stream , write_stream ):
59+ # Send a test message
60+ test_message = JSONRPCMessage .model_validate ({"jsonrpc" : "2.0" , "method" : "test" })
61+ await write_stream .send (test_message )
4962
50- @pytest .fixture
51- async def sse_client ():
52- async with sse_client ("http://test/sse" ) as (read_stream , write_stream ):
63+ # Receive echoed message
5364 async with read_stream :
54- async for message in read_stream :
55- if isinstance (message , Exception ):
56- raise message
65+ message = await read_stream . __anext__ ()
66+ assert isinstance (message , JSONRPCMessage )
67+ assert message . model_dump () == test_message . model_dump ()
5768
58- return read_stream , write_stream
5969
6070@pytest .mark .anyio
61- async def test_sse_happy_path (monkeypatch ):
62- # Mock httpx.AsyncClient to return our mock response
63- monkeypatch .setattr (httpx , "AsyncClient" , MockClient )
64-
65- with pytest .raises (ReadTimeout ) as exc_info :
71+ async def test_sse_read_timeout (test_client ):
72+ """Test that SSE client properly handles read timeouts."""
73+ with pytest .raises (ReadTimeout ):
6674 async with sse_client (
67- "http://test/sse" ,
68- timeout = 5 , # Connection timeout - make this longer
69- sse_read_timeout = 1 # Read timeout - this should trigger
75+ "http://testserver/sse" ,
76+ headers = {"Host" : "testserver" },
77+ timeout = 5 ,
78+ sse_read_timeout = 1 ,
79+ client = test_client ,
7080 ) as (read_stream , write_stream ):
7181 async with read_stream :
72- async for message in read_stream :
73- if isinstance (message , Exception ):
74- raise message
82+ # This should timeout since no messages are being sent
83+ await read_stream .__anext__ ()
84+
85+
86+ @pytest .mark .anyio
87+ async def test_sse_connection_error (test_client ):
88+ """Test SSE client behavior with connection errors."""
89+ with pytest .raises (httpx .HTTPError ):
90+ async with sse_client (
91+ "http://testserver/nonexistent" ,
92+ headers = {"Host" : "testserver" },
93+ timeout = 5 ,
94+ client = test_client ,
95+ ):
96+ pass # Should not reach here
7597
76- error = exc_info .value
77- assert isinstance (error , ReadTimeout )
78- assert str (error ) == "Read timeout"
7998
8099@pytest .mark .anyio
81- async def test_sse_read_timeouts (monkeypatch ):
82- """Test that the SSE client properly handles read timeouts between SSE messages."""
100+ async def test_sse_multiple_messages (test_client ):
101+ """Test sending and receiving multiple SSE messages."""
102+ async with sse_client (
103+ "http://testserver/sse" ,
104+ headers = {"Host" : "testserver" },
105+ timeout = 5 ,
106+ client = test_client ,
107+ ) as (read_stream , write_stream ):
108+ # Send multiple test messages
109+ messages = [
110+ JSONRPCMessage .model_validate ({"jsonrpc" : "2.0" , "method" : f"test{ i } " })
111+ for i in range (3 )
112+ ]
113+
114+ for msg in messages :
115+ await write_stream .send (msg )
116+
117+ # Receive all echoed messages
118+ received = []
119+ async with read_stream :
120+ for _ in range (len (messages )):
121+ message = await read_stream .__anext__ ()
122+ assert isinstance (message , JSONRPCMessage )
123+ received .append (message )
124+
125+ # Verify all messages were received in order
126+ for sent , received in zip (messages , received ):
127+ assert sent .model_dump () == received .model_dump ()
0 commit comments