Skip to content

Commit a6c8998

Browse files
committed
Add support for context-only resources (#1405)
1 parent 5983a65 commit a6c8998

File tree

12 files changed

+186
-84
lines changed

12 files changed

+186
-84
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from mcp.server.fastmcp import Context, FastMCP
2+
from mcp.server.session import ServerSession
3+
4+
mcp = FastMCP(name="Context Resource Example")
5+
6+
7+
@mcp.resource("resource://only_context")
8+
def resource_only_context(ctx: Context[ServerSession, None]) -> str:
9+
"""Resource that only receives context."""
10+
assert ctx is not None
11+
return "Resource with only context injected"

src/mcp/server/fastmcp/prompts/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ async def render(
150150

151151
try:
152152
# Add context to arguments if needed
153-
call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)
153+
# context_kwarg is unused since it's automatically determined by inject_context
154+
call_args = inject_context(self.fn, arguments or {}, context)
154155

155156
# Call function and check if result is a coroutine
156157
result = self.fn(**call_args)

src/mcp/server/fastmcp/resources/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Base classes and interfaces for FastMCP resources."""
22

33
import abc
4-
from typing import Annotated
4+
from typing import Annotated, Any
55

66
from pydantic import (
77
AnyUrl,
@@ -44,6 +44,6 @@ def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
4444
raise ValueError("Either name or uri must be provided")
4545

4646
@abc.abstractmethod
47-
async def read(self) -> str | bytes:
47+
async def read(self, context: Any | None = None) -> str | bytes:
4848
"""Read the resource content."""
4949
pass # pragma: no cover

src/mcp/server/fastmcp/resources/templates.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ async def create_resource(
9797
"""Create a resource from the template with the given parameters."""
9898
try:
9999
# Add context to params if needed
100-
params = inject_context(self.fn, params, context, self.context_kwarg)
100+
# context_kwarg is unused since it's automatically determined by inject_context
101+
params = inject_context(self.fn, params, context)
101102

102103
# Call function and check if result is a coroutine
103104
result = self.fn(**params)

src/mcp/server/fastmcp/resources/types.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic import AnyUrl, Field, ValidationInfo, validate_call
1515

1616
from mcp.server.fastmcp.resources.base import Resource
17+
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
1718
from mcp.types import Annotations, Icon
1819

1920

@@ -22,7 +23,7 @@ class TextResource(Resource):
2223

2324
text: str = Field(description="Text content of the resource")
2425

25-
async def read(self) -> str:
26+
async def read(self, context: Any | None = None) -> str:
2627
"""Read the text content."""
2728
return self.text # pragma: no cover
2829

@@ -32,7 +33,7 @@ class BinaryResource(Resource):
3233

3334
data: bytes = Field(description="Binary content of the resource")
3435

35-
async def read(self) -> bytes:
36+
async def read(self, context: Any | None = None) -> bytes:
3637
"""Read the binary content."""
3738
return self.data # pragma: no cover
3839

@@ -51,24 +52,38 @@ class FunctionResource(Resource):
5152
"""
5253

5354
fn: Callable[[], Any] = Field(exclude=True)
55+
context_kwarg: str | None = Field(None, exclude=True)
56+
57+
async def read(self, context: Any | None = None) -> str | bytes:
58+
"""Read the resource content by calling the function."""
59+
args = {}
60+
if self.context_kwarg:
61+
args[self.context_kwarg] = context
5462

55-
async def read(self) -> str | bytes:
56-
"""Read the resource by calling the wrapped function."""
5763
try:
58-
# Call the function first to see if it returns a coroutine
59-
result = self.fn()
60-
# If it's a coroutine, await it
64+
if inspect.iscoroutinefunction(self.fn):
65+
result = await self.fn(**args)
66+
else:
67+
result = self.fn(**args)
68+
69+
# Support cases where a sync function returns a coroutine
6170
if inspect.iscoroutine(result):
62-
result = await result
71+
result = await result # pragma: no cover
6372

64-
if isinstance(result, Resource): # pragma: no cover
65-
return await result.read()
66-
elif isinstance(result, bytes):
67-
return result
68-
elif isinstance(result, str):
73+
# Support returning a Resource instance (recursive read)
74+
if isinstance(result, Resource):
75+
return await result.read(context) # pragma: no cover
76+
77+
if isinstance(result, str | bytes):
6978
return result
70-
else:
71-
return pydantic_core.to_json(result, fallback=str, indent=2).decode()
79+
if isinstance(result, pydantic.BaseModel):
80+
return result.model_dump_json(indent=2)
81+
82+
# For other types, convert to a JSON string
83+
try:
84+
return json.dumps(pydantic_core.to_jsonable_python(result))
85+
except pydantic_core.PydanticSerializationError:
86+
return json.dumps(str(result))
7287
except Exception as e:
7388
raise ValueError(f"Error reading resource {self.uri}: {e}")
7489

@@ -86,8 +101,10 @@ def from_function(
86101
) -> "FunctionResource":
87102
"""Create a FunctionResource from a function."""
88103
func_name = name or fn.__name__
89-
if func_name == "<lambda>": # pragma: no cover
90-
raise ValueError("You must provide a name for lambda functions")
104+
if func_name == "<lambda>":
105+
raise ValueError("You must provide a name for lambda functions") # pragma: no cover
106+
107+
context_kwarg = find_context_parameter(fn)
91108

92109
# ensure the arguments are properly cast
93110
fn = validate_call(fn)
@@ -100,6 +117,7 @@ def from_function(
100117
mime_type=mime_type or "text/plain",
101118
fn=fn,
102119
icons=icons,
120+
context_kwarg=context_kwarg,
103121
annotations=annotations,
104122
)
105123

@@ -125,7 +143,7 @@ class FileResource(Resource):
125143
def validate_absolute_path(cls, path: Path) -> Path: # pragma: no cover
126144
"""Ensure path is absolute."""
127145
if not path.is_absolute():
128-
raise ValueError("Path must be absolute")
146+
raise ValueError("Path must be absolute") # pragma: no cover
129147
return path
130148

131149
@pydantic.field_validator("is_binary")
@@ -137,7 +155,7 @@ def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> boo
137155
mime_type = info.data.get("mime_type", "text/plain")
138156
return not mime_type.startswith("text/")
139157

140-
async def read(self) -> str | bytes:
158+
async def read(self, context: Any | None = None) -> str | bytes:
141159
"""Read the file content."""
142160
try:
143161
if self.is_binary:
@@ -153,7 +171,7 @@ class HttpResource(Resource):
153171
url: str = Field(description="URL to fetch content from")
154172
mime_type: str = Field(default="application/json", description="MIME type of the resource content")
155173

156-
async def read(self) -> str | bytes:
174+
async def read(self, context: Any | None = None) -> str | bytes:
157175
"""Read the HTTP content."""
158176
async with httpx.AsyncClient() as client: # pragma: no cover
159177
response = await client.get(self.url)
@@ -191,7 +209,7 @@ def list_files(self) -> list[Path]: # pragma: no cover
191209
except Exception as e:
192210
raise ValueError(f"Error listing directory {self.path}: {e}")
193211

194-
async def read(self) -> str: # Always returns JSON string # pragma: no cover
212+
async def read(self, context: Any | None = None) -> str: # Always returns JSON string # pragma: no cover
195213
"""Read the directory listing."""
196214
try:
197215
files = await anyio.to_thread.run_sync(self.list_files)

src/mcp/server/fastmcp/server.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContent
376376
raise ResourceError(f"Unknown resource: {uri}")
377377

378378
try:
379-
content = await resource.read()
379+
content = await resource.read(context=context)
380380
return [ReadResourceContents(content=content, mime_type=resource.mime_type)]
381381
except Exception as e: # pragma: no cover
382382
logger.exception(f"Error reading resource {uri}")
@@ -575,27 +575,24 @@ async def get_weather(city: str) -> str:
575575
)
576576

577577
def decorator(fn: AnyFunction) -> AnyFunction:
578-
# Check if this should be a template
579578
sig = inspect.signature(fn)
580-
has_uri_params = "{" in uri and "}" in uri
581-
has_func_params = bool(sig.parameters)
579+
context_param = find_context_parameter(fn)
580+
581+
# Determine effective parameters, excluding context
582+
effective_func_params = {p for p in sig.parameters.keys() if p != context_param}
582583

583-
if has_uri_params or has_func_params:
584-
# Check for Context parameter to exclude from validation
585-
context_param = find_context_parameter(fn)
584+
has_uri_params = "{" in uri and "}" in uri
585+
has_effective_func_params = bool(effective_func_params)
586586

587-
# Validate that URI params match function params (excluding context)
587+
if has_uri_params or has_effective_func_params:
588+
# Register as template
588589
uri_params = set(re.findall(r"{(\w+)}", uri))
589-
# We need to remove the context_param from the resource function if
590-
# there is any.
591-
func_params = {p for p in sig.parameters.keys() if p != context_param}
592590

593-
if uri_params != func_params:
591+
if uri_params != effective_func_params:
594592
raise ValueError(
595-
f"Mismatch between URI parameters {uri_params} and function parameters {func_params}"
593+
f"Mismatch between URI parameters {uri_params} and function parameters {effective_func_params}"
596594
)
597595

598-
# Register as template
599596
self._resource_manager.add_template(
600597
fn=fn,
601598
uri_template=uri,
Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,54 @@
1-
"""Context injection utilities for FastMCP."""
2-
3-
from __future__ import annotations
4-
51
import inspect
62
import typing
73
from collections.abc import Callable
8-
from typing import Any
4+
from typing import TYPE_CHECKING, Any
95

6+
if TYPE_CHECKING:
7+
from mcp.server.fastmcp import Context
108

11-
def find_context_parameter(fn: Callable[..., Any]) -> str | None:
12-
"""Find the parameter that should receive the Context object.
13-
14-
Searches through the function's signature to find a parameter
15-
with a Context type annotation.
169

17-
Args:
18-
fn: The function to inspect
19-
20-
Returns:
21-
The name of the context parameter, or None if not found
10+
def find_context_parameter(fn: Callable[..., Any]) -> str | None:
11+
"""
12+
Inspect a function signature to find a parameter annotated with Context.
13+
Returns the name of the parameter if found, otherwise None.
2214
"""
23-
from mcp.server.fastmcp.server import Context
15+
from mcp.server.fastmcp import Context
2416

25-
# Get type hints to properly resolve string annotations
2617
try:
27-
hints = typing.get_type_hints(fn)
28-
except Exception:
29-
# If we can't resolve type hints, we can't find the context parameter
18+
sig = inspect.signature(fn)
19+
except ValueError: # pragma: no cover
20+
# Can't inspect signature (e.g. some builtins/wrappers)
3021
return None
3122

32-
# Check each parameter's type hint
33-
for param_name, annotation in hints.items():
34-
# Handle direct Context type
23+
for param_name, param in sig.parameters.items():
24+
annotation = param.annotation
25+
if annotation is inspect.Parameter.empty:
26+
continue
27+
28+
# Handle Optional[Context], Annotated[Context, ...], etc.
29+
origin = typing.get_origin(annotation)
30+
31+
# Check if the annotation itself is Context or a subclass
3532
if inspect.isclass(annotation) and issubclass(annotation, Context):
3633
return param_name
3734

38-
# Handle generic types like Optional[Context]
39-
origin = typing.get_origin(annotation)
40-
if origin is not None:
41-
args = typing.get_args(annotation)
42-
for arg in args:
43-
if inspect.isclass(arg) and issubclass(arg, Context):
44-
return param_name
35+
# Check if it's a generic alias of Context (e.g., Context[...])
36+
if origin is not None and inspect.isclass(origin) and issubclass(origin, Context):
37+
return param_name # pragma: no cover
4538

4639
return None
4740

4841

4942
def inject_context(
5043
fn: Callable[..., Any],
5144
kwargs: dict[str, Any],
52-
context: Any | None,
53-
context_kwarg: str | None,
45+
context: "Context[Any, Any, Any] | None",
5446
) -> dict[str, Any]:
55-
"""Inject context into function kwargs if needed.
56-
57-
Args:
58-
fn: The function that will be called
59-
kwargs: The current keyword arguments
60-
context: The context object to inject (if any)
61-
context_kwarg: The name of the parameter to inject into
62-
63-
Returns:
64-
Updated kwargs with context injected if applicable
6547
"""
66-
if context_kwarg is not None and context is not None:
67-
return {**kwargs, context_kwarg: context}
48+
Inject the Context object into kwargs if the function expects it.
49+
Returns the updated kwargs.
50+
"""
51+
context_param = find_context_parameter(fn)
52+
if context_param:
53+
kwargs[context_param] = context
6854
return kwargs

tests/server/fastmcp/resources/test_function_resources.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def my_func() -> str: # pragma: no cover
1818
name="test",
1919
description="test function",
2020
fn=my_func,
21+
context_kwarg=None,
2122
)
2223
assert str(resource.uri) == "fn://test"
2324
assert resource.name == "test"
@@ -36,6 +37,7 @@ def get_data() -> str:
3637
uri=AnyUrl("function://test"),
3738
name="test",
3839
fn=get_data,
40+
context_kwarg=None,
3941
)
4042
content = await resource.read()
4143
assert content == "Hello, world!"
@@ -52,6 +54,7 @@ def get_data() -> bytes:
5254
uri=AnyUrl("function://test"),
5355
name="test",
5456
fn=get_data,
57+
context_kwarg=None,
5558
)
5659
content = await resource.read()
5760
assert content == b"Hello, world!"
@@ -67,6 +70,7 @@ def get_data() -> dict[str, str]:
6770
uri=AnyUrl("function://test"),
6871
name="test",
6972
fn=get_data,
73+
context_kwarg=None,
7074
)
7175
content = await resource.read()
7276
assert isinstance(content, str)
@@ -83,6 +87,7 @@ def failing_func() -> str:
8387
uri=AnyUrl("function://test"),
8488
name="test",
8589
fn=failing_func,
90+
context_kwarg=None,
8691
)
8792
with pytest.raises(ValueError, match="Error reading resource function://test"):
8893
await resource.read()
@@ -98,6 +103,7 @@ class MyModel(BaseModel):
98103
uri=AnyUrl("function://test"),
99104
name="test",
100105
fn=lambda: MyModel(name="test"),
106+
context_kwarg=None,
101107
)
102108
content = await resource.read()
103109
assert content == '{\n "name": "test"\n}'
@@ -117,6 +123,7 @@ def get_data() -> CustomData:
117123
uri=AnyUrl("function://test"),
118124
name="test",
119125
fn=get_data,
126+
context_kwarg=None,
120127
)
121128
content = await resource.read()
122129
assert isinstance(content, str)
@@ -132,6 +139,7 @@ async def get_data() -> str:
132139
uri=AnyUrl("function://test"),
133140
name="test",
134141
fn=get_data,
142+
context_kwarg=None,
135143
)
136144
content = await resource.read()
137145
assert content == "Hello, world!"

0 commit comments

Comments
 (0)