99from pydantic import AnyUrl , TypeAdapter
1010
1111import mcp .types as types
12+ from mcp .shared .async_operations import ClientAsyncOperationManager
1213from mcp .shared .context import RequestContext
1314from mcp .shared .message import SessionMessage
1415from mcp .shared .session import BaseSession , ProgressFnT , RequestResponder
@@ -136,6 +137,7 @@ def __init__(
136137 self ._logging_callback = logging_callback or _default_logging_callback
137138 self ._message_handler = message_handler or _default_message_handler
138139 self ._tool_output_schemas : dict [str , dict [str , Any ] | None ] = {}
140+ self ._operation_manager = ClientAsyncOperationManager ()
139141
140142 async def initialize (self ) -> types .InitializeResult :
141143 sampling = types .SamplingCapability () if self ._sampling_callback is not _default_sampling_callback else None
@@ -174,8 +176,15 @@ async def initialize(self) -> types.InitializeResult:
174176
175177 await self .send_notification (types .ClientNotification (types .InitializedNotification ()))
176178
179+ # Start cleanup task for operations
180+ await self ._operation_manager .start_cleanup_task ()
181+
177182 return result
178183
184+ async def close (self ) -> None :
185+ """Clean up resources."""
186+ await self ._operation_manager .stop_cleanup_task ()
187+
179188 async def send_ping (self ) -> types .EmptyResult :
180189 """Send a ping request."""
181190 return await self .send_request (
@@ -305,7 +314,14 @@ async def call_tool(
305314 )
306315
307316 if not result .isError :
308- await self ._validate_tool_result (name , result )
317+ # Track operation for async operations
318+ if result .operation is not None :
319+ self ._operation_manager .track_operation (
320+ result .operation .token , name , result .operation .keepAlive or 3600
321+ )
322+ logger .debug (f"Tracking operation for token: { result .operation .token } " )
323+ else :
324+ await self ._validate_tool_result (name , result )
309325
310326 return result
311327
@@ -336,7 +352,7 @@ async def get_operation_result(self, token: str) -> types.GetOperationPayloadRes
336352 Returns:
337353 The final tool result
338354 """
339- return await self .send_request (
355+ result = await self .send_request (
340356 types .ClientRequest (
341357 types .GetOperationPayloadRequest (
342358 params = types .GetOperationPayloadParams (token = token ),
@@ -345,7 +361,18 @@ async def get_operation_result(self, token: str) -> types.GetOperationPayloadRes
345361 types .GetOperationPayloadResult ,
346362 )
347363
348- async def _validate_tool_result (self , name : str , result : types .CallToolResult ) -> None :
364+ # Validate using the stored tool name
365+ if hasattr (result , "result" ) and result .result :
366+ # Clean up expired operations first
367+ self ._operation_manager .cleanup_expired ()
368+
369+ tool_name = self ._operation_manager .get_tool_name (token )
370+ await self ._validate_tool_result (tool_name , result .result )
371+ # Keep the operation for potential future retrievals
372+
373+ return result
374+
375+ async def _validate_tool_result (self , name : str | None , result : types .CallToolResult ) -> None :
349376 """Validate the structured content of a tool result against its output schema."""
350377 if name not in self ._tool_output_schemas :
351378 # refresh output schema cache
@@ -358,6 +385,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -
358385 logger .warning (f"Tool { name } not listed by server, cannot validate any structured content" )
359386
360387 if output_schema is not None :
388+ logger .debug (f"Validating structured content for tool: { name } " )
361389 if result .structuredContent is None :
362390 raise RuntimeError (f"Tool { name } has an output schema but did not return structured content" )
363391 try :
0 commit comments