Skip to content

Commit 283f43c

Browse files
committed
fix for SSE URL handling when a server name is specified
1 parent 697b6e8 commit 283f43c

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

src/mcp/client/sse.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from contextlib import asynccontextmanager
33
from typing import Any
4-
from urllib.parse import urljoin, urlparse
4+
from urllib.parse import urljoin, urlparse, urlunparse
55

66
import anyio
77
import httpx
@@ -18,6 +18,31 @@ def remove_request_params(url: str) -> str:
1818
return urljoin(url, urlparse(url).path)
1919

2020

21+
def custom_url_join(base_url: str, endpoint: str) -> str:
22+
"""
23+
Custom URL join function to handle the case where the endpoint is relative
24+
to a base URL. This function ensures that the base URL and endpoint are
25+
combined correctly, even if the endpoint is not a full URL.
26+
"""
27+
# Parse the base URL
28+
parsed_base = urlparse(base_url)
29+
30+
# Get the path prefix (e.g., '/weather')
31+
path_prefix = "/".join(parsed_base.path.split("/")[:-1])
32+
33+
# Remove any leading slash from the endpoint
34+
clean_endpoint = endpoint.lstrip("/")
35+
36+
# Create the new path by joining prefix and endpoint
37+
new_path = f"{path_prefix}/{clean_endpoint}"
38+
39+
# Create a new parsed URL with the updated path
40+
parsed_new = parsed_base._replace(path=new_path)
41+
42+
# Convert back to a string URL
43+
return urlunparse(parsed_new)
44+
45+
2146
@asynccontextmanager
2247
async def sse_client(
2348
url: str,
@@ -43,6 +68,15 @@ async def sse_client(
4368
async with anyio.create_task_group() as tg:
4469
try:
4570
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
71+
72+
# extract MCP server name from URL, for example: https://mcp.example.com/weather/sse
73+
# will extract 'weather'
74+
path_tokens = urlparse(url).path.split("/")
75+
optional_mcp_server_name = (
76+
path_tokens[1:-1] if len(path_tokens) > 2 else None
77+
)
78+
logger.debug(f"MCP Server name (optional): {optional_mcp_server_name}")
79+
4680
async with httpx.AsyncClient(headers=headers) as client:
4781
async with aconnect_sse(
4882
client,
@@ -61,7 +95,12 @@ async def sse_reader(
6195
logger.debug(f"Received SSE event: {sse.event}")
6296
match sse.event:
6397
case "endpoint":
64-
endpoint_url = urljoin(url, sse.data)
98+
if optional_mcp_server_name:
99+
endpoint_url = custom_url_join(
100+
url, sse.data
101+
)
102+
else:
103+
endpoint_url = urljoin(url, sse.data)
65104
logger.info(
66105
f"Received endpoint URL: {endpoint_url}"
67106
)

tests/client/test_sse.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import pytest
2+
from urllib.parse import urlparse, urlunparse
3+
4+
from mcp.client.sse import custom_url_join
5+
6+
7+
@pytest.mark.parametrize(
8+
"base_url,endpoint,expected",
9+
[
10+
# Additional test cases to verify behavior with different URL structures
11+
(
12+
"https://mcp.example.com/weather/sse",
13+
"/messages/?session_id=616df71373444d76bd566df4377c9629",
14+
"https://mcp.example.com/weather/messages/?session_id=616df71373444d76bd566df4377c9629"
15+
),
16+
(
17+
"https://mcp.example.com/weather/clarksburg/sse",
18+
"/messages/?session_id=616df71373444d76bd566df4377c9629",
19+
"https://mcp.example.com/weather/clarksburg/messages/?session_id=616df71373444d76bd566df4377c9629"
20+
),
21+
(
22+
"https://mcp.example.com/sse",
23+
"/messages/?session_id=616df71373444d76bd566df4377c9629",
24+
"https://mcp.example.com/messages/?session_id=616df71373444d76bd566df4377c9629"
25+
),
26+
],
27+
)
28+
def test_custom_url_join(base_url, endpoint, expected):
29+
"""Test the custom_url_join function with messages endpoint and session ID."""
30+
result = custom_url_join(base_url, endpoint)
31+
assert result == expected

0 commit comments

Comments
 (0)