Skip to content

Commit 7dd550b

Browse files
committed
Implement server-side handling for async tool calls
1 parent e5e4078 commit 7dd550b

File tree

7 files changed

+832
-12
lines changed

7 files changed

+832
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ mccabe.max-complexity = 24 # Default is 10
128128

129129
[tool.ruff.lint.pylint]
130130
allow-magic-value-types = ["bytes", "float", "int", "str"]
131-
max-args = 23 # Default is 5
131+
max-args = 24 # Default is 5
132132
max-branches = 23 # Default is 12
133133
max-returns = 13 # Default is 6
134134
max-statements = 102 # Default is 50

src/mcp/server/fastmcp/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from mcp.server.fastmcp.tools.base import InvocationMode
3434
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
3535
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
36+
from mcp.server.lowlevel.async_operations import AsyncOperationManager
3637
from mcp.server.lowlevel.helper_types import ReadResourceContents
3738
from mcp.server.lowlevel.server import LifespanResultT
3839
from mcp.server.lowlevel.server import Server as MCPServer
@@ -129,6 +130,7 @@ def __init__(
129130
token_verifier: TokenVerifier | None = None,
130131
event_store: EventStore | None = None,
131132
*,
133+
async_operations: AsyncOperationManager | None = None,
132134
tools: list[Tool] | None = None,
133135
debug: bool = False,
134136
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
@@ -178,6 +180,7 @@ def __init__(
178180
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
179181
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
180182
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
183+
self.async_operations = async_operations or AsyncOperationManager()
181184
# Validate auth configuration
182185
if self.settings.auth is not None:
183186
if auth_server_provider and token_verifier:
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
"""Async operations management for FastMCP servers."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import secrets
7+
import time
8+
from collections.abc import Callable
9+
from dataclasses import dataclass
10+
from typing import Any
11+
12+
import mcp.types as types
13+
from mcp.types import AsyncOperationStatus
14+
15+
16+
@dataclass
17+
class AsyncOperation:
18+
"""Represents an async tool operation."""
19+
20+
token: str
21+
tool_name: str
22+
arguments: dict[str, Any]
23+
session_id: str
24+
status: AsyncOperationStatus
25+
created_at: float
26+
keep_alive: int
27+
result: types.CallToolResult | None = None
28+
error: str | None = None
29+
30+
@property
31+
def is_expired(self) -> bool:
32+
"""Check if operation has expired based on keepAlive."""
33+
if self.status in ("completed", "failed", "canceled"):
34+
return time.time() > (self.created_at + self.keep_alive)
35+
return False
36+
37+
@property
38+
def is_terminal(self) -> bool:
39+
"""Check if operation is in a terminal state."""
40+
return self.status in ("completed", "failed", "canceled", "unknown")
41+
42+
43+
class AsyncOperationManager:
44+
"""Manages async tool operations with token-based tracking."""
45+
46+
def __init__(self, *, token_generator: Callable[[str], str] | None = None):
47+
self._operations: dict[str, AsyncOperation] = {}
48+
self._cleanup_task: asyncio.Task[None] | None = None
49+
self._cleanup_interval = 60 # Cleanup every 60 seconds
50+
self._token_generator = token_generator or self._default_token_generator
51+
52+
def _default_token_generator(self, session_id: str) -> str:
53+
"""Default token generation using random tokens."""
54+
return secrets.token_urlsafe(32)
55+
56+
def generate_token(self, session_id: str) -> str:
57+
"""Generate a token."""
58+
return self._token_generator(session_id)
59+
60+
def create_operation(
61+
self,
62+
tool_name: str,
63+
arguments: dict[str, Any],
64+
session_id: str,
65+
keep_alive: int = 3600,
66+
) -> AsyncOperation:
67+
"""Create a new async operation."""
68+
token = self.generate_token(session_id)
69+
operation = AsyncOperation(
70+
token=token,
71+
tool_name=tool_name,
72+
arguments=arguments,
73+
session_id=session_id,
74+
status="submitted",
75+
created_at=time.time(),
76+
keep_alive=keep_alive,
77+
)
78+
self._operations[token] = operation
79+
return operation
80+
81+
def get_operation(self, token: str) -> AsyncOperation | None:
82+
"""Get operation by token."""
83+
return self._operations.get(token)
84+
85+
def mark_working(self, token: str) -> bool:
86+
"""Mark operation as working."""
87+
operation = self._operations.get(token)
88+
if not operation:
89+
return False
90+
91+
# Can only transition to working from submitted
92+
if operation.status != "submitted":
93+
return False
94+
95+
operation.status = "working"
96+
return True
97+
98+
def complete_operation(self, token: str, result: types.CallToolResult) -> bool:
99+
"""Complete operation with result."""
100+
operation = self._operations.get(token)
101+
if not operation:
102+
return False
103+
104+
# Can only complete from submitted or working states
105+
if operation.status not in ("submitted", "working"):
106+
return False
107+
108+
operation.status = "completed"
109+
operation.result = result
110+
return True
111+
112+
def fail_operation(self, token: str, error: str) -> bool:
113+
"""Fail operation with error."""
114+
operation = self._operations.get(token)
115+
if not operation:
116+
return False
117+
118+
# Can only fail from submitted or working states
119+
if operation.status not in ("submitted", "working"):
120+
return False
121+
122+
operation.status = "failed"
123+
operation.error = error
124+
return True
125+
126+
def get_operation_result(self, token: str) -> types.CallToolResult | None:
127+
"""Get result for completed operation."""
128+
operation = self._operations.get(token)
129+
if not operation or operation.status != "completed":
130+
return None
131+
return operation.result
132+
133+
def cancel_operation(self, token: str) -> bool:
134+
"""Cancel operation."""
135+
operation = self._operations.get(token)
136+
if not operation:
137+
return False
138+
139+
# Can only cancel from submitted or working states
140+
if operation.status not in ("submitted", "working"):
141+
return False
142+
143+
operation.status = "canceled"
144+
return True
145+
146+
def remove_operation(self, token: str) -> bool:
147+
"""Remove operation by token."""
148+
return self._operations.pop(token, None) is not None
149+
150+
def cleanup_expired_operations(self) -> int:
151+
"""Remove expired operations and return count removed."""
152+
expired_tokens = [token for token, op in self._operations.items() if op.is_expired]
153+
154+
for token in expired_tokens:
155+
del self._operations[token]
156+
157+
return len(expired_tokens)
158+
159+
def get_session_operations(self, session_id: str) -> list[AsyncOperation]:
160+
"""Get all operations for a session."""
161+
return [op for op in self._operations.values() if op.session_id == session_id]
162+
163+
def cancel_session_operations(self, session_id: str) -> int:
164+
"""Cancel all operations for a session."""
165+
session_ops = self.get_session_operations(session_id)
166+
canceled_count = 0
167+
168+
for op in session_ops:
169+
if not op.is_terminal:
170+
op.status = "canceled"
171+
canceled_count += 1
172+
173+
return canceled_count
174+
175+
async def start_cleanup_task(self) -> None:
176+
"""Start the background cleanup task."""
177+
if self._cleanup_task is not None:
178+
return
179+
180+
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
181+
182+
async def stop_cleanup_task(self) -> None:
183+
"""Stop the background cleanup task."""
184+
if self._cleanup_task is not None:
185+
self._cleanup_task.cancel()
186+
try:
187+
await self._cleanup_task
188+
except asyncio.CancelledError:
189+
pass
190+
self._cleanup_task = None
191+
192+
async def _cleanup_loop(self) -> None:
193+
"""Background cleanup loop."""
194+
while True:
195+
try:
196+
await asyncio.sleep(self._cleanup_interval)
197+
self.cleanup_expired_operations()
198+
except asyncio.CancelledError:
199+
break
200+
except Exception:
201+
# Log error but continue cleanup loop
202+
pass

src/mcp/server/lowlevel/server.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ async def main():
8282
from typing_extensions import TypeVar
8383

8484
import mcp.types as types
85+
from mcp.server.lowlevel.async_operations import AsyncOperation, AsyncOperationManager
8586
from mcp.server.lowlevel.helper_types import ReadResourceContents
8687
from mcp.server.models import InitializationOptions
8788
from mcp.server.session import ServerSession
@@ -135,6 +136,7 @@ def __init__(
135136
name: str,
136137
version: str | None = None,
137138
instructions: str | None = None,
139+
async_operations: AsyncOperationManager | None = None,
138140
lifespan: Callable[
139141
[Server[LifespanResultT, RequestT]],
140142
AbstractAsyncContextManager[LifespanResultT],
@@ -144,6 +146,7 @@ def __init__(
144146
self.version = version
145147
self.instructions = instructions
146148
self.lifespan = lifespan
149+
self.async_operations = async_operations or AsyncOperationManager()
147150
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
148151
types.PingRequest: _ping_handler,
149152
}
@@ -554,6 +557,64 @@ async def handler(req: types.CompleteRequest):
554557

555558
return decorator
556559

560+
def _validate_operation_token(self, token: str) -> AsyncOperation:
561+
"""Validate operation token and return operation if valid."""
562+
operation = self.async_operations.get_operation(token)
563+
if not operation:
564+
raise McpError(types.ErrorData(code=-32602, message="Invalid token"))
565+
566+
if operation.is_expired:
567+
raise McpError(types.ErrorData(code=-32602, message="Token expired"))
568+
569+
return operation
570+
571+
def check_tool_async_status(self):
572+
"""Register a handler for checking async tool execution status."""
573+
574+
def decorator(func: Callable[[str], Awaitable[types.CheckToolAsyncStatusResult]]):
575+
logger.debug("Registering handler for CheckToolAsyncStatusRequest")
576+
577+
async def handler(req: types.CheckToolAsyncStatusRequest):
578+
# Validate token and get operation
579+
operation = self._validate_operation_token(req.params.token)
580+
581+
return types.ServerResult(
582+
types.CheckToolAsyncStatusResult(
583+
status=operation.status,
584+
error=operation.error,
585+
)
586+
)
587+
588+
self.request_handlers[types.CheckToolAsyncStatusRequest] = handler
589+
return func
590+
591+
return decorator
592+
593+
def get_tool_async_result(self):
594+
"""Register a handler for retrieving async tool execution results."""
595+
596+
def decorator(func: Callable[[str], Awaitable[types.GetToolAsyncPayloadResult]]):
597+
logger.debug("Registering handler for GetToolAsyncPayloadRequest")
598+
599+
async def handler(req: types.GetToolAsyncPayloadRequest):
600+
# Validate token and get operation
601+
operation = self._validate_operation_token(req.params.token)
602+
603+
if operation.status != "completed":
604+
raise McpError(
605+
types.ErrorData(code=-32600, message=f"Operation not completed (status: {operation.status})")
606+
)
607+
608+
if not operation.result:
609+
raise McpError(types.ErrorData(code=-32600, message="No result available for completed operation"))
610+
611+
return types.ServerResult(types.GetToolAsyncPayloadResult(result=operation.result))
612+
613+
self.request_handlers[types.GetToolAsyncPayloadRequest] = handler
614+
return func
615+
616+
return decorator
617+
557618
async def run(
558619
self,
559620
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -581,17 +642,27 @@ async def run(
581642
)
582643
)
583644

584-
async with anyio.create_task_group() as tg:
585-
async for message in session.incoming_messages:
586-
logger.debug("Received message: %s", message)
645+
# Start async operations cleanup task
646+
await self.async_operations.start_cleanup_task()
587647

588-
tg.start_soon(
589-
self._handle_message,
590-
message,
591-
session,
592-
lifespan_context,
593-
raise_exceptions,
594-
)
648+
try:
649+
async with anyio.create_task_group() as tg:
650+
async for message in session.incoming_messages:
651+
logger.debug("Received message: %s", message)
652+
653+
tg.start_soon(
654+
self._handle_message,
655+
message,
656+
session,
657+
lifespan_context,
658+
raise_exceptions,
659+
)
660+
finally:
661+
# Cancel session operations and stop cleanup task
662+
session_id = getattr(session, "session_id", None)
663+
if session_id is not None:
664+
self.async_operations.cancel_session_operations(session_id)
665+
await self.async_operations.stop_cleanup_task()
595666

596667
async def _handle_message(
597668
self,

src/mcp/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -910,10 +910,14 @@ class CheckToolAsyncStatusRequest(Request[CheckToolAsyncStatusParams, Literal["t
910910
params: CheckToolAsyncStatusParams
911911

912912

913+
"""Status values for async operations."""
914+
AsyncOperationStatus = Literal["submitted", "working", "completed", "canceled", "failed", "unknown"]
915+
916+
913917
class CheckToolAsyncStatusResult(Result):
914918
"""Result of checking async tool status."""
915919

916-
status: Literal["submitted", "working", "completed", "canceled", "failed", "unknown"]
920+
status: AsyncOperationStatus
917921
"""Current status of the async operation."""
918922
error: str | None = None
919923
"""Error message if status is 'failed'."""

0 commit comments

Comments
 (0)