diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7ca8d19afd..2025f4a0eb 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,14 +7,17 @@ import httpx from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx._config import DEFAULT_TIMEOUT_CONFIG from httpx_sse import aconnect_sse import mcp.types as types -from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) +HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG + def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) @@ -26,8 +29,8 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, - httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + httpx_client: httpx.AsyncClient | None = None, ): """ Client transport for SSE. @@ -38,9 +41,12 @@ async def sse_client( Args: url: The SSE endpoint URL. headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations. - sse_read_timeout: Timeout for SSE read operations. + timeout: HTTP timeout for regular operations. Defaults to 5 seconds. + sse_read_timeout: Timeout for SSE read operations. Defaults to 300 seconds (5 minutes). auth: Optional HTTPX authentication handler. + httpx_client: Optional pre-configured httpx.AsyncClient. If provided, the client's + existing configuration is preserved. Timeout is only overridden if the provided + client uses httpx's default timeout configuration. """ read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] @@ -51,12 +57,28 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + if httpx_client is not None: + client = httpx_client + if not getattr(client, "follow_redirects", False): + logger.warning("httpx_client does not have follow_redirects=True, which is recommended for MCP") + if headers: + existing_headers = dict(client.headers) if client.headers else {} + existing_headers.update(headers) + client.headers = existing_headers + if auth and not client.auth: + client.auth = auth + + if client.timeout == HTTPX_DEFAULT_TIMEOUT: + client.timeout = httpx.Timeout(timeout, read=sse_read_timeout) + else: + client = create_mcp_http_client( + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + ) + async with anyio.create_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx_client_factory( - headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) - ) as client: + async with client: async with aconnect_sse( client, "GET", diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 57df647057..4a0c8a7e21 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -16,9 +16,10 @@ import httpx from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx._config import DEFAULT_TIMEOUT_CONFIG from httpx_sse import EventSource, ServerSentEvent, aconnect_sse -from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, @@ -33,6 +34,7 @@ logger = logging.getLogger(__name__) +HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG SessionMessageOrError = SessionMessage | Exception StreamWriter = MemoryObjectSendStream[SessionMessageOrError] @@ -448,8 +450,8 @@ async def streamablehttp_client( timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 60 * 5, terminate_on_close: bool = True, - httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, + httpx_client: httpx.AsyncClient | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -464,6 +466,19 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + Args: + url: The StreamableHTTP endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. Defaults to 30 seconds. + Can be specified as float (seconds) or timedelta object. + sse_read_timeout: Timeout for SSE read operations. Defaults to 300 seconds (5 minutes). + Can be specified as float (seconds) or timedelta object. + terminate_on_close: Whether to send a terminate request when closing the connection. + auth: Optional HTTPX authentication handler. + httpx_client: Optional pre-configured httpx.AsyncClient. If provided, the client's + existing configuration is preserved. Timeout is only overridden if the provided + client uses httpx's default timeout configuration. + Yields: Tuple containing: - read_stream: Stream for reading messages from the server @@ -475,15 +490,30 @@ async def streamablehttp_client( read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + if httpx_client is not None: + client = httpx_client + if not getattr(client, "follow_redirects", False): + logger.warning("httpx_client does not have follow_redirects=True, which is recommended for MCP") + if headers: + existing_headers = dict(client.headers) if client.headers else {} + existing_headers.update(transport.request_headers) + client.headers = existing_headers + if auth and not client.auth: + client.auth = auth + if client.timeout == HTTPX_DEFAULT_TIMEOUT: + client.timeout = httpx.Timeout(transport.timeout, read=transport.sse_read_timeout) + else: + client = create_mcp_http_client( + headers=transport.request_headers, + timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), + auth=transport.auth, + ) + async with anyio.create_task_group() as tg: try: logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") - async with httpx_client_factory( - headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), - auth=transport.auth, - ) as client: + async with client: # Define callbacks that need access to tg def start_get_stream() -> None: tg.start_soon(transport.handle_get_stream, client, read_stream_writer) diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index e0611ce73d..5240c970c6 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -1,21 +1,12 @@ """Utilities for creating standardized httpx AsyncClient instances.""" -from typing import Any, Protocol +from typing import Any import httpx __all__ = ["create_mcp_http_client"] -class McpHttpClientFactory(Protocol): - def __call__( - self, - headers: dict[str, str] | None = None, - timeout: httpx.Timeout | None = None, - auth: httpx.Auth | None = None, - ) -> httpx.AsyncClient: ... - - def create_mcp_http_client( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, diff --git a/tests/shared/test_httpx_utils.py b/tests/shared/test_httpx_utils.py index dcc6fd003c..cabc186540 100644 --- a/tests/shared/test_httpx_utils.py +++ b/tests/shared/test_httpx_utils.py @@ -18,7 +18,7 @@ def test_custom_parameters(): headers = {"Authorization": "Bearer token"} timeout = httpx.Timeout(60.0) - client = create_mcp_http_client(headers, timeout) + client = create_mcp_http_client(headers=headers, timeout=timeout) assert client.headers["Authorization"] == "Bearer token" assert client.timeout.connect == 60.0