Skip to content

Commit c2be5af

Browse files
committed
streamable http client
1 parent da1df74 commit c2be5af

File tree

2 files changed

+438
-13
lines changed

2 files changed

+438
-13
lines changed

src/mcp/client/streamableHttp.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
"""
2+
StreamableHTTP Client Transport Module
3+
4+
This module implements the StreamableHTTP transport for MCP clients,
5+
providing support for HTTP POST requests with optional SSE streaming responses
6+
and session management.
7+
"""
8+
9+
import logging
10+
from contextlib import asynccontextmanager
11+
from typing import Any
12+
13+
import anyio
14+
import httpx
15+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
16+
from httpx_sse import EventSource, aconnect_sse
17+
18+
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest
19+
20+
logger = logging.getLogger(__name__)
21+
22+
# Header names
23+
MCP_SESSION_ID_HEADER = "mcp-session-id"
24+
LAST_EVENT_ID_HEADER = "last-event-id"
25+
26+
# Content types
27+
CONTENT_TYPE_JSON = "application/json"
28+
CONTENT_TYPE_SSE = "text/event-stream"
29+
30+
31+
@asynccontextmanager
32+
async def streamablehttp_client(
33+
url: str,
34+
headers: dict[str, Any] | None = None,
35+
timeout: float = 30,
36+
sse_read_timeout: float = 60 * 5,
37+
):
38+
"""
39+
Client transport for StreamableHTTP.
40+
41+
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
42+
event before disconnecting. All other HTTP operations are controlled by `timeout`.
43+
"""
44+
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
45+
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
46+
47+
write_stream: MemoryObjectSendStream[JSONRPCMessage]
48+
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
49+
50+
read_stream_writer, read_stream = anyio.create_memory_object_stream[
51+
JSONRPCMessage | Exception
52+
](0)
53+
write_stream, write_stream_reader = anyio.create_memory_object_stream[
54+
JSONRPCMessage
55+
](0)
56+
57+
async with anyio.create_task_group() as tg:
58+
try:
59+
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
60+
# Set up headers with required Accept header
61+
request_headers = {
62+
"Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}",
63+
"Content-Type": CONTENT_TYPE_JSON,
64+
**(headers or {}),
65+
}
66+
67+
# Track session ID if provided by server
68+
session_id: str | None = None
69+
70+
async with httpx.AsyncClient(
71+
headers=request_headers, timeout=timeout, follow_redirects=True
72+
) as client:
73+
74+
async def post_writer():
75+
nonlocal session_id
76+
try:
77+
async with write_stream_reader:
78+
async for message in write_stream_reader:
79+
# Add session ID to headers if we have one
80+
post_headers = request_headers.copy()
81+
if session_id:
82+
post_headers[MCP_SESSION_ID_HEADER] = session_id
83+
84+
logger.debug(f"Sending client message: {message}")
85+
86+
# Handle initial initialization request
87+
is_initialization = (
88+
isinstance(message.root, JSONRPCRequest)
89+
and message.root.method == "initialize"
90+
)
91+
if (
92+
isinstance(message.root, JSONRPCNotification)
93+
and message.root.method
94+
== "notifications/initialized"
95+
):
96+
tg.start_soon(get_stream)
97+
98+
async with client.stream(
99+
"POST",
100+
url,
101+
json=message.model_dump(
102+
by_alias=True, mode="json", exclude_none=True
103+
),
104+
headers=post_headers,
105+
) as response:
106+
if response.status_code == 202:
107+
logger.debug("Received 202 Accepted")
108+
continue
109+
# Check for 404 (session expired/invalid)
110+
if response.status_code == 404:
111+
if is_initialization and session_id:
112+
logger.info(
113+
"Session expired, retrying without ID"
114+
)
115+
session_id = None
116+
post_headers.pop(
117+
MCP_SESSION_ID_HEADER, None
118+
)
119+
# Retry with client.stream
120+
async with client.stream(
121+
"POST",
122+
url,
123+
json=message.model_dump(
124+
by_alias=True,
125+
mode="json",
126+
exclude_none=True,
127+
),
128+
headers=post_headers,
129+
) as new_response:
130+
response = new_response
131+
else:
132+
response.raise_for_status()
133+
134+
response.raise_for_status()
135+
136+
# Extract session ID from response headers
137+
if is_initialization:
138+
new_session_id = response.headers.get(
139+
MCP_SESSION_ID_HEADER
140+
)
141+
if new_session_id:
142+
session_id = new_session_id
143+
logger.info(
144+
f"Received session ID: {session_id}"
145+
)
146+
147+
# Handle different response types
148+
content_type = response.headers.get(
149+
"content-type", ""
150+
).lower()
151+
152+
if content_type.startswith(CONTENT_TYPE_JSON):
153+
try:
154+
content = await response.aread()
155+
json_message = (
156+
JSONRPCMessage.model_validate_json(
157+
content
158+
)
159+
)
160+
await read_stream_writer.send(json_message)
161+
except Exception as exc:
162+
logger.error(
163+
f"Error parsing JSON response: {exc}"
164+
)
165+
await read_stream_writer.send(exc)
166+
167+
elif content_type.startswith(CONTENT_TYPE_SSE):
168+
# Parse SSE events from the response
169+
try:
170+
event_source = EventSource(response)
171+
async for sse in event_source.aiter_sse():
172+
if sse.event == "message":
173+
try:
174+
await read_stream_writer.send(
175+
JSONRPCMessage.model_validate_json(
176+
sse.data
177+
)
178+
)
179+
except Exception as exc:
180+
logger.exception(
181+
"Error parsing message"
182+
)
183+
await read_stream_writer.send(
184+
exc
185+
)
186+
else:
187+
logger.warning(
188+
f"Unknown event: {sse.event}"
189+
)
190+
191+
except Exception as e:
192+
logger.exception(
193+
"Error reading SSE stream:"
194+
)
195+
await read_stream_writer.send(e)
196+
197+
else:
198+
# For 202 Accepted with no body
199+
if response.status_code == 202:
200+
logger.debug("Received 202 Accepted")
201+
continue
202+
203+
error_msg = (
204+
f"Unexpected content type: {content_type}"
205+
)
206+
logger.error(error_msg)
207+
await read_stream_writer.send(
208+
ValueError(error_msg)
209+
)
210+
211+
except Exception as exc:
212+
logger.error(f"Error in post_writer: {exc}")
213+
await read_stream_writer.send(exc)
214+
finally:
215+
await read_stream_writer.aclose()
216+
await write_stream.aclose()
217+
218+
async def get_stream():
219+
"""
220+
Optional GET stream for server-initiated messages
221+
"""
222+
nonlocal session_id
223+
try:
224+
# Only attempt GET if we have a session ID
225+
if not session_id:
226+
return
227+
228+
get_headers = request_headers.copy()
229+
get_headers[MCP_SESSION_ID_HEADER] = session_id
230+
231+
async with aconnect_sse(
232+
client, "GET", url, headers=get_headers
233+
) as event_source:
234+
event_source.response.raise_for_status()
235+
logger.debug("GET SSE connection established")
236+
237+
async for sse in event_source.aiter_sse():
238+
if sse.event == "message":
239+
try:
240+
message = JSONRPCMessage.model_validate_json(
241+
sse.data
242+
)
243+
logger.debug(f"GET message: {message}")
244+
await read_stream_writer.send(message)
245+
except Exception as exc:
246+
logger.error(
247+
f"Error parsing GET message: {exc}"
248+
)
249+
await read_stream_writer.send(exc)
250+
else:
251+
logger.warning(
252+
f"Unknown SSE event from GET: {sse.event}"
253+
)
254+
except Exception as exc:
255+
# GET stream is optional, so don't propagate errors
256+
logger.debug(f"GET stream error (non-fatal): {exc}")
257+
258+
tg.start_soon(post_writer)
259+
260+
try:
261+
yield read_stream, write_stream
262+
finally:
263+
tg.cancel_scope.cancel()
264+
finally:
265+
await read_stream_writer.aclose()
266+
await write_stream.aclose()

0 commit comments

Comments
 (0)