Skip to content

Commit e5e4078

Browse files
committed
Implement session functions for async tools
1 parent cbda6e3 commit e5e4078

File tree

6 files changed

+270
-4
lines changed

6 files changed

+270
-4
lines changed

src/mcp/client/session.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from datetime import timedelta
33
from typing import Any, Protocol
44

5+
import anyio
56
import anyio.lowlevel
67
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
78
from jsonschema import SchemaError, ValidationError, validate
@@ -273,15 +274,26 @@ async def call_tool(
273274
arguments: dict[str, Any] | None = None,
274275
read_timeout_seconds: timedelta | None = None,
275276
progress_callback: ProgressFnT | None = None,
277+
*,
278+
async_properties: types.AsyncRequestProperties | None = None,
276279
) -> types.CallToolResult:
277-
"""Send a tools/call request with optional progress callback support."""
280+
"""Send a tools/call request with optional progress callback support.
281+
282+
Args:
283+
name: Name of the tool to call
284+
arguments: Arguments to pass to the tool
285+
read_timeout_seconds: Read timeout for the request
286+
progress_callback: Optional progress callback
287+
async_properties: Optional async parameters for async tool execution
288+
"""
278289

279290
result = await self.send_request(
280291
types.ClientRequest(
281292
types.CallToolRequest(
282293
params=types.CallToolRequestParams(
283294
name=name,
284295
arguments=arguments,
296+
async_properties=async_properties,
285297
),
286298
)
287299
),
@@ -295,6 +307,42 @@ async def call_tool(
295307

296308
return result
297309

310+
async def check_tool_async_status(self, token: str) -> types.CheckToolAsyncStatusResult:
311+
"""Check the status of an async tool operation.
312+
313+
Args:
314+
token: Token returned from async call_tool
315+
316+
Returns:
317+
Status result with current operation state
318+
"""
319+
return await self.send_request(
320+
types.ClientRequest(
321+
types.CheckToolAsyncStatusRequest(
322+
params=types.CheckToolAsyncStatusParams(token=token),
323+
)
324+
),
325+
types.CheckToolAsyncStatusResult,
326+
)
327+
328+
async def get_tool_async_result(self, token: str) -> types.GetToolAsyncPayloadResult:
329+
"""Get the result of a completed async tool operation.
330+
331+
Args:
332+
token: Token returned from async call_tool
333+
334+
Returns:
335+
The final tool result
336+
"""
337+
return await self.send_request(
338+
types.ClientRequest(
339+
types.GetToolAsyncPayloadRequest(
340+
params=types.GetToolAsyncPayloadParams(token=token),
341+
)
342+
),
343+
types.GetToolAsyncPayloadResult,
344+
)
345+
298346
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
299347
"""Validate the structured content of a tool result against its output schema."""
300348
if name not in self._tool_output_schemas:

src/mcp/server/fastmcp/server.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from mcp.server.fastmcp.prompts import Prompt, PromptManager
3131
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
3232
from mcp.server.fastmcp.tools import Tool, ToolManager
33+
from mcp.server.fastmcp.tools.base import InvocationMode
3334
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
3435
from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger
3536
from mcp.server.lowlevel.helper_types import ReadResourceContents
@@ -43,7 +44,7 @@
4344
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4445
from mcp.server.transport_security import TransportSecuritySettings
4546
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
46-
from mcp.types import AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations
47+
from mcp.types import NEXT_PROTOCOL_VERSION, AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations
4748
from mcp.types import Prompt as MCPPrompt
4849
from mcp.types import PromptArgument as MCPPromptArgument
4950
from mcp.types import Resource as MCPResource
@@ -266,9 +267,39 @@ def _setup_handlers(self) -> None:
266267
self._mcp_server.get_prompt()(self.get_prompt)
267268
self._mcp_server.list_resource_templates()(self.list_resource_templates)
268269

270+
def _client_supports_async(self) -> bool:
271+
"""Check if the current client supports async tools based on protocol version."""
272+
try:
273+
context = self.get_context()
274+
if context.request_context and context.request_context.session.client_params:
275+
client_version = str(context.request_context.session.client_params.protocolVersion)
276+
# Only "next" version supports async tools for now
277+
return client_version == NEXT_PROTOCOL_VERSION
278+
except ValueError:
279+
# Context not available (outside of request), assume no async support
280+
pass
281+
return False
282+
283+
def _get_invocation_mode(self, info: Tool, client_supports_async: bool) -> Literal["sync", "async"] | None:
284+
"""Determine invocationMode field based on client support."""
285+
if not client_supports_async:
286+
return None # Old clients don't see invocationMode field
287+
288+
# New clients see the invocationMode field
289+
if "async" in info.invocation_modes and len(info.invocation_modes) == 1:
290+
return "async" # Async-only
291+
elif len(info.invocation_modes) > 1 or info.invocation_modes == ["sync"]:
292+
return "sync" # Hybrid or explicit sync
293+
return None
294+
269295
async def list_tools(self) -> list[MCPTool]:
270296
"""List all available tools."""
271297
tools = self._tool_manager.list_tools()
298+
299+
# Check if client supports async tools based on protocol version
300+
client_supports_async = self._client_supports_async()
301+
302+
# Filter out async-only tools for old clients and set invocationMode based on client support
272303
return [
273304
MCPTool(
274305
name=info.name,
@@ -277,8 +308,10 @@ async def list_tools(self) -> list[MCPTool]:
277308
inputSchema=info.parameters,
278309
outputSchema=info.output_schema,
279310
annotations=info.annotations,
311+
invocationMode=self._get_invocation_mode(info, client_supports_async),
280312
)
281313
for info in tools
314+
if client_supports_async or info.invocation_modes != ["async"]
282315
]
283316

284317
def get_context(self) -> Context[ServerSession, LifespanResultT, Request]:
@@ -348,6 +381,7 @@ def add_tool(
348381
description: str | None = None,
349382
annotations: ToolAnnotations | None = None,
350383
structured_output: bool | None = None,
384+
invocation_modes: list[InvocationMode] | None = None,
351385
) -> None:
352386
"""Add a tool to the server.
353387
@@ -364,6 +398,8 @@ def add_tool(
364398
- If None, auto-detects based on the function's return type annotation
365399
- If True, unconditionally creates a structured tool (return type annotation permitting)
366400
- If False, unconditionally creates an unstructured tool
401+
invocation_modes: List of supported invocation modes (e.g., ["sync", "async"])
402+
- If None, defaults to ["sync"] for backwards compatibility
367403
"""
368404
self._tool_manager.add_tool(
369405
fn,
@@ -372,6 +408,7 @@ def add_tool(
372408
description=description,
373409
annotations=annotations,
374410
structured_output=structured_output,
411+
invocation_modes=invocation_modes,
375412
)
376413

377414
def tool(
@@ -381,6 +418,7 @@ def tool(
381418
description: str | None = None,
382419
annotations: ToolAnnotations | None = None,
383420
structured_output: bool | None = None,
421+
invocation_modes: list[InvocationMode] | None = None,
384422
) -> Callable[[AnyFunction], AnyFunction]:
385423
"""Decorator to register a tool.
386424
@@ -397,6 +435,10 @@ def tool(
397435
- If None, auto-detects based on the function's return type annotation
398436
- If True, unconditionally creates a structured tool (return type annotation permitting)
399437
- If False, unconditionally creates an unstructured tool
438+
invocation_modes: List of supported invocation modes (e.g., ["sync", "async"])
439+
- If None, defaults to ["sync"] for backwards compatibility
440+
- Supports "sync" for synchronous execution and "async" for asynchronous execution
441+
- Tools with "async" mode will be hidden from clients that don't support async execution
400442
401443
Example:
402444
@server.tool()
@@ -412,6 +454,17 @@ def tool_with_context(x: int, ctx: Context) -> str:
412454
async def async_tool(x: int, context: Context) -> str:
413455
await context.report_progress(50, 100)
414456
return str(x)
457+
458+
@server.tool(invocation_modes=["async"])
459+
async def async_only_tool(data: str, ctx: Context) -> str:
460+
# This tool only supports async execution
461+
await ctx.info("Starting long-running analysis...")
462+
return await analyze_data(data)
463+
464+
@server.tool(invocation_modes=["sync", "async"])
465+
def hybrid_tool(x: int) -> str:
466+
# This tool supports both sync and async execution
467+
return str(x)
415468
"""
416469
# Check if user passed function directly instead of calling decorator
417470
if callable(name):
@@ -427,6 +480,7 @@ def decorator(fn: AnyFunction) -> AnyFunction:
427480
description=description,
428481
annotations=annotations,
429482
structured_output=structured_output,
483+
invocation_modes=invocation_modes,
430484
)
431485
return fn
432486

src/mcp/server/fastmcp/tools/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import inspect
55
from collections.abc import Callable
66
from functools import cached_property
7-
from typing import TYPE_CHECKING, Any
7+
from typing import TYPE_CHECKING, Any, Literal
88

99
from pydantic import BaseModel, Field
1010

@@ -18,6 +18,8 @@
1818
from mcp.server.session import ServerSessionT
1919
from mcp.shared.context import LifespanContextT, RequestT
2020

21+
InvocationMode = Literal["sync", "async"]
22+
2123

2224
class Tool(BaseModel):
2325
"""Internal tool registration info."""
@@ -33,6 +35,9 @@ class Tool(BaseModel):
3335
is_async: bool = Field(description="Whether the tool is async")
3436
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
3537
annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
38+
invocation_modes: list[InvocationMode] = Field(
39+
default=["sync"], description="Supported invocation modes (sync/async)"
40+
)
3641

3742
@cached_property
3843
def output_schema(self) -> dict[str, Any] | None:
@@ -48,6 +53,7 @@ def from_function(
4853
context_kwarg: str | None = None,
4954
annotations: ToolAnnotations | None = None,
5055
structured_output: bool | None = None,
56+
invocation_modes: list[InvocationMode] | None = None,
5157
) -> Tool:
5258
"""Create a Tool from a function."""
5359
func_name = name or fn.__name__
@@ -68,6 +74,10 @@ def from_function(
6874
)
6975
parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
7076

77+
# Default to sync mode if no invocation modes specified
78+
if invocation_modes is None:
79+
invocation_modes = ["sync"]
80+
7181
return cls(
7282
fn=fn,
7383
name=func_name,
@@ -78,6 +88,7 @@ def from_function(
7888
is_async=is_async,
7989
context_kwarg=context_kwarg,
8090
annotations=annotations,
91+
invocation_modes=invocation_modes,
8192
)
8293

8394
async def run(

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING, Any
55

66
from mcp.server.fastmcp.exceptions import ToolError
7-
from mcp.server.fastmcp.tools.base import Tool
7+
from mcp.server.fastmcp.tools.base import InvocationMode, Tool
88
from mcp.server.fastmcp.utilities.logging import get_logger
99
from mcp.shared.context import LifespanContextT, RequestT
1010
from mcp.types import ToolAnnotations
@@ -50,15 +50,21 @@ def add_tool(
5050
description: str | None = None,
5151
annotations: ToolAnnotations | None = None,
5252
structured_output: bool | None = None,
53+
invocation_modes: list[InvocationMode] | None = None,
5354
) -> Tool:
5455
"""Add a tool to the server."""
56+
# Default to sync mode if no invocation modes specified
57+
if invocation_modes is None:
58+
invocation_modes = ["sync"]
59+
5560
tool = Tool.from_function(
5661
fn,
5762
name=name,
5863
title=title,
5964
description=description,
6065
annotations=annotations,
6166
structured_output=structured_output,
67+
invocation_modes=invocation_modes,
6268
)
6369
existing = self._tools.get(tool.name)
6470
if existing:

tests/server/fastmcp/test_server.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,54 @@ def get_settings() -> dict[str, str]:
603603
assert result.isError is False
604604
assert result.structuredContent == {"theme": "dark", "language": "en", "timezone": "UTC"}
605605

606+
@pytest.mark.anyio
607+
async def test_list_tools_invocation_mode_sync(self):
608+
"""Test that sync tools have proper invocationMode field."""
609+
mcp = FastMCP()
610+
611+
@mcp.tool()
612+
def sync_tool(x: int) -> int:
613+
"""A sync tool."""
614+
return x * 2
615+
616+
async with client_session(mcp._mcp_server) as client:
617+
tools = await client.list_tools()
618+
tool = next(t for t in tools.tools if t.name == "sync_tool")
619+
# Sync tools should not have invocationMode field (None) for old clients
620+
assert tool.invocationMode is None
621+
622+
@pytest.mark.anyio
623+
async def test_list_tools_invocation_mode_async_only(self):
624+
"""Test that async-only tools have proper invocationMode field."""
625+
mcp = FastMCP()
626+
627+
@mcp.tool(invocation_modes=["async"])
628+
async def async_only_tool(x: int) -> int:
629+
"""An async-only tool."""
630+
return x * 2
631+
632+
async with client_session(mcp._mcp_server) as client:
633+
tools = await client.list_tools()
634+
# Async-only tools should be filtered out for old clients
635+
async_tools = [t for t in tools.tools if t.name == "async_only_tool"]
636+
assert len(async_tools) == 0
637+
638+
@pytest.mark.anyio
639+
async def test_list_tools_invocation_mode_hybrid(self):
640+
"""Test that hybrid tools have proper invocationMode field."""
641+
mcp = FastMCP()
642+
643+
@mcp.tool(invocation_modes=["sync", "async"])
644+
def hybrid_tool(x: int) -> int:
645+
"""A hybrid tool."""
646+
return x * 2
647+
648+
async with client_session(mcp._mcp_server) as client:
649+
tools = await client.list_tools()
650+
tool = next(t for t in tools.tools if t.name == "hybrid_tool")
651+
# Hybrid tools should not have invocationMode field (None) for old clients
652+
assert tool.invocationMode is None
653+
606654

607655
class TestServerResources:
608656
@pytest.mark.anyio

0 commit comments

Comments
 (0)