Skip to content

Commit 596864e

Browse files
committed
Support injectable httpx client
1 parent 6353dd1 commit 596864e

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

src/mcp/client/sse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from contextlib import asynccontextmanager
3-
from typing import Any
3+
from typing import Any, Callable
44
from urllib.parse import urljoin, urlparse
55

66
import anyio
@@ -10,7 +10,7 @@
1010
from httpx_sse import aconnect_sse
1111

1212
import mcp.types as types
13-
from mcp.shared._httpx_utils import create_mcp_http_client
13+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1414
from mcp.shared.message import SessionMessage
1515

1616
logger = logging.getLogger(__name__)
@@ -26,6 +26,7 @@ async def sse_client(
2626
headers: dict[str, Any] | None = None,
2727
timeout: float = 5,
2828
sse_read_timeout: float = 60 * 5,
29+
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
2930
):
3031
"""
3132
Client transport for SSE.
@@ -45,7 +46,7 @@ async def sse_client(
4546
async with anyio.create_task_group() as tg:
4647
try:
4748
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
48-
async with create_mcp_http_client(headers=headers) as client:
49+
async with httpx_client_factory(headers=headers) as client:
4950
async with aconnect_sse(
5051
client,
5152
"GET",

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2020
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2121

22-
from mcp.shared._httpx_utils import create_mcp_http_client
22+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
2323
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2424
from mcp.types import (
2525
ErrorData,
@@ -427,6 +427,7 @@ async def streamablehttp_client(
427427
timeout: timedelta = timedelta(seconds=30),
428428
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
429429
terminate_on_close: bool = True,
430+
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
430431
) -> AsyncGenerator[
431432
tuple[
432433
MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -460,7 +461,7 @@ async def streamablehttp_client(
460461
try:
461462
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
462463

463-
async with create_mcp_http_client(
464+
async with httpx_client_factory(
464465
headers=transport.request_headers,
465466
timeout=httpx.Timeout(
466467
transport.timeout.seconds, read=transport.sse_read_timeout.seconds

src/mcp/shared/_httpx_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
"""Utilities for creating standardized httpx AsyncClient instances."""
22

3-
from typing import Any
3+
from typing import Any, Protocol
44

55
import httpx
66

77
__all__ = ["create_mcp_http_client"]
88

99

10+
class McpHttpClientFactory(Protocol):
11+
def __call__(
12+
self,
13+
headers: dict[str, str] | None = None,
14+
timeout: httpx.Timeout | None = None,
15+
) -> httpx.AsyncClient:
16+
...
17+
18+
1019
def create_mcp_http_client(
1120
headers: dict[str, str] | None = None,
1221
timeout: httpx.Timeout | None = None,

0 commit comments

Comments
 (0)