Skip to content

Commit 8714c53

Browse files
committed
Adapt to PR #1177
1 parent ee35583 commit 8714c53

File tree

2 files changed

+20
-22
lines changed

2 files changed

+20
-22
lines changed

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ def get_session_id(self) -> str | None:
464464
@asynccontextmanager
465465
async def streamable_http_client(
466466
url: str,
467+
extensions: dict[str, str] | None = None,
467468
*,
468469
http_client: httpx.AsyncClient | None = None,
469470
terminate_on_close: bool = True,
@@ -480,6 +481,7 @@ async def streamable_http_client(
480481
481482
Args:
482483
url: The MCP server endpoint URL.
484+
extensions: Optional extensions to include in requests.
483485
http_client: Optional pre-configured httpx.AsyncClient. If None, a default
484486
client with recommended MCP timeouts will be created. To configure headers,
485487
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
@@ -515,7 +517,7 @@ async def streamable_http_client(
515517
auth = client.auth
516518

517519
# Create transport with extracted configuration
518-
transport = StreamableHTTPTransport(url, headers_dict, timeout, sse_read_timeout, auth)
520+
transport = StreamableHTTPTransport(url, headers_dict, extensions, timeout, sse_read_timeout, auth)
519521

520522
async with anyio.create_task_group() as tg:
521523
try:

tests/shared/test_streamable_http.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,7 +1799,7 @@ async def test_extensions_passed_to_streamablehttp_client(self, basic_server: No
17991799
"custom_metadata": "custom_data",
18001800
}
18011801

1802-
async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
1802+
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
18031803
read_stream,
18041804
write_stream,
18051805
_,
@@ -1817,15 +1817,15 @@ async def test_extensions_passed_to_streamablehttp_client(self, basic_server: No
18171817
@pytest.mark.anyio
18181818
async def test_extensions_with_empty_dict(self, basic_server: None, basic_server_url: str):
18191819
"""Test streamablehttp_client with empty extensions dict."""
1820-
async with streamablehttp_client(f"{basic_server_url}/mcp", extensions={}) as (read_stream, write_stream, _):
1820+
async with streamable_http_client(f"{basic_server_url}/mcp", extensions={}) as (read_stream, write_stream, _):
18211821
async with ClientSession(read_stream, write_stream) as session:
18221822
result = await session.initialize()
18231823
assert isinstance(result, InitializeResult)
18241824

18251825
@pytest.mark.anyio
18261826
async def test_extensions_with_none(self, basic_server: None, basic_server_url: str):
18271827
"""Test streamablehttp_client with None extensions."""
1828-
async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=None) as (read_stream, write_stream, _):
1828+
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=None) as (read_stream, write_stream, _):
18291829
async with ClientSession(read_stream, write_stream) as session:
18301830
result = await session.initialize()
18311831
assert isinstance(result, InitializeResult)
@@ -1887,7 +1887,7 @@ async def test_extensions_isolation_between_clients(self, basic_server: None, ba
18871887
# Create two clients with different extensions
18881888
results: list[tuple[str, str]] = []
18891889

1890-
async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=extensions_1) as (
1890+
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=extensions_1) as (
18911891
read_stream1,
18921892
write_stream1,
18931893
_,
@@ -1896,7 +1896,7 @@ async def test_extensions_isolation_between_clients(self, basic_server: None, ba
18961896
result1 = await session1.initialize()
18971897
results.append(("client1", result1.serverInfo.name))
18981898

1899-
async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=extensions_2) as (
1899+
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=extensions_2) as (
19001900
read_stream2,
19011901
write_stream2,
19021902
_,
@@ -1950,18 +1950,11 @@ async def stream(self, *args: Any, **kwargs: Any):
19501950
async with super().stream(*args, **kwargs) as response:
19511951
yield response
19521952

1953-
# Custom client factory that returns our capturing client
1954-
def custom_client_factory(
1955-
headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None
1956-
) -> httpx.AsyncClient:
1957-
return ExtensionCapturingClient(
1958-
headers=headers,
1959-
timeout=timeout,
1960-
auth=auth,
1961-
)
1962-
1963-
async with streamablehttp_client(
1964-
f"{basic_server_url}/mcp/", extensions=test_extensions, httpx_client_factory=custom_client_factory
1953+
# Create the custom client that will capture extensions
1954+
custom_client = ExtensionCapturingClient()
1955+
1956+
async with streamable_http_client(
1957+
f"{basic_server_url}/mcp/", extensions=test_extensions, http_client=custom_client
19651958
) as (read_stream, write_stream, _):
19661959
async with ClientSession(read_stream, write_stream) as session:
19671960
# Initialize - this should make a POST request with extensions
@@ -1970,6 +1963,9 @@ def custom_client_factory(
19701963
# Make another request to capture more extensions usage
19711964
await session.list_tools()
19721965

1966+
# Close the custom client
1967+
await custom_client.aclose()
1968+
19731969
# Verify extensions were captured in requests
19741970
assert len(captured_extensions) > 0
19751971

@@ -1986,7 +1982,7 @@ async def test_extensions_with_json_and_sse_responses(self, basic_server: None,
19861982
test_extensions = {"response_test": "json_sse_test", "format": "both"}
19871983

19881984
# Test with regular SSE response (default behavior)
1989-
async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
1985+
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
19901986
read_stream,
19911987
write_stream,
19921988
_,
@@ -2010,7 +2006,7 @@ async def test_extensions_with_json_response_server(self, json_response_server:
20102006
"""Test extensions work with JSON response mode."""
20112007
test_extensions = {"response_mode": "json_only", "test_id": "json_test_123"}
20122008

2013-
async with streamablehttp_client(f"{json_server_url}/mcp", extensions=test_extensions) as (
2009+
async with streamable_http_client(f"{json_server_url}/mcp", extensions=test_extensions) as (
20142010
read_stream,
20152011
write_stream,
20162012
_,
@@ -2049,7 +2045,7 @@ async def test_extensions_with_special_characters(self, basic_server: None, basi
20492045
"url_like": "https://example.com/path?param=value",
20502046
}
20512047

2052-
async with streamablehttp_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
2048+
async with streamable_http_client(f"{basic_server_url}/mcp", extensions=test_extensions) as (
20532049
read_stream,
20542050
write_stream,
20552051
_,
@@ -2061,4 +2057,4 @@ async def test_extensions_with_special_characters(self, basic_server: None, basi
20612057

20622058
# Should work normally with tools
20632059
tools = await session.list_tools()
2064-
assert len(tools.tools) == 6
2060+
assert len(tools.tools) == 6

0 commit comments

Comments
 (0)