Skip to content

Commit 9b1fd81

Browse files
committed
work in progress towards async protocol - simple call and get work
1 parent a2c0ade commit 9b1fd81

File tree

7 files changed

+238
-1
lines changed

7 files changed

+238
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"sse-starlette>=1.6.1",
3232
"pydantic-settings>=2.5.2",
3333
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
34+
"cachetools==6",
3435
]
3536

3637
[project.optional-dependencies]

src/mcp/server/fastmcp/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
tools: list[Tool] | None = None,
148148
**settings: Any,
149149
):
150-
self.settings = Settings(**settings) # type: ignore
150+
self.settings = Settings(**settings) # type: ignore
151151

152152
self._mcp_server = MCPServer(
153153
name=name or "FastMCP",
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from collections.abc import Awaitable, Callable
2+
from dataclasses import dataclass, field
3+
from time import time
4+
from typing import Any
5+
from uuid import uuid4
6+
7+
from anyio import Lock, create_task_group, move_on_after
8+
from anyio.abc import TaskGroup
9+
from cachetools import TTLCache
10+
11+
from mcp import types
12+
from mcp.shared.context import BaseSession, RequestContext, SessionT
13+
14+
15+
@dataclass
16+
class InProgress:
17+
token: str
18+
task_group: TaskGroup | None = None
19+
sessions: list[BaseSession[Any, Any, Any, Any, Any]] = field(
20+
default_factory=lambda: []
21+
)
22+
23+
24+
class ResultCache:
25+
_in_progress: dict[types.AsyncToken, InProgress]
26+
27+
def __init__(self, max_size: int, max_keep_alive: int):
28+
self._max_size = max_size
29+
self._max_keep_alive = max_keep_alive
30+
self._result_cache = TTLCache[types.AsyncToken, types.CallToolResult](
31+
self._max_size, self._max_keep_alive
32+
)
33+
self._in_progress = {}
34+
self._lock = Lock()
35+
36+
async def add_call(
37+
self,
38+
call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]],
39+
req: types.CallToolAsyncRequest,
40+
ctx: RequestContext[SessionT, Any, Any],
41+
) -> types.CallToolAsyncResult:
42+
in_progress = await self._new_in_progress()
43+
timeout = min(
44+
req.params.keepAlive or self._max_keep_alive, self._max_keep_alive
45+
)
46+
47+
async def call_tool():
48+
with move_on_after(timeout) as scope:
49+
result = await call(
50+
types.CallToolRequest(
51+
method="tools/call",
52+
params=types.CallToolRequestParams(
53+
name=req.params.name, arguments=req.params.arguments
54+
),
55+
)
56+
)
57+
if not scope.cancel_called:
58+
async with self._lock:
59+
assert type(result.root) is types.CallToolResult
60+
self._result_cache[in_progress.token] = result.root
61+
62+
async with create_task_group() as tg:
63+
tg.start_soon(call_tool)
64+
in_progress.task_group = tg
65+
in_progress.sessions.append(ctx.session)
66+
result = types.CallToolAsyncResult(
67+
token=in_progress.token,
68+
recieved=round(time()),
69+
keepAlive=timeout,
70+
accepted=True,
71+
)
72+
return result
73+
74+
async def join_call(
75+
self,
76+
req: types.JoinCallToolAsyncRequest,
77+
ctx: RequestContext[SessionT, Any, Any],
78+
) -> types.CallToolAsyncResult:
79+
async with self._lock:
80+
in_progress = self._in_progress.get(req.params.token)
81+
if in_progress is None:
82+
# TODO consider creating new token to allow client
83+
# to get message describing why it wasn't accepted
84+
return types.CallToolAsyncResult(accepted=False)
85+
else:
86+
in_progress.sessions.append(ctx.session)
87+
return types.CallToolAsyncResult(accepted=True)
88+
89+
return
90+
91+
async def cancel(self, notification: types.CancelToolAsyncNotification) -> None:
92+
async with self._lock:
93+
in_progress = self._in_progress.get(notification.params.token)
94+
if in_progress is not None and in_progress.task_group is not None:
95+
in_progress.task_group.cancel_scope.cancel()
96+
del self._in_progress[notification.params.token]
97+
98+
async def get_result(self, req: types.GetToolAsyncResultRequest):
99+
async with self._lock:
100+
in_progress = self._in_progress.get(req.params.token)
101+
if in_progress is None:
102+
return types.CallToolResult(
103+
content=[
104+
types.TextContent(type="text", text="Unknown progress token")
105+
],
106+
isError=True,
107+
)
108+
else:
109+
result = self._result_cache.get(in_progress.token)
110+
if result is None:
111+
return types.CallToolResult(content=[], isPending=True)
112+
else:
113+
return result
114+
115+
async def _new_in_progress(self) -> InProgress:
116+
async with self._lock:
117+
while True:
118+
token = str(uuid4())
119+
if token not in self._in_progress:
120+
new_in_progress = InProgress(token)
121+
self._in_progress[token] = new_in_progress
122+
return new_in_progress

src/mcp/server/lowlevel/server.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ async def main():
8181

8282
import mcp.types as types
8383
from mcp.server.lowlevel.helper_types import ReadResourceContents
84+
from mcp.server.lowlevel.result_cache import ResultCache
8485
from mcp.server.models import InitializationOptions
8586
from mcp.server.session import ServerSession
8687
from mcp.server.stdio import stdio_server as stdio_server
@@ -135,6 +136,8 @@ def __init__(
135136
[Server[LifespanResultT, RequestT]],
136137
AbstractAsyncContextManager[LifespanResultT],
137138
] = lifespan,
139+
max_cache_size: int = 1000,
140+
max_cache_ttl: int = 60,
138141
):
139142
self.name = name
140143
self.version = version
@@ -145,6 +148,7 @@ def __init__(
145148
] = {
146149
types.PingRequest: _ping_handler,
147150
}
151+
self.result_cache = ResultCache(max_cache_size, max_cache_ttl)
148152
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
149153
self.notification_options = NotificationOptions()
150154
logger.debug(f"Initializing server '{name}'")
@@ -426,7 +430,32 @@ async def handler(req: types.CallToolRequest):
426430
)
427431
)
428432

433+
async def async_call_handler(req: types.CallToolAsyncRequest):
434+
ctx = request_ctx.get()
435+
result = await self.result_cache.add_call(handler, req, ctx)
436+
return types.ServerResult(result)
437+
438+
async def async_join_handler(req: types.JoinCallToolAsyncRequest):
439+
ctx = request_ctx.get()
440+
result = await self.result_cache.join_call(req, ctx)
441+
return types.ServerResult(result)
442+
443+
async def async_cancel_handler(req: types.CancelToolAsyncNotification):
444+
await self.result_cache.cancel(req)
445+
446+
async def async_result_handler(req: types.GetToolAsyncResultRequest):
447+
result = await self.result_cache.get_result(req)
448+
return types.ServerResult(result)
449+
429450
self.request_handlers[types.CallToolRequest] = handler
451+
self.request_handlers[types.CallToolAsyncRequest] = async_call_handler
452+
self.request_handlers[types.JoinCallToolAsyncRequest] = async_join_handler
453+
self.request_handlers[types.GetToolAsyncResultRequest] = (
454+
async_result_handler
455+
)
456+
self.notification_handlers[types.CancelToolAsyncNotification] = (
457+
async_cancel_handler
458+
)
430459
return func
431460

432461
return decorator

src/mcp/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,8 @@ class Tool(BaseModel):
799799
"""A JSON Schema object defining the expected parameters for the tool."""
800800
annotations: ToolAnnotations | None = None
801801
"""Optional additional tool information."""
802+
preferAsync: bool | None = None
803+
"""Optional flag to suggest to client async calls should be preferred"""
802804
model_config = ConfigDict(extra="allow")
803805

804806

@@ -1227,6 +1229,7 @@ class ClientRequest(
12271229
| CallToolRequest
12281230
| CallToolAsyncRequest
12291231
| JoinCallToolAsyncRequest
1232+
| GetToolAsyncResultRequest
12301233
| ListToolsRequest
12311234
]
12321235
):

tests/shared/test_session.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,77 @@ async def make_request(client_session):
129129
await ev_cancelled.wait()
130130

131131

132+
@pytest.mark.anyio
133+
async def test_request_async():
134+
"""Test that requests can be run asynchronously."""
135+
# The tool is already registered in the fixture
136+
137+
ev_tool_called = anyio.Event()
138+
139+
# Start the request in a separate task so we can cancel it
140+
def make_server() -> Server:
141+
server = Server(name="TestSessionServer")
142+
143+
# Register the tool handler
144+
@server.call_tool()
145+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
146+
nonlocal ev_tool_called
147+
if name == "async_tool":
148+
ev_tool_called.set()
149+
return [types.TextContent(type="text", text="test")]
150+
raise ValueError(f"Unknown tool: {name}")
151+
152+
# Register the tool so it shows up in list_tools
153+
@server.list_tools()
154+
async def handle_list_tools() -> list[types.Tool]:
155+
return [
156+
types.Tool(
157+
name="async_tool",
158+
description="A tool that does things asynchronously",
159+
inputSchema={},
160+
preferAsync=True,
161+
)
162+
]
163+
164+
return server
165+
166+
async def make_request(client_session: ClientSession):
167+
return await client_session.send_request(
168+
ClientRequest(
169+
types.CallToolAsyncRequest(
170+
method="tools/async/call",
171+
params=types.CallToolAsyncRequestParams(
172+
name="async_tool", arguments={}
173+
),
174+
)
175+
),
176+
types.CallToolAsyncResult,
177+
)
178+
179+
async def get_result(client_session: ClientSession, async_token: types.AsyncToken):
180+
return await client_session.send_request(
181+
ClientRequest(
182+
types.GetToolAsyncResultRequest(
183+
method="tools/async/get",
184+
params=types.GetToolAsyncResultRequestParams(token=async_token),
185+
)
186+
),
187+
types.CallToolResult,
188+
)
189+
190+
async with create_connected_server_and_client_session(
191+
make_server()
192+
) as client_session:
193+
async_result = await make_request(client_session)
194+
assert async_result is not None
195+
assert async_result.token is not None
196+
with anyio.fail_after(1): # Timeout after 1 second
197+
await ev_tool_called.wait()
198+
result = await get_result(client_session, async_result.token)
199+
assert type(result.content[0]) is types.TextContent
200+
assert result.content[0].text == "test"
201+
202+
132203
@pytest.mark.anyio
133204
async def test_connection_closed():
134205
"""

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)