Skip to content

Commit 90a647a

Browse files
committed
RDCIST-3853: Add support extensions - stream
1 parent 5983a65 commit 90a647a

File tree

1 file changed

+51
-35
lines changed

1 file changed

+51
-35
lines changed

src/mcp/client/streamable_http.py

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class RequestContext:
6464

6565
client: httpx.AsyncClient
6666
headers: dict[str, str]
67+
extensions: dict[str, str] | None
6768
session_id: str | None
6869
session_message: SessionMessage
6970
metadata: ClientMessageMetadata | None
@@ -78,6 +79,7 @@ def __init__(
7879
self,
7980
url: str,
8081
headers: dict[str, str] | None = None,
82+
extensions: dict[str, str] | None = None,
8183
timeout: float | timedelta = 30,
8284
sse_read_timeout: float | timedelta = 60 * 5,
8385
auth: httpx.Auth | None = None,
@@ -87,12 +89,14 @@ def __init__(
8789
Args:
8890
url: The endpoint URL.
8991
headers: Optional headers to include in requests.
92+
extensions: Optional extensions to include in requests.
9093
timeout: HTTP timeout for regular operations.
9194
sse_read_timeout: Timeout for SSE read operations.
9295
auth: Optional HTTPX authentication handler.
9396
"""
9497
self.url = url
9598
self.headers = headers or {}
99+
self.extensions = extensions.copy() if extensions else {}
96100
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
97101
self.sse_read_timeout = (
98102
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
@@ -115,6 +119,12 @@ def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, st
115119
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
116120
return headers
117121

122+
def _prepare_request_extensions(self, base_extensions: dict[str, str] | None) -> dict[str, str]:
123+
"""Update extensions with session-specific data if available."""
124+
extensions = base_extensions.copy() if base_extensions else {}
125+
# Add any session-specific extensions here if needed
126+
return extensions
127+
118128
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
119129
"""Check if the message is an initialization request."""
120130
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
@@ -138,16 +148,14 @@ def _maybe_extract_protocol_version_from_message(
138148
message: JSONRPCMessage,
139149
) -> None:
140150
"""Extract protocol version from initialization response message."""
141-
if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch
151+
if isinstance(message.root, JSONRPCResponse) and message.root.result:
142152
try:
143153
# Parse the result as InitializeResult for type safety
144154
init_result = InitializeResult.model_validate(message.root.result)
145155
self.protocol_version = str(init_result.protocolVersion)
146156
logger.info(f"Negotiated protocol version: {self.protocol_version}")
147-
except Exception as exc: # pragma: no cover
148-
logger.warning(
149-
f"Failed to parse initialization response as InitializeResult: {exc}"
150-
) # pragma: no cover
157+
except Exception as exc:
158+
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
151159
logger.warning(f"Raw result: {message.root.result}")
152160

153161
async def _handle_sse_event(
@@ -160,9 +168,6 @@ async def _handle_sse_event(
160168
) -> bool:
161169
"""Handle an SSE event, returning True if the response is complete."""
162170
if sse.event == "message":
163-
# Skip empty data (keep-alive pings)
164-
if not sse.data:
165-
return False
166171
try:
167172
message = JSONRPCMessage.model_validate_json(sse.data)
168173
logger.debug(f"SSE message: {message}")
@@ -186,11 +191,11 @@ async def _handle_sse_event(
186191
# Otherwise, return False to continue listening
187192
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
188193

189-
except Exception as exc: # pragma: no cover
194+
except Exception as exc:
190195
logger.exception("Error parsing SSE message")
191196
await read_stream_writer.send(exc)
192197
return False
193-
else: # pragma: no cover
198+
else:
194199
logger.warning(f"Unknown SSE event: {sse.event}")
195200
return False
196201

@@ -220,19 +225,19 @@ async def handle_get_stream(
220225
await self._handle_sse_event(sse, read_stream_writer)
221226

222227
except Exception as exc:
223-
logger.debug(f"GET stream error (non-fatal): {exc}") # pragma: no cover
228+
logger.debug(f"GET stream error (non-fatal): {exc}")
224229

225230
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
226231
"""Handle a resumption request using GET with SSE."""
227232
headers = self._prepare_request_headers(ctx.headers)
228233
if ctx.metadata and ctx.metadata.resumption_token:
229234
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
230235
else:
231-
raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover
236+
raise ResumptionError("Resumption request requires a resumption token")
232237

233238
# Extract original request ID to map responses
234239
original_request_id = None
235-
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
240+
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
236241
original_request_id = ctx.session_message.message.root.id
237242

238243
async with aconnect_sse(
@@ -245,7 +250,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
245250
event_source.response.raise_for_status()
246251
logger.debug("Resumption GET SSE connection established")
247252

248-
async for sse in event_source.aiter_sse(): # pragma: no branch
253+
async for sse in event_source.aiter_sse():
249254
is_complete = await self._handle_sse_event(
250255
sse,
251256
ctx.read_stream_writer,
@@ -259,6 +264,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
259264
async def _handle_post_request(self, ctx: RequestContext) -> None:
260265
"""Handle a POST request with response processing."""
261266
headers = self._prepare_request_headers(ctx.headers)
267+
extensions = self._prepare_request_extensions(ctx.extensions)
262268
message = ctx.session_message.message
263269
is_initialization = self._is_initialization_request(message)
264270

@@ -267,18 +273,19 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
267273
self.url,
268274
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
269275
headers=headers,
276+
extensions=extensions,
270277
) as response:
271278
if response.status_code == 202:
272279
logger.debug("Received 202 Accepted")
273280
return
274281

275-
if response.status_code == 404: # pragma: no branch
282+
if response.status_code == 404:
276283
if isinstance(message.root, JSONRPCRequest):
277-
await self._send_session_terminated_error( # pragma: no cover
278-
ctx.read_stream_writer, # pragma: no cover
279-
message.root.id, # pragma: no cover
280-
) # pragma: no cover
281-
return # pragma: no cover
284+
await self._send_session_terminated_error(
285+
ctx.read_stream_writer,
286+
message.root.id,
287+
)
288+
return
282289

283290
response.raise_for_status()
284291
if is_initialization:
@@ -293,10 +300,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
293300
elif content_type.startswith(SSE):
294301
await self._handle_sse_response(response, ctx, is_initialization)
295302
else:
296-
await self._handle_unexpected_content_type( # pragma: no cover
297-
content_type, # pragma: no cover
298-
ctx.read_stream_writer, # pragma: no cover
299-
) # pragma: no cover
303+
await self._handle_unexpected_content_type(
304+
content_type,
305+
ctx.read_stream_writer,
306+
)
300307

301308
async def _handle_json_response(
302309
self,
@@ -315,7 +322,7 @@ async def _handle_json_response(
315322

316323
session_message = SessionMessage(message)
317324
await read_stream_writer.send(session_message)
318-
except Exception as exc: # pragma: no cover
325+
except Exception as exc:
319326
logger.exception("Error parsing JSON response")
320327
await read_stream_writer.send(exc)
321328

@@ -328,7 +335,7 @@ async def _handle_sse_response(
328335
"""Handle SSE response from the server."""
329336
try:
330337
event_source = EventSource(response)
331-
async for sse in event_source.aiter_sse(): # pragma: no branch
338+
async for sse in event_source.aiter_sse():
332339
is_complete = await self._handle_sse_event(
333340
sse,
334341
ctx.read_stream_writer,
@@ -341,18 +348,18 @@ async def _handle_sse_response(
341348
await response.aclose()
342349
break
343350
except Exception as e:
344-
logger.exception("Error reading SSE stream:") # pragma: no cover
345-
await ctx.read_stream_writer.send(e) # pragma: no cover
351+
logger.exception("Error reading SSE stream:")
352+
await ctx.read_stream_writer.send(e)
346353

347354
async def _handle_unexpected_content_type(
348355
self,
349356
content_type: str,
350357
read_stream_writer: StreamWriter,
351-
) -> None: # pragma: no cover
358+
) -> None:
352359
"""Handle unexpected content type in response."""
353-
error_msg = f"Unexpected content type: {content_type}" # pragma: no cover
354-
logger.error(error_msg) # pragma: no cover
355-
await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover
360+
error_msg = f"Unexpected content type: {content_type}"
361+
logger.error(error_msg)
362+
await read_stream_writer.send(ValueError(error_msg))
356363

357364
async def _send_session_terminated_error(
358365
self,
@@ -400,6 +407,7 @@ async def post_writer(
400407
ctx = RequestContext(
401408
client=client,
402409
headers=self.request_headers,
410+
extensions=self.extensions,
403411
session_id=self.session_id,
404412
session_message=session_message,
405413
metadata=metadata,
@@ -420,12 +428,12 @@ async def handle_request_async():
420428
await handle_request_async()
421429

422430
except Exception:
423-
logger.exception("Error in post_writer") # pragma: no cover
431+
logger.exception("Error in post_writer")
424432
finally:
425433
await read_stream_writer.aclose()
426434
await write_stream.aclose()
427435

428-
async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma: no cover
436+
async def terminate_session(self, client: httpx.AsyncClient) -> None:
429437
"""Terminate the session by sending a DELETE request."""
430438
if not self.session_id:
431439
return
@@ -450,6 +458,7 @@ def get_session_id(self) -> str | None:
450458
async def streamablehttp_client(
451459
url: str,
452460
headers: dict[str, str] | None = None,
461+
extensions: dict[str, str] | None = None,
453462
timeout: float | timedelta = 30,
454463
sse_read_timeout: float | timedelta = 60 * 5,
455464
terminate_on_close: bool = True,
@@ -475,7 +484,14 @@ async def streamablehttp_client(
475484
- write_stream: Stream for sending messages to the server
476485
- get_session_id_callback: Function to retrieve the current session ID
477486
"""
478-
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)
487+
transport = StreamableHTTPTransport(
488+
url=url,
489+
headers=headers,
490+
extensions=extensions,
491+
timeout=timeout,
492+
sse_read_timeout=sse_read_timeout,
493+
auth=auth,
494+
)
479495

480496
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
481497
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

0 commit comments

Comments
 (0)