@@ -116,24 +116,28 @@ def server_app() -> Starlette:
116116
117117
118118@pytest .fixture ()
119- async def http_client (server_app : Starlette ) -> AsyncGenerator [httpx .AsyncClient , None ]:
120- """Create test client using StreamingASGITransport"""
119+ async def tg () -> AsyncGenerator [TaskGroup , None ]:
121120 async with anyio .create_task_group () as tg :
122- transport = StreamingASGITransport (app = server_app , task_group = tg )
123- async with httpx .AsyncClient (transport = transport , base_url = TEST_SERVER_BASE_URL ) as client :
124- yield client
121+ yield tg
125122
126123
127124@pytest .fixture ()
128- async def sse_client_session (server_app : Starlette ) -> AsyncGenerator [ClientSession , None ]:
129- async with anyio .create_task_group () as tg :
130- asgi_client_factory = create_asgi_client_factory (server_app , tg )
125+ async def http_client (tg : TaskGroup , server_app : Starlette ) -> AsyncGenerator [httpx .AsyncClient , None ]:
126+ """Create test client using StreamingASGITransport"""
127+ transport = StreamingASGITransport (app = server_app , task_group = tg )
128+ async with httpx .AsyncClient (transport = transport , base_url = TEST_SERVER_BASE_URL ) as client :
129+ yield client
131130
132- async with sse_client (
133- f"{ TEST_SERVER_BASE_URL } /sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory
134- ) as streams :
135- async with ClientSession (* streams ) as session :
136- yield session
131+
132+ @pytest .fixture ()
133+ async def sse_client_session (tg : TaskGroup , server_app : Starlette ) -> AsyncGenerator [ClientSession , None ]:
134+ asgi_client_factory = create_asgi_client_factory (server_app , tg )
135+
136+ async with sse_client (
137+ f"{ TEST_SERVER_BASE_URL } /sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory ,
138+ ) as streams :
139+ async with ClientSession (* streams ) as session :
140+ yield session
137141
138142
139143# Tests
@@ -228,15 +232,16 @@ async def mounted_server_app(server_app: Starlette) -> Starlette:
228232
229233
230234@pytest .fixture ()
231- async def sse_client_mounted_server_app_session (mounted_server_app : Starlette ) -> AsyncGenerator [ClientSession , None ]:
232- async with anyio .create_task_group () as tg :
233- asgi_client_factory = create_asgi_client_factory (mounted_server_app , tg )
235+ async def sse_client_mounted_server_app_session (
236+ tg : TaskGroup , mounted_server_app : Starlette
237+ ) -> AsyncGenerator [ClientSession , None ]:
238+ asgi_client_factory = create_asgi_client_factory (mounted_server_app , tg )
234239
235- async with sse_client (
236- f"{ TEST_SERVER_BASE_URL } /mounted_app/sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory
237- ) as streams :
238- async with ClientSession (* streams ) as session :
239- yield session
240+ async with sse_client (
241+ f"{ TEST_SERVER_BASE_URL } /mounted_app/sse" , sse_read_timeout = 0.5 , httpx_client_factory = asgi_client_factory ,
242+ ) as streams :
243+ async with ClientSession (* streams ) as session :
244+ yield session
240245
241246
242247@pytest .mark .anyio
@@ -303,7 +308,7 @@ async def context_server_app() -> Starlette:
303308
304309
305310@pytest .mark .anyio
306- async def test_request_context_propagation (context_server_app : Starlette ) -> None :
311+ async def test_request_context_propagation (tg : TaskGroup , context_server_app : Starlette ) -> None :
307312 """Test that request context is properly propagated through SSE transport."""
308313 # Test with custom headers
309314 custom_headers = {
@@ -312,63 +317,59 @@ async def test_request_context_propagation(context_server_app: Starlette) -> Non
312317 "X-Trace-Id" : "trace-123" ,
313318 }
314319
315- async with anyio .create_task_group () as tg :
316- asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
320+ asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
317321
318- async with sse_client (
319- f"{ TEST_SERVER_BASE_URL } /sse" ,
320- headers = custom_headers ,
321- httpx_client_factory = asgi_client_factory ,
322- sse_read_timeout = 0.5 ,
323- ) as streams :
324- async with ClientSession (* streams ) as session :
325- # Initialize the session
326- result = await session .initialize ()
327- assert isinstance (result , InitializeResult )
322+ async with sse_client (
323+ f"{ TEST_SERVER_BASE_URL } /sse" ,
324+ headers = custom_headers ,
325+ httpx_client_factory = asgi_client_factory ,
326+ sse_read_timeout = 0.5 ,
327+
328+ ) as streams :
329+ async with ClientSession (* streams ) as session :
330+ # Initialize the session
331+ result = await session .initialize ()
332+ assert isinstance (result , InitializeResult )
328333
329- # Call the tool that echoes headers back
330- tool_result = await session .call_tool ("echo_headers" , {})
334+ # Call the tool that echoes headers back
335+ tool_result = await session .call_tool ("echo_headers" , {})
331336
332- # Parse the JSON response
333- assert len (tool_result .content ) == 1
334- content_item = tool_result .content [0 ]
335- headers_data = json .loads (content_item .text if content_item .type == "text" else "{}" )
337+ # Parse the JSON response
338+ assert len (tool_result .content ) == 1
339+ content_item = tool_result .content [0 ]
340+ headers_data = json .loads (content_item .text if content_item .type == "text" else "{}" )
336341
337- # Verify headers were propagated
338- assert headers_data .get ("authorization" ) == "Bearer test-token"
339- assert headers_data .get ("x-custom-header" ) == "test-value"
340- assert headers_data .get ("x-trace-id" ) == "trace-123"
342+ # Verify headers were propagated
343+ assert headers_data .get ("authorization" ) == "Bearer test-token"
344+ assert headers_data .get ("x-custom-header" ) == "test-value"
345+ assert headers_data .get ("x-trace-id" ) == "trace-123"
341346
342347
343348@pytest .mark .anyio
344- async def test_request_context_isolation (context_server_app : Starlette ) -> None :
349+ async def test_request_context_isolation (tg : TaskGroup , context_server_app : Starlette ) -> None :
345350 """Test that request contexts are isolated between different SSE clients."""
346351 contexts : list [dict [str , Any ]] = []
347352
348- async with anyio .create_task_group () as tg :
349- asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
350-
351- # Create multiple clients with different headers
352- for i in range (3 ):
353- headers = {"X-Request-Id" : f"request-{ i } " , "X-Custom-Value" : f"value-{ i } " }
354-
355- async with sse_client (
356- f"{ TEST_SERVER_BASE_URL } /sse" , headers = headers , httpx_client_factory = asgi_client_factory
357- ) as (
358- read_stream ,
359- write_stream ,
360- ):
361- async with ClientSession (read_stream , write_stream ) as session :
362- await session .initialize ()
363-
364- # Call the tool that echoes context
365- tool_result = await session .call_tool ("echo_context" , {"request_id" : f"request-{ i } " })
366-
367- assert len (tool_result .content ) == 1
368- context_data = json .loads (
369- tool_result .content [0 ].text if tool_result .content [0 ].type == "text" else "{}"
370- )
371- contexts .append (context_data )
353+ asgi_client_factory = create_asgi_client_factory (context_server_app , tg )
354+
355+ # Create multiple clients with different headers
356+ for i in range (3 ):
357+ headers = {"X-Request-Id" : f"request-{ i } " , "X-Custom-Value" : f"value-{ i } " }
358+
359+ async with sse_client (
360+ f"{ TEST_SERVER_BASE_URL } /sse" , headers = headers , httpx_client_factory = asgi_client_factory ,
361+ ) as streams :
362+ async with ClientSession (* streams ) as session :
363+ await session .initialize ()
364+
365+ # Call the tool that echoes context
366+ tool_result = await session .call_tool ("echo_context" , {"request_id" : f"request-{ i } " })
367+
368+ assert len (tool_result .content ) == 1
369+ context_data = json .loads (
370+ tool_result .content [0 ].text if tool_result .content [0 ].type == "text" else "{}"
371+ )
372+ contexts .append (context_data )
372373
373374 # Verify each request had its own context
374375 assert len (contexts ) == 3
0 commit comments