Skip to content

Commit 6a24fc9

Browse files
test(mcp): Use AsyncClient for SSE (#5396)
Stop mocking transport layer of the `mcp` package in tests.
1 parent b3e9c10 commit 6a24fc9

File tree

2 files changed

+231
-38
lines changed

2 files changed

+231
-38
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import asyncio
2+
from httpx import ASGITransport, Request, Response, AsyncByteStream
3+
import anyio
4+
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
from typing import Any, Callable, MutableMapping
9+
10+
11+
class StreamingASGITransport(ASGITransport):
12+
"""
13+
Simple transport whose only purpose is to keep GET request alive in SSE connections, allowing
14+
tests involving SSE interactions to run in-process.
15+
"""
16+
17+
def __init__(
18+
self,
19+
app: "Callable",
20+
keep_sse_alive: "asyncio.Event",
21+
) -> None:
22+
self.keep_sse_alive = keep_sse_alive
23+
super().__init__(app)
24+
25+
async def handle_async_request(self, request: "Request") -> "Response":
26+
scope = {
27+
"type": "http",
28+
"method": request.method,
29+
"headers": [(k.lower(), v) for (k, v) in request.headers.raw],
30+
"path": request.url.path,
31+
"query_string": request.url.query,
32+
}
33+
34+
is_streaming_sse = scope["method"] == "GET" and scope["path"] == "/sse"
35+
if not is_streaming_sse:
36+
return await super().handle_async_request(request)
37+
38+
request_body = b""
39+
if request.content:
40+
request_body = await request.aread()
41+
42+
body_sender, body_receiver = anyio.create_memory_object_stream[bytes](0)
43+
44+
async def receive() -> "dict[str, Any]":
45+
if self.keep_sse_alive.is_set():
46+
return {"type": "http.disconnect"}
47+
48+
await self.keep_sse_alive.wait() # Keep alive :)
49+
return {"type": "http.request", "body": request_body, "more_body": False}
50+
51+
async def send(message: "MutableMapping[str, Any]") -> None:
52+
if message["type"] == "http.response.body":
53+
body = message.get("body", b"")
54+
more_body = message.get("more_body", False)
55+
56+
if body == b"" and not more_body:
57+
return
58+
59+
if body:
60+
await body_sender.send(body)
61+
62+
if not more_body:
63+
await body_sender.aclose()
64+
65+
async def run_app():
66+
await self.app(scope, receive, send)
67+
68+
class StreamingBodyStream(AsyncByteStream):
69+
def __init__(self, receiver):
70+
self.receiver = receiver
71+
72+
async def __aiter__(self):
73+
try:
74+
async for chunk in self.receiver:
75+
yield chunk
76+
except anyio.EndOfStream:
77+
pass
78+
79+
stream = StreamingBodyStream(body_receiver)
80+
response = Response(status_code=200, headers=[], stream=stream)
81+
82+
asyncio.create_task(run_app())
83+
return response

tests/integrations/mcp/test_mcp.py

Lines changed: 148 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
that the integration properly instruments MCP handlers with Sentry spans.
1616
"""
1717

18+
from urllib.parse import urlparse, parse_qs
1819
import anyio
20+
import asyncio
21+
import httpx
22+
from .streaming_asgi_transport import StreamingASGITransport
23+
1924
import pytest
2025
import json
2126
from unittest import mock
@@ -43,9 +48,11 @@ async def __call__(self, *args, **kwargs):
4348
from sentry_sdk.consts import SPANDATA, OP
4449
from sentry_sdk.integrations.mcp import MCPIntegration
4550

51+
from mcp.server.sse import SseServerTransport
4652
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
47-
from starlette.routing import Mount
53+
from starlette.routing import Mount, Route
4854
from starlette.applications import Starlette
55+
from starlette.responses import Response
4956

5057

5158
@pytest.fixture(autouse=True)
@@ -67,39 +74,103 @@ def reset_request_ctx():
6774
pass
6875

6976

70-
class MockRequestContext:
71-
"""Mock MCP request context"""
72-
73-
def __init__(self, request_id=None, session_id=None, transport="stdio"):
74-
self.request_id = request_id
75-
if transport in ("http", "sse"):
76-
self.request = MockHTTPRequest(session_id, transport)
77-
else:
78-
self.request = None
77+
class MockTextContent:
78+
"""Mock TextContent object"""
7979

80+
def __init__(self, text):
81+
self.text = text
8082

81-
class MockHTTPRequest:
82-
"""Mock HTTP request for SSE/StreamableHTTP transport"""
8383

84-
def __init__(self, session_id=None, transport="http"):
85-
self.headers = {}
86-
self.query_params = {}
84+
async def json_rpc_sse(
85+
app, method: str, params, request_id: str, keep_sse_alive: "asyncio.Event"
86+
):
87+
context = {}
88+
89+
stream_complete = asyncio.Event()
90+
endpoint_parsed = asyncio.Event()
91+
92+
# https://github.com/Kludex/starlette/issues/104#issuecomment-729087925
93+
async with httpx.AsyncClient(
94+
transport=StreamingASGITransport(app=app, keep_sse_alive=keep_sse_alive),
95+
base_url="http://test",
96+
) as client:
97+
98+
async def parse_stream():
99+
async with client.stream("GET", "/sse") as stream:
100+
# Read directly from stream.stream instead of aiter_bytes()
101+
async for chunk in stream.stream:
102+
if b"event: endpoint" in chunk:
103+
sse_text = chunk.decode("utf-8")
104+
url = sse_text.split("data: ")[1]
105+
106+
parsed = urlparse(url)
107+
query_params = parse_qs(parsed.query)
108+
context["session_id"] = query_params["session_id"][0]
109+
endpoint_parsed.set()
110+
continue
111+
112+
if b"event: message" in chunk and b"structuredContent" in chunk:
113+
sse_text = chunk.decode("utf-8")
114+
115+
json_str = sse_text.split("data: ")[1]
116+
context["response"] = json.loads(json_str)
117+
break
118+
119+
stream_complete.set()
120+
121+
task = asyncio.create_task(parse_stream())
122+
await endpoint_parsed.wait()
123+
124+
await client.post(
125+
f"/messages/?session_id={context['session_id']}",
126+
headers={
127+
"Content-Type": "application/json",
128+
},
129+
json={
130+
"jsonrpc": "2.0",
131+
"method": "initialize",
132+
"params": {
133+
"clientInfo": {"name": "test-client", "version": "1.0"},
134+
"protocolVersion": "2025-11-25",
135+
"capabilities": {},
136+
},
137+
"id": request_id,
138+
},
139+
)
87140

88-
if transport == "sse":
89-
# SSE transport uses query parameter
90-
if session_id:
91-
self.query_params["session_id"] = session_id
92-
else:
93-
# StreamableHTTP transport uses header
94-
if session_id:
95-
self.headers["mcp-session-id"] = session_id
141+
# Notification response is mandatory.
142+
# https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle
143+
await client.post(
144+
f"/messages/?session_id={context['session_id']}",
145+
headers={
146+
"Content-Type": "application/json",
147+
"mcp-session-id": context["session_id"],
148+
},
149+
json={
150+
"jsonrpc": "2.0",
151+
"method": "notifications/initialized",
152+
"params": {},
153+
},
154+
)
96155

156+
await client.post(
157+
f"/messages/?session_id={context['session_id']}",
158+
headers={
159+
"Content-Type": "application/json",
160+
"mcp-session-id": context["session_id"],
161+
},
162+
json={
163+
"jsonrpc": "2.0",
164+
"method": method,
165+
"params": params,
166+
"id": request_id,
167+
},
168+
)
97169

98-
class MockTextContent:
99-
"""Mock TextContent object"""
170+
await stream_complete.wait()
171+
keep_sse_alive.set()
100172

101-
def __init__(self, text):
102-
self.text = text
173+
return task, context["session_id"], context["response"]
103174

104175

105176
def test_integration_patches_server(sentry_init):
@@ -986,7 +1057,8 @@ def test_tool_complex(tool_name, arguments):
9861057
assert span["data"]["mcp.request.argument.number"] == "42"
9871058

9881059

989-
def test_sse_transport_detection(sentry_init, capture_events):
1060+
@pytest.mark.asyncio
1061+
async def test_sse_transport_detection(sentry_init, capture_events):
9901062
"""Test that SSE transport is correctly detected via query parameter"""
9911063
sentry_init(
9921064
integrations=[MCPIntegration()],
@@ -995,29 +1067,67 @@ def test_sse_transport_detection(sentry_init, capture_events):
9951067
events = capture_events()
9961068

9971069
server = Server("test-server")
1070+
sse = SseServerTransport("/messages/")
9981071

999-
# Set up mock request context with SSE transport
1000-
mock_ctx = MockRequestContext(
1001-
request_id="req-sse", session_id="session-sse-123", transport="sse"
1072+
sse_connection_closed = asyncio.Event()
1073+
1074+
async def handle_sse(request):
1075+
async with sse.connect_sse(
1076+
request.scope, request.receive, request._send
1077+
) as streams:
1078+
async with anyio.create_task_group() as tg:
1079+
1080+
async def run_server():
1081+
await server.run(
1082+
streams[0], streams[1], server.create_initialization_options()
1083+
)
1084+
1085+
tg.start_soon(run_server)
1086+
1087+
sse_connection_closed.set()
1088+
return Response()
1089+
1090+
app = Starlette(
1091+
routes=[
1092+
Route("/sse", endpoint=handle_sse, methods=["GET"]),
1093+
Mount("/messages/", app=sse.handle_post_message),
1094+
],
10021095
)
1003-
request_ctx.set(mock_ctx)
10041096

10051097
@server.call_tool()
1006-
def test_tool(tool_name, arguments):
1098+
async def test_tool(tool_name, arguments):
10071099
return {"result": "success"}
10081100

1009-
with start_transaction(name="mcp tx"):
1010-
result = test_tool("sse_tool", {})
1101+
keep_sse_alive = asyncio.Event()
1102+
app_task, session_id, result = await json_rpc_sse(
1103+
app,
1104+
method="tools/call",
1105+
params={
1106+
"name": "sse_tool",
1107+
"arguments": {},
1108+
},
1109+
request_id="req-sse",
1110+
keep_sse_alive=keep_sse_alive,
1111+
)
10111112

1012-
assert result == {"result": "success"}
1113+
await sse_connection_closed.wait()
1114+
await app_task
10131115

1014-
(tx,) = events
1116+
assert result["result"]["structuredContent"] == {"result": "success"}
1117+
1118+
transactions = [
1119+
event
1120+
for event in events
1121+
if event["type"] == "transaction" and event["transaction"] == "/sse"
1122+
]
1123+
assert len(transactions) == 1
1124+
tx = transactions[0]
10151125
span = tx["spans"][0]
10161126

10171127
# Check that SSE transport is detected
10181128
assert span["data"][SPANDATA.MCP_TRANSPORT] == "sse"
10191129
assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp"
1020-
assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-sse-123"
1130+
assert span["data"][SPANDATA.MCP_SESSION_ID] == session_id
10211131

10221132

10231133
def test_streamable_http_transport_detection(

0 commit comments

Comments
 (0)