Skip to content

Commit fde04eb

Browse files
committed
Fix SSE server transport to support absolute endpoints
This change fixes the endpoint URL handling in the SSE server transport to support both relative and absolute URLs. Some clients like Copilot Studio require absolute URLs. This change aligns with the TypeScript SDK's support for absolute endpoint URLs as in https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/server/sse.ts The PR: 1. Removes unnecessary URL quoting which would break absolute URLs 2. Adds comprehensive tests for both relative and absolute URL endpoints
1 parent 58c5e72 commit fde04eb

File tree

2 files changed

+260
-12
lines changed

2 files changed

+260
-12
lines changed

src/mcp/server/sse.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ async def handle_sse(request):
4040
import logging
4141
from contextlib import asynccontextmanager
4242
from typing import Any
43-
from urllib.parse import quote
4443
from uuid import UUID, uuid4
4544

4645
import anyio
@@ -100,7 +99,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
10099
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
101100

102101
session_id = uuid4()
103-
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
102+
session_uri = f"{self._endpoint}?session_id={session_id.hex}"
104103
self._read_stream_writers[session_id] = read_stream_writer
105104
logger.debug(f"Created new session with ID: {session_id}")
106105

tests/shared/test_sse.py

Lines changed: 259 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,54 @@ def run_server(server_port: int) -> None:
119119
time.sleep(0.5)
120120

121121

122+
def make_server_app_with_endpoint(endpoint: str) -> Starlette:
123+
"""Create test Starlette app with SSE transport using the specified endpoint"""
124+
sse = SseServerTransport(endpoint)
125+
server = ServerTest()
126+
127+
async def handle_sse(request: Request) -> Response:
128+
async with sse.connect_sse(
129+
request.scope, request.receive, request._send
130+
) as streams:
131+
await server.run(
132+
streams[0], streams[1], server.create_initialization_options()
133+
)
134+
return Response()
135+
136+
# For absolute URLs, we route all paths
137+
if endpoint.startswith(("http://", "https://")):
138+
route_path = "/sse"
139+
mount_path = "/"
140+
else:
141+
route_path = "/sse"
142+
mount_path = endpoint
143+
144+
app = Starlette(
145+
routes=[
146+
Route(route_path, endpoint=handle_sse),
147+
Mount(mount_path, app=sse.handle_post_message),
148+
]
149+
)
150+
151+
return app
152+
153+
154+
def run_server_with_endpoint(server_port: int, endpoint: str) -> None:
155+
app = make_server_app_with_endpoint(endpoint)
156+
server = uvicorn.Server(
157+
config=uvicorn.Config(
158+
app=app, host="127.0.0.1", port=server_port, log_level="error"
159+
)
160+
)
161+
print(f"starting server on {server_port} with endpoint {endpoint}")
162+
server.run()
163+
164+
# Give server time to start
165+
while not server.started:
166+
print("waiting for server to start")
167+
time.sleep(0.5)
168+
169+
122170
@pytest.fixture()
123171
def server(server_port: int) -> Generator[None, None, None]:
124172
proc = multiprocessing.Process(
@@ -159,6 +207,129 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N
159207
yield client
160208

161209

210+
@pytest.fixture()
211+
def server_with_relative_endpoint(server_port: int) -> Generator[None, None, None]:
212+
"""Setup a server with a relative endpoint path"""
213+
proc = multiprocessing.Process(
214+
target=run_server_with_endpoint,
215+
kwargs={"server_port": server_port, "endpoint": "/messages/"},
216+
daemon=True,
217+
)
218+
print("starting process with relative endpoint")
219+
proc.start()
220+
221+
# Wait for server to be running
222+
max_attempts = 20
223+
attempt = 0
224+
print("waiting for server to start")
225+
while attempt < max_attempts:
226+
try:
227+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
228+
s.connect(("127.0.0.1", server_port))
229+
break
230+
except ConnectionRefusedError:
231+
time.sleep(0.1)
232+
attempt += 1
233+
else:
234+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
235+
236+
yield
237+
238+
print("killing server")
239+
# Signal the server to stop
240+
proc.kill()
241+
proc.join(timeout=2)
242+
if proc.is_alive():
243+
print("server process failed to terminate")
244+
245+
246+
@pytest.fixture()
247+
def server_with_absolute_endpoint(
248+
server_port: int, server_url: str
249+
) -> Generator[None, None, None]:
250+
"""Setup a server with an absolute endpoint URL"""
251+
absolute_endpoint = f"{server_url}/messages/"
252+
proc = multiprocessing.Process(
253+
target=run_server_with_endpoint,
254+
kwargs={"server_port": server_port, "endpoint": absolute_endpoint},
255+
daemon=True,
256+
)
257+
print(f"starting process with absolute endpoint: {absolute_endpoint}")
258+
proc.start()
259+
260+
# Wait for server to be running
261+
max_attempts = 20
262+
attempt = 0
263+
print("waiting for server to start")
264+
while attempt < max_attempts:
265+
try:
266+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
267+
s.connect(("127.0.0.1", server_port))
268+
break
269+
except ConnectionRefusedError:
270+
time.sleep(0.1)
271+
attempt += 1
272+
else:
273+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
274+
275+
yield
276+
277+
print("killing server")
278+
# Signal the server to stop
279+
proc.kill()
280+
proc.join(timeout=2)
281+
if proc.is_alive():
282+
print("server process failed to terminate")
283+
284+
285+
@pytest.fixture()
286+
async def http_client_with_relative_endpoint(
287+
server_with_relative_endpoint, server_url
288+
) -> AsyncGenerator[httpx.AsyncClient, None]:
289+
"""Create test client for server with relative endpoint"""
290+
async with httpx.AsyncClient(base_url=server_url) as client:
291+
yield client
292+
293+
294+
@pytest.fixture()
295+
async def http_client_with_absolute_endpoint(
296+
server_with_absolute_endpoint, server_url
297+
) -> AsyncGenerator[httpx.AsyncClient, None]:
298+
"""Create test client for server with absolute endpoint"""
299+
async with httpx.AsyncClient(base_url=server_url) as client:
300+
yield client
301+
302+
303+
@pytest.fixture
304+
async def initialized_sse_client_session(
305+
server, server_url: str
306+
) -> AsyncGenerator[ClientSession, None]:
307+
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
308+
async with ClientSession(*streams) as session:
309+
await session.initialize()
310+
yield session
311+
312+
313+
@pytest.fixture
314+
async def initialized_sse_client_session_with_relative_endpoint(
315+
server_with_relative_endpoint, server_url: str
316+
) -> AsyncGenerator[ClientSession, None]:
317+
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
318+
async with ClientSession(*streams) as session:
319+
await session.initialize()
320+
yield session
321+
322+
323+
@pytest.fixture
324+
async def initialized_sse_client_session_with_absolute_endpoint(
325+
server_with_absolute_endpoint, server_url: str
326+
) -> AsyncGenerator[ClientSession, None]:
327+
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
328+
async with ClientSession(*streams) as session:
329+
await session.initialize()
330+
yield session
331+
332+
162333
# Tests
163334
@pytest.mark.anyio
164335
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
@@ -202,16 +373,6 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
202373
assert isinstance(ping_result, EmptyResult)
203374

204375

205-
@pytest.fixture
206-
async def initialized_sse_client_session(
207-
server, server_url: str
208-
) -> AsyncGenerator[ClientSession, None]:
209-
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:
210-
async with ClientSession(*streams) as session:
211-
await session.initialize()
212-
yield session
213-
214-
215376
@pytest.mark.anyio
216377
async def test_sse_client_happy_request_and_response(
217378
initialized_sse_client_session: ClientSession,
@@ -252,3 +413,91 @@ async def test_sse_client_timeout(
252413
return
253414

254415
pytest.fail("the client should have timed out and returned an error already")
416+
417+
418+
@pytest.mark.anyio
419+
async def test_raw_sse_connection_with_relative_endpoint(http_client_with_relative_endpoint: httpx.AsyncClient) -> None:
420+
"""Test the SSE connection establishment with a relative endpoint URL."""
421+
async with anyio.create_task_group():
422+
423+
async def connection_test() -> None:
424+
async with http_client_with_relative_endpoint.stream("GET", "/sse") as response:
425+
assert response.status_code == 200
426+
assert (
427+
response.headers["content-type"]
428+
== "text/event-stream; charset=utf-8"
429+
)
430+
431+
line_number = 0
432+
async for line in response.aiter_lines():
433+
if line_number == 0:
434+
assert line == "event: endpoint"
435+
elif line_number == 1:
436+
assert line.startswith("data: /messages/?session_id=")
437+
# Verify it's a relative URL
438+
endpoint_data = line.removeprefix("data: ")
439+
assert not endpoint_data.startswith(("http://", "https://"))
440+
assert endpoint_data.startswith("/messages/?session_id=")
441+
else:
442+
return
443+
line_number += 1
444+
445+
# Add timeout to prevent test from hanging if it fails
446+
with anyio.fail_after(3):
447+
await connection_test()
448+
449+
450+
@pytest.mark.anyio
451+
async def test_raw_sse_connection_with_absolute_endpoint(http_client_with_absolute_endpoint: httpx.AsyncClient) -> None:
452+
"""Test the SSE connection establishment with an absolute endpoint URL."""
453+
async with anyio.create_task_group():
454+
455+
async def connection_test() -> None:
456+
async with http_client_with_absolute_endpoint.stream("GET", "/sse") as response:
457+
assert response.status_code == 200
458+
assert (
459+
response.headers["content-type"]
460+
== "text/event-stream; charset=utf-8"
461+
)
462+
463+
line_number = 0
464+
async for line in response.aiter_lines():
465+
if line_number == 0:
466+
assert line == "event: endpoint"
467+
elif line_number == 1:
468+
# Verify it's an absolute URL
469+
assert line.startswith("data: http://")
470+
assert "/messages/?session_id=" in line
471+
else:
472+
return
473+
line_number += 1
474+
475+
# Add timeout to prevent test from hanging if it fails
476+
with anyio.fail_after(3):
477+
await connection_test()
478+
479+
480+
@pytest.mark.anyio
481+
async def test_sse_client_with_relative_endpoint(
482+
initialized_sse_client_session_with_relative_endpoint: ClientSession,
483+
) -> None:
484+
"""Test that a client session works properly with a relative endpoint."""
485+
session = initialized_sse_client_session_with_relative_endpoint
486+
# Test basic functionality
487+
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
488+
assert len(response.contents) == 1
489+
assert isinstance(response.contents[0], TextResourceContents)
490+
assert response.contents[0].text == "Read should-work"
491+
492+
493+
@pytest.mark.anyio
494+
async def test_sse_client_with_absolute_endpoint(
495+
initialized_sse_client_session_with_absolute_endpoint: ClientSession,
496+
) -> None:
497+
"""Test that a client session works properly with an absolute endpoint."""
498+
session = initialized_sse_client_session_with_absolute_endpoint
499+
# Test basic functionality
500+
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
501+
assert len(response.contents) == 1
502+
assert isinstance(response.contents[0], TextResourceContents)
503+
assert response.contents[0].text == "Read should-work"

0 commit comments

Comments
 (0)