11import logging
22from contextlib import asynccontextmanager
33from typing import Any
4- from urllib .parse import urljoin , urlparse
4+ from urllib .parse import urljoin , urlparse , urlunparse
55
66import anyio
77import httpx
@@ -18,6 +18,31 @@ def remove_request_params(url: str) -> str:
1818 return urljoin (url , urlparse (url ).path )
1919
2020
21+ def custom_url_join (base_url : str , endpoint : str ) -> str :
22+ """
23+ Custom URL join function to handle the case where the endpoint is relative
24+ to a base URL. This function ensures that the base URL and endpoint are
25+ combined correctly, even if the endpoint is not a full URL.
26+ """
27+ # Parse the base URL
28+ parsed_base = urlparse (base_url )
29+
30+ # Get the path prefix (e.g., '/weather')
31+ path_prefix = "/" .join (parsed_base .path .split ("/" )[:- 1 ])
32+
33+ # Remove any leading slash from the endpoint
34+ clean_endpoint = endpoint .lstrip ("/" )
35+
36+ # Create the new path by joining prefix and endpoint
37+ new_path = f"{ path_prefix } /{ clean_endpoint } "
38+
39+ # Create a new parsed URL with the updated path
40+ parsed_new = parsed_base ._replace (path = new_path )
41+
42+ # Convert back to a string URL
43+ return urlunparse (parsed_new )
44+
45+
2146@asynccontextmanager
2247async def sse_client (
2348 url : str ,
@@ -43,6 +68,15 @@ async def sse_client(
4368 async with anyio .create_task_group () as tg :
4469 try :
4570 logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
71+
72+ # extract MCP server name from URL, for example: https://mcp.example.com/weather/sse
73+ # will extract 'weather'
74+ path_tokens = urlparse (url ).path .split ("/" )
75+ optional_mcp_server_name = (
76+ path_tokens [1 :- 1 ] if len (path_tokens ) > 2 else None
77+ )
78+ logger .debug (f"MCP Server name (optional): { optional_mcp_server_name } " )
79+
4680 async with httpx .AsyncClient (headers = headers ) as client :
4781 async with aconnect_sse (
4882 client ,
@@ -61,7 +95,12 @@ async def sse_reader(
6195 logger .debug (f"Received SSE event: { sse .event } " )
6296 match sse .event :
6397 case "endpoint" :
64- endpoint_url = urljoin (url , sse .data )
98+ if optional_mcp_server_name :
99+ endpoint_url = custom_url_join (
100+ url , sse .data
101+ )
102+ else :
103+ endpoint_url = urljoin (url , sse .data )
65104 logger .info (
66105 f"Received endpoint URL: { endpoint_url } "
67106 )
0 commit comments