@@ -239,3 +239,221 @@ async def run_handler():
239239 assert isinstance (response , types .ServerResult )
240240 payload_result = cast (types .GetOperationPayloadResult , response .root )
241241 assert payload_result .result == result
242+
243+
244+ class TestCancellationLogic :
245+ """Test cancellation logic for async operations."""
246+
247+ def test_handle_cancelled_notification (self ):
248+ """Test handling of cancelled notifications."""
249+ manager = AsyncOperationManager ()
250+ server = Server ("Test" , async_operations = manager )
251+
252+ # Create an operation
253+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
254+
255+ # Track the operation with a request ID
256+ request_id = "req_123"
257+ server ._request_to_operation [request_id ] = operation .token
258+
259+ # Handle cancellation
260+ server .handle_cancelled_notification (request_id )
261+
262+ # Verify operation was cancelled
263+ cancelled_op = manager .get_operation (operation .token )
264+ assert cancelled_op is not None
265+ assert cancelled_op .status == "canceled"
266+
267+ # Verify mapping was cleaned up
268+ assert request_id not in server ._request_to_operation
269+
270+ def test_cancelled_notification_handler (self ):
271+ """Test the async cancelled notification handler."""
272+ manager = AsyncOperationManager ()
273+ server = Server ("Test" , async_operations = manager )
274+
275+ # Create an operation
276+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
277+
278+ # Track the operation with a request ID
279+ request_id = "req_456"
280+ server ._request_to_operation [request_id ] = operation .token
281+
282+ # Create cancelled notification
283+ notification = types .CancelledNotification (params = types .CancelledNotificationParams (requestId = request_id ))
284+
285+ # Handle the notification
286+ import asyncio
287+
288+ asyncio .run (server ._handle_cancelled_notification (notification ))
289+
290+ # Verify operation was cancelled
291+ cancelled_op = manager .get_operation (operation .token )
292+ assert cancelled_op is not None
293+ assert cancelled_op .status == "canceled"
294+
295+ def test_validate_operation_token_cancelled (self ):
296+ """Test that cancelled operations are rejected."""
297+ manager = AsyncOperationManager ()
298+ server = Server ("Test" , async_operations = manager )
299+
300+ # Create and cancel an operation
301+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
302+ manager .cancel_operation (operation .token )
303+
304+ # Verify that accessing cancelled operation raises error
305+ with pytest .raises (McpError ) as exc_info :
306+ server ._validate_operation_token (operation .token )
307+
308+ assert exc_info .value .error .code == - 32602
309+ assert "cancelled" in exc_info .value .error .message .lower ()
310+
311+ def test_nonexistent_request_id_cancellation (self ):
312+ """Test cancellation of non-existent request ID."""
313+ server = Server ("Test" )
314+
315+ # Should not raise error for non-existent request ID
316+ server .handle_cancelled_notification ("nonexistent_request" )
317+
318+ # Verify no operations were affected
319+ assert len (server ._request_to_operation ) == 0
320+
321+
322+ class TestInputRequiredBehavior :
323+ """Test input_required status handling for async operations."""
324+
325+ def test_mark_input_required (self ):
326+ """Test marking operation as requiring input."""
327+ manager = AsyncOperationManager ()
328+
329+ # Create operation in submitted state
330+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
331+ assert operation .status == "submitted"
332+
333+ # Mark as input required
334+ result = manager .mark_input_required (operation .token )
335+ assert result is True
336+
337+ # Verify status changed
338+ updated_op = manager .get_operation (operation .token )
339+ assert updated_op is not None
340+ assert updated_op .status == "input_required"
341+
342+ def test_mark_input_required_from_working (self ):
343+ """Test marking working operation as requiring input."""
344+ manager = AsyncOperationManager ()
345+
346+ # Create and mark as working
347+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
348+ manager .mark_working (operation .token )
349+ assert operation .status == "working"
350+
351+ # Mark as input required
352+ result = manager .mark_input_required (operation .token )
353+ assert result is True
354+ assert operation .status == "input_required"
355+
356+ def test_mark_input_required_invalid_states (self ):
357+ """Test that input_required can only be set from valid states."""
358+ manager = AsyncOperationManager ()
359+
360+ # Test from completed state
361+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
362+ manager .complete_operation (operation .token , types .CallToolResult (content = []))
363+
364+ result = manager .mark_input_required (operation .token )
365+ assert result is False
366+ assert operation .status == "completed"
367+
368+ def test_mark_input_completed (self ):
369+ """Test marking input as completed."""
370+ manager = AsyncOperationManager ()
371+
372+ # Create operation and mark as input required
373+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
374+ manager .mark_input_required (operation .token )
375+ assert operation .status == "input_required"
376+
377+ # Mark input as completed
378+ result = manager .mark_input_completed (operation .token )
379+ assert result is True
380+ assert operation .status == "working"
381+
382+ def test_mark_input_completed_invalid_state (self ):
383+ """Test that input can only be completed from input_required state."""
384+ manager = AsyncOperationManager ()
385+
386+ # Create operation in submitted state
387+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
388+ assert operation .status == "submitted"
389+
390+ # Try to mark input completed from wrong state
391+ result = manager .mark_input_completed (operation .token )
392+ assert result is False
393+ assert operation .status == "submitted"
394+
395+ def test_nonexistent_token_operations (self ):
396+ """Test input_required operations on nonexistent tokens."""
397+ manager = AsyncOperationManager ()
398+
399+ # Test with fake token
400+ assert manager .mark_input_required ("fake_token" ) is False
401+ assert manager .mark_input_completed ("fake_token" ) is False
402+
403+ def test_server_send_request_for_operation (self ):
404+ """Test server method for sending requests with operation tokens."""
405+ manager = AsyncOperationManager ()
406+ server = Server ("Test" , async_operations = manager )
407+
408+ # Create operation
409+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
410+ manager .mark_working (operation .token )
411+
412+ # Create a mock request
413+ request = types .ServerRequest (
414+ types .CreateMessageRequest (
415+ params = types .CreateMessageRequestParams (
416+ messages = [types .SamplingMessage (role = "user" , content = types .TextContent (type = "text" , text = "test" ))],
417+ maxTokens = 100 ,
418+ )
419+ )
420+ )
421+
422+ # Send request for operation
423+ server .send_request_for_operation (operation .token , request )
424+
425+ # Verify operation status changed
426+ updated_op = manager .get_operation (operation .token )
427+ assert updated_op is not None
428+ assert updated_op .status == "input_required"
429+
430+ def test_server_complete_request_for_operation (self ):
431+ """Test server method for completing requests."""
432+ manager = AsyncOperationManager ()
433+ server = Server ("Test" , async_operations = manager )
434+
435+ # Create operation and mark as input required
436+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
437+ manager .mark_input_required (operation .token )
438+
439+ # Complete request for operation
440+ server .complete_request_for_operation (operation .token )
441+
442+ # Verify operation status changed back to working
443+ updated_op = manager .get_operation (operation .token )
444+ assert updated_op is not None
445+ assert updated_op .status == "working"
446+
447+ def test_input_required_is_terminal_check (self ):
448+ """Test that input_required is not considered a terminal state."""
449+ manager = AsyncOperationManager ()
450+
451+ # Create operation and mark as input required
452+ operation = manager .create_operation ("test_tool" , {"arg" : "value" }, "session1" )
453+ manager .mark_input_required (operation .token )
454+
455+ # Verify it's not terminal
456+ assert not operation .is_terminal
457+
458+ # Verify it doesn't expire while in input_required state
459+ assert not operation .is_expired
0 commit comments