Skip to content

Commit 673ae93

Browse files
author
rkondra-eightfold
committed
fix: incorrect resolution of the /messages endpoint URL in the SSE client when the FastAPI app is mounted under a base path (e.g., /mcp).
1 parent babb477 commit 673ae93

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

src/mcp/client/sse.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,31 @@
1414
logger = logging.getLogger(__name__)
1515

1616

17+
# TODO: move these to utils/url_utils.py
18+
def get_origin(url: str) -> str:
19+
parsed_url = urlparse(url)
20+
return f"{parsed_url.scheme}://{parsed_url.netloc}"
21+
22+
23+
def get_path(url: str) -> str:
24+
parsed_url = urlparse(url)
25+
return parsed_url.path
26+
27+
28+
def get_endpoint_url(
29+
base_url: str, sse_relative_url: str, server_mount_path: str = ""
30+
) -> str:
31+
endpoint_url = urljoin(base_url, sse_relative_url)
32+
if server_mount_path:
33+
origin, path = get_origin(endpoint_url), get_path(endpoint_url)
34+
endpoint_url = urljoin(
35+
f"{origin}/{server_mount_path.strip('/')}/", path.lstrip("/")
36+
)
37+
return endpoint_url
38+
39+
1740
def remove_request_params(url: str) -> str:
18-
return urljoin(url, urlparse(url).path)
41+
return urljoin(url, get_path(url))
1942

2043

2144
@asynccontextmanager
@@ -24,12 +47,16 @@ async def sse_client(
2447
headers: dict[str, Any] | None = None,
2548
timeout: float = 5,
2649
sse_read_timeout: float = 60 * 5,
50+
server_mount_path: str = "",
2751
):
2852
"""
2953
Client transport for SSE.
3054
3155
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
3256
event before disconnecting. All other HTTP operations are controlled by `timeout`.
57+
58+
`server_mount_path` provides the relative mount path of the MCP server
59+
(used if it is mounted relatively on another ASGI server).
3360
"""
3461
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
3562
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
@@ -61,18 +88,15 @@ async def sse_reader(
6188
logger.debug(f"Received SSE event: {sse.event}")
6289
match sse.event:
6390
case "endpoint":
64-
endpoint_url = urljoin(url, sse.data)
91+
endpoint_url = get_endpoint_url(
92+
base_url=url,
93+
sse_relative_url=sse.data,
94+
server_mount_path=server_mount_path,
95+
)
6596
logger.info(
6697
f"Received endpoint URL: {endpoint_url}"
6798
)
68-
69-
url_parsed = urlparse(url)
70-
endpoint_parsed = urlparse(endpoint_url)
71-
if (
72-
url_parsed.netloc != endpoint_parsed.netloc
73-
or url_parsed.scheme
74-
!= endpoint_parsed.scheme
75-
):
99+
if get_origin(url) != get_origin(endpoint_url):
76100
error_msg = (
77101
"Endpoint origin does not match "
78102
f"connection origin: {endpoint_url}"

0 commit comments

Comments
 (0)