@@ -135,6 +135,7 @@ def __init__(
135135 self ._logging_callback = logging_callback or _default_logging_callback
136136 self ._message_handler = message_handler or _default_message_handler
137137 self ._tool_output_schemas : dict [str , dict [str , Any ] | None ] = {}
138+ self ._resumable = False
138139
139140 async def initialize (self ) -> types .InitializeResult :
140141 sampling = types .SamplingCapability () if self ._sampling_callback is not _default_sampling_callback else None
@@ -172,24 +173,12 @@ async def initialize(self) -> types.InitializeResult:
172173 if result .protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS :
173174 raise RuntimeError (f"Unsupported protocol version from the server: { result .protocolVersion } " )
174175
176+ self ._resumable = result .capabilities .resume and result .capabilities .resume .resumable
177+
175178 await self .send_notification (
176179 types .ClientNotification (types .InitializedNotification (method = "notifications/initialized" ))
177180 )
178181
179- resume_token = await self ._request_state_manager .get_resume_token ()
180- if resume_token :
181- metadata = ClientMessageMetadata (resumption_token = resume_token )
182- timeout = None
183- if self ._session_read_timeout_seconds is not None :
184- timeout = self ._session_read_timeout_seconds .total_seconds ()
185-
186- await self .send_request (
187- request = types .PingRequest (method = "ping" ), # type: ignore
188- result_type = types .EmptyResult ,
189- request_read_timeout_seconds = None if timeout is None else timedelta (seconds = timeout ),
190- metadata = metadata ,
191- )
192-
193182 return result
194183
195184 async def send_ping (self ) -> types .EmptyResult :
@@ -303,21 +292,58 @@ async def request_call_tool(
303292 arguments : dict [str , Any ] | None = None ,
304293 progress_callback : ProgressFnT | None = None ,
305294 ) -> types .RequestId :
306- metadata = ClientMessageMetadata (on_resumption_token_update = self ._request_state_manager .update_resume_token )
307-
308- return await self .start_request (
309- types .ClientRequest (
310- types .CallToolRequest (
311- method = "tools/call" ,
312- params = types .CallToolRequestParams (
313- name = name ,
314- arguments = arguments ,
295+ if self ._resumable :
296+ send_stream , receive_stream = anyio .create_memory_object_stream [str ](1 )
297+
298+ async def close () -> None :
299+ await send_stream .aclose ()
300+ await receive_stream .aclose ()
301+
302+ self ._exit_stack .push_async_callback (close )
303+
304+ with send_stream , receive_stream :
305+
306+ async def send_token (token : str ):
307+ try :
308+ await send_stream .send (token )
309+ except anyio .BrokenResourceError as e :
310+ raise e
311+
312+ metadata = ClientMessageMetadata (on_resumption_token_update = send_token )
313+
314+ request_id = await self .start_request (
315+ types .ClientRequest (
316+ types .CallToolRequest (
317+ method = "tools/call" ,
318+ params = types .CallToolRequestParams (
319+ name = name ,
320+ arguments = arguments ,
321+ ),
322+ )
315323 ),
324+ progress_callback = progress_callback ,
325+ metadata = metadata ,
316326 )
317- ),
318- progress_callback = progress_callback ,
319- metadata = metadata ,
320- )
327+
328+ await anyio .lowlevel .checkpoint ()
329+
330+ token = await receive_stream .receive ()
331+ await self ._request_state_manager .update_resume_token (request_id , token )
332+
333+ return request_id
334+ else :
335+ return await self .start_request (
336+ types .ClientRequest (
337+ types .CallToolRequest (
338+ method = "tools/call" ,
339+ params = types .CallToolRequestParams (
340+ name = name ,
341+ arguments = arguments ,
342+ ),
343+ )
344+ ),
345+ progress_callback = progress_callback ,
346+ )
321347
322348 async def join_call_tool (
323349 self ,
0 commit comments