Skip to content

Commit 011a363

Browse files
committed
Implement optoken to tool name map on client end for validation
1 parent 759a9a3 commit 011a363

File tree

6 files changed

+218
-100
lines changed

6 files changed

+218
-100
lines changed

src/mcp/client/session.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import AnyUrl, TypeAdapter
1010

1111
import mcp.types as types
12+
from mcp.shared.async_operations import ClientAsyncOperationManager
1213
from mcp.shared.context import RequestContext
1314
from mcp.shared.message import SessionMessage
1415
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
@@ -136,6 +137,7 @@ def __init__(
136137
self._logging_callback = logging_callback or _default_logging_callback
137138
self._message_handler = message_handler or _default_message_handler
138139
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
140+
self._operation_manager = ClientAsyncOperationManager()
139141

140142
async def initialize(self) -> types.InitializeResult:
141143
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -174,8 +176,15 @@ async def initialize(self) -> types.InitializeResult:
174176

175177
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
176178

179+
# Start cleanup task for operations
180+
await self._operation_manager.start_cleanup_task()
181+
177182
return result
178183

184+
async def close(self) -> None:
185+
"""Clean up resources."""
186+
await self._operation_manager.stop_cleanup_task()
187+
179188
async def send_ping(self) -> types.EmptyResult:
180189
"""Send a ping request."""
181190
return await self.send_request(
@@ -305,7 +314,14 @@ async def call_tool(
305314
)
306315

307316
if not result.isError:
308-
await self._validate_tool_result(name, result)
317+
# Track operation for async operations
318+
if result.operation is not None:
319+
self._operation_manager.track_operation(
320+
result.operation.token, name, result.operation.keepAlive or 3600
321+
)
322+
logger.debug(f"Tracking operation for token: {result.operation.token}")
323+
else:
324+
await self._validate_tool_result(name, result)
309325

310326
return result
311327

@@ -336,7 +352,7 @@ async def get_operation_result(self, token: str) -> types.GetOperationPayloadRes
336352
Returns:
337353
The final tool result
338354
"""
339-
return await self.send_request(
355+
result = await self.send_request(
340356
types.ClientRequest(
341357
types.GetOperationPayloadRequest(
342358
params=types.GetOperationPayloadParams(token=token),
@@ -345,7 +361,18 @@ async def get_operation_result(self, token: str) -> types.GetOperationPayloadRes
345361
types.GetOperationPayloadResult,
346362
)
347363

348-
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
364+
# Validate using the stored tool name
365+
if hasattr(result, "result") and result.result:
366+
# Clean up expired operations first
367+
self._operation_manager.cleanup_expired()
368+
369+
tool_name = self._operation_manager.get_tool_name(token)
370+
await self._validate_tool_result(tool_name, result.result)
371+
# Keep the operation for potential future retrievals
372+
373+
return result
374+
375+
async def _validate_tool_result(self, name: str | None, result: types.CallToolResult) -> None:
349376
"""Validate the structured content of a tool result against its output schema."""
350377
if name not in self._tool_output_schemas:
351378
# refresh output schema cache
@@ -358,6 +385,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -
358385
logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")
359386

360387
if output_schema is not None:
388+
logger.debug(f"Validating structured content for tool: {name}")
361389
if result.structuredContent is None:
362390
raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
363391
try:

src/mcp/server/fastmcp/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from mcp.server.fastmcp.tools.base import InvocationMode
3434
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
3535
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
36-
from mcp.server.lowlevel.async_operations import AsyncOperationManager
3736
from mcp.server.lowlevel.helper_types import ReadResourceContents
3837
from mcp.server.lowlevel.server import LifespanResultT
3938
from mcp.server.lowlevel.server import Server as MCPServer
@@ -44,6 +43,7 @@
4443
from mcp.server.streamable_http import EventStore
4544
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4645
from mcp.server.transport_security import TransportSecuritySettings
46+
from mcp.shared.async_operations import ServerAsyncOperationManager
4747
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
4848
from mcp.types import (
4949
NEXT_PROTOCOL_VERSION,
@@ -138,7 +138,7 @@ def __init__(
138138
token_verifier: TokenVerifier | None = None,
139139
event_store: EventStore | None = None,
140140
*,
141-
async_operations: AsyncOperationManager | None = None,
141+
async_operations: ServerAsyncOperationManager | None = None,
142142
tools: list[Tool] | None = None,
143143
debug: bool = False,
144144
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
@@ -178,7 +178,7 @@ def __init__(
178178
transport_security=transport_security,
179179
)
180180

181-
self._async_operations = async_operations or AsyncOperationManager()
181+
self._async_operations = async_operations or ServerAsyncOperationManager()
182182

183183
self._mcp_server = MCPServer(
184184
name=name or "FastMCP",

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ async def main():
8383
from typing_extensions import TypeVar
8484

8585
import mcp.types as types
86-
from mcp.server.lowlevel.async_operations import AsyncOperation, AsyncOperationManager
8786
from mcp.server.lowlevel.helper_types import ReadResourceContents
8887
from mcp.server.models import InitializationOptions
8988
from mcp.server.session import ServerSession
89+
from mcp.shared.async_operations import ServerAsyncOperation, ServerAsyncOperationManager
9090
from mcp.shared.context import RequestContext
9191
from mcp.shared.exceptions import McpError
9292
from mcp.shared.message import ServerMessageMetadata, SessionMessage
@@ -138,7 +138,7 @@ def __init__(
138138
name: str,
139139
version: str | None = None,
140140
instructions: str | None = None,
141-
async_operations: AsyncOperationManager | None = None,
141+
async_operations: ServerAsyncOperationManager | None = None,
142142
lifespan: Callable[
143143
[Server[LifespanResultT, RequestT]],
144144
AbstractAsyncContextManager[LifespanResultT],
@@ -148,7 +148,7 @@ def __init__(
148148
self.version = version
149149
self.instructions = instructions
150150
self.lifespan = lifespan
151-
self.async_operations = async_operations or AsyncOperationManager()
151+
self.async_operations = async_operations or ServerAsyncOperationManager()
152152
# Track request ID to operation token mapping for cancellation
153153
self._request_to_operation: dict[RequestId, str] = {}
154154
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
@@ -469,11 +469,9 @@ async def handler(req: types.CallToolRequest):
469469
# Check for async execution
470470
if tool and self.async_operations and self._should_execute_async(tool):
471471
# Create async operation
472-
session_id = f"session_{id(self.request_context.session)}"
473472
operation = self.async_operations.create_operation(
474473
tool_name=tool_name,
475474
arguments=arguments,
476-
session_id=session_id,
477475
)
478476
logger.debug(f"Created async operation with token: {operation.token}")
479477

@@ -627,7 +625,7 @@ async def handler(req: types.CompleteRequest):
627625

628626
return decorator
629627

630-
def _validate_operation_token(self, token: str) -> AsyncOperation:
628+
def _validate_operation_token(self, token: str) -> ServerAsyncOperation:
631629
"""Validate operation token and return operation if valid."""
632630
operation = self.async_operations.get_operation(token)
633631
if not operation:

0 commit comments

Comments
 (0)