Skip to content

Commit d64d191

Browse files
Fix StreamableHTTP transport API for backwards compatibility and cleaner header handling
This commit addresses three issues identified in PR review: 1. **Restore RequestContext fields for backwards compatibility** - Re-add `headers` and `sse_read_timeout` fields as optional with None defaults - Mark them as deprecated in docstring since they're no longer used internally - Prevents breaking changes for any code accessing these fields 2. **Add runtime deprecation warnings for StreamableHTTPTransport constructor** - Use sentinel value pattern to detect when deprecated parameters are passed - Issue DeprecationWarning at runtime when headers, timeout, sse_read_timeout, or auth are provided - Complements existing @deprecated decorator for type checkers with actual runtime warnings - Improve deprecation message clarity 3. **Simplify header handling by removing redundant client parameter** - Remove `client` parameter from `_prepare_headers()` method - Stop extracting and re-passing client.headers since httpx automatically merges them - Only build MCP-specific headers (Accept, Content-Type, session headers) - httpx merges these with client.headers automatically, with our headers taking precedence - Reduces code complexity and eliminates unnecessary header extraction The header handling change leverages httpx's built-in header merging behavior, similar to how headers were handled before the refactoring but without the redundant extraction-and-repass pattern.
1 parent 862c22f commit d64d191

File tree

3 files changed

+92
-39
lines changed

3 files changed

+92
-39
lines changed

src/mcp/client/streamable_http.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@
6262
JSON = "application/json"
6363
SSE = "text/event-stream"
6464

65+
# Sentinel value for detecting unset optional parameters
66+
_UNSET = object()
67+
6568

6669
class StreamableHTTPError(Exception):
6770
"""Base exception for StreamableHTTP transport errors."""
@@ -81,7 +84,8 @@ class RequestContext:
8184
session_message: SessionMessage
8285
metadata: ClientMessageMetadata | None
8386
read_stream_writer: StreamWriter
84-
sse_read_timeout: float
87+
headers: dict[str, str] | None = None # Deprecated - no longer used
88+
sse_read_timeout: float | None = None # Deprecated - no longer used
8589

8690

8791
class StreamableHTTPTransport:
@@ -90,8 +94,11 @@ class StreamableHTTPTransport:
9094
@overload
9195
def __init__(self, url: str) -> None: ...
9296

93-
@deprecated("Those parameters are deprecated. Use the url parameter instead.")
9497
@overload
98+
@deprecated(
99+
"Parameters headers, timeout, sse_read_timeout, and auth are deprecated. "
100+
"Configure these on the httpx.AsyncClient instead."
101+
)
95102
def __init__(
96103
self,
97104
url: str,
@@ -104,11 +111,10 @@ def __init__(
104111
def __init__(
105112
self,
106113
url: str,
107-
headers: dict[str, str] | None = None,
108-
timeout: float | timedelta = 30,
109-
sse_read_timeout: float | timedelta = 60 * 5,
110-
auth: httpx.Auth | None = None,
111-
**deprecated: dict[str, Any],
114+
headers: Any = _UNSET,
115+
timeout: Any = _UNSET,
116+
sse_read_timeout: Any = _UNSET,
117+
auth: Any = _UNSET,
112118
) -> None:
113119
"""Initialize the StreamableHTTP transport.
114120
@@ -119,26 +125,40 @@ def __init__(
119125
sse_read_timeout: Timeout for SSE read operations.
120126
auth: Optional HTTPX authentication handler.
121127
"""
122-
if deprecated:
123-
warn(f"Deprecated parameters: {deprecated}", DeprecationWarning)
128+
# Check for deprecated parameters and issue runtime warning
129+
deprecated_params: list[str] = []
130+
if headers is not _UNSET:
131+
deprecated_params.append("headers")
132+
if timeout is not _UNSET:
133+
deprecated_params.append("timeout")
134+
if sse_read_timeout is not _UNSET:
135+
deprecated_params.append("sse_read_timeout")
136+
if auth is not _UNSET:
137+
deprecated_params.append("auth")
138+
139+
if deprecated_params:
140+
warn(
141+
f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. "
142+
"Configure these on the httpx.AsyncClient instead.",
143+
DeprecationWarning,
144+
stacklevel=2,
145+
)
146+
124147
self.url = url
125-
self.headers = headers or {}
126-
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
127-
self.sse_read_timeout = (
128-
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
129-
)
130-
self.auth = auth
131148
self.session_id = None
132149
self.protocol_version = None
133-
self.request_headers = {
134-
**self.headers,
135-
ACCEPT: f"{JSON}, {SSE}",
136-
CONTENT_TYPE: JSON,
137-
}
138-
139-
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
140-
"""Update headers with session ID and protocol version if available."""
141-
headers = base_headers.copy()
150+
151+
def _prepare_headers(self) -> dict[str, str]:
152+
"""Build MCP-specific request headers.
153+
154+
These headers will be merged with the httpx.AsyncClient's default headers,
155+
with these MCP-specific headers taking precedence.
156+
"""
157+
headers: dict[str, str] = {}
158+
# Add MCP protocol headers
159+
headers[ACCEPT] = f"{JSON}, {SSE}"
160+
headers[CONTENT_TYPE] = JSON
161+
# Add session headers if available
142162
if self.session_id:
143163
headers[MCP_SESSION_ID] = self.session_id
144164
if self.protocol_version:
@@ -242,7 +262,7 @@ async def handle_get_stream(
242262
if not self.session_id:
243263
return
244264

245-
headers = self._prepare_request_headers(self.request_headers)
265+
headers = self._prepare_headers()
246266
if last_event_id:
247267
headers[LAST_EVENT_ID] = last_event_id # pragma: no cover
248268

@@ -251,7 +271,6 @@ async def handle_get_stream(
251271
"GET",
252272
self.url,
253273
headers=headers,
254-
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
255274
) as event_source:
256275
event_source.response.raise_for_status()
257276
logger.debug("GET SSE connection established")
@@ -284,7 +303,7 @@ async def handle_get_stream(
284303

285304
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
286305
"""Handle a resumption request using GET with SSE."""
287-
headers = self._prepare_request_headers(ctx.headers)
306+
headers = self._prepare_headers()
288307
if ctx.metadata and ctx.metadata.resumption_token:
289308
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
290309
else:
@@ -300,7 +319,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
300319
"GET",
301320
self.url,
302321
headers=headers,
303-
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
304322
) as event_source:
305323
event_source.response.raise_for_status()
306324
logger.debug("Resumption GET SSE connection established")
@@ -318,7 +336,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
318336

319337
async def _handle_post_request(self, ctx: RequestContext) -> None:
320338
"""Handle a POST request with response processing."""
321-
headers = self._prepare_request_headers(ctx.headers)
339+
headers = self._prepare_headers()
322340
message = ctx.session_message.message
323341
is_initialization = self._is_initialization_request(message)
324342

@@ -436,7 +454,7 @@ async def _handle_reconnection(
436454
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
437455
await anyio.sleep(delay_ms / 1000.0)
438456

439-
headers = self._prepare_request_headers(ctx.headers)
457+
headers = self._prepare_headers()
440458
headers[LAST_EVENT_ID] = last_event_id
441459

442460
# Extract original request ID to map responses
@@ -450,7 +468,6 @@ async def _handle_reconnection(
450468
"GET",
451469
self.url,
452470
headers=headers,
453-
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
454471
) as event_source:
455472
event_source.response.raise_for_status()
456473
logger.info("Reconnected to SSE stream")
@@ -538,12 +555,10 @@ async def post_writer(
538555

539556
ctx = RequestContext(
540557
client=client,
541-
headers=self.request_headers,
542558
session_id=self.session_id,
543559
session_message=session_message,
544560
metadata=metadata,
545561
read_stream_writer=read_stream_writer,
546-
sse_read_timeout=self.sse_read_timeout,
547562
)
548563

549564
async def handle_request_async():
@@ -570,7 +585,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma:
570585
return
571586

572587
try:
573-
headers = self._prepare_request_headers(self.request_headers)
588+
headers = self._prepare_headers()
574589
response = await client.delete(self.url, headers=headers)
575590

576591
if response.status_code == 405:
@@ -678,8 +693,8 @@ def start_get_stream() -> None:
678693
await write_stream.aclose()
679694

680695

681-
@deprecated("Use `streamable_http_client` instead.")
682696
@asynccontextmanager
697+
@deprecated("Use `streamable_http_client` instead.")
683698
async def streamablehttp_client(
684699
url: str,
685700
headers: dict[str, str] | None = None,

tests/client/test_http_unicode.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313

1414
from mcp.client.session import ClientSession
15-
from mcp.client.streamable_http import streamablehttp_client
15+
from mcp.client.streamable_http import streamable_http_client
1616
from tests.test_helpers import wait_for_server
1717

1818
# Test constants with various Unicode characters
@@ -178,7 +178,7 @@ async def test_streamable_http_client_unicode_tool_call(running_unicode_server:
178178
base_url = running_unicode_server
179179
endpoint_url = f"{base_url}/mcp"
180180

181-
async with streamablehttp_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
181+
async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
182182
async with ClientSession(read_stream, write_stream) as session:
183183
await session.initialize()
184184

@@ -210,7 +210,7 @@ async def test_streamable_http_client_unicode_prompts(running_unicode_server: st
210210
base_url = running_unicode_server
211211
endpoint_url = f"{base_url}/mcp"
212212

213-
async with streamablehttp_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
213+
async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
214214
async with ClientSession(read_stream, write_stream) as session:
215215
await session.initialize()
216216

tests/shared/test_streamable_http.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import socket
1010
import time
1111
from collections.abc import Generator
12+
from datetime import timedelta
1213
from typing import Any
1314
from unittest.mock import MagicMock
1415

@@ -25,7 +26,11 @@
2526

2627
import mcp.types as types
2728
from mcp.client.session import ClientSession
28-
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
29+
from mcp.client.streamable_http import (
30+
StreamableHTTPTransport,
31+
streamable_http_client,
32+
streamablehttp_client, # pyright: ignore[reportDeprecated]
33+
)
2934
from mcp.server import Server
3035
from mcp.server.streamable_http import (
3136
MCP_PROTOCOL_VERSION_HEADER,
@@ -2356,3 +2361,36 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(
23562361

23572362
assert "content-type" in headers_data
23582363
assert headers_data["content-type"] == "application/json"
2364+
2365+
2366+
@pytest.mark.anyio
2367+
async def test_streamable_http_transport_deprecated_params_ignored(basic_server: None, basic_server_url: str) -> None:
2368+
"""Test that deprecated parameters passed to StreamableHTTPTransport are properly ignored."""
2369+
with pytest.warns(DeprecationWarning):
2370+
transport = StreamableHTTPTransport( # pyright: ignore[reportDeprecated]
2371+
url=f"{basic_server_url}/mcp",
2372+
headers={"X-Should-Be-Ignored": "ignored"},
2373+
timeout=999,
2374+
sse_read_timeout=timedelta(seconds=999),
2375+
auth=None,
2376+
)
2377+
2378+
headers = transport._prepare_headers()
2379+
assert "X-Should-Be-Ignored" not in headers
2380+
assert headers["accept"] == "application/json, text/event-stream"
2381+
assert headers["content-type"] == "application/json"
2382+
2383+
2384+
@pytest.mark.anyio
2385+
async def test_streamablehttp_client_deprecation_warning(basic_server: None, basic_server_url: str) -> None:
2386+
"""Test that the old streamablehttp_client() function issues a deprecation warning."""
2387+
with pytest.warns(DeprecationWarning, match="Use `streamable_http_client` instead"):
2388+
async with streamablehttp_client(f"{basic_server_url}/mcp") as ( # pyright: ignore[reportDeprecated]
2389+
read_stream,
2390+
write_stream,
2391+
_,
2392+
):
2393+
async with ClientSession(read_stream, write_stream) as session:
2394+
await session.initialize()
2395+
tools = await session.list_tools()
2396+
assert len(tools.tools) > 0

0 commit comments

Comments
 (0)