diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 719595916a..e0b09df84a 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -61,7 +61,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import LifespanContextT, RequestContext, RequestT -from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations +from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ListToolsRequest, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -298,9 +298,19 @@ def _setup_handlers(self) -> None: self._mcp_server.get_prompt()(self.get_prompt) self._mcp_server.list_resource_templates()(self.list_resource_templates) - async def list_tools(self) -> list[MCPTool]: - """List all available tools.""" - tools = self._tool_manager.list_tools() + async def list_tools( + self, + request: ListToolsRequest | None = None, + ) -> list[MCPTool]: + """List all available tools, optionally filtered by include/exclude parameters.""" + if request and request.params: + tools = self._tool_manager.list_tools( + include=request.params.include, + exclude=request.params.exclude, + ) + else: + tools = self._tool_manager.list_tools() + return [ MCPTool( name=info.name, diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 095753de69..b61b6cb1e5 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -38,8 +38,40 @@ def get_tool(self, name: str) -> Tool | None: """Get tool by name.""" return self._tools.get(name) - def list_tools(self) -> list[Tool]: - """List all registered tools.""" + def _include_tools(self, tools: dict[str, Tool], include: list[str]) -> list[Tool]: + """Filter tools to include only the specified tool names.""" + filtered_tools: list[Tool] = [] + for tool_name in include: + tool = tools.get(tool_name) + if tool is None: + raise ValueError(f"Tool '{tool_name}' not found in available tools, cannot be included.") + filtered_tools.append(tool) + return filtered_tools + + def _exclude_tools(self, tools: dict[str, Tool], exclude: list[str]) -> list[Tool]: + """Filter tools to exclude the specified tool names.""" + exclude_set = set(exclude) + + for tool_name in exclude: + if tool_name not in tools: + raise ValueError(f"Tool '{tool_name}' not found in available tools, cannot be excluded.") + + return [tool for name, tool in tools.items() if name not in exclude_set] + + def list_tools( + self, + *, + include: list[str] | None = None, + exclude: list[str] | None = None, + ) -> list[Tool]: + """List all registered tools, optionally filtered by include or exclude parameters.""" + if include is not None and exclude is not None: + raise ValueError("Cannot specify both 'include' and 'exclude' parameters") + elif include is not None: + return self._include_tools(self._tools, include) + elif exclude is not None: + return self._exclude_tools(self._tools, exclude) + return list(self._tools.values()) def add_tool( diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index f5a745db2f..87f4e0b616 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -1,11 +1,35 @@ import inspect from collections.abc import Callable -from typing import Any, TypeVar, get_type_hints +from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints T = TypeVar("T") R = TypeVar("R") +def _type_matches_request(param_type: Any, request_type: type[T]) -> bool: + """ + Check if a parameter type matches the request type. + + This handles direct matches, Union types (e.g., RequestType | None), + and Optional types (e.g., Optional[RequestType]). + """ + if param_type == request_type: + return True + + origin = get_origin(param_type) + args = get_args(param_type) + + # Handle typing.Union and Python 3.10+ | syntax + if origin is Union: + return request_type in args + + # Handle types.UnionType from Python 3.10+ | syntax + if hasattr(param_type, "__args__") and args: + return request_type in args + + return False + + def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callable[[T], R]: """ Create a wrapper function that knows how to call func with the request object. @@ -13,9 +37,12 @@ def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callab Returns a wrapper function that takes the request and calls func appropriately. The wrapper handles three calling patterns: - 1. Positional-only parameter typed as request_type (no default): func(req) - 2. Positional/keyword parameter typed as request_type (no default): func(**{param_name: req}) - 3. No request parameter or parameter with default: func() + 1. Positional-only parameter typed as request_type or Union containing request_type: func(req) + 2. Positional/keyword parameter typed as request_type or Union containing request_type: func(**{param_name: req}) + 3. No matching request parameter: func() + + Union types like `RequestType | None` and `Optional[RequestType]` are supported, + allowing for optional request parameters with default values. """ try: sig = inspect.signature(func) @@ -27,23 +54,16 @@ def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callab for param_name, param in sig.parameters.items(): if param.kind == inspect.Parameter.POSITIONAL_ONLY: param_type = type_hints.get(param_name) - if param_type == request_type: - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: - return lambda _: func() - # Found positional-only parameter with correct type and no default + if _type_matches_request(param_type, request_type): + # Found positional-only parameter with correct type return lambda req: func(req) # Check for any positional/keyword parameter typed as request_type for param_name, param in sig.parameters.items(): if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY): param_type = type_hints.get(param_name) - if param_type == request_type: - # Check if it has a default - if so, treat as old style - if param.default is not inspect.Parameter.empty: - return lambda _: func() - - # Found keyword parameter with correct type and no default + if _type_matches_request(param_type, request_type): + # Found keyword parameter with correct type # Need to capture param_name in closure properly def make_keyword_wrapper(name: str) -> Callable[[Any], Any]: return lambda req: func(**{name: req}) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 9cec31bab1..e2a2507334 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -408,7 +408,7 @@ async def handler(req: types.UnsubscribeRequest): def list_tools(self): def decorator( - func: Callable[[], Awaitable[list[types.Tool]]] + func: Callable[..., Awaitable[list[types.Tool]]] | Callable[[types.ListToolsRequest], Awaitable[types.ListToolsResult]], ): logger.debug("Registering handler for ListToolsRequest") diff --git a/src/mcp/types.py b/src/mcp/types.py index 8713227404..2f247f556a 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -56,11 +56,11 @@ class Meta(BaseModel): class PaginatedRequestParams(RequestParams): + """Request parameters for paginated operations with optional filtering.""" + cursor: Cursor | None = None - """ - An opaque token representing the current pagination position. - If provided, the server should return results starting after this cursor. - """ + include: list[str] | None = None + exclude: list[str] | None = None class NotificationParams(BaseModel): diff --git a/tests/issues/test_100_tool_listing.py b/tests/issues/test_100_tool_listing.py index 6dccec84d9..4c4acd8f90 100644 --- a/tests/issues/test_100_tool_listing.py +++ b/tests/issues/test_100_tool_listing.py @@ -20,7 +20,7 @@ def dummy_tool_func(): globals()[f"dummy_tool_{i}"] = dummy_tool_func # Keep reference to avoid garbage collection # Get all tools - tools = await mcp.list_tools() + tools = await mcp.list_tools(request=None) # Verify we get all tools assert len(tools) == num_tools, f"Expected {num_tools} tools, but got {len(tools)}" diff --git a/tests/issues/test_1338_icons_and_metadata.py b/tests/issues/test_1338_icons_and_metadata.py index 8a9897fcf7..1fd821bafc 100644 --- a/tests/issues/test_1338_icons_and_metadata.py +++ b/tests/issues/test_1338_icons_and_metadata.py @@ -55,7 +55,7 @@ def test_resource_template(city: str) -> str: assert mcp.icons[0].sizes == test_icon.sizes # Test tool includes icon - tools = await mcp.list_tools() + tools = await mcp.list_tools(request=None) assert len(tools) == 1 tool = tools[0] assert tool.name == "test_tool" @@ -109,7 +109,7 @@ def multi_icon_tool() -> str: return "success" # Test tool has all icons - tools = await mcp.list_tools() + tools = await mcp.list_tools(request=None) assert len(tools) == 1 tool = tools[0] assert tool.icons is not None @@ -135,7 +135,7 @@ def basic_tool() -> str: assert mcp.icons is None # Test tool has no icons - tools = await mcp.list_tools() + tools = await mcp.list_tools(request=None) assert len(tools) == 1 tool = tools[0] assert tool.name == "basic_tool" diff --git a/tests/server/fastmcp/test_parameter_descriptions.py b/tests/server/fastmcp/test_parameter_descriptions.py index 29470ed19c..2bb592a76c 100644 --- a/tests/server/fastmcp/test_parameter_descriptions.py +++ b/tests/server/fastmcp/test_parameter_descriptions.py @@ -18,7 +18,7 @@ def greet( """A greeting tool""" return f"Hello {title} {name}" - tools = await mcp.list_tools() + tools = await mcp.list_tools(request=None) assert len(tools) == 1 tool = tools[0] diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 7a53ec37a1..6d975242e1 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -447,7 +447,7 @@ def echo(message: str) -> str: """Echo a message back.""" return message - tools = await app.list_tools() + tools = await app.list_tools(request=None) assert len(tools) == 1 assert tools[0].annotations is not None assert tools[0].annotations.title == "Echo Tool" @@ -704,7 +704,7 @@ def analyze_text(text: str) -> dict[str, Any]: """Analyze text content.""" return {"length": len(text), "words": len(text.split())} - tools = await app.list_tools() + tools = await app.list_tools(request=None) assert len(tools) == 1 assert tools[0].meta is not None assert tools[0].meta == metadata @@ -733,7 +733,7 @@ def tool3(z: bool) -> bool: """Third tool without metadata.""" return z - tools = await app.list_tools() + tools = await app.list_tools(request=None) assert len(tools) == 3 # Find tools by name and check metadata @@ -799,7 +799,7 @@ def combined_tool(data: str) -> str: """Tool with both metadata and annotations.""" return data - tools = await app.list_tools() + tools = await app.list_tools(request=None) assert len(tools) == 1 assert tools[0].meta == metadata assert tools[0].annotations is not None @@ -807,6 +807,179 @@ def combined_tool(data: str) -> str: assert tools[0].annotations.readOnlyHint is True +class TestListTools: + """Test tool listing functionality in the tool manager.""" + + def test_list_all_tools(self): + """Test listing all tools when no filters are applied.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + def divide(a: int, b: int) -> float: + """Divide two numbers.""" + return a / b + + manager = ToolManager() + manager.add_tool(add) + manager.add_tool(multiply) + manager.add_tool(divide) + + tools = manager.list_tools() + assert len(tools) == 3 + tool_names = {tool.name for tool in tools} + assert tool_names == {"add", "multiply", "divide"} + + def test_list_tools_with_include(self): + """Test listing tools with include filter.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + def divide(a: int, b: int) -> float: + """Divide two numbers.""" + return a / b + + manager = ToolManager() + manager.add_tool(add) + manager.add_tool(multiply) + manager.add_tool(divide) + + # Test including specific tools + tools = manager.list_tools(include=["add", "multiply"]) + assert len(tools) == 2 + tool_names = {tool.name for tool in tools} + assert tool_names == {"add", "multiply"} + + # Test including single tool + tools = manager.list_tools(include=["divide"]) + assert len(tools) == 1 + assert tools[0].name == "divide" + + # Test including all tools explicitly + tools = manager.list_tools(include=["add", "multiply", "divide"]) + assert len(tools) == 3 + tool_names = {tool.name for tool in tools} + assert tool_names == {"add", "multiply", "divide"} + + def test_list_tools_with_exclude(self): + """Test listing tools with exclude filter.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + def divide(a: int, b: int) -> float: + """Divide two numbers.""" + return a / b + + manager = ToolManager() + manager.add_tool(add) + manager.add_tool(multiply) + manager.add_tool(divide) + + # Test excluding specific tools + tools = manager.list_tools(exclude=["divide"]) + assert len(tools) == 2 + tool_names = {tool.name for tool in tools} + assert tool_names == {"add", "multiply"} + + # Test excluding multiple tools + tools = manager.list_tools(exclude=["add", "multiply"]) + assert len(tools) == 1 + assert tools[0].name == "divide" + + # Test excluding all tools + tools = manager.list_tools(exclude=["add", "multiply", "divide"]) + assert len(tools) == 0 + + def test_list_tools_include_nonexistent(self): + """Test that including a non-existent tool raises ValueError.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + + with pytest.raises(ValueError, match="Tool 'nonexistent' not found in available tools, cannot be included."): + manager.list_tools(include=["add", "nonexistent"]) + + def test_list_tools_exclude_nonexistent(self): + """Test that excluding a non-existent tool raises ValueError.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + + with pytest.raises(ValueError): + manager.list_tools(exclude=["add", "nonexistent"]) + + def test_list_tools_include_and_exclude_error(self): + """Test that providing both include and exclude raises ValueError.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + + with pytest.raises(ValueError): + manager.list_tools(include=["add"], exclude=["add"]) + + def test_list_tools_empty_include(self): + """Test listing tools with empty include list.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + manager = ToolManager() + manager.add_tool(add) + + tools = manager.list_tools(include=[]) + assert len(tools) == 0 + + def test_list_tools_empty_exclude(self): + """Test listing tools with empty exclude list.""" + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def multiply(a: int, b: int) -> int: + """Multiply two numbers.""" + return a * b + + manager = ToolManager() + manager.add_tool(add) + manager.add_tool(multiply) + + tools = manager.list_tools(exclude=[]) + assert len(tools) == 2 + tool_names = {tool.name for tool in tools} + assert tool_names == {"add", "multiply"} + + class TestRemoveTools: """Test tool removal functionality in the tool manager."""