2121 ClientNotification ,
2222 ClientRequest ,
2323 ClientResult ,
24+ EmptyResult ,
2425 ErrorData ,
2526 JSONRPCError ,
2627 JSONRPCMessage ,
2728 JSONRPCNotification ,
2829 JSONRPCRequest ,
2930 JSONRPCResponse ,
31+ PingRequest ,
3032 ProgressNotification ,
3133 RequestParams ,
3234 ServerNotification ,
@@ -165,6 +167,10 @@ class RequestStateManager(
165167):
166168 def new_request (self , request : SendRequestT ) -> RequestId : ...
167169
170+ def update_resume_token (self , request_id : RequestId , token : str ) -> None : ...
171+
172+ def get_resume_token (self , request_id : RequestId ) -> str | None : ...
173+
168174 def add_progress_callback (self , request_id : RequestId , progress_callback : ProgressFnT ): ...
169175
170176 async def send_progress (
@@ -204,11 +210,13 @@ class ImMemoryRequestStateManager(
204210 ],
205211 ]
206212 _progress_callbacks : dict [RequestId , list [ProgressFnT ]]
213+ _resume_tokens : dict [RequestId , str ]
207214
208215 def __init__ (self ):
209216 self ._request_id = 0
210217 self ._response_streams = {}
211218 self ._progress_callbacks = {}
219+ self ._resume_tokens = {}
212220
213221 def new_request (self , request : SendRequestT ) -> RequestId :
214222 request_id = self ._request_id
@@ -219,6 +227,12 @@ def new_request(self, request: SendRequestT) -> RequestId:
219227
220228 return request_id
221229
230+ def update_resume_token (self , request_id : RequestId , token : str ) -> None :
231+ self ._resume_tokens [request_id ] = token
232+
233+ def get_resume_token (self , request_id : RequestId ) -> str | None :
234+ return self ._resume_tokens .get (request_id )
235+
222236 def add_progress_callback (self , request_id : RequestId , progress_callback : ProgressFnT ):
223237 progress_list = self ._progress_callbacks .get (request_id )
224238 if progress_list is None :
@@ -289,6 +303,7 @@ async def close_request(self, request_id: RequestId) -> bool:
289303 await response_stream_reader .aclose ()
290304
291305 self ._progress_callbacks .pop (request_id , None )
306+ self ._resume_tokens .pop (request_id , None )
292307
293308 return response_stream is not None
294309
@@ -373,38 +388,17 @@ async def start_request(
373388 instead.
374389 """
375390 request_id = self ._request_state_manager .new_request (request )
376-
377- # Set up progress token if progress callback is provided
378- request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
379- if progress_callback is not None :
380- # Use request_id as progress token
381- if "params" not in request_data :
382- request_data ["params" ] = {}
383- if "_meta" not in request_data ["params" ]:
384- request_data ["params" ]["_meta" ] = {}
385- request_data ["params" ]["_meta" ]["progressToken" ] = request_id
386- # Store the callback for this request
387- self ._request_state_manager .add_progress_callback (request_id , progress_callback )
388-
389- jsonrpc_request = JSONRPCRequest (
390- jsonrpc = "2.0" ,
391- id = request_id ,
392- ** request_data ,
391+ return await self ._send_request (
392+ request_id = request_id , request = request , metadata = metadata , progress_callback = progress_callback
393393 )
394394
395- try :
396- await self ._write_stream .send (SessionMessage (message = JSONRPCMessage (jsonrpc_request ), metadata = metadata ))
397- return request_id
398- except Exception as e :
399- await self ._request_state_manager .close_request (request_id )
400- raise e
401-
402395 async def join_request (
403396 self ,
404397 request_id : RequestId ,
405398 result_type : type [ReceiveResultT ],
406399 request_read_timeout_seconds : timedelta | None = None ,
407400 progress_callback : ProgressFnT | None = None ,
401+ metadata : MessageMetadata | None = None ,
408402 done_on_timeout : bool = True ,
409403 ) -> ReceiveResultT :
410404 """
@@ -420,6 +414,15 @@ async def join_request(
420414 elif self ._session_read_timeout_seconds is not None :
421415 timeout = self ._session_read_timeout_seconds .total_seconds ()
422416
417+ if metadata :
418+ # need to resend metadata - primary use case is client resume support
419+ await self .send_request (
420+ request = PingRequest (method = "ping" ), # type: ignore
421+ result_type = EmptyResult ,
422+ request_read_timeout_seconds = None if timeout is None else timedelta (seconds = timeout ),
423+ metadata = metadata ,
424+ )
425+
423426 response_or_error = await self ._request_state_manager .receive_response (request_id , timeout )
424427
425428 if isinstance (response_or_error , JSONRPCError ):
@@ -433,6 +436,38 @@ async def join_request(
433436 await self ._request_state_manager .close_request (request_id )
434437 return result_type .model_validate (response_or_error .result )
435438
439+ async def _send_request (
440+ self ,
441+ request_id : RequestId ,
442+ request : SendRequestT ,
443+ metadata : MessageMetadata = None ,
444+ progress_callback : ProgressFnT | None = None ,
445+ ):
446+ # Set up progress token if progress callback is provided
447+ request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
448+ if progress_callback is not None :
449+ # Use request_id as progress token
450+ if "params" not in request_data :
451+ request_data ["params" ] = {}
452+ if "_meta" not in request_data ["params" ]:
453+ request_data ["params" ]["_meta" ] = {}
454+ request_data ["params" ]["_meta" ]["progressToken" ] = request_id
455+ # Store the callback for this request
456+ self ._request_state_manager .add_progress_callback (request_id , progress_callback )
457+
458+ jsonrpc_request = JSONRPCRequest (
459+ jsonrpc = "2.0" ,
460+ id = request_id ,
461+ ** request_data ,
462+ )
463+
464+ try :
465+ await self ._write_stream .send (SessionMessage (message = JSONRPCMessage (jsonrpc_request ), metadata = metadata ))
466+ return request_id
467+ except Exception as e :
468+ await self ._request_state_manager .close_request (request_id )
469+ raise e
470+
436471 async def cancel_request (self , request_id : RequestId ) -> bool :
437472 """
438473 Cancels a request previously started via start_request
0 commit comments