diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 3f26ddcea6..46cf877e16 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -10,8 +10,9 @@ from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.utilities.context_injection import find_context_parameter +from mcp.server.fastmcp.utilities.dependencies import DependencyResolver, find_dependencies from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata -from mcp.types import Icon, ToolAnnotations +from mcp.types import Depends, Icon, ToolAnnotations if TYPE_CHECKING: from mcp.server.fastmcp.server import Context @@ -32,6 +33,7 @@ class Tool(BaseModel): ) is_async: bool = Field(description="Whether the tool is async") context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + dependencies: dict[str, Depends] | None = Field(None, description="Tool dependencies") annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool") @@ -47,6 +49,7 @@ def from_function( title: str | None = None, description: str | None = None, context_kwarg: str | None = None, + dependencies: dict[str, Depends] | None = None, annotations: ToolAnnotations | None = None, icons: list[Icon] | None = None, structured_output: bool | None = None, @@ -63,9 +66,16 @@ def from_function( if context_kwarg is None: context_kwarg = find_context_parameter(fn) + if dependencies is None: + dependencies = find_dependencies(fn) + + skip_names = [context_kwarg] if context_kwarg is not None else [] + if dependencies: + skip_names.extend(dependencies.keys()) + func_arg_metadata = func_metadata( fn, - skip_names=[context_kwarg] if context_kwarg is not None else [], + skip_names=skip_names, structured_output=structured_output, ) parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) @@ -79,6 +89,7 @@ def from_function( fn_metadata=func_arg_metadata, is_async=is_async, context_kwarg=context_kwarg, + dependencies=dependencies, annotations=annotations, icons=icons, ) @@ -90,12 +101,20 @@ async def run( convert_result: bool = False, ) -> Any: """Run the tool with arguments.""" + dependency_resolver = DependencyResolver() try: + # Resolve dependencies + resolved_dependencies = await dependency_resolver.resolve_dependencies(self.dependencies or {}) + + # Prepare arguments to pass directly to the function + arguments_to_pass_directly = {self.context_kwarg: context} if self.context_kwarg is not None else {} + arguments_to_pass_directly.update(resolved_dependencies) + result = await self.fn_metadata.call_fn_with_arg_validation( self.fn, self.is_async, arguments, - {self.context_kwarg: context} if self.context_kwarg is not None else None, + arguments_to_pass_directly, ) if convert_result: diff --git a/src/mcp/server/fastmcp/utilities/dependencies.py b/src/mcp/server/fastmcp/utilities/dependencies.py new file mode 100644 index 0000000000..22c8bc01d0 --- /dev/null +++ b/src/mcp/server/fastmcp/utilities/dependencies.py @@ -0,0 +1,125 @@ +import inspect +from collections.abc import AsyncGenerator, Callable, Generator +from typing import Annotated, Any, get_args, get_origin, get_type_hints + +from mcp.types import Depends + + +def find_dependencies(fn: Callable[..., Any]) -> dict[str, Depends]: + """Find all dependencies in a function's parameters.""" + # Get type hints to properly resolve string annotations + try: + hints = get_type_hints(fn, include_extras=True) + except Exception: + # If we can't resolve type hints, we can't find dependencies + hints = {} + + dependencies: dict[str, Depends] = {} + + # Get function signature to access parameter defaults + sig = inspect.signature(fn) + + # Check each parameter's type hint and default value + for param_name, param in sig.parameters.items(): + # Check if it's in Annotated form + if param_name in hints: + annotation = hints[param_name] + if get_origin(annotation) is Annotated: + _, *extras = get_args(annotation) + dep = next((x for x in extras if isinstance(x, Depends)), None) + if dep is not None: + dependencies[param_name] = dep + continue + + # Check if default value is a Depends instance + if param.default is not inspect.Parameter.empty and isinstance(param.default, Depends): + dependencies[param_name] = param.default + + return dependencies + + +def _is_async_callable(obj: Any) -> bool: + """Check if a callable is async.""" + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + ) + + +def _is_generator_function(obj: Any) -> bool: + """Check if a callable is a generator function.""" + return inspect.isgeneratorfunction(obj) + + +def _is_async_generator_function(obj: Any) -> bool: + """Check if a callable is an async generator function.""" + return inspect.isasyncgenfunction(obj) + + +class DependencyResolver: + """Resolve dependencies and clean up properly when errors occur.""" + + def __init__(self): + self._generators: list[Generator[Any, None, None]] = [] + self._async_generators: list[AsyncGenerator[Any, None]] = [] + + async def resolve_dependencies(self, dependencies: dict[str, Depends]) -> dict[str, Any]: + """Resolve all dependencies and return their values.""" + if not dependencies: + return {} + + resolved: dict[str, Any] = {} + + for param_name, depends in dependencies.items(): + try: + resolved[param_name] = await self._resolve_single_dependency(depends) + except Exception as e: + # Cleanup any generators and async generators that were already created + await self.cleanup() + raise RuntimeError(f"Failed to resolve dependency '{param_name}': {e}") from e + + return resolved + + async def _resolve_single_dependency(self, depends: Depends) -> Any: + """Resolve a single dependency.""" + dependency_fn = depends.dependency + + if _is_async_generator_function(dependency_fn): + gen = dependency_fn() + self._async_generators.append(gen) + try: + value = await gen.__anext__() + return value + except StopAsyncIteration: + raise RuntimeError(f"Async generator dependency {dependency_fn.__name__} didn't yield a value") + + elif _is_generator_function(dependency_fn): + gen = dependency_fn() + self._generators.append(gen) + try: + value = next(gen) + return value + except StopIteration: + raise RuntimeError(f"Generator dependency {dependency_fn.__name__} didn't yield a value") + + elif _is_async_callable(dependency_fn): + return await dependency_fn() + + else: + return dependency_fn() + + async def cleanup(self): + """Cleanup all generator dependencies.""" + for gen in self._async_generators: + try: + await gen.aclose() + except Exception: + pass + + for gen in self._generators: + try: + gen.close() + except Exception: + pass + + self._generators.clear() + self._async_generators.clear() diff --git a/src/mcp/types.py b/src/mcp/types.py index 8713227404..8d7a6e1d32 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -40,6 +40,12 @@ AnyFunction: TypeAlias = Callable[..., Any] +class Depends(BaseModel): + """Dependency injection for tool parameters.""" + + dependency: Callable[..., Any] + + class RequestParams(BaseModel): class Meta(BaseModel): progressToken: ProgressToken | None = None diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 8caa3b1f6f..942bce974e 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1,6 +1,8 @@ +import asyncio import base64 +from collections.abc import AsyncGenerator, Generator from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Annotated, Any from unittest.mock import patch import pytest @@ -10,16 +12,16 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.prompts.base import Message, UserMessage from mcp.server.fastmcp.resources import FileResource, FunctionResource +from mcp.server.fastmcp.utilities.dependencies import DependencyResolver from mcp.server.fastmcp.utilities.types import Audio, Image from mcp.server.session import ServerSession from mcp.shared.exceptions import McpError -from mcp.shared.memory import ( - create_connected_server_and_client_session as client_session, -) +from mcp.shared.memory import create_connected_server_and_client_session as client_session from mcp.types import ( AudioContent, BlobResourceContents, ContentBlock, + Depends, EmbeddedResource, ImageContent, TextContent, @@ -906,6 +908,134 @@ def get_csv(user: str) -> str: assert result.contents[0].text == "csv for bob" +class TestDependenciesInjection: + """Test dependency injection functionality.""" + + @pytest.mark.anyio + async def test_tool_with_regular_dependency(self): + """Test tool with regular function dependency.""" + mcp = FastMCP() + + def load_resource() -> int: + return 42 + + @mcp.tool() + def add_numbers(a: int, b: int, resource: Annotated[int, Depends(dependency=load_resource)]) -> int: + return a + b + resource + + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("add_numbers", {"a": 1, "b": 2}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "45" # 1 + 2 + 42 + + @pytest.mark.anyio + async def test_tool_with_async_dependency(self): + """Test tool with async function dependency.""" + mcp = FastMCP() + + async def load_async_resource() -> str: + await asyncio.sleep(0.01) + return "async_data" + + @mcp.tool() + async def process_text(text: str, data: Annotated[str, Depends(dependency=load_async_resource)]) -> str: + return f"Processed '{text}' with {data}" + + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("process_text", {"text": "hello"}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Processed 'hello' with async_data" + + @pytest.mark.anyio + async def test_tool_with_generator_dependency_cleanup(self): + """Test tool with generator dependency and proper cleanup.""" + mcp = FastMCP() + cleanup_called = False + + def database_connection() -> Generator[str, None, None]: + nonlocal cleanup_called + try: + yield "db_conn_123" + finally: + cleanup_called = True + + @mcp.tool() + def query_database(query: str, db_conn: Annotated[str, Depends(dependency=database_connection)]) -> str: + return f"Executed '{query}' on {db_conn}" + + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("query_database", {"query": "SELECT * FROM users"}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Executed 'SELECT * FROM users' on db_conn_123" + + # Cleanup should have been called after tool execution + assert cleanup_called + + @pytest.mark.anyio + async def test_tool_with_async_generator_dependency_cleanup(self): + """Test tool with async generator dependency and proper cleanup.""" + mcp = FastMCP() + cleanup_called = False + + async def async_file_handler() -> AsyncGenerator[str, None]: + nonlocal cleanup_called + try: + yield "file_123" + finally: + cleanup_called = True + + @mcp.tool() + async def process_file( + content: str, file_handler: Annotated[str, Depends(dependency=async_file_handler)] + ) -> str: + await asyncio.sleep(0.01) + return f"Processed '{content}' with {file_handler}" + + async with client_session(mcp._mcp_server) as client: + result = await client.call_tool("process_file", {"content": "data"}) + assert len(result.content) == 1 + content = result.content[0] + assert isinstance(content, TextContent) + assert content.text == "Processed 'data' with file_123" + + # Cleanup should have been called after tool execution + assert cleanup_called + + @pytest.mark.anyio + async def test_generator_no_yield_error(self): + """Test error when generator doesn't yield a value.""" + + def empty_generator() -> Generator[str, None, None]: + return + yield # This line is never reached + + resolver = DependencyResolver() + dependencies = {"dep": Depends(dependency=empty_generator)} + + with pytest.raises(RuntimeError): + await resolver.resolve_dependencies(dependencies) + + @pytest.mark.anyio + async def test_async_generator_no_yield_error(self): + """Test error when async generator doesn't yield a value.""" + + async def empty_async_generator() -> AsyncGenerator[str, None]: + return + yield # This line is never reached + + resolver = DependencyResolver() + dependencies = {"dep": Depends(dependency=empty_async_generator)} + + with pytest.raises(RuntimeError): + await resolver.resolve_dependencies(dependencies) + + class TestContextInjection: """Test context injection in tools, resources, and prompts.""" diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 71884fba22..66d74741e0 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -53,6 +53,7 @@ class AddArguments(ArgModelBase): is_async=False, parameters=AddArguments.model_json_schema(), context_kwarg=None, + dependencies=None, annotations=None, ) manager = ToolManager(tools=[original_tool])