Skip to content

Commit 7250289

Browse files
committed
create context_app using StreamingASGITransport and update test_request_context_propogation to apply this methodology
1 parent 7f241e7 commit 7250289

File tree

1 file changed

+58
-19
lines changed

1 file changed

+58
-19
lines changed

tests/shared/test_sse.py

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from mcp.server import Server
2323
from mcp.server.sse import SseServerTransport
2424
from mcp.server.transport_security import TransportSecuritySettings
25+
from mcp.server.streaming_asgi_transport import StreamingASGITransport
2526
from mcp.shared.exceptions import McpError
2627
from mcp.types import (
2728
EmptyResult,
@@ -367,9 +368,32 @@ def context_server(server_port: int) -> Generator[None, None, None]:
367368
if proc.is_alive():
368369
print("context server process failed to terminate")
369370

371+
@pytest.fixture()
372+
async def context_app() -> Starlette:
373+
"""Fixture that provides the context server app"""
374+
security_settings = TransportSecuritySettings(
375+
allowed_hosts=["127.0.0.1:*", "localhost:*", "testserver"],
376+
allowed_origins=["http://127.0.0.1:*", "http://localhost:*", "http://testserver"]
377+
)
378+
sse = SseServerTransport("/messages/", security_settings=security_settings)
379+
context_server = RequestContextServer()
380+
381+
async def handle_sse(request: Request) -> Response:
382+
async with sse.connect_sse(request.scope, request.receive, request._send) as streams:
383+
await context_server.run(streams[0], streams[1], context_server.create_initialization_options())
384+
return Response()
385+
386+
app = Starlette(
387+
routes=[
388+
Route("/sse", endpoint=handle_sse),
389+
Mount("/messages/", app=sse.handle_post_message),
390+
]
391+
)
392+
return app
393+
370394

371395
@pytest.mark.anyio
372-
async def test_request_context_propagation(context_server: None, server_url: str) -> None:
396+
async def test_request_context_propagation(context_app: Starlette) -> None:
373397
"""Test that request context is properly propagated through SSE transport."""
374398
# Test with custom headers
375399
custom_headers = {
@@ -378,27 +402,42 @@ async def test_request_context_propagation(context_server: None, server_url: str
378402
"X-Trace-Id": "trace-123",
379403
}
380404

381-
async with sse_client(server_url + "/sse", headers=custom_headers) as (
382-
read_stream,
383-
write_stream,
384-
):
385-
async with ClientSession(read_stream, write_stream) as session:
386-
# Initialize the session
387-
result = await session.initialize()
388-
assert isinstance(result, InitializeResult)
389-
390-
# Call the tool that echoes headers back
391-
tool_result = await session.call_tool("echo_headers", {})
405+
async with anyio.create_task_group() as tg:
406+
def create_test_client(
407+
headers: dict[str, str] | None = None,
408+
timeout: httpx.Timeout | None = None,
409+
auth: httpx.Auth | None = None,
410+
) -> httpx.AsyncClient:
411+
transport = StreamingASGITransport(app=context_app, task_group=tg)
412+
return httpx.AsyncClient(
413+
transport=transport,
414+
base_url="http://testserver",
415+
headers=headers,
416+
timeout=timeout,
417+
auth=auth,
418+
follow_redirects=True,
419+
)
420+
421+
async with sse_client("http://testserver/sse", headers=custom_headers, httpx_client_factory=create_test_client) as (
422+
read_stream,
423+
write_stream,
424+
):
425+
async with ClientSession(read_stream, write_stream) as session:
426+
# Initialize the session
427+
result = await session.initialize()
428+
assert isinstance(result, InitializeResult)
392429

393-
# Parse the JSON response
430+
# Call the tool that echoes headers back
431+
tool_result = await session.call_tool("echo_headers", {})
394432

395-
assert len(tool_result.content) == 1
396-
headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}")
433+
# Parse the JSON response
434+
assert len(tool_result.content) == 1
435+
headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}")
397436

398-
# Verify headers were propagated
399-
assert headers_data.get("authorization") == "Bearer test-token"
400-
assert headers_data.get("x-custom-header") == "test-value"
401-
assert headers_data.get("x-trace-id") == "trace-123"
437+
# Verify headers were propagated
438+
assert headers_data.get("authorization") == "Bearer test-token"
439+
assert headers_data.get("x-custom-header") == "test-value"
440+
assert headers_data.get("x-trace-id") == "trace-123"
402441

403442

404443
@pytest.mark.anyio

0 commit comments

Comments
 (0)