Skip to content

Commit e70f441

Browse files
committed
Implement support for input_required status
1 parent 0dc8d43 commit e70f441

File tree

4 files changed

+283
-86
lines changed

4 files changed

+283
-86
lines changed

src/mcp/server/lowlevel/async_operations.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,32 @@ def cancel_session_operations(self, session_id: str) -> int:
172172

173173
return canceled_count
174174

175+
def mark_input_required(self, token: str) -> bool:
176+
"""Mark operation as requiring input from client."""
177+
operation = self._operations.get(token)
178+
if not operation:
179+
return False
180+
181+
# Can only move to input_required from submitted or working states
182+
if operation.status not in ("submitted", "working"):
183+
return False
184+
185+
operation.status = "input_required"
186+
return True
187+
188+
def mark_input_completed(self, token: str) -> bool:
189+
"""Mark operation as no longer requiring input, return to working state."""
190+
operation = self._operations.get(token)
191+
if not operation:
192+
return False
193+
194+
# Can only move from input_required back to working
195+
if operation.status != "input_required":
196+
return False
197+
198+
operation.status = "working"
199+
return True
200+
175201
async def start_cleanup_task(self) -> None:
176202
"""Start the background cleanup task."""
177203
if self._cleanup_task is not None:

src/mcp/server/lowlevel/server.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,35 @@ async def _handle_cancelled_notification(self, notification: types.CancelledNoti
641641
logger.debug(f"Received cancellation notification for request {request_id}")
642642
self.handle_cancelled_notification(request_id)
643643

644+
def send_request_for_operation(self, token: str, request: types.ServerRequest) -> None:
645+
"""Send a request associated with an async operation."""
646+
# Mark operation as requiring input
647+
if self.async_operations.mark_input_required(token):
648+
# Add operation token to request
649+
if hasattr(request.root, "params") and request.root.params is not None:
650+
if not hasattr(request.root.params, "operation") or request.root.params.operation is None:
651+
# Create operation field if it doesn't exist
652+
operation_data = types.RequestParams.Operation(token=token)
653+
request.root.params.operation = operation_data
654+
logger.debug(f"Marked operation {token} as input_required and added to request")
655+
656+
def send_notification_for_operation(self, token: str, notification: types.ServerNotification) -> None:
657+
"""Send a notification associated with an async operation."""
658+
# Mark operation as requiring input
659+
if self.async_operations.mark_input_required(token):
660+
# Add operation token to notification
661+
if hasattr(notification.root, "params") and notification.root.params is not None:
662+
if not hasattr(notification.root.params, "operation") or notification.root.params.operation is None:
663+
# Create operation field if it doesn't exist
664+
operation_data = types.NotificationParams.Operation(token=token)
665+
notification.root.params.operation = operation_data
666+
logger.debug(f"Marked operation {token} as input_required and added to notification")
667+
668+
def complete_request_for_operation(self, token: str) -> None:
669+
"""Mark that a request for an operation has been completed."""
670+
if self.async_operations.mark_input_completed(token):
671+
logger.debug(f"Marked operation {token} as no longer requiring input")
672+
644673
async def run(
645674
self,
646675
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
@@ -741,6 +770,16 @@ async def _handle_request(
741770
)
742771
response = await handler(req)
743772

773+
# Handle operation token in response (for input_required operations)
774+
if (
775+
hasattr(req, "params")
776+
and req.params is not None
777+
and hasattr(req.params, "operation")
778+
and req.params.operation is not None
779+
):
780+
operation_token = req.params.operation.token
781+
self.complete_request_for_operation(operation_token)
782+
744783
# Track async operations for cancellation
745784
if isinstance(req, types.CallToolRequest):
746785
result = response.root

tests/server/test_cancellation_logic.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

tests/server/test_lowlevel_async_operations.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,221 @@ async def run_handler():
239239
assert isinstance(response, types.ServerResult)
240240
payload_result = cast(types.GetOperationPayloadResult, response.root)
241241
assert payload_result.result == result
242+
243+
244+
class TestCancellationLogic:
245+
"""Test cancellation logic for async operations."""
246+
247+
def test_handle_cancelled_notification(self):
248+
"""Test handling of cancelled notifications."""
249+
manager = AsyncOperationManager()
250+
server = Server("Test", async_operations=manager)
251+
252+
# Create an operation
253+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
254+
255+
# Track the operation with a request ID
256+
request_id = "req_123"
257+
server._request_to_operation[request_id] = operation.token
258+
259+
# Handle cancellation
260+
server.handle_cancelled_notification(request_id)
261+
262+
# Verify operation was cancelled
263+
cancelled_op = manager.get_operation(operation.token)
264+
assert cancelled_op is not None
265+
assert cancelled_op.status == "canceled"
266+
267+
# Verify mapping was cleaned up
268+
assert request_id not in server._request_to_operation
269+
270+
def test_cancelled_notification_handler(self):
271+
"""Test the async cancelled notification handler."""
272+
manager = AsyncOperationManager()
273+
server = Server("Test", async_operations=manager)
274+
275+
# Create an operation
276+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
277+
278+
# Track the operation with a request ID
279+
request_id = "req_456"
280+
server._request_to_operation[request_id] = operation.token
281+
282+
# Create cancelled notification
283+
notification = types.CancelledNotification(params=types.CancelledNotificationParams(requestId=request_id))
284+
285+
# Handle the notification
286+
import asyncio
287+
288+
asyncio.run(server._handle_cancelled_notification(notification))
289+
290+
# Verify operation was cancelled
291+
cancelled_op = manager.get_operation(operation.token)
292+
assert cancelled_op is not None
293+
assert cancelled_op.status == "canceled"
294+
295+
def test_validate_operation_token_cancelled(self):
296+
"""Test that cancelled operations are rejected."""
297+
manager = AsyncOperationManager()
298+
server = Server("Test", async_operations=manager)
299+
300+
# Create and cancel an operation
301+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
302+
manager.cancel_operation(operation.token)
303+
304+
# Verify that accessing cancelled operation raises error
305+
with pytest.raises(McpError) as exc_info:
306+
server._validate_operation_token(operation.token)
307+
308+
assert exc_info.value.error.code == -32602
309+
assert "cancelled" in exc_info.value.error.message.lower()
310+
311+
def test_nonexistent_request_id_cancellation(self):
312+
"""Test cancellation of non-existent request ID."""
313+
server = Server("Test")
314+
315+
# Should not raise error for non-existent request ID
316+
server.handle_cancelled_notification("nonexistent_request")
317+
318+
# Verify no operations were affected
319+
assert len(server._request_to_operation) == 0
320+
321+
322+
class TestInputRequiredBehavior:
323+
"""Test input_required status handling for async operations."""
324+
325+
def test_mark_input_required(self):
326+
"""Test marking operation as requiring input."""
327+
manager = AsyncOperationManager()
328+
329+
# Create operation in submitted state
330+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
331+
assert operation.status == "submitted"
332+
333+
# Mark as input required
334+
result = manager.mark_input_required(operation.token)
335+
assert result is True
336+
337+
# Verify status changed
338+
updated_op = manager.get_operation(operation.token)
339+
assert updated_op is not None
340+
assert updated_op.status == "input_required"
341+
342+
def test_mark_input_required_from_working(self):
343+
"""Test marking working operation as requiring input."""
344+
manager = AsyncOperationManager()
345+
346+
# Create and mark as working
347+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
348+
manager.mark_working(operation.token)
349+
assert operation.status == "working"
350+
351+
# Mark as input required
352+
result = manager.mark_input_required(operation.token)
353+
assert result is True
354+
assert operation.status == "input_required"
355+
356+
def test_mark_input_required_invalid_states(self):
357+
"""Test that input_required can only be set from valid states."""
358+
manager = AsyncOperationManager()
359+
360+
# Test from completed state
361+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
362+
manager.complete_operation(operation.token, types.CallToolResult(content=[]))
363+
364+
result = manager.mark_input_required(operation.token)
365+
assert result is False
366+
assert operation.status == "completed"
367+
368+
def test_mark_input_completed(self):
369+
"""Test marking input as completed."""
370+
manager = AsyncOperationManager()
371+
372+
# Create operation and mark as input required
373+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
374+
manager.mark_input_required(operation.token)
375+
assert operation.status == "input_required"
376+
377+
# Mark input as completed
378+
result = manager.mark_input_completed(operation.token)
379+
assert result is True
380+
assert operation.status == "working"
381+
382+
def test_mark_input_completed_invalid_state(self):
383+
"""Test that input can only be completed from input_required state."""
384+
manager = AsyncOperationManager()
385+
386+
# Create operation in submitted state
387+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
388+
assert operation.status == "submitted"
389+
390+
# Try to mark input completed from wrong state
391+
result = manager.mark_input_completed(operation.token)
392+
assert result is False
393+
assert operation.status == "submitted"
394+
395+
def test_nonexistent_token_operations(self):
396+
"""Test input_required operations on nonexistent tokens."""
397+
manager = AsyncOperationManager()
398+
399+
# Test with fake token
400+
assert manager.mark_input_required("fake_token") is False
401+
assert manager.mark_input_completed("fake_token") is False
402+
403+
def test_server_send_request_for_operation(self):
404+
"""Test server method for sending requests with operation tokens."""
405+
manager = AsyncOperationManager()
406+
server = Server("Test", async_operations=manager)
407+
408+
# Create operation
409+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
410+
manager.mark_working(operation.token)
411+
412+
# Create a mock request
413+
request = types.ServerRequest(
414+
types.CreateMessageRequest(
415+
params=types.CreateMessageRequestParams(
416+
messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="test"))],
417+
maxTokens=100,
418+
)
419+
)
420+
)
421+
422+
# Send request for operation
423+
server.send_request_for_operation(operation.token, request)
424+
425+
# Verify operation status changed
426+
updated_op = manager.get_operation(operation.token)
427+
assert updated_op is not None
428+
assert updated_op.status == "input_required"
429+
430+
def test_server_complete_request_for_operation(self):
431+
"""Test server method for completing requests."""
432+
manager = AsyncOperationManager()
433+
server = Server("Test", async_operations=manager)
434+
435+
# Create operation and mark as input required
436+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
437+
manager.mark_input_required(operation.token)
438+
439+
# Complete request for operation
440+
server.complete_request_for_operation(operation.token)
441+
442+
# Verify operation status changed back to working
443+
updated_op = manager.get_operation(operation.token)
444+
assert updated_op is not None
445+
assert updated_op.status == "working"
446+
447+
def test_input_required_is_terminal_check(self):
448+
"""Test that input_required is not considered a terminal state."""
449+
manager = AsyncOperationManager()
450+
451+
# Create operation and mark as input required
452+
operation = manager.create_operation("test_tool", {"arg": "value"}, "session1")
453+
manager.mark_input_required(operation.token)
454+
455+
# Verify it's not terminal
456+
assert not operation.is_terminal
457+
458+
# Verify it doesn't expire while in input_required state
459+
assert not operation.is_expired

0 commit comments

Comments
 (0)