Skip to content

Commit f667b63

Browse files
committed
Add tests for cancellation behaviour and implement
1 parent f926235 commit f667b63

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

src/mcp/server/lowlevel/result_cache.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22
from collections.abc import Awaitable, Callable
3-
from concurrent.futures import Future
3+
from concurrent.futures import CancelledError, Future
44
from dataclasses import dataclass, field
55
from logging import getLogger
66
from types import TracebackType
@@ -168,7 +168,8 @@ async def cancel(self, notification: types.CancelToolAsyncNotification) -> None:
168168
if in_progress is not None:
169169
if in_progress.user == user_context.get():
170170
# in_progress.task_group.cancel_scope.cancel()
171-
del self._in_progress[notification.params.token]
171+
assert in_progress.future is not None, "In progress future not found"
172+
in_progress.future.cancel()
172173
else:
173174
logger.warning(
174175
"Permission denied for cancel notification received"
@@ -202,6 +203,11 @@ async def get_result(self, req: types.GetToolAsyncResultRequest):
202203
result = in_progress.future.result(1)
203204
logger.debug(f"Found result {result}")
204205
return result
206+
except CancelledError:
207+
return types.CallToolResult(
208+
content=[types.TextContent(type="text", text="cancelled")],
209+
isError=True,
210+
)
205211
except TimeoutError:
206212
return types.CallToolResult(
207213
content=[],

tests/server/lowlevel/test_result_cache.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from contextlib import AsyncExitStack
22
from unittest.mock import AsyncMock, Mock
33

4+
import anyio
45
import pytest
56

67
from mcp import types
@@ -124,6 +125,84 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
124125
)
125126

126127

128+
@pytest.mark.anyio
129+
async def test_async_cancel_in_progress():
130+
"""Tests basic async call"""
131+
132+
async def slow_call(call: types.CallToolRequest) -> types.ServerResult:
133+
with anyio.move_on_after(10) as scope:
134+
await anyio.sleep(10)
135+
136+
if scope.cancel_called:
137+
return types.ServerResult(
138+
types.CallToolResult(
139+
content=[
140+
types.TextContent(type="text", text="should be discarded")
141+
],
142+
isError=True,
143+
)
144+
)
145+
else:
146+
return types.ServerResult(
147+
types.CallToolResult(
148+
content=[types.TextContent(type="text", text="test")]
149+
)
150+
)
151+
152+
async_call = types.CallToolAsyncRequest(
153+
method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test")
154+
)
155+
156+
mock_session_1 = AsyncMock()
157+
mock_context_1 = Mock()
158+
mock_context_1.session = mock_session_1
159+
160+
result_cache = ResultCache(max_size=1, max_keep_alive=1)
161+
async with AsyncExitStack() as stack:
162+
await stack.enter_async_context(result_cache)
163+
async_call_ref = await result_cache.start_call(
164+
slow_call, async_call, mock_context_1
165+
)
166+
assert async_call_ref.token is not None
167+
168+
await result_cache.cancel(
169+
notification=types.CancelToolAsyncNotification(
170+
method="tools/async/cancel",
171+
params=types.CancelToolAsyncNotificationParams(
172+
token=async_call_ref.token
173+
),
174+
),
175+
)
176+
177+
assert async_call_ref.token is not None
178+
await result_cache.notification_hook(
179+
session=mock_session_1,
180+
notification=types.ServerNotification(
181+
types.ProgressNotification(
182+
method="notifications/progress",
183+
params=types.ProgressNotificationParams(
184+
progressToken="test", progress=1
185+
),
186+
)
187+
),
188+
)
189+
190+
result = await result_cache.get_result(
191+
types.GetToolAsyncResultRequest(
192+
method="tools/async/get",
193+
params=types.GetToolAsyncResultRequestParams(
194+
token=async_call_ref.token
195+
),
196+
)
197+
)
198+
199+
assert result.isError
200+
assert not result.isPending
201+
assert len(result.content) == 1
202+
assert type(result.content[0]) is types.TextContent
203+
assert result.content[0].text == "cancelled"
204+
205+
127206
@pytest.mark.anyio
128207
async def test_async_call_keep_alive():
129208
"""Tests async call keep alive"""

0 commit comments

Comments
 (0)