1414logger = logging .getLogger (__name__ )
1515
1616
17+ # TODO: move these to utils/url_utils.py
18+ def get_origin (url : str ) -> str :
19+ parsed_url = urlparse (url )
20+ return f"{ parsed_url .scheme } ://{ parsed_url .netloc } "
21+
22+
23+ def get_path (url : str ) -> str :
24+ parsed_url = urlparse (url )
25+ return parsed_url .path
26+
27+
28+ def get_endpoint_url (
29+ base_url : str , sse_relative_url : str , server_mount_path : str = ""
30+ ) -> str :
31+ endpoint_url = urljoin (base_url , sse_relative_url )
32+ if server_mount_path :
33+ origin , path = get_origin (endpoint_url ), get_path (endpoint_url )
34+ endpoint_url = urljoin (
35+ f"{ origin } /{ server_mount_path .strip ('/' )} /" , path .lstrip ("/" )
36+ )
37+ return endpoint_url
38+
39+
1740def remove_request_params (url : str ) -> str :
18- return urljoin (url , urlparse (url ). path )
41+ return urljoin (url , get_path (url ))
1942
2043
2144@asynccontextmanager
@@ -24,12 +47,16 @@ async def sse_client(
2447 headers : dict [str , Any ] | None = None ,
2548 timeout : float = 5 ,
2649 sse_read_timeout : float = 60 * 5 ,
50+ server_mount_path : str = "" ,
2751):
2852 """
2953 Client transport for SSE.
3054
3155 `sse_read_timeout` determines how long (in seconds) the client will wait for a new
3256 event before disconnecting. All other HTTP operations are controlled by `timeout`.
57+
58+ `server_mount_path` provides the relative mount path of the MCP server
59+ (used if it is mounted relatively on another ASGI server).
3360 """
3461 read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ]
3562 read_stream_writer : MemoryObjectSendStream [types .JSONRPCMessage | Exception ]
@@ -61,18 +88,15 @@ async def sse_reader(
6188 logger .debug (f"Received SSE event: { sse .event } " )
6289 match sse .event :
6390 case "endpoint" :
64- endpoint_url = urljoin (url , sse .data )
91+ endpoint_url = get_endpoint_url (
92+ base_url = url ,
93+ sse_relative_url = sse .data ,
94+ server_mount_path = server_mount_path ,
95+ )
6596 logger .info (
6697 f"Received endpoint URL: { endpoint_url } "
6798 )
68-
69- url_parsed = urlparse (url )
70- endpoint_parsed = urlparse (endpoint_url )
71- if (
72- url_parsed .netloc != endpoint_parsed .netloc
73- or url_parsed .scheme
74- != endpoint_parsed .scheme
75- ):
99+ if get_origin (url ) != get_origin (endpoint_url ):
76100 error_msg = (
77101 "Endpoint origin does not match "
78102 f"connection origin: { endpoint_url } "
0 commit comments