Skip to content

Commit 0dc8d43

Browse files
committed
Handle cancellation notifications on async ops
1 parent 8d281be commit 0dc8d43

File tree

2 files changed

+127
-5
lines changed

2 files changed

+127
-5
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ async def main():
9090
from mcp.shared.exceptions import McpError
9191
from mcp.shared.message import ServerMessageMetadata, SessionMessage
9292
from mcp.shared.session import RequestResponder
93+
from mcp.types import RequestId
9394

9495
logger = logging.getLogger(__name__)
9596

@@ -147,10 +148,14 @@ def __init__(
147148
self.instructions = instructions
148149
self.lifespan = lifespan
149150
self.async_operations = async_operations or AsyncOperationManager()
151+
# Track request ID to operation token mapping for cancellation
152+
self._request_to_operation: dict[RequestId, str] = {}
150153
self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
151154
types.PingRequest: _ping_handler,
152155
}
153-
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
156+
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {
157+
types.CancelledNotification: self._handle_cancelled_notification,
158+
}
154159
self._tool_cache: dict[str, types.Tool] = {}
155160
logger.debug("Initializing server %r", name)
156161

@@ -566,6 +571,10 @@ def _validate_operation_token(self, token: str) -> AsyncOperation:
566571
if operation.is_expired:
567572
raise McpError(types.ErrorData(code=-32602, message="Token expired"))
568573

574+
# Check if operation was cancelled - ignore subsequent requests
575+
if operation.status == "canceled":
576+
raise McpError(types.ErrorData(code=-32602, message="Operation was cancelled"))
577+
569578
return operation
570579

571580
def get_operation_status(self):
@@ -615,6 +624,23 @@ async def handler(req: types.GetOperationPayloadRequest):
615624

616625
return decorator
617626

627+
def handle_cancelled_notification(self, request_id: RequestId) -> None:
628+
"""Handle cancellation notification for a request."""
629+
# Check if this request ID corresponds to an async operation
630+
if request_id in self._request_to_operation:
631+
token = self._request_to_operation[request_id]
632+
# Cancel the operation
633+
if self.async_operations.cancel_operation(token):
634+
logger.debug(f"Cancelled async operation {token} for request {request_id}")
635+
# Clean up the mapping
636+
del self._request_to_operation[request_id]
637+
638+
async def _handle_cancelled_notification(self, notification: types.CancelledNotification) -> None:
639+
"""Handle cancelled notification from client."""
640+
request_id = notification.params.requestId
641+
logger.debug(f"Received cancellation notification for request {request_id}")
642+
self.handle_cancelled_notification(request_id)
643+
618644
async def run(
619645
self,
620646
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -695,7 +721,7 @@ async def _handle_request(
695721
if handler := self.request_handlers.get(type(req)): # type: ignore
696722
logger.debug("Dispatching request of type %s", type(req).__name__)
697723

698-
token = None
724+
context_token = None
699725
try:
700726
# Extract request context from message metadata
701727
request_data = None
@@ -704,7 +730,7 @@ async def _handle_request(
704730

705731
# Set our global state that can be retrieved via
706732
# app.get_request_context()
707-
token = request_ctx.set(
733+
context_token = request_ctx.set(
708734
RequestContext(
709735
message.request_id,
710736
message.request_meta,
@@ -714,6 +740,16 @@ async def _handle_request(
714740
)
715741
)
716742
response = await handler(req)
743+
744+
# Track async operations for cancellation
745+
if isinstance(req, types.CallToolRequest):
746+
result = response.root
747+
if isinstance(result, types.CallToolResult) and result.operation_result is not None:
748+
# This is an async operation, track the request ID to token mapping
749+
operation_token = result.operation_result.token
750+
self._request_to_operation[message.request_id] = operation_token
751+
logger.debug(f"Tracking async operation {operation_token} for request {message.request_id}")
752+
717753
except McpError as err:
718754
response = err.error
719755
except anyio.get_cancelled_exc_class():
@@ -728,8 +764,8 @@ async def _handle_request(
728764
response = types.ErrorData(code=0, message=str(err), data=None)
729765
finally:
730766
# Reset the global state after we are done
731-
if token is not None:
732-
request_ctx.reset(token)
767+
if context_token is not None:
768+
request_ctx.reset(context_token)
733769

734770
await message.respond(response)
735771
else:
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Tests for async operation cancellation logic."""
2+
3+
import pytest
4+
5+
import mcp.types as types
6+
from mcp.server.lowlevel.async_operations import AsyncOperationManager
7+
from mcp.server.lowlevel.server import Server
8+
from mcp.shared.exceptions import McpError
9+
10+
11+
class TestCancellationLogic:
12+
"""Test cancellation logic for async operations."""
13+
14+
def test_handle_cancelled_notification(self):
15+
"""Test handling of cancelled notifications."""
16+
manager = AsyncOperationManager()
17+
server = Server("Test", async_operations=manager)
18+
19+
# Create an operation
20+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
21+
22+
# Track the operation with a request ID
23+
request_id = "req_123"
24+
server._request_to_operation[request_id] = operation.token
25+
26+
# Handle cancellation
27+
server.handle_cancelled_notification(request_id)
28+
29+
# Verify operation was cancelled
30+
cancelled_op = manager.get_operation(operation.token)
31+
assert cancelled_op is not None
32+
assert cancelled_op.status == "canceled"
33+
34+
# Verify mapping was cleaned up
35+
assert request_id not in server._request_to_operation
36+
37+
def test_cancelled_notification_handler(self):
38+
"""Test the async cancelled notification handler."""
39+
manager = AsyncOperationManager()
40+
server = Server("Test", async_operations=manager)
41+
42+
# Create an operation
43+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
44+
45+
# Track the operation with a request ID
46+
request_id = "req_456"
47+
server._request_to_operation[request_id] = operation.token
48+
49+
# Create cancelled notification
50+
notification = types.CancelledNotification(params=types.CancelledNotificationParams(requestId=request_id))
51+
52+
# Handle the notification
53+
import asyncio
54+
55+
asyncio.run(server._handle_cancelled_notification(notification))
56+
57+
# Verify operation was cancelled
58+
cancelled_op = manager.get_operation(operation.token)
59+
assert cancelled_op is not None
60+
assert cancelled_op.status == "canceled"
61+
62+
def test_validate_operation_token_cancelled(self):
63+
"""Test that cancelled operations are rejected."""
64+
manager = AsyncOperationManager()
65+
server = Server("Test", async_operations=manager)
66+
67+
# Create and cancel an operation
68+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
69+
manager.cancel_operation(operation.token)
70+
71+
# Verify that accessing cancelled operation raises error
72+
with pytest.raises(McpError) as exc_info:
73+
server._validate_operation_token(operation.token)
74+
75+
assert exc_info.value.error.code == -32602
76+
assert "cancelled" in exc_info.value.error.message.lower()
77+
78+
def test_nonexistent_request_id_cancellation(self):
79+
"""Test cancellation of non-existent request ID."""
80+
server = Server("Test")
81+
82+
# Should not raise error for non-existent request ID
83+
server.handle_cancelled_notification("nonexistent_request")
84+
85+
# Verify no operations were affected
86+
assert len(server._request_to_operation) == 0

0 commit comments

Comments
 (0)