2121 ClientNotification ,
2222 ClientRequest ,
2323 ClientResult ,
24- EmptyResult ,
2524 ErrorData ,
2625 JSONRPCError ,
2726 JSONRPCMessage ,
2827 JSONRPCNotification ,
2928 JSONRPCRequest ,
3029 JSONRPCResponse ,
31- PingRequest ,
3230 ProgressNotification ,
3331 RequestParams ,
3432 ServerNotification ,
@@ -167,9 +165,9 @@ class RequestStateManager(
167165):
168166 def new_request (self , request : SendRequestT ) -> RequestId : ...
169167
170- def update_resume_token (self , request_id : RequestId , token : str ) -> None : ...
168+ async def update_resume_token (self , token : str ) -> None : ...
171169
172- def get_resume_token (self , request_id : RequestId ) -> str | None : ...
170+ async def get_resume_token (self ) -> str | None : ...
173171
174172 def add_progress_callback (self , request_id : RequestId , progress_callback : ProgressFnT ): ...
175173
@@ -210,13 +208,13 @@ class InMemoryRequestStateManager(
210208 ],
211209 ]
212210 _progress_callbacks : dict [RequestId , list [ProgressFnT ]]
213- _resume_tokens : dict [ RequestId , str ]
211+ _resume_token : str | None
214212
215213 def __init__ (self ):
216214 self ._request_id = 0
217215 self ._response_streams = {}
218216 self ._progress_callbacks = {}
219- self ._resume_tokens = {}
217+ self ._resume_token = None
220218
221219 def new_request (self , request : SendRequestT ) -> RequestId :
222220 request_id = self ._request_id
@@ -227,11 +225,11 @@ def new_request(self, request: SendRequestT) -> RequestId:
227225
228226 return request_id
229227
230- def update_resume_token (self , request_id : RequestId , token : str ) -> None :
231- self ._resume_tokens [ request_id ] = token
228+ async def update_resume_token (self , token : str ) -> None :
229+ self ._resume_token = token
232230
233- def get_resume_token (self , request_id : RequestId ) -> str | None :
234- return self ._resume_tokens . get ( request_id )
231+ async def get_resume_token (self ) -> str | None :
232+ return self ._resume_token
235233
236234 def add_progress_callback (self , request_id : RequestId , progress_callback : ProgressFnT ):
237235 progress_list = self ._progress_callbacks .get (request_id )
@@ -303,7 +301,6 @@ async def close_request(self, request_id: RequestId) -> bool:
303301 await response_stream_reader .aclose ()
304302
305303 self ._progress_callbacks .pop (request_id , None )
306- self ._resume_tokens .pop (request_id , None )
307304
308305 return response_stream is not None
309306
@@ -388,17 +385,37 @@ async def start_request(
388385 instead.
389386 """
390387 request_id = self ._request_state_manager .new_request (request )
391- return await self ._send_request (
392- request_id = request_id , request = request , metadata = metadata , progress_callback = progress_callback
388+ # Set up progress token if progress callback is provided
389+ request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
390+ if progress_callback is not None :
391+ # Use request_id as progress token
392+ if "params" not in request_data :
393+ request_data ["params" ] = {}
394+ if "_meta" not in request_data ["params" ]:
395+ request_data ["params" ]["_meta" ] = {}
396+ request_data ["params" ]["_meta" ]["progressToken" ] = request_id
397+ # Store the callback for this request
398+ self ._request_state_manager .add_progress_callback (request_id , progress_callback )
399+
400+ jsonrpc_request = JSONRPCRequest (
401+ jsonrpc = "2.0" ,
402+ id = request_id ,
403+ ** request_data ,
393404 )
394405
406+ try :
407+ await self ._write_stream .send (SessionMessage (message = JSONRPCMessage (jsonrpc_request ), metadata = metadata ))
408+ return request_id
409+ except Exception as e :
410+ await self ._request_state_manager .close_request (request_id )
411+ raise e
412+
395413 async def join_request (
396414 self ,
397415 request_id : RequestId ,
398416 result_type : type [ReceiveResultT ],
399417 request_read_timeout_seconds : timedelta | None = None ,
400418 progress_callback : ProgressFnT | None = None ,
401- metadata : MessageMetadata | None = None ,
402419 done_on_timeout : bool = True ,
403420 ) -> ReceiveResultT :
404421 """
@@ -414,15 +431,6 @@ async def join_request(
414431 elif self ._session_read_timeout_seconds is not None :
415432 timeout = self ._session_read_timeout_seconds .total_seconds ()
416433
417- if metadata :
418- # need to resend metadata - primary use case is client resume support
419- await self .send_request (
420- request = PingRequest (method = "ping" ), # type: ignore
421- result_type = EmptyResult ,
422- request_read_timeout_seconds = None if timeout is None else timedelta (seconds = timeout ),
423- metadata = metadata ,
424- )
425-
426434 response_or_error = await self ._request_state_manager .receive_response (request_id , timeout )
427435
428436 if isinstance (response_or_error , JSONRPCError ):
@@ -436,37 +444,6 @@ async def join_request(
436444 await self ._request_state_manager .close_request (request_id )
437445 return result_type .model_validate (response_or_error .result )
438446
439- async def _send_request (
440- self ,
441- request_id : RequestId ,
442- request : SendRequestT ,
443- metadata : MessageMetadata = None ,
444- progress_callback : ProgressFnT | None = None ,
445- ):
446- # Set up progress token if progress callback is provided
447- request_data = request .model_dump (by_alias = True , mode = "json" , exclude_none = True )
448- if progress_callback is not None :
449- # Use request_id as progress token
450- if "params" not in request_data :
451- request_data ["params" ] = {}
452- if "_meta" not in request_data ["params" ]:
453- request_data ["params" ]["_meta" ] = {}
454- request_data ["params" ]["_meta" ]["progressToken" ] = request_id
455- # Store the callback for this request
456- self ._request_state_manager .add_progress_callback (request_id , progress_callback )
457-
458- jsonrpc_request = JSONRPCRequest (
459- jsonrpc = "2.0" ,
460- id = request_id ,
461- ** request_data ,
462- )
463-
464- try :
465- await self ._write_stream .send (SessionMessage (message = JSONRPCMessage (jsonrpc_request ), metadata = metadata ))
466- return request_id
467- except Exception as e :
468- await self ._request_state_manager .close_request (request_id )
469- raise e
470447
471448 async def cancel_request (self , request_id : RequestId ) -> bool :
472449 """
0 commit comments