@@ -119,6 +119,54 @@ def run_server(server_port: int) -> None:
119119 time .sleep (0.5 )
120120
121121
122+ def make_server_app_with_endpoint (endpoint : str ) -> Starlette :
123+ """Create test Starlette app with SSE transport using the specified endpoint"""
124+ sse = SseServerTransport (endpoint )
125+ server = ServerTest ()
126+
127+ async def handle_sse (request : Request ) -> Response :
128+ async with sse .connect_sse (
129+ request .scope , request .receive , request ._send
130+ ) as streams :
131+ await server .run (
132+ streams [0 ], streams [1 ], server .create_initialization_options ()
133+ )
134+ return Response ()
135+
136+ # For absolute URLs, we route all paths
137+ if endpoint .startswith (("http://" , "https://" )):
138+ route_path = "/sse"
139+ mount_path = "/"
140+ else :
141+ route_path = "/sse"
142+ mount_path = endpoint
143+
144+ app = Starlette (
145+ routes = [
146+ Route (route_path , endpoint = handle_sse ),
147+ Mount (mount_path , app = sse .handle_post_message ),
148+ ]
149+ )
150+
151+ return app
152+
153+
154+ def run_server_with_endpoint (server_port : int , endpoint : str ) -> None :
155+ app = make_server_app_with_endpoint (endpoint )
156+ server = uvicorn .Server (
157+ config = uvicorn .Config (
158+ app = app , host = "127.0.0.1" , port = server_port , log_level = "error"
159+ )
160+ )
161+ print (f"starting server on { server_port } with endpoint { endpoint } " )
162+ server .run ()
163+
164+ # Give server time to start
165+ while not server .started :
166+ print ("waiting for server to start" )
167+ time .sleep (0.5 )
168+
169+
122170@pytest .fixture ()
123171def server (server_port : int ) -> Generator [None , None , None ]:
124172 proc = multiprocessing .Process (
@@ -159,6 +207,129 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N
159207 yield client
160208
161209
210+ @pytest .fixture ()
211+ def server_with_relative_endpoint (server_port : int ) -> Generator [None , None , None ]:
212+ """Setup a server with a relative endpoint path"""
213+ proc = multiprocessing .Process (
214+ target = run_server_with_endpoint ,
215+ kwargs = {"server_port" : server_port , "endpoint" : "/messages/" },
216+ daemon = True ,
217+ )
218+ print ("starting process with relative endpoint" )
219+ proc .start ()
220+
221+ # Wait for server to be running
222+ max_attempts = 20
223+ attempt = 0
224+ print ("waiting for server to start" )
225+ while attempt < max_attempts :
226+ try :
227+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
228+ s .connect (("127.0.0.1" , server_port ))
229+ break
230+ except ConnectionRefusedError :
231+ time .sleep (0.1 )
232+ attempt += 1
233+ else :
234+ raise RuntimeError (f"Server failed to start after { max_attempts } attempts" )
235+
236+ yield
237+
238+ print ("killing server" )
239+ # Signal the server to stop
240+ proc .kill ()
241+ proc .join (timeout = 2 )
242+ if proc .is_alive ():
243+ print ("server process failed to terminate" )
244+
245+
246+ @pytest .fixture ()
247+ def server_with_absolute_endpoint (
248+ server_port : int , server_url : str
249+ ) -> Generator [None , None , None ]:
250+ """Setup a server with an absolute endpoint URL"""
251+ absolute_endpoint = f"{ server_url } /messages/"
252+ proc = multiprocessing .Process (
253+ target = run_server_with_endpoint ,
254+ kwargs = {"server_port" : server_port , "endpoint" : absolute_endpoint },
255+ daemon = True ,
256+ )
257+ print (f"starting process with absolute endpoint: { absolute_endpoint } " )
258+ proc .start ()
259+
260+ # Wait for server to be running
261+ max_attempts = 20
262+ attempt = 0
263+ print ("waiting for server to start" )
264+ while attempt < max_attempts :
265+ try :
266+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
267+ s .connect (("127.0.0.1" , server_port ))
268+ break
269+ except ConnectionRefusedError :
270+ time .sleep (0.1 )
271+ attempt += 1
272+ else :
273+ raise RuntimeError (f"Server failed to start after { max_attempts } attempts" )
274+
275+ yield
276+
277+ print ("killing server" )
278+ # Signal the server to stop
279+ proc .kill ()
280+ proc .join (timeout = 2 )
281+ if proc .is_alive ():
282+ print ("server process failed to terminate" )
283+
284+
285+ @pytest .fixture ()
286+ async def http_client_with_relative_endpoint (
287+ server_with_relative_endpoint , server_url
288+ ) -> AsyncGenerator [httpx .AsyncClient , None ]:
289+ """Create test client for server with relative endpoint"""
290+ async with httpx .AsyncClient (base_url = server_url ) as client :
291+ yield client
292+
293+
294+ @pytest .fixture ()
295+ async def http_client_with_absolute_endpoint (
296+ server_with_absolute_endpoint , server_url
297+ ) -> AsyncGenerator [httpx .AsyncClient , None ]:
298+ """Create test client for server with absolute endpoint"""
299+ async with httpx .AsyncClient (base_url = server_url ) as client :
300+ yield client
301+
302+
303+ @pytest .fixture
304+ async def initialized_sse_client_session (
305+ server , server_url : str
306+ ) -> AsyncGenerator [ClientSession , None ]:
307+ async with sse_client (server_url + "/sse" , sse_read_timeout = 0.5 ) as streams :
308+ async with ClientSession (* streams ) as session :
309+ await session .initialize ()
310+ yield session
311+
312+
313+ @pytest .fixture
314+ async def initialized_sse_client_session_with_relative_endpoint (
315+ server_with_relative_endpoint , server_url : str
316+ ) -> AsyncGenerator [ClientSession , None ]:
317+ async with sse_client (server_url + "/sse" , sse_read_timeout = 0.5 ) as streams :
318+ async with ClientSession (* streams ) as session :
319+ await session .initialize ()
320+ yield session
321+
322+
323+ @pytest .fixture
324+ async def initialized_sse_client_session_with_absolute_endpoint (
325+ server_with_absolute_endpoint , server_url : str
326+ ) -> AsyncGenerator [ClientSession , None ]:
327+ async with sse_client (server_url + "/sse" , sse_read_timeout = 0.5 ) as streams :
328+ async with ClientSession (* streams ) as session :
329+ await session .initialize ()
330+ yield session
331+
332+
162333# Tests
163334@pytest .mark .anyio
164335async def test_raw_sse_connection (http_client : httpx .AsyncClient ) -> None :
@@ -202,16 +373,6 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
202373 assert isinstance (ping_result , EmptyResult )
203374
204375
205- @pytest .fixture
206- async def initialized_sse_client_session (
207- server , server_url : str
208- ) -> AsyncGenerator [ClientSession , None ]:
209- async with sse_client (server_url + "/sse" , sse_read_timeout = 0.5 ) as streams :
210- async with ClientSession (* streams ) as session :
211- await session .initialize ()
212- yield session
213-
214-
215376@pytest .mark .anyio
216377async def test_sse_client_happy_request_and_response (
217378 initialized_sse_client_session : ClientSession ,
@@ -252,3 +413,91 @@ async def test_sse_client_timeout(
252413 return
253414
254415 pytest .fail ("the client should have timed out and returned an error already" )
416+
417+
418+ @pytest .mark .anyio
419+ async def test_raw_sse_connection_with_relative_endpoint (http_client_with_relative_endpoint : httpx .AsyncClient ) -> None :
420+ """Test the SSE connection establishment with a relative endpoint URL."""
421+ async with anyio .create_task_group ():
422+
423+ async def connection_test () -> None :
424+ async with http_client_with_relative_endpoint .stream ("GET" , "/sse" ) as response :
425+ assert response .status_code == 200
426+ assert (
427+ response .headers ["content-type" ]
428+ == "text/event-stream; charset=utf-8"
429+ )
430+
431+ line_number = 0
432+ async for line in response .aiter_lines ():
433+ if line_number == 0 :
434+ assert line == "event: endpoint"
435+ elif line_number == 1 :
436+ assert line .startswith ("data: /messages/?session_id=" )
437+ # Verify it's a relative URL
438+ endpoint_data = line .removeprefix ("data: " )
439+ assert not endpoint_data .startswith (("http://" , "https://" ))
440+ assert endpoint_data .startswith ("/messages/?session_id=" )
441+ else :
442+ return
443+ line_number += 1
444+
445+ # Add timeout to prevent test from hanging if it fails
446+ with anyio .fail_after (3 ):
447+ await connection_test ()
448+
449+
450+ @pytest .mark .anyio
451+ async def test_raw_sse_connection_with_absolute_endpoint (http_client_with_absolute_endpoint : httpx .AsyncClient ) -> None :
452+ """Test the SSE connection establishment with an absolute endpoint URL."""
453+ async with anyio .create_task_group ():
454+
455+ async def connection_test () -> None :
456+ async with http_client_with_absolute_endpoint .stream ("GET" , "/sse" ) as response :
457+ assert response .status_code == 200
458+ assert (
459+ response .headers ["content-type" ]
460+ == "text/event-stream; charset=utf-8"
461+ )
462+
463+ line_number = 0
464+ async for line in response .aiter_lines ():
465+ if line_number == 0 :
466+ assert line == "event: endpoint"
467+ elif line_number == 1 :
468+ # Verify it's an absolute URL
469+ assert line .startswith ("data: http://" )
470+ assert "/messages/?session_id=" in line
471+ else :
472+ return
473+ line_number += 1
474+
475+ # Add timeout to prevent test from hanging if it fails
476+ with anyio .fail_after (3 ):
477+ await connection_test ()
478+
479+
480+ @pytest .mark .anyio
481+ async def test_sse_client_with_relative_endpoint (
482+ initialized_sse_client_session_with_relative_endpoint : ClientSession ,
483+ ) -> None :
484+ """Test that a client session works properly with a relative endpoint."""
485+ session = initialized_sse_client_session_with_relative_endpoint
486+ # Test basic functionality
487+ response = await session .read_resource (uri = AnyUrl ("foobar://should-work" ))
488+ assert len (response .contents ) == 1
489+ assert isinstance (response .contents [0 ], TextResourceContents )
490+ assert response .contents [0 ].text == "Read should-work"
491+
492+
493+ @pytest .mark .anyio
494+ async def test_sse_client_with_absolute_endpoint (
495+ initialized_sse_client_session_with_absolute_endpoint : ClientSession ,
496+ ) -> None :
497+ """Test that a client session works properly with an absolute endpoint."""
498+ session = initialized_sse_client_session_with_absolute_endpoint
499+ # Test basic functionality
500+ response = await session .read_resource (uri = AnyUrl ("foobar://should-work" ))
501+ assert len (response .contents ) == 1
502+ assert isinstance (response .contents [0 ], TextResourceContents )
503+ assert response .contents [0 ].text == "Read should-work"
0 commit comments