Skip to content

Commit ebd10b3

Browse files
fix: resolve URL path truncation in SSE transport for proxied servers
1 parent 959d4e3 commit ebd10b3

File tree

3 files changed

+95
-25
lines changed

3 files changed

+95
-25
lines changed

src/mcp/server/sse.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
"""
2-
SSE Server Transport Module
2+
SSE Server Transport Module - Fixed Version
33
44
This module implements a Server-Sent Events (SSE) transport layer for MCP servers.
5+
Fixes the URL path joining issue when using subpaths/proxied servers.
56
67
Example usage:
7-
```
8-
# Create an SSE transport at an endpoint
8+
```python
9+
# Option 1: Create an SSE transport with absolute path (leading slash)
10+
# This treats "/messages/" as absolute within the app
911
sse = SseServerTransport("/messages/")
1012
13+
# Option 2: Create an SSE transport with relative path (no leading slash)
14+
# This treats "messages/" as relative to the root path - RECOMMENDED for proxied servers
15+
sse = SseServerTransport("messages/")
16+
1117
# Create Starlette routes for SSE and message handling
1218
routes = [
1319
Route("/sse", endpoint=handle_sse, methods=["GET"]),
@@ -30,6 +36,15 @@ async def handle_sse(request):
3036
uvicorn.run(starlette_app, host="127.0.0.1", port=port)
3137
```
3238
39+
Path behavior examples:
40+
- With root_path="" and endpoint="/messages/": Final path = "/messages/"
41+
- With root_path="" and endpoint="messages/": Final path = "/messages/"
42+
- With root_path="/api" and endpoint="/messages/": Final path = "/api/messages/"
43+
- With root_path="/api" and endpoint="messages/": Final path = "/api/messages/"
44+
45+
For servers behind proxies or mounted at subpaths, use the relative path format
46+
(without leading slash) to ensure proper URL joining with urllib.parse.urljoin().
47+
3348
Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType'
3449
object is not callable" error when client disconnects. The example above returns
3550
an empty Response() after the SSE connection ends to fix this.
@@ -84,7 +99,7 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
8499
85100
Args:
86101
endpoint: A relative path where messages should be posted
87-
(e.g., "/messages/").
102+
(e.g., "/messages/" or "messages/").
88103
security_settings: Optional security settings for DNS rebinding protection.
89104
90105
Note:
@@ -96,6 +111,9 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
96111
3. Portability: The same endpoint configuration works across different
97112
environments (development, staging, production)
98113
114+
The endpoint path handling has been updated to work correctly with urllib.parse.urljoin()
115+
when servers are behind proxies or mounted at subpaths.
116+
99117
Raises:
100118
ValueError: If the endpoint is a full URL instead of a relative path
101119
"""
@@ -105,19 +123,49 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
105123
# Validate that endpoint is a relative path and not a full URL
106124
if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint:
107125
raise ValueError(
108-
f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), "
109-
"expecting a relative path (e.g., '/messages/')."
126+
f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/' or 'messages/'), "
127+
"expecting a relative path (e.g., '/messages/' or 'messages/')."
110128
)
111129

112-
# Ensure endpoint starts with a forward slash
113-
if not endpoint.startswith("/"):
114-
endpoint = "/" + endpoint
115-
130+
# Handle leading slash more intelligently
131+
# Remove automatic leading slash enforcement to support proper URL joining
132+
# Store the endpoint as-is, allowing both "/messages/" and "messages/" formats
116133
self._endpoint = endpoint
134+
117135
self._read_stream_writers = {}
118136
self._security = TransportSecurityMiddleware(security_settings)
119137
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
120138

139+
def _build_message_path(self, root_path: str) -> str:
140+
"""
141+
Helper method to properly construct the message path
142+
143+
This method handles the path construction logic that was causing issues
144+
with urllib.parse.urljoin() when servers are proxied or mounted at subpaths.
145+
146+
Args:
147+
root_path: The root path from ASGI scope (e.g., "" or "/api_prefix")
148+
149+
Returns:
150+
The properly constructed path for client message posting
151+
"""
152+
# Clean up the root path
153+
clean_root_path = root_path.rstrip("/")
154+
155+
# If endpoint starts with "/", it's meant to be absolute within the app
156+
# If endpoint doesn't start with "/", it's meant to be relative to root_path
157+
if self._endpoint.startswith("/"):
158+
# Absolute path within the app - just concatenate
159+
full_path = clean_root_path + self._endpoint
160+
else:
161+
# Relative path - ensure proper joining
162+
if clean_root_path:
163+
full_path = clean_root_path + "/" + self._endpoint
164+
else:
165+
full_path = "/" + self._endpoint
166+
167+
return full_path
168+
121169
@asynccontextmanager
122170
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
123171
if scope["type"] != "http":
@@ -145,17 +193,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
145193
self._read_stream_writers[session_id] = read_stream_writer
146194
logger.debug(f"Created new session with ID: {session_id}")
147195

148-
# Determine the full path for the message endpoint to be sent to the client.
149-
# scope['root_path'] is the prefix where the current Starlette app
150-
# instance is mounted.
151-
# e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix".
196+
# Use the new helper method for proper path construction
152197
root_path = scope.get("root_path", "")
153-
154-
# self._endpoint is the path *within* this app, e.g., "/messages".
155-
# Concatenating them gives the full absolute path from the server root.
156-
# e.g., "" + "/messages" -> "/messages"
157-
# e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages"
158-
full_message_path_for_client = root_path.rstrip("/") + self._endpoint
198+
full_message_path_for_client = self._build_message_path(root_path)
159199

160200
# This is the URI (path + query) the client will use to POST messages.
161201
client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
@@ -246,4 +286,4 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
246286
logger.debug(f"Sending session message to writer: {session_message}")
247287
response = Response("Accepted", status_code=202)
248288
await response(scope, receive, send)
249-
await writer.send(session_message)
289+
await writer.send(session_message)

tests/server/test_sse_security.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,34 @@ async def test_sse_security_post_valid_content_type(server_port: int):
291291
finally:
292292
process.terminate()
293293
process.join()
294+
295+
296+
@pytest.mark.anyio
297+
async def test_endpoint_validation_rejects_absolute_urls():
298+
"""Test that SseServerTransport properly validates endpoint format."""
299+
# These should all raise ValueError due to being absolute URLs or having invalid characters
300+
invalid_endpoints = [
301+
"http://example.com/messages/",
302+
"https://example.com/messages/",
303+
"//example.com/messages/",
304+
"/messages/?query=test",
305+
"/messages/#fragment",
306+
]
307+
308+
for invalid_endpoint in invalid_endpoints:
309+
with pytest.raises(ValueError, match="is not a relative path"):
310+
SseServerTransport(invalid_endpoint)
311+
312+
# These should all be valid - endpoint is stored as-is (no automatic normalization)
313+
valid_endpoints_and_expected = [
314+
("/messages/", "/messages/"), # Absolute path format
315+
("messages/", "messages/"), # Relative path format
316+
("/api/v1/messages/", "/api/v1/messages/"),
317+
("api/v1/messages/", "api/v1/messages/"),
318+
]
319+
320+
for valid_endpoint, expected_stored_value in valid_endpoints_and_expected:
321+
# Should not raise an exception
322+
transport = SseServerTransport(valid_endpoint)
323+
# Endpoint should be stored exactly as provided (no normalization)
324+
assert transport._endpoint == expected_stored_value

tests/shared/test_sse.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ def test_sse_message_id_coercion():
487487
@pytest.mark.parametrize(
488488
"endpoint, expected_result",
489489
[
490-
# Valid endpoints - should normalize and work
490+
# These should all be valid - endpoint is stored as-is (no automatic normalization)
491491
("/messages/", "/messages/"),
492-
("messages/", "/messages/"),
492+
("messages/", "messages/"),
493493
("/", "/"),
494494
# Invalid endpoints - should raise ValueError
495495
("http://example.com/messages/", ValueError),
@@ -506,7 +506,6 @@ def test_sse_server_transport_endpoint_validation(endpoint: str, expected_result
506506
with pytest.raises(expected_result, match="is not a relative path.*expecting a relative path"):
507507
SseServerTransport(endpoint)
508508
else:
509-
# Test valid endpoints that should normalize correctly
509+
# Endpoint should be stored exactly as provided (no normalization)
510510
sse = SseServerTransport(endpoint)
511511
assert sse._endpoint == expected_result
512-
assert sse._endpoint.startswith("/")

0 commit comments

Comments
 (0)