Skip to content

Commit b0a6aaf

Browse files
committed
WIP
1 parent f164291 commit b0a6aaf

File tree

2 files changed

+116
-54
lines changed

2 files changed

+116
-54
lines changed

src/mcp/client/sse.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,20 @@ async def sse_client(
2424
headers: dict[str, Any] | None = None,
2525
timeout: float = 5,
2626
sse_read_timeout: float = 60 * 5,
27+
client: httpx.AsyncClient | None = None,
2728
):
2829
"""
2930
Client transport for SSE.
3031
3132
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
3233
event before disconnecting. All other HTTP operations are controlled by `timeout`.
34+
35+
Args:
36+
url: The URL to connect to
37+
headers: Optional headers to send with the request
38+
timeout: Connection timeout in seconds
39+
sse_read_timeout: Read timeout in seconds
40+
client: Optional httpx.AsyncClient instance to use for requests
3341
"""
3442
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
3543
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
@@ -43,7 +51,13 @@ async def sse_client(
4351
async with anyio.create_task_group() as tg:
4452
try:
4553
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
46-
async with httpx.AsyncClient(headers=headers) as client:
54+
if client is None:
55+
client = httpx.AsyncClient(headers=headers)
56+
should_close_client = True
57+
else:
58+
should_close_client = False
59+
60+
try:
4761
async with aconnect_sse(
4862
client,
4963
"GET",
@@ -137,6 +151,9 @@ async def post_writer(endpoint_url: str):
137151
yield read_stream, write_stream
138152
finally:
139153
tg.cancel_scope.cancel()
154+
finally:
155+
if should_close_client:
156+
await client.aclose()
140157
finally:
141158
await read_stream_writer.aclose()
142159
await write_stream.aclose()

tests/client/test_sse_attempt.py

Lines changed: 98 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,127 @@
1-
import pytest
21
import anyio
2+
import pytest
33
from starlette.applications import Starlette
44
from starlette.routing import Mount, Route
5-
import uvicorn
6-
from mcp.client.sse import sse_client
7-
from exceptiongroup import ExceptionGroup
8-
import asyncio
95
import httpx
10-
from httpx import ReadTimeout
6+
from httpx import ReadTimeout, ASGITransport
117

8+
from mcp.client.sse import sse_client
129
from mcp.server.sse import SseServerTransport
10+
from mcp.types import JSONRPCMessage
11+
1312

1413
@pytest.fixture
15-
async def sse_server():
14+
async def sse_transport():
15+
"""Fixture that creates an SSE transport instance."""
16+
return SseServerTransport("/messages/")
1617

17-
# Create an SSE transport at an endpoint
18-
sse = SseServerTransport("/messages/")
1918

20-
# Create Starlette routes for SSE and message handling
19+
@pytest.fixture
20+
async def sse_app(sse_transport):
21+
"""Fixture that creates a Starlette app with SSE endpoints."""
22+
async def handle_sse(request):
23+
"""Handler for SSE connections."""
24+
async with sse_transport.connect_sse(
25+
request.scope, request.receive, request._send
26+
) as streams:
27+
client_to_server, server_to_client = streams
28+
async for message in client_to_server:
29+
# Echo messages back for testing
30+
await server_to_client.send(message)
31+
2132
routes = [
2233
Route("/sse", endpoint=handle_sse),
23-
Mount("/messages/", app=sse.handle_post_message),
34+
Mount("/messages", app=sse_transport.handle_post_message),
2435
]
25-
#
26-
# Create and run Starlette app
27-
app = Starlette(routes=routes)
2836

29-
# Define handler functions
30-
async def handle_sse(request):
31-
async with sse.connect_sse(
32-
request.scope, request.receive, request._send
33-
) as streams:
34-
await app.run(
35-
streams[0], streams[1], app.create_initialization_options()
36-
)
37+
return Starlette(routes=routes)
3738

38-
uvicorn.run(app, host="127.0.0.1", port=34891)
3939

40-
async def sse_handler(request):
41-
response = httpx.Response(200, content_type="text/event-stream")
42-
response.send_headers()
43-
response.write("data: test\n\n")
44-
await response.aclose()
40+
@pytest.fixture
41+
async def test_client(sse_app):
42+
"""Create a test client with ASGI transport."""
43+
async with httpx.AsyncClient(
44+
transport=ASGITransport(app=sse_app),
45+
base_url="http://testserver",
46+
) as client:
47+
yield client
4548

46-
async with httpx.AsyncServer(sse_handler) as server:
47-
yield server.url
4849

50+
@pytest.mark.anyio
51+
async def test_sse_connection(test_client):
52+
"""Test basic SSE connection and message exchange."""
53+
async with sse_client(
54+
"http://testserver/sse",
55+
headers={"Host": "testserver"},
56+
timeout=5,
57+
client=test_client,
58+
) as (read_stream, write_stream):
59+
# Send a test message
60+
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
61+
await write_stream.send(test_message)
4962

50-
@pytest.fixture
51-
async def sse_client():
52-
async with sse_client("http://test/sse") as (read_stream, write_stream):
63+
# Receive echoed message
5364
async with read_stream:
54-
async for message in read_stream:
55-
if isinstance(message, Exception):
56-
raise message
65+
message = await read_stream.__anext__()
66+
assert isinstance(message, JSONRPCMessage)
67+
assert message.model_dump() == test_message.model_dump()
5768

58-
return read_stream, write_stream
5969

6070
@pytest.mark.anyio
61-
async def test_sse_happy_path(monkeypatch):
62-
# Mock httpx.AsyncClient to return our mock response
63-
monkeypatch.setattr(httpx, "AsyncClient", MockClient)
64-
65-
with pytest.raises(ReadTimeout) as exc_info:
71+
async def test_sse_read_timeout(test_client):
72+
"""Test that SSE client properly handles read timeouts."""
73+
with pytest.raises(ReadTimeout):
6674
async with sse_client(
67-
"http://test/sse",
68-
timeout=5, # Connection timeout - make this longer
69-
sse_read_timeout=1 # Read timeout - this should trigger
75+
"http://testserver/sse",
76+
headers={"Host": "testserver"},
77+
timeout=5,
78+
sse_read_timeout=1,
79+
client=test_client,
7080
) as (read_stream, write_stream):
7181
async with read_stream:
72-
async for message in read_stream:
73-
if isinstance(message, Exception):
74-
raise message
82+
# This should timeout since no messages are being sent
83+
await read_stream.__anext__()
84+
85+
86+
@pytest.mark.anyio
87+
async def test_sse_connection_error(test_client):
88+
"""Test SSE client behavior with connection errors."""
89+
with pytest.raises(httpx.HTTPError):
90+
async with sse_client(
91+
"http://testserver/nonexistent",
92+
headers={"Host": "testserver"},
93+
timeout=5,
94+
client=test_client,
95+
):
96+
pass # Should not reach here
7597

76-
error = exc_info.value
77-
assert isinstance(error, ReadTimeout)
78-
assert str(error) == "Read timeout"
7998

8099
@pytest.mark.anyio
81-
async def test_sse_read_timeouts(monkeypatch):
82-
"""Test that the SSE client properly handles read timeouts between SSE messages."""
100+
async def test_sse_multiple_messages(test_client):
101+
"""Test sending and receiving multiple SSE messages."""
102+
async with sse_client(
103+
"http://testserver/sse",
104+
headers={"Host": "testserver"},
105+
timeout=5,
106+
client=test_client,
107+
) as (read_stream, write_stream):
108+
# Send multiple test messages
109+
messages = [
110+
JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
111+
for i in range(3)
112+
]
113+
114+
for msg in messages:
115+
await write_stream.send(msg)
116+
117+
# Receive all echoed messages
118+
received = []
119+
async with read_stream:
120+
for _ in range(len(messages)):
121+
message = await read_stream.__anext__()
122+
assert isinstance(message, JSONRPCMessage)
123+
received.append(message)
124+
125+
# Verify all messages were received in order
126+
for sent, received in zip(messages, received):
127+
assert sent.model_dump() == received.model_dump()

0 commit comments

Comments
 (0)