@@ -90,6 +90,7 @@ async def main():
9090from mcp .shared .exceptions import McpError
9191from mcp .shared .message import ServerMessageMetadata , SessionMessage
9292from mcp .shared .session import RequestResponder
93+ from mcp .types import RequestId
9394
9495logger = logging .getLogger (__name__ )
9596
@@ -147,10 +148,14 @@ def __init__(
147148 self .instructions = instructions
148149 self .lifespan = lifespan
149150 self .async_operations = async_operations or AsyncOperationManager ()
151+ # Track request ID to operation token mapping for cancellation
152+ self ._request_to_operation : dict [RequestId , str ] = {}
150153 self .request_handlers : dict [type , Callable [..., Awaitable [types .ServerResult ]]] = {
151154 types .PingRequest : _ping_handler ,
152155 }
153- self .notification_handlers : dict [type , Callable [..., Awaitable [None ]]] = {}
156+ self .notification_handlers : dict [type , Callable [..., Awaitable [None ]]] = {
157+ types .CancelledNotification : self ._handle_cancelled_notification ,
158+ }
154159 self ._tool_cache : dict [str , types .Tool ] = {}
155160 logger .debug ("Initializing server %r" , name )
156161
@@ -566,6 +571,10 @@ def _validate_operation_token(self, token: str) -> AsyncOperation:
566571 if operation .is_expired :
567572 raise McpError (types .ErrorData (code = - 32602 , message = "Token expired" ))
568573
574+ # Check if operation was cancelled - ignore subsequent requests
575+ if operation .status == "canceled" :
576+ raise McpError (types .ErrorData (code = - 32602 , message = "Operation was cancelled" ))
577+
569578 return operation
570579
571580 def get_operation_status (self ):
@@ -615,6 +624,23 @@ async def handler(req: types.GetOperationPayloadRequest):
615624
616625 return decorator
617626
627+ def handle_cancelled_notification (self , request_id : RequestId ) -> None :
628+ """Handle cancellation notification for a request."""
629+ # Check if this request ID corresponds to an async operation
630+ if request_id in self ._request_to_operation :
631+ token = self ._request_to_operation [request_id ]
632+ # Cancel the operation
633+ if self .async_operations .cancel_operation (token ):
634+ logger .debug (f"Cancelled async operation { token } for request { request_id } " )
635+ # Clean up the mapping
636+ del self ._request_to_operation [request_id ]
637+
638+ async def _handle_cancelled_notification (self , notification : types .CancelledNotification ) -> None :
639+ """Handle cancelled notification from client."""
640+ request_id = notification .params .requestId
641+ logger .debug (f"Received cancellation notification for request { request_id } " )
642+ self .handle_cancelled_notification (request_id )
643+
618644 async def run (
619645 self ,
620646 read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ],
@@ -695,7 +721,7 @@ async def _handle_request(
695721 if handler := self .request_handlers .get (type (req )): # type: ignore
696722 logger .debug ("Dispatching request of type %s" , type (req ).__name__ )
697723
698- token = None
724+ context_token = None
699725 try :
700726 # Extract request context from message metadata
701727 request_data = None
@@ -704,7 +730,7 @@ async def _handle_request(
704730
705731 # Set our global state that can be retrieved via
706732 # app.get_request_context()
707- token = request_ctx .set (
733+ context_token = request_ctx .set (
708734 RequestContext (
709735 message .request_id ,
710736 message .request_meta ,
@@ -714,6 +740,16 @@ async def _handle_request(
714740 )
715741 )
716742 response = await handler (req )
743+
744+ # Track async operations for cancellation
745+ if isinstance (req , types .CallToolRequest ):
746+ result = response .root
747+ if isinstance (result , types .CallToolResult ) and result .operation_result is not None :
748+ # This is an async operation, track the request ID to token mapping
749+ operation_token = result .operation_result .token
750+ self ._request_to_operation [message .request_id ] = operation_token
751+ logger .debug (f"Tracking async operation { operation_token } for request { message .request_id } " )
752+
717753 except McpError as err :
718754 response = err .error
719755 except anyio .get_cancelled_exc_class ():
@@ -728,8 +764,8 @@ async def _handle_request(
728764 response = types .ErrorData (code = 0 , message = str (err ), data = None )
729765 finally :
730766 # Reset the global state after we are done
731- if token is not None :
732- request_ctx .reset (token )
767+ if context_token is not None :
768+ request_ctx .reset (context_token )
733769
734770 await message .respond (response )
735771 else :
0 commit comments