Skip to content

Commit 2df5e7c

Browse files
committed
Implement lowlevel async CallTool
1 parent 04bac41 commit 2df5e7c

File tree

3 files changed

+204
-41
lines changed

3 files changed

+204
-41
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 104 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ async def main():
6767

6868
from __future__ import annotations as _annotations
6969

70+
import asyncio
7071
import contextvars
7172
import json
7273
import logging
@@ -465,46 +466,55 @@ async def handler(req: types.CallToolRequest):
465466
except jsonschema.ValidationError as e:
466467
return self._make_error_result(f"Input validation error: {e.message}")
467468

468-
# tool call
469-
results = await func(tool_name, arguments)
469+
# Check for async execution
470+
if tool and self.async_operations and self._should_execute_async(tool):
471+
# Create async operation
472+
session_id = f"session_{id(self.request_context.session)}"
473+
operation = self.async_operations.create_operation(
474+
tool_name=tool_name,
475+
arguments=arguments,
476+
session_id=session_id,
477+
)
478+
logger.debug(f"Created async operation with token: {operation.token}")
470479

471-
# output normalization
472-
unstructured_content: UnstructuredContent
473-
maybe_structured_content: StructuredContent | None
474-
if isinstance(results, tuple) and len(results) == 2:
475-
# tool returned both structured and unstructured content
476-
unstructured_content, maybe_structured_content = cast(CombinationContent, results)
477-
elif isinstance(results, dict):
478-
# tool returned structured content only
479-
maybe_structured_content = cast(StructuredContent, results)
480-
unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))]
481-
elif hasattr(results, "__iter__"):
482-
# tool returned unstructured content only
483-
unstructured_content = cast(UnstructuredContent, results)
484-
maybe_structured_content = None
485-
else:
486-
return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}")
487-
488-
# output validation
489-
if tool and tool.outputSchema is not None:
490-
if maybe_structured_content is None:
491-
return self._make_error_result(
492-
"Output validation error: outputSchema defined but no structured output returned"
493-
)
494-
else:
480+
# Start async execution in background
481+
async def execute_async():
495482
try:
496-
jsonschema.validate(instance=maybe_structured_content, schema=tool.outputSchema)
497-
except jsonschema.ValidationError as e:
498-
return self._make_error_result(f"Output validation error: {e.message}")
499-
500-
# result
501-
return types.ServerResult(
502-
types.CallToolResult(
503-
content=list(unstructured_content),
504-
structuredContent=maybe_structured_content,
505-
isError=False,
483+
logger.debug(f"Starting async execution of {tool_name}")
484+
results = await func(tool_name, arguments)
485+
logger.debug(f"Async execution completed for {tool_name}")
486+
487+
# Process results using shared logic
488+
result = self._process_tool_result(results, tool)
489+
self.async_operations.complete_operation(operation.token, result)
490+
logger.debug(f"Completed async operation {operation.token}")
491+
except Exception as e:
492+
logger.exception(f"Async execution failed for {tool_name}")
493+
self.async_operations.fail_operation(operation.token, str(e))
494+
495+
asyncio.create_task(execute_async())
496+
497+
# Return operation result immediately
498+
logger.info(f"Returning async operation result for {tool_name}")
499+
return types.ServerResult(
500+
types.CallToolResult(
501+
content=[],
502+
operation=types.AsyncResultProperties(
503+
token=operation.token,
504+
keepAlive=3600,
505+
),
506+
)
506507
)
507-
)
508+
509+
# tool call
510+
results = await func(tool_name, arguments)
511+
512+
# Process results using shared logic
513+
try:
514+
result = self._process_tool_result(results, tool)
515+
return types.ServerResult(result)
516+
except ValueError as e:
517+
return self._make_error_result(str(e))
508518
except Exception as e:
509519
return self._make_error_result(str(e))
510520

@@ -513,6 +523,61 @@ async def handler(req: types.CallToolRequest):
513523

514524
return decorator
515525

526+
def _process_tool_result(
527+
self, results: UnstructuredContent | StructuredContent | CombinationContent, tool: types.Tool | None = None
528+
) -> types.CallToolResult:
529+
"""Process tool results and create CallToolResult with validation."""
530+
# output normalization
531+
unstructured_content: UnstructuredContent
532+
maybe_structured_content: StructuredContent | None
533+
if isinstance(results, tuple) and len(results) == 2:
534+
# tool returned both structured and unstructured content
535+
unstructured_content, maybe_structured_content = cast(CombinationContent, results)
536+
elif isinstance(results, dict):
537+
# tool returned structured content only
538+
maybe_structured_content = cast(StructuredContent, results)
539+
unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))]
540+
elif hasattr(results, "__iter__"):
541+
# tool returned unstructured content only
542+
unstructured_content = cast(UnstructuredContent, results)
543+
maybe_structured_content = None
544+
else:
545+
raise ValueError(f"Unexpected return type from tool: {type(results).__name__}")
546+
547+
# output validation
548+
if tool and tool.outputSchema is not None:
549+
if maybe_structured_content is None:
550+
raise ValueError("Output validation error: outputSchema defined but no structured output returned")
551+
else:
552+
try:
553+
jsonschema.validate(instance=maybe_structured_content, schema=tool.outputSchema)
554+
except jsonschema.ValidationError as e:
555+
raise ValueError(f"Output validation error: {e.message}")
556+
557+
# result
558+
return types.CallToolResult(
559+
content=list(unstructured_content),
560+
structuredContent=maybe_structured_content,
561+
isError=False,
562+
)
563+
564+
def _should_execute_async(self, tool: types.Tool) -> bool:
565+
"""Check if a tool should be executed asynchronously."""
566+
# Check if client supports async tools (protocol version "next")
567+
try:
568+
if self.request_context and self.request_context.session.client_params:
569+
client_version = str(self.request_context.session.client_params.protocolVersion)
570+
if client_version != "next":
571+
return False
572+
else:
573+
return False
574+
except (AttributeError, ValueError):
575+
return False
576+
577+
# Check if tool is async-only
578+
invocation_mode = getattr(tool, "invocationMode", None)
579+
return invocation_mode == "async"
580+
516581
def progress_notification(self):
517582
def decorator(
518583
func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
@@ -783,9 +848,9 @@ async def _handle_request(
783848
# Track async operations for cancellation
784849
if isinstance(req, types.CallToolRequest):
785850
result = response.root
786-
if isinstance(result, types.CallToolResult) and result.operation_result is not None:
851+
if isinstance(result, types.CallToolResult) and result.operation is not None:
787852
# This is an async operation, track the request ID to token mapping
788-
operation_token = result.operation_result.token
853+
operation_token = result.operation.token
789854
self._request_to_operation[message.request_id] = operation_token
790855
logger.debug(f"Tracking async operation {operation_token} for request {message.request_id}")
791856

src/mcp/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class Operation(BaseModel):
130130
See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
131131
for notes on _meta usage.
132132
"""
133-
operation: Operation | None = Field(alias="_operation", default=None)
133+
_operation: Operation | None = None
134134
"""
135135
Async operation parameters, only used when a result is sent in response to a request with operation parameters.
136136
"""
@@ -992,7 +992,7 @@ class CallToolResult(Result):
992992
structuredContent: dict[str, Any] | None = None
993993
"""An optional JSON object that represents the structured result of the tool call."""
994994
isError: bool = False
995-
operation_result: AsyncResultProperties | None = Field(serialization_alias="operation", default=None)
995+
operation: AsyncResultProperties | None = Field(default=None)
996996
"""Optional async execution information. Present when tool is executed asynchronously."""
997997

998998

tests/server/fastmcp/test_server.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import base64
23
from pathlib import Path
34
from typing import TYPE_CHECKING, Any
@@ -651,6 +652,103 @@ def hybrid_tool(x: int) -> int:
651652
# Hybrid tools should not have invocationMode field (None) for old clients
652653
assert tool.invocationMode is None
653654

655+
@pytest.mark.anyio
656+
async def test_async_tool_call_basic(self):
657+
"""Test basic async tool call functionality."""
658+
mcp = FastMCP("AsyncTest")
659+
660+
@mcp.tool(invocation_modes=["async"])
661+
async def async_add(a: int, b: int) -> int:
662+
"""Add two numbers asynchronously."""
663+
await asyncio.sleep(0.01) # Simulate async work
664+
return a + b
665+
666+
async with client_session(mcp._mcp_server, protocol_version="next") as client:
667+
result = await client.call_tool("async_add", {"a": 5, "b": 3})
668+
669+
# Should get operation token for async call
670+
assert result.operation is not None
671+
token = result.operation.token
672+
673+
# Poll for completion
674+
while True:
675+
status = await client.get_operation_status(token)
676+
if status.status == "completed":
677+
final_result = await client.get_operation_result(token)
678+
assert not final_result.result.isError
679+
assert len(final_result.result.content) == 1
680+
content = final_result.result.content[0]
681+
assert isinstance(content, TextContent)
682+
assert content.text == "8"
683+
break
684+
elif status.status == "failed":
685+
pytest.fail(f"Operation failed: {status.error}")
686+
await asyncio.sleep(0.01)
687+
688+
@pytest.mark.anyio
689+
async def test_async_tool_call_structured_output(self):
690+
"""Test async tool call with structured output."""
691+
mcp = FastMCP("AsyncTest")
692+
693+
class AsyncResult(BaseModel):
694+
value: int
695+
processed: bool = True
696+
697+
@mcp.tool(invocation_modes=["async"])
698+
async def async_structured_tool(x: int) -> AsyncResult:
699+
"""Process data and return structured result."""
700+
await asyncio.sleep(0.01) # Simulate async work
701+
return AsyncResult(value=x * 2)
702+
703+
async with client_session(mcp._mcp_server, protocol_version="next") as client:
704+
result = await client.call_tool("async_structured_tool", {"x": 21})
705+
706+
# Should get operation token for async call
707+
assert result.operation is not None
708+
token = result.operation.token
709+
710+
# Poll for completion
711+
while True:
712+
status = await client.get_operation_status(token)
713+
if status.status == "completed":
714+
final_result = await client.get_operation_result(token)
715+
assert not final_result.result.isError
716+
assert final_result.result.structuredContent is not None
717+
assert final_result.result.structuredContent == {"value": 42, "processed": True}
718+
break
719+
elif status.status == "failed":
720+
pytest.fail(f"Operation failed: {status.error}")
721+
await asyncio.sleep(0.01)
722+
723+
@pytest.mark.anyio
724+
async def test_async_tool_call_validation_error(self):
725+
"""Test async tool call with server-side validation error."""
726+
mcp = FastMCP("AsyncTest")
727+
728+
@mcp.tool(invocation_modes=["async"])
729+
async def async_invalid_tool() -> list[int]:
730+
"""Tool that returns invalid structured output."""
731+
await asyncio.sleep(0.01) # Simulate async work
732+
return [1, 2, 3, [4]] # type: ignore
733+
734+
async with client_session(mcp._mcp_server, protocol_version="next") as client:
735+
result = await client.call_tool("async_invalid_tool", {})
736+
737+
# Should get operation token for async call
738+
assert result.operation is not None
739+
token = result.operation.token
740+
741+
# Poll for completion - should fail due to validation error
742+
while True:
743+
status = await client.get_operation_status(token)
744+
if status.status == "failed":
745+
# Operation should fail due to validation error
746+
assert status.error is not None
747+
break
748+
elif status.status == "completed":
749+
pytest.fail("Operation should have failed due to validation error")
750+
await asyncio.sleep(0.01)
751+
654752

655753
class TestServerResources:
656754
@pytest.mark.anyio

0 commit comments

Comments
 (0)