From 2a09a78a614966fc988689f0bb75f8ab9f8401e0 Mon Sep 17 00:00:00 2001 From: 0Delta <0deltast@gmail.com> Date: Fri, 26 Sep 2025 21:55:36 +0900 Subject: [PATCH] Add support for context-only resources (#1405) --- examples/snippets/servers/context_resource.py | 11 +++ src/mcp/server/fastmcp/prompts/base.py | 4 +- src/mcp/server/fastmcp/resources/base.py | 4 +- src/mcp/server/fastmcp/resources/templates.py | 4 +- src/mcp/server/fastmcp/resources/types.py | 61 ++++++++++----- src/mcp/server/fastmcp/server.py | 25 +++--- .../fastmcp/utilities/context_injection.py | 77 ++++++++----------- tests/server/fastmcp/test_integration.py | 42 ++++++++++ tests/server/fastmcp/test_server.py | 31 +++++++- 9 files changed, 172 insertions(+), 87 deletions(-) create mode 100644 examples/snippets/servers/context_resource.py diff --git a/examples/snippets/servers/context_resource.py b/examples/snippets/servers/context_resource.py new file mode 100644 index 0000000000..d2d7c54092 --- /dev/null +++ b/examples/snippets/servers/context_resource.py @@ -0,0 +1,11 @@ +from mcp.server.fastmcp import Context, FastMCP +from mcp.server.session import ServerSession + +mcp = FastMCP(name="Context Resource Example") + + +@mcp.resource("resource://only_context") +def resource_only_context(ctx: Context[ServerSession, None]) -> str: + """Resource that only receives context.""" + assert ctx is not None + return "Resource with only context injected" diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index 48c65b57c5..dcaace22ca 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -99,12 +99,12 @@ def from_function( # Find context parameter if it exists if context_kwarg is None: # pragma: no branch - context_kwarg = find_context_parameter(fn) + context_kwarg = find_context_parameter(fn) or "" # Get schema from func_metadata, excluding context parameter func_arg_metadata = func_metadata( fn, - skip_names=[context_kwarg] if context_kwarg is not None else [], + skip_names=[context_kwarg] if context_kwarg else [], ) parameters = func_arg_metadata.arg_model.model_json_schema() diff --git a/src/mcp/server/fastmcp/resources/base.py b/src/mcp/server/fastmcp/resources/base.py index c733e1a46b..2ae1810670 100644 --- a/src/mcp/server/fastmcp/resources/base.py +++ b/src/mcp/server/fastmcp/resources/base.py @@ -1,7 +1,7 @@ """Base classes and interfaces for FastMCP resources.""" import abc -from typing import Annotated +from typing import Annotated, Any from pydantic import ( AnyUrl, @@ -44,6 +44,6 @@ def set_default_name(cls, name: str | None, info: ValidationInfo) -> str: raise ValueError("Either name or uri must be provided") @abc.abstractmethod - async def read(self) -> str | bytes: + async def read(self, context: Any | None = None) -> str | bytes: """Read the resource content.""" pass # pragma: no cover diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index a98d37f0ac..38e4191802 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -54,12 +54,12 @@ def from_function( # Find context parameter if it exists if context_kwarg is None: # pragma: no branch - context_kwarg = find_context_parameter(fn) + context_kwarg = find_context_parameter(fn) or "" # Get schema from func_metadata, excluding context parameter func_arg_metadata = func_metadata( fn, - skip_names=[context_kwarg] if context_kwarg is not None else [], + skip_names=[context_kwarg] if context_kwarg else [], ) parameters = func_arg_metadata.arg_model.model_json_schema() diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py index 680e72dc09..be9fff5ffc 100644 --- a/src/mcp/server/fastmcp/resources/types.py +++ b/src/mcp/server/fastmcp/resources/types.py @@ -14,6 +14,7 @@ from pydantic import AnyUrl, Field, ValidationInfo, validate_call from mcp.server.fastmcp.resources.base import Resource +from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context from mcp.types import Annotations, Icon @@ -22,7 +23,7 @@ class TextResource(Resource): text: str = Field(description="Text content of the resource") - async def read(self) -> str: + async def read(self, context: Any | None = None) -> str: """Read the text content.""" return self.text # pragma: no cover @@ -32,7 +33,7 @@ class BinaryResource(Resource): data: bytes = Field(description="Binary content of the resource") - async def read(self) -> bytes: + async def read(self, context: Any | None = None) -> bytes: """Read the binary content.""" return self.data # pragma: no cover @@ -51,24 +52,39 @@ class FunctionResource(Resource): """ fn: Callable[[], Any] = Field(exclude=True) + context_kwarg: str | None = Field(None, exclude=True) + + async def read(self, context: Any | None = None) -> str | bytes: + """Read the resource content by calling the function.""" + # Inject context using utility which handles optimization + # If context_kwarg is set, it's used directly (fast) + # If not set (manual init), it falls back to inspection (safe) + args = inject_context(self.fn, {}, context, self.context_kwarg) - async def read(self) -> str | bytes: - """Read the resource by calling the wrapped function.""" try: - # Call the function first to see if it returns a coroutine - result = self.fn() - # If it's a coroutine, await it + if inspect.iscoroutinefunction(self.fn): + result = await self.fn(**args) + else: + result = self.fn(**args) + + # Support cases where a sync function returns a coroutine if inspect.iscoroutine(result): - result = await result + result = await result # pragma: no cover - if isinstance(result, Resource): # pragma: no cover - return await result.read() - elif isinstance(result, bytes): - return result - elif isinstance(result, str): + # Support returning a Resource instance (recursive read) + if isinstance(result, Resource): + return await result.read(context) # pragma: no cover + + if isinstance(result, str | bytes): return result - else: - return pydantic_core.to_json(result, fallback=str, indent=2).decode() + if isinstance(result, pydantic.BaseModel): + return result.model_dump_json(indent=2) + + # For other types, convert to a JSON string + try: + return json.dumps(pydantic_core.to_jsonable_python(result)) + except pydantic_core.PydanticSerializationError: + return json.dumps(str(result)) except Exception as e: raise ValueError(f"Error reading resource {self.uri}: {e}") @@ -86,8 +102,10 @@ def from_function( ) -> "FunctionResource": """Create a FunctionResource from a function.""" func_name = name or fn.__name__ - if func_name == "": # pragma: no cover - raise ValueError("You must provide a name for lambda functions") + if func_name == "": + raise ValueError("You must provide a name for lambda functions") # pragma: no cover + + context_kwarg = find_context_parameter(fn) or "" # ensure the arguments are properly cast fn = validate_call(fn) @@ -100,6 +118,7 @@ def from_function( mime_type=mime_type or "text/plain", fn=fn, icons=icons, + context_kwarg=context_kwarg, annotations=annotations, ) @@ -125,7 +144,7 @@ class FileResource(Resource): def validate_absolute_path(cls, path: Path) -> Path: # pragma: no cover """Ensure path is absolute.""" if not path.is_absolute(): - raise ValueError("Path must be absolute") + raise ValueError("Path must be absolute") # pragma: no cover return path @pydantic.field_validator("is_binary") @@ -137,7 +156,7 @@ def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> boo mime_type = info.data.get("mime_type", "text/plain") return not mime_type.startswith("text/") - async def read(self) -> str | bytes: + async def read(self, context: Any | None = None) -> str | bytes: """Read the file content.""" try: if self.is_binary: @@ -153,7 +172,7 @@ class HttpResource(Resource): url: str = Field(description="URL to fetch content from") mime_type: str = Field(default="application/json", description="MIME type of the resource content") - async def read(self) -> str | bytes: + async def read(self, context: Any | None = None) -> str | bytes: """Read the HTTP content.""" async with httpx.AsyncClient() as client: # pragma: no cover response = await client.get(self.url) @@ -191,7 +210,7 @@ def list_files(self) -> list[Path]: # pragma: no cover except Exception as e: raise ValueError(f"Error listing directory {self.path}: {e}") - async def read(self) -> str: # Always returns JSON string # pragma: no cover + async def read(self, context: Any | None = None) -> str: # Always returns JSON string # pragma: no cover """Read the directory listing.""" try: files = await anyio.to_thread.run_sync(self.list_files) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 2e596c9f9a..6b81b3e264 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -376,7 +376,7 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent raise ResourceError(f"Unknown resource: {uri}") try: - content = await resource.read() + content = await resource.read(context=context) return [ReadResourceContents(content=content, mime_type=resource.mime_type)] except Exception as e: # pragma: no cover logger.exception(f"Error reading resource {uri}") @@ -575,27 +575,24 @@ async def get_weather(city: str) -> str: ) def decorator(fn: AnyFunction) -> AnyFunction: - # Check if this should be a template sig = inspect.signature(fn) - has_uri_params = "{" in uri and "}" in uri - has_func_params = bool(sig.parameters) + context_param = find_context_parameter(fn) + + # Determine effective parameters, excluding context + effective_func_params = {p for p in sig.parameters.keys() if p != context_param} - if has_uri_params or has_func_params: - # Check for Context parameter to exclude from validation - context_param = find_context_parameter(fn) + has_uri_params = "{" in uri and "}" in uri + has_effective_func_params = bool(effective_func_params) - # Validate that URI params match function params (excluding context) + if has_uri_params or has_effective_func_params: + # Register as template uri_params = set(re.findall(r"{(\w+)}", uri)) - # We need to remove the context_param from the resource function if - # there is any. - func_params = {p for p in sig.parameters.keys() if p != context_param} - if uri_params != func_params: + if uri_params != effective_func_params: raise ValueError( - f"Mismatch between URI parameters {uri_params} and function parameters {func_params}" + f"Mismatch between URI parameters {uri_params} and function parameters {effective_func_params}" ) - # Register as template self._resource_manager.add_template( fn=fn, uri_template=uri, diff --git a/src/mcp/server/fastmcp/utilities/context_injection.py b/src/mcp/server/fastmcp/utilities/context_injection.py index 66d0cbaa0c..7a1f042808 100644 --- a/src/mcp/server/fastmcp/utilities/context_injection.py +++ b/src/mcp/server/fastmcp/utilities/context_injection.py @@ -1,47 +1,40 @@ -"""Context injection utilities for FastMCP.""" - -from __future__ import annotations - import inspect import typing from collections.abc import Callable -from typing import Any - - -def find_context_parameter(fn: Callable[..., Any]) -> str | None: - """Find the parameter that should receive the Context object. +from typing import TYPE_CHECKING, Any - Searches through the function's signature to find a parameter - with a Context type annotation. +if TYPE_CHECKING: + from mcp.server.fastmcp import Context - Args: - fn: The function to inspect - Returns: - The name of the context parameter, or None if not found +def find_context_parameter(fn: Callable[..., Any]) -> str | None: """ - from mcp.server.fastmcp.server import Context + Inspect a function signature to find a parameter annotated with Context. + Returns the name of the parameter if found, otherwise None. + """ + from mcp.server.fastmcp import Context - # Get type hints to properly resolve string annotations try: - hints = typing.get_type_hints(fn) - except Exception: - # If we can't resolve type hints, we can't find the context parameter + sig = inspect.signature(fn) + except ValueError: # pragma: no cover + # Can't inspect signature (e.g. some builtins/wrappers) return None - # Check each parameter's type hint - for param_name, annotation in hints.items(): - # Handle direct Context type + for param_name, param in sig.parameters.items(): + annotation = param.annotation + if annotation is inspect.Parameter.empty: + continue + + # Handle Optional[Context], Annotated[Context, ...], etc. + origin = typing.get_origin(annotation) + + # Check if the annotation itself is Context or a subclass if inspect.isclass(annotation) and issubclass(annotation, Context): return param_name - # Handle generic types like Optional[Context] - origin = typing.get_origin(annotation) - if origin is not None: - args = typing.get_args(annotation) - for arg in args: - if inspect.isclass(arg) and issubclass(arg, Context): - return param_name + # Check if it's a generic alias of Context (e.g., Context[...]) + if origin is not None and inspect.isclass(origin) and issubclass(origin, Context): + return param_name # pragma: no cover return None @@ -49,20 +42,16 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: def inject_context( fn: Callable[..., Any], kwargs: dict[str, Any], - context: Any | None, - context_kwarg: str | None, + context: "Context[Any, Any, Any] | None", + context_kwarg: str | None = None, ) -> dict[str, Any]: - """Inject context into function kwargs if needed. - - Args: - fn: The function that will be called - kwargs: The current keyword arguments - context: The context object to inject (if any) - context_kwarg: The name of the parameter to inject into - - Returns: - Updated kwargs with context injected if applicable """ - if context_kwarg is not None and context is not None: - return {**kwargs, context_kwarg: context} + Inject the Context object into kwargs if the function expects it. + Returns the updated kwargs. + """ + if context_kwarg is None: + context_kwarg = find_context_parameter(fn) + + if context_kwarg: + kwargs[context_kwarg] = context return kwargs diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index b1cefca29c..eb28c01ee1 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -25,6 +25,7 @@ basic_resource, basic_tool, completion, + context_resource, elicitation, fastmcp_quickstart, notifications, @@ -124,6 +125,8 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No mcp = fastmcp_quickstart.mcp elif module_name == "structured_output": mcp = structured_output.mcp + elif module_name == "context_resource": + mcp = context_resource.mcp else: raise ImportError(f"Unknown module: {module_name}") @@ -686,3 +689,42 @@ async def test_structured_output(server_transport: str, server_url: str) -> None assert "sunny" in result_text # condition assert "45" in result_text # humidity assert "5.2" in result_text # wind_speed + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "server_transport", + [ + ("context_resource", "sse"), + ("context_resource", "streamable-http"), + ], + indirect=True, +) +async def test_context_only_resource(server_transport: str, server_url: str) -> None: + """Test that a resource with only a context argument is registered as a regular resource.""" + transport = server_transport + client_cm = create_client_for_transport(transport, server_url) + + async with client_cm as client_streams: + read_stream, write_stream = unpack_streams(client_streams) + async with ClientSession(read_stream, write_stream) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "Context Resource Example" + + # Check that it is not in templates + templates = await session.list_resource_templates() + assert len(templates.resourceTemplates) == 0 + + # Check that it is in resources + resources = await session.list_resources() + assert len(resources.resources) == 1 + resource = resources.resources[0] + assert resource.uri == AnyUrl("resource://only_context") + + # Check that we can read it + read_result = await session.read_resource(AnyUrl("resource://only_context")) + assert len(read_result.contents) == 1 + assert isinstance(read_result.contents[0], TextResourceContents) + assert read_result.contents[0].text == "Resource with only context injected" diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index fdbb04694c..c007b29136 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -686,7 +686,7 @@ async def test_text_resource(self): def get_text(): return "Hello, world!" - resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text) + resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text, context_kwarg=None) mcp.add_resource(resource) async with client_session(mcp._mcp_server) as client: @@ -1085,7 +1085,7 @@ def resource_no_context(name: str) -> str: templates = mcp._resource_manager.list_templates() assert len(templates) == 1 template = templates[0] - assert template.context_kwarg is None + assert not template.context_kwarg # Test via client async with client_session(mcp._mcp_server) as client: @@ -1120,6 +1120,33 @@ def resource_custom_ctx(id: str, my_ctx: Context[ServerSession, None]) -> str: assert isinstance(content, TextResourceContents) assert "Resource 123 with context" in content.text + @pytest.mark.anyio + async def test_resource_only_context(self): + """Test that resources without template args can receive context.""" + mcp = FastMCP() + + @mcp.resource("resource://only_context", name="resource_with_context_no_args") + def resource_only_context(ctx: Context[ServerSession, None]) -> str: + """Resource that only receives context.""" + assert ctx is not None + return "Resource with only context injected" + + # Test via client + async with client_session(mcp._mcp_server) as client: + # Verify resource is registered via client + resources = await client.list_resources() + assert len(resources.resources) == 1 + resource = resources.resources[0] + assert resource.uri == AnyUrl("resource://only_context") + assert resource.name == "resource_with_context_no_args" + + # Test reading the resource + result = await client.read_resource(AnyUrl("resource://only_context")) + assert len(result.contents) == 1 + content = result.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Resource with only context injected" + @pytest.mark.anyio async def test_prompt_with_context(self): """Test that prompts can receive context parameter."""