Skip to content

Commit 288ebe3

Browse files
committed
add resume logic to request/join call_tool functions
1 parent 04ff73a commit 288ebe3

File tree

3 files changed

+222
-50
lines changed

3 files changed

+222
-50
lines changed

src/mcp/client/session.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import mcp.types as types
1111
from mcp.shared.context import RequestContext
12-
from mcp.shared.message import SessionMessage
12+
from mcp.shared.message import ClientMessageMetadata, SessionMessage
1313
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder, RequestStateManager
1414
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1515

@@ -289,7 +289,11 @@ async def request_call_tool(
289289
arguments: dict[str, Any] | None = None,
290290
progress_callback: ProgressFnT | None = None,
291291
) -> types.RequestId:
292-
return await self.start_request(
292+
write, read = anyio.create_memory_object_stream[str]()
293+
294+
metadata = ClientMessageMetadata(on_resumption_token_update=write.send)
295+
296+
request_id = await self.start_request(
293297
types.ClientRequest(
294298
types.CallToolRequest(
295299
method="tools/call",
@@ -300,18 +304,39 @@ async def request_call_tool(
300304
)
301305
),
302306
progress_callback=progress_callback,
307+
metadata=metadata,
303308
)
304309

310+
async def update_token() -> None:
311+
try:
312+
async for token in read:
313+
self._request_state_manager.update_resume_token(request_id, token)
314+
except anyio.ClosedResourceError:
315+
pass
316+
317+
async def close() -> None:
318+
await write.aclose()
319+
await read.aclose()
320+
321+
self._exit_stack.push_async_callback(update_token)
322+
self._exit_stack.push_async_callback(close)
323+
324+
return request_id
325+
305326
async def join_call_tool(
306327
self,
307328
request_id: types.RequestId,
308329
progress_callback: ProgressFnT | None = None,
309330
request_read_timeout_seconds: timedelta | None = None,
310331
done_on_timeout: bool = True,
311332
) -> types.CallToolResult | None:
333+
resume_token = self._request_state_manager.get_resume_token(request_id)
334+
metadata = ClientMessageMetadata(resumption_token=resume_token)
335+
312336
return await self.join_request(
313337
request_id,
314338
types.CallToolResult,
339+
metadata=metadata,
315340
request_read_timeout_seconds=request_read_timeout_seconds,
316341
progress_callback=progress_callback,
317342
done_on_timeout=done_on_timeout,

src/mcp/shared/session.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121
ClientNotification,
2222
ClientRequest,
2323
ClientResult,
24+
EmptyResult,
2425
ErrorData,
2526
JSONRPCError,
2627
JSONRPCMessage,
2728
JSONRPCNotification,
2829
JSONRPCRequest,
2930
JSONRPCResponse,
31+
PingRequest,
3032
ProgressNotification,
3133
RequestParams,
3234
ServerNotification,
@@ -165,6 +167,10 @@ class RequestStateManager(
165167
):
166168
def new_request(self, request: SendRequestT) -> RequestId: ...
167169

170+
def update_resume_token(self, request_id: RequestId, token: str) -> None: ...
171+
172+
def get_resume_token(self, request_id: RequestId) -> str | None: ...
173+
168174
def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): ...
169175

170176
async def send_progress(
@@ -204,11 +210,13 @@ class ImMemoryRequestStateManager(
204210
],
205211
]
206212
_progress_callbacks: dict[RequestId, list[ProgressFnT]]
213+
_resume_tokens: dict[RequestId, str]
207214

208215
def __init__(self):
209216
self._request_id = 0
210217
self._response_streams = {}
211218
self._progress_callbacks = {}
219+
self._resume_tokens = {}
212220

213221
def new_request(self, request: SendRequestT) -> RequestId:
214222
request_id = self._request_id
@@ -219,6 +227,12 @@ def new_request(self, request: SendRequestT) -> RequestId:
219227

220228
return request_id
221229

230+
def update_resume_token(self, request_id: RequestId, token: str) -> None:
231+
self._resume_tokens[request_id] = token
232+
233+
def get_resume_token(self, request_id: RequestId) -> str | None:
234+
return self._resume_tokens.get(request_id)
235+
222236
def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT):
223237
progress_list = self._progress_callbacks.get(request_id)
224238
if progress_list is None:
@@ -289,6 +303,7 @@ async def close_request(self, request_id: RequestId) -> bool:
289303
await response_stream_reader.aclose()
290304

291305
self._progress_callbacks.pop(request_id, None)
306+
self._resume_tokens.pop(request_id, None)
292307

293308
return response_stream is not None
294309

@@ -373,38 +388,17 @@ async def start_request(
373388
instead.
374389
"""
375390
request_id = self._request_state_manager.new_request(request)
376-
377-
# Set up progress token if progress callback is provided
378-
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
379-
if progress_callback is not None:
380-
# Use request_id as progress token
381-
if "params" not in request_data:
382-
request_data["params"] = {}
383-
if "_meta" not in request_data["params"]:
384-
request_data["params"]["_meta"] = {}
385-
request_data["params"]["_meta"]["progressToken"] = request_id
386-
# Store the callback for this request
387-
self._request_state_manager.add_progress_callback(request_id, progress_callback)
388-
389-
jsonrpc_request = JSONRPCRequest(
390-
jsonrpc="2.0",
391-
id=request_id,
392-
**request_data,
391+
return await self._send_request(
392+
request_id=request_id, request=request, metadata=metadata, progress_callback=progress_callback
393393
)
394394

395-
try:
396-
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
397-
return request_id
398-
except Exception as e:
399-
await self._request_state_manager.close_request(request_id)
400-
raise e
401-
402395
async def join_request(
403396
self,
404397
request_id: RequestId,
405398
result_type: type[ReceiveResultT],
406399
request_read_timeout_seconds: timedelta | None = None,
407400
progress_callback: ProgressFnT | None = None,
401+
metadata: MessageMetadata | None = None,
408402
done_on_timeout: bool = True,
409403
) -> ReceiveResultT:
410404
"""
@@ -420,6 +414,15 @@ async def join_request(
420414
elif self._session_read_timeout_seconds is not None:
421415
timeout = self._session_read_timeout_seconds.total_seconds()
422416

417+
if metadata:
418+
# need to resend metadata - primary use case is client resume support
419+
await self.send_request(
420+
request=PingRequest(method="ping"), # type: ignore
421+
result_type=EmptyResult,
422+
request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout),
423+
metadata=metadata,
424+
)
425+
423426
response_or_error = await self._request_state_manager.receive_response(request_id, timeout)
424427

425428
if isinstance(response_or_error, JSONRPCError):
@@ -433,6 +436,38 @@ async def join_request(
433436
await self._request_state_manager.close_request(request_id)
434437
return result_type.model_validate(response_or_error.result)
435438

439+
async def _send_request(
440+
self,
441+
request_id: RequestId,
442+
request: SendRequestT,
443+
metadata: MessageMetadata = None,
444+
progress_callback: ProgressFnT | None = None,
445+
):
446+
# Set up progress token if progress callback is provided
447+
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
448+
if progress_callback is not None:
449+
# Use request_id as progress token
450+
if "params" not in request_data:
451+
request_data["params"] = {}
452+
if "_meta" not in request_data["params"]:
453+
request_data["params"]["_meta"] = {}
454+
request_data["params"]["_meta"]["progressToken"] = request_id
455+
# Store the callback for this request
456+
self._request_state_manager.add_progress_callback(request_id, progress_callback)
457+
458+
jsonrpc_request = JSONRPCRequest(
459+
jsonrpc="2.0",
460+
id=request_id,
461+
**request_data,
462+
)
463+
464+
try:
465+
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
466+
return request_id
467+
except Exception as e:
468+
await self._request_state_manager.close_request(request_id)
469+
raise e
470+
436471
async def cancel_request(self, request_id: RequestId) -> bool:
437472
"""
438473
Cancels a request previously started via start_request

0 commit comments

Comments
 (0)