2222from mcp .server import Server
2323from mcp .server .sse import SseServerTransport
2424from mcp .server .transport_security import TransportSecuritySettings
25+ from mcp .server .streaming_asgi_transport import StreamingASGITransport
2526from mcp .shared .exceptions import McpError
2627from mcp .types import (
2728 EmptyResult ,
@@ -367,9 +368,32 @@ def context_server(server_port: int) -> Generator[None, None, None]:
367368 if proc .is_alive ():
368369 print ("context server process failed to terminate" )
369370
371+ @pytest .fixture ()
372+ async def context_app () -> Starlette :
373+ """Fixture that provides the context server app"""
374+ security_settings = TransportSecuritySettings (
375+ allowed_hosts = ["127.0.0.1:*" , "localhost:*" , "testserver" ],
376+ allowed_origins = ["http://127.0.0.1:*" , "http://localhost:*" , "http://testserver" ]
377+ )
378+ sse = SseServerTransport ("/messages/" , security_settings = security_settings )
379+ context_server = RequestContextServer ()
380+
381+ async def handle_sse (request : Request ) -> Response :
382+ async with sse .connect_sse (request .scope , request .receive , request ._send ) as streams :
383+ await context_server .run (streams [0 ], streams [1 ], context_server .create_initialization_options ())
384+ return Response ()
385+
386+ app = Starlette (
387+ routes = [
388+ Route ("/sse" , endpoint = handle_sse ),
389+ Mount ("/messages/" , app = sse .handle_post_message ),
390+ ]
391+ )
392+ return app
393+
370394
371395@pytest .mark .anyio
372- async def test_request_context_propagation (context_server : None , server_url : str ) -> None :
396+ async def test_request_context_propagation (context_app : Starlette ) -> None :
373397 """Test that request context is properly propagated through SSE transport."""
374398 # Test with custom headers
375399 custom_headers = {
@@ -378,27 +402,42 @@ async def test_request_context_propagation(context_server: None, server_url: str
378402 "X-Trace-Id" : "trace-123" ,
379403 }
380404
381- async with sse_client (server_url + "/sse" , headers = custom_headers ) as (
382- read_stream ,
383- write_stream ,
384- ):
385- async with ClientSession (read_stream , write_stream ) as session :
386- # Initialize the session
387- result = await session .initialize ()
388- assert isinstance (result , InitializeResult )
389-
390- # Call the tool that echoes headers back
391- tool_result = await session .call_tool ("echo_headers" , {})
405+ async with anyio .create_task_group () as tg :
406+ def create_test_client (
407+ headers : dict [str , str ] | None = None ,
408+ timeout : httpx .Timeout | None = None ,
409+ auth : httpx .Auth | None = None ,
410+ ) -> httpx .AsyncClient :
411+ transport = StreamingASGITransport (app = context_app , task_group = tg )
412+ return httpx .AsyncClient (
413+ transport = transport ,
414+ base_url = "http://testserver" ,
415+ headers = headers ,
416+ timeout = timeout ,
417+ auth = auth ,
418+ follow_redirects = True ,
419+ )
420+
421+ async with sse_client ("http://testserver/sse" , headers = custom_headers , httpx_client_factory = create_test_client ) as (
422+ read_stream ,
423+ write_stream ,
424+ ):
425+ async with ClientSession (read_stream , write_stream ) as session :
426+ # Initialize the session
427+ result = await session .initialize ()
428+ assert isinstance (result , InitializeResult )
392429
393- # Parse the JSON response
430+ # Call the tool that echoes headers back
431+ tool_result = await session .call_tool ("echo_headers" , {})
394432
395- assert len (tool_result .content ) == 1
396- headers_data = json .loads (tool_result .content [0 ].text if tool_result .content [0 ].type == "text" else "{}" )
433+ # Parse the JSON response
434+ assert len (tool_result .content ) == 1
435+ headers_data = json .loads (tool_result .content [0 ].text if tool_result .content [0 ].type == "text" else "{}" )
397436
398- # Verify headers were propagated
399- assert headers_data .get ("authorization" ) == "Bearer test-token"
400- assert headers_data .get ("x-custom-header" ) == "test-value"
401- assert headers_data .get ("x-trace-id" ) == "trace-123"
437+ # Verify headers were propagated
438+ assert headers_data .get ("authorization" ) == "Bearer test-token"
439+ assert headers_data .get ("x-custom-header" ) == "test-value"
440+ assert headers_data .get ("x-trace-id" ) == "trace-123"
402441
403442
404443@pytest .mark .anyio
0 commit comments