Skip to content

Commit 161da46

Browse files
committed
send resume on init rather than part of join, refactor resume to be global to session rather than per request (read the spec)
1 parent 40028da commit 161da46

File tree

3 files changed

+57
-193
lines changed

3 files changed

+57
-193
lines changed

src/mcp/client/session.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,20 @@ async def initialize(self) -> types.InitializeResult:
176176
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
177177
)
178178

179+
resume_token = await self._request_state_manager.get_resume_token()
180+
if resume_token:
181+
metadata = ClientMessageMetadata(resumption_token=resume_token)
182+
timeout = None
183+
if self._session_read_timeout_seconds is not None:
184+
timeout = self._session_read_timeout_seconds.total_seconds()
185+
186+
await self.send_request(
187+
request=PingRequest(method="ping"), # type: ignore
188+
result_type=types.EmptyResult,
189+
request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout),
190+
metadata=metadata,
191+
)
192+
179193
return result
180194

181195
async def send_ping(self) -> types.EmptyResult:
@@ -289,11 +303,9 @@ async def request_call_tool(
289303
arguments: dict[str, Any] | None = None,
290304
progress_callback: ProgressFnT | None = None,
291305
) -> types.RequestId:
292-
write, read = anyio.create_memory_object_stream[str]()
293-
294-
metadata = ClientMessageMetadata(on_resumption_token_update=write.send)
306+
metadata = ClientMessageMetadata(on_resumption_token_update=self._request_state_manager.update_resume_token)
295307

296-
request_id = await self.start_request(
308+
return await self.start_request(
297309
types.ClientRequest(
298310
types.CallToolRequest(
299311
method="tools/call",
@@ -307,36 +319,16 @@ async def request_call_tool(
307319
metadata=metadata,
308320
)
309321

310-
async def update_token() -> None:
311-
try:
312-
async for token in read:
313-
self._request_state_manager.update_resume_token(request_id, token)
314-
except anyio.ClosedResourceError:
315-
pass
316-
317-
async def close() -> None:
318-
await write.aclose()
319-
await read.aclose()
320-
321-
self._exit_stack.push_async_callback(update_token)
322-
self._exit_stack.push_async_callback(close)
323-
324-
return request_id
325-
326322
async def join_call_tool(
327323
self,
328324
request_id: types.RequestId,
329325
progress_callback: ProgressFnT | None = None,
330326
request_read_timeout_seconds: timedelta | None = None,
331327
done_on_timeout: bool = True,
332328
) -> types.CallToolResult:
333-
resume_token = self._request_state_manager.get_resume_token(request_id)
334-
metadata = ClientMessageMetadata(resumption_token=resume_token)
335-
336329
return await self.join_request(
337330
request_id,
338331
types.CallToolResult,
339-
metadata=metadata,
340332
request_read_timeout_seconds=request_read_timeout_seconds,
341333
progress_callback=progress_callback,
342334
done_on_timeout=done_on_timeout,

src/mcp/shared/session.py

Lines changed: 31 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@
2121
ClientNotification,
2222
ClientRequest,
2323
ClientResult,
24-
EmptyResult,
2524
ErrorData,
2625
JSONRPCError,
2726
JSONRPCMessage,
2827
JSONRPCNotification,
2928
JSONRPCRequest,
3029
JSONRPCResponse,
31-
PingRequest,
3230
ProgressNotification,
3331
RequestParams,
3432
ServerNotification,
@@ -167,9 +165,9 @@ class RequestStateManager(
167165
):
168166
def new_request(self, request: SendRequestT) -> RequestId: ...
169167

170-
def update_resume_token(self, request_id: RequestId, token: str) -> None: ...
168+
async def update_resume_token(self, token: str) -> None: ...
171169

172-
def get_resume_token(self, request_id: RequestId) -> str | None: ...
170+
async def get_resume_token(self) -> str | None: ...
173171

174172
def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT): ...
175173

@@ -210,13 +208,13 @@ class InMemoryRequestStateManager(
210208
],
211209
]
212210
_progress_callbacks: dict[RequestId, list[ProgressFnT]]
213-
_resume_tokens: dict[RequestId, str]
211+
_resume_token: str | None
214212

215213
def __init__(self):
216214
self._request_id = 0
217215
self._response_streams = {}
218216
self._progress_callbacks = {}
219-
self._resume_tokens = {}
217+
self._resume_token = None
220218

221219
def new_request(self, request: SendRequestT) -> RequestId:
222220
request_id = self._request_id
@@ -227,11 +225,11 @@ def new_request(self, request: SendRequestT) -> RequestId:
227225

228226
return request_id
229227

230-
def update_resume_token(self, request_id: RequestId, token: str) -> None:
231-
self._resume_tokens[request_id] = token
228+
async def update_resume_token(self, token: str) -> None:
229+
self._resume_token = token
232230

233-
def get_resume_token(self, request_id: RequestId) -> str | None:
234-
return self._resume_tokens.get(request_id)
231+
async def get_resume_token(self) -> str | None:
232+
return self._resume_token
235233

236234
def add_progress_callback(self, request_id: RequestId, progress_callback: ProgressFnT):
237235
progress_list = self._progress_callbacks.get(request_id)
@@ -303,7 +301,6 @@ async def close_request(self, request_id: RequestId) -> bool:
303301
await response_stream_reader.aclose()
304302

305303
self._progress_callbacks.pop(request_id, None)
306-
self._resume_tokens.pop(request_id, None)
307304

308305
return response_stream is not None
309306

@@ -388,17 +385,37 @@ async def start_request(
388385
instead.
389386
"""
390387
request_id = self._request_state_manager.new_request(request)
391-
return await self._send_request(
392-
request_id=request_id, request=request, metadata=metadata, progress_callback=progress_callback
388+
# Set up progress token if progress callback is provided
389+
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
390+
if progress_callback is not None:
391+
# Use request_id as progress token
392+
if "params" not in request_data:
393+
request_data["params"] = {}
394+
if "_meta" not in request_data["params"]:
395+
request_data["params"]["_meta"] = {}
396+
request_data["params"]["_meta"]["progressToken"] = request_id
397+
# Store the callback for this request
398+
self._request_state_manager.add_progress_callback(request_id, progress_callback)
399+
400+
jsonrpc_request = JSONRPCRequest(
401+
jsonrpc="2.0",
402+
id=request_id,
403+
**request_data,
393404
)
394405

406+
try:
407+
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
408+
return request_id
409+
except Exception as e:
410+
await self._request_state_manager.close_request(request_id)
411+
raise e
412+
395413
async def join_request(
396414
self,
397415
request_id: RequestId,
398416
result_type: type[ReceiveResultT],
399417
request_read_timeout_seconds: timedelta | None = None,
400418
progress_callback: ProgressFnT | None = None,
401-
metadata: MessageMetadata | None = None,
402419
done_on_timeout: bool = True,
403420
) -> ReceiveResultT:
404421
"""
@@ -414,15 +431,6 @@ async def join_request(
414431
elif self._session_read_timeout_seconds is not None:
415432
timeout = self._session_read_timeout_seconds.total_seconds()
416433

417-
if metadata:
418-
# need to resend metadata - primary use case is client resume support
419-
await self.send_request(
420-
request=PingRequest(method="ping"), # type: ignore
421-
result_type=EmptyResult,
422-
request_read_timeout_seconds=None if timeout is None else timedelta(seconds=timeout),
423-
metadata=metadata,
424-
)
425-
426434
response_or_error = await self._request_state_manager.receive_response(request_id, timeout)
427435

428436
if isinstance(response_or_error, JSONRPCError):
@@ -436,37 +444,6 @@ async def join_request(
436444
await self._request_state_manager.close_request(request_id)
437445
return result_type.model_validate(response_or_error.result)
438446

439-
async def _send_request(
440-
self,
441-
request_id: RequestId,
442-
request: SendRequestT,
443-
metadata: MessageMetadata = None,
444-
progress_callback: ProgressFnT | None = None,
445-
):
446-
# Set up progress token if progress callback is provided
447-
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
448-
if progress_callback is not None:
449-
# Use request_id as progress token
450-
if "params" not in request_data:
451-
request_data["params"] = {}
452-
if "_meta" not in request_data["params"]:
453-
request_data["params"]["_meta"] = {}
454-
request_data["params"]["_meta"]["progressToken"] = request_id
455-
# Store the callback for this request
456-
self._request_state_manager.add_progress_callback(request_id, progress_callback)
457-
458-
jsonrpc_request = JSONRPCRequest(
459-
jsonrpc="2.0",
460-
id=request_id,
461-
**request_data,
462-
)
463-
464-
try:
465-
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
466-
return request_id
467-
except Exception as e:
468-
await self._request_state_manager.close_request(request_id)
469-
raise e
470447

471448
async def cancel_request(self, request_id: RequestId) -> bool:
472449
"""

0 commit comments

Comments
 (0)