Skip to content

Commit 7329cba

Browse files
committed
Refactor code to send resume as part of join call rather than it, this results in the response being consumed prior to the join, also added a capability that identifies whether the server/transport supports resumption that is passed during initialisation
1 parent aa2cbec commit 7329cba

File tree

6 files changed

+299
-75
lines changed

6 files changed

+299
-75
lines changed

src/mcp/client/session.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(
135135
self._logging_callback = logging_callback or _default_logging_callback
136136
self._message_handler = message_handler or _default_message_handler
137137
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
138+
self._resumable = False
138139

139140
async def initialize(self) -> types.InitializeResult:
140141
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -172,24 +173,12 @@ async def initialize(self) -> types.InitializeResult:
172173
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
173174
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
174175

176+
self._resumable = result.capabilities.resume and result.capabilities.resume.resumable
177+
175178
await self.send_notification(
176179
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
177180
)
178181

179-
resume_token = await self._request_state_manager.get_resume_token()
180-
if resume_token:
181-
metadata = ClientMessageMetadata(resumption_token=resume_token)
182-
timeout = None
183-
if self._session_read_timeout_seconds is not None:
184-
timeout = self._session_read_timeout_seconds.total_seconds()
185-
186-
await self.send_request(
187-
request=types.PingRequest(method="ping"), # type: ignore
188-
result_type=types.EmptyResult,
189-
request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout),
190-
metadata=metadata,
191-
)
192-
193182
return result
194183

195184
async def send_ping(self) -> types.EmptyResult:
@@ -303,21 +292,58 @@ async def request_call_tool(
303292
arguments: dict[str, Any] | None = None,
304293
progress_callback: ProgressFnT | None = None,
305294
) -> types.RequestId:
306-
metadata = ClientMessageMetadata(on_resumption_token_update=self._request_state_manager.update_resume_token)
307-
308-
return await self.start_request(
309-
types.ClientRequest(
310-
types.CallToolRequest(
311-
method="tools/call",
312-
params=types.CallToolRequestParams(
313-
name=name,
314-
arguments=arguments,
295+
if self._resumable:
296+
send_stream, receive_stream = anyio.create_memory_object_stream[str](1)
297+
298+
async def close() -> None:
299+
await send_stream.aclose()
300+
await receive_stream.aclose()
301+
302+
self._exit_stack.push_async_callback(close)
303+
304+
with send_stream, receive_stream:
305+
306+
async def send_token(token: str):
307+
try:
308+
await send_stream.send(token)
309+
except anyio.BrokenResourceError as e:
310+
raise e
311+
312+
metadata = ClientMessageMetadata(on_resumption_token_update=send_token)
313+
314+
request_id = await self.start_request(
315+
types.ClientRequest(
316+
types.CallToolRequest(
317+
method="tools/call",
318+
params=types.CallToolRequestParams(
319+
name=name,
320+
arguments=arguments,
321+
),
322+
)
315323
),
324+
progress_callback=progress_callback,
325+
metadata=metadata,
316326
)
317-
),
318-
progress_callback=progress_callback,
319-
metadata=metadata,
320-
)
327+
328+
await anyio.lowlevel.checkpoint()
329+
330+
token = await receive_stream.receive()
331+
await self._request_state_manager.update_resume_token(request_id, token)
332+
333+
return request_id
334+
else:
335+
return await self.start_request(
336+
types.ClientRequest(
337+
types.CallToolRequest(
338+
method="tools/call",
339+
params=types.CallToolRequestParams(
340+
name=name,
341+
arguments=arguments,
342+
),
343+
)
344+
),
345+
progress_callback=progress_callback,
346+
)
321347

322348
async def join_call_tool(
323349
self,

src/mcp/client/streamable_http.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
JSONRPCRequest,
3030
JSONRPCResponse,
3131
RequestId,
32+
ResumeCapability,
3233
)
3334

3435
logger = logging.getLogger(__name__)
@@ -136,18 +137,26 @@ def _maybe_extract_session_id_from_response(
136137
def _maybe_extract_protocol_version_from_message(
137138
self,
138139
message: JSONRPCMessage,
139-
) -> None:
140+
) -> JSONRPCMessage:
140141
"""Extract protocol version from initialization response message."""
141142
if isinstance(message.root, JSONRPCResponse) and message.root.result:
142143
try:
143144
# Parse the result as InitializeResult for type safety
144145
init_result = InitializeResult.model_validate(message.root.result)
145146
self.protocol_version = str(init_result.protocolVersion)
146147
logger.info(f"Negotiated protocol version: {self.protocol_version}")
148+
if init_result.capabilities.resume is None:
149+
# resumeablity is predicated on the server and the transport
150+
# this assumes that if the server hasn't explicitly configured
151+
# that streamable http transports are resumeable
152+
init_result.capabilities.resume = ResumeCapability(resumable=True)
153+
message.root.result = init_result.model_dump()
147154
except Exception as exc:
148155
logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}")
149156
logger.warning(f"Raw result: {message.root.result}")
150157

158+
return message
159+
151160
async def _handle_sse_event(
152161
self,
153162
sse: ServerSentEvent,
@@ -183,7 +192,10 @@ async def _handle_sse_event(
183192

184193
except Exception as exc:
185194
logger.exception("Error parsing SSE message")
186-
await read_stream_writer.send(exc)
195+
try:
196+
await read_stream_writer.send(exc)
197+
except anyio.BrokenResourceError:
198+
pass
187199
return False
188200
else:
189201
logger.warning(f"Unknown SSE event: {sse.event}")
@@ -303,7 +315,7 @@ async def _handle_json_response(
303315

304316
# Extract protocol version from initialization response
305317
if is_initialization:
306-
self._maybe_extract_protocol_version_from_message(message)
318+
message = self._maybe_extract_protocol_version_from_message(message)
307319

308320
session_message = SessionMessage(message)
309321
await read_stream_writer.send(session_message)
@@ -333,7 +345,10 @@ async def _handle_sse_response(
333345
break
334346
except Exception as e:
335347
logger.exception("Error reading SSE stream:")
336-
await ctx.read_stream_writer.send(e)
348+
try:
349+
await ctx.read_stream_writer.send(e)
350+
except anyio.ClosedResourceError:
351+
pass
337352

338353
async def _handle_unexpected_content_type(
339354
self,

0 commit comments

Comments
 (0)