Skip to content

Commit 04c4a72

Browse files
committed
add authorizer plugin to enable fine granded authorization checks on tools/resources/prompts
1 parent 4d45bb8 commit 04c4a72

File tree

6 files changed

+193
-19
lines changed

6 files changed

+193
-19
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from __future__ import annotations
2+
3+
import abc
4+
from typing import TYPE_CHECKING, Any
5+
from pydantic import AnyUrl
6+
7+
from mcp.shared.context import LifespanContextT, RequestT
8+
9+
if TYPE_CHECKING:
10+
from mcp.server.fastmcp.server import Context
11+
from mcp.server.session import ServerSessionT
12+
13+
14+
class Authorizer:
15+
__metaclass__ = abc.ABCMeta
16+
17+
@abc.abstractmethod
18+
def permit_get_tool(self, name: str) -> bool:
19+
"""Check if the specified tool can be retrieved from the associated mcp server"""
20+
return False
21+
22+
@abc.abstractmethod
23+
def permit_list_tool(self, name: str) -> bool:
24+
"""Check if the specified tool can be listed from the associated mcp server"""
25+
return False
26+
27+
@abc.abstractmethod
28+
def permit_call_tool(
29+
self,
30+
name: str,
31+
arguments: dict[str, Any],
32+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
33+
) -> bool:
34+
"""Check if the specified tool can be called from the associated mcp server"""
35+
return False
36+
37+
@abc.abstractmethod
38+
def permit_get_resource(self, resource: AnyUrl | str) -> bool:
39+
"""Check if the specified resource can be retrieved from the associated mcp server"""
40+
return False
41+
42+
@abc.abstractmethod
43+
def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool:
44+
"""Check if the specified resource can be created on the associated mcp server"""
45+
return False
46+
47+
@abc.abstractmethod
48+
def permit_list_resource(self, resource: AnyUrl | str) -> bool:
49+
"""Check if the specified resource can be listed from the associated mcp server"""
50+
return False
51+
52+
@abc.abstractmethod
53+
def permit_list_template(self, resource: AnyUrl | str) -> bool:
54+
"""Check if the specified template can be listed from the associated mcp server"""
55+
return False
56+
57+
@abc.abstractmethod
58+
def permit_get_prompt(self, name: str) -> bool:
59+
"""Check if the specified prompt can be retrieved from the associated mcp server"""
60+
return False
61+
62+
@abc.abstractmethod
63+
def permit_list_prompt(self, name: str) -> bool:
64+
"""Check if the specified prompt can be listed from the associated mcp server"""
65+
return False
66+
67+
@abc.abstractmethod
68+
def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool:
69+
"""Check if the specified prompt can be rendered from the associated mcp server"""
70+
return False
71+
72+
class AllAllAuthorizer(Authorizer):
73+
def permit_get_tool(self, name: str) -> bool:
74+
return True
75+
76+
def permit_list_tool(self, name: str) -> bool:
77+
return True
78+
79+
def permit_call_tool(
80+
self,
81+
name: str,
82+
arguments: dict[str, Any],
83+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
84+
) -> bool:
85+
return True
86+
87+
def permit_get_resource(self, resource: AnyUrl | str) -> bool:
88+
return True
89+
90+
def permit_create_resource(self, uri: str, params: dict[str, Any]) -> bool:
91+
return True
92+
93+
def permit_list_resource(self, resource: AnyUrl | str) -> bool:
94+
return True
95+
96+
def permit_list_template(self, resource: AnyUrl | str) -> bool:
97+
return True
98+
99+
def permit_get_prompt(self, name: str) -> bool:
100+
return True
101+
102+
def permit_list_prompt(self, name: str) -> bool:
103+
return True
104+
105+
def permit_render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> bool:
106+
return True
107+

src/mcp/server/fastmcp/prompts/manager.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from typing import Any
44

5+
from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer
56
from mcp.server.fastmcp.prompts.base import Message, Prompt
67
from mcp.server.fastmcp.utilities.logging import get_logger
78

@@ -11,17 +12,25 @@
1112
class PromptManager:
1213
"""Manages FastMCP prompts."""
1314

14-
def __init__(self, warn_on_duplicate_prompts: bool = True):
15+
def __init__(
16+
self,
17+
warn_on_duplicate_prompts: bool = True,
18+
authorizer: Authorizer = AllAllAuthorizer(),
19+
):
1520
self._prompts: dict[str, Prompt] = {}
21+
self._authorizer = authorizer
1622
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts
1723

1824
def get_prompt(self, name: str) -> Prompt | None:
1925
"""Get prompt by name."""
20-
return self._prompts.get(name)
26+
if self._authorizer.permit_get_prompt(name):
27+
return self._prompts.get(name)
28+
else:
29+
return None
2130

2231
def list_prompts(self) -> list[Prompt]:
2332
"""List all registered prompts."""
24-
return list(self._prompts.values())
33+
return [prompt for name, prompt in self._prompts.items() if self._authorizer.permit_list_prompt(name)]
2534

2635
def add_prompt(
2736
self,
@@ -44,5 +53,7 @@ async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None
4453
prompt = self.get_prompt(name)
4554
if not prompt:
4655
raise ValueError(f"Unknown prompt: {name}")
47-
48-
return await prompt.render(arguments)
56+
if self._authorizer.permit_render_prompt(name, arguments):
57+
return await prompt.render(arguments)
58+
else:
59+
raise ValueError(f"Unknown prompt: {name}")

src/mcp/server/fastmcp/resources/resource_manager.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from pydantic import AnyUrl
77

8+
from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer
89
from mcp.server.fastmcp.resources.base import Resource
910
from mcp.server.fastmcp.resources.templates import ResourceTemplate
1011
from mcp.server.fastmcp.utilities.logging import get_logger
@@ -15,10 +16,15 @@
1516
class ResourceManager:
1617
"""Manages FastMCP resources."""
1718

18-
def __init__(self, warn_on_duplicate_resources: bool = True):
19+
def __init__(
20+
self,
21+
warn_on_duplicate_resources: bool = True,
22+
authorizer: Authorizer = AllAllAuthorizer(),
23+
):
1924
self._resources: dict[str, Resource] = {}
2025
self._templates: dict[str, ResourceTemplate] = {}
2126
self.warn_on_duplicate_resources = warn_on_duplicate_resources
27+
self._authorizer = authorizer
2228

2329
def add_resource(self, resource: Resource) -> Resource:
2430
"""Add a resource to the manager.
@@ -74,13 +80,19 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
7480

7581
# First check concrete resources
7682
if resource := self._resources.get(uri_str):
77-
return resource
83+
if self._authorizer.permit_get_resource(uri_str):
84+
return resource
85+
else:
86+
raise ValueError(f"Unknown resource: {uri}")
7887

7988
# Then check templates
8089
for template in self._templates.values():
8190
if params := template.matches(uri_str):
8291
try:
83-
return await template.create_resource(uri_str, params)
92+
if self._authorizer.permit_create_resource(uri_str, params):
93+
return await template.create_resource(uri_str, params)
94+
else:
95+
raise ValueError(f"Unknown resource: {uri}")
8496
except Exception as e:
8597
raise ValueError(f"Error creating resource from template: {e}")
8698

@@ -89,9 +101,9 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
89101
def list_resources(self) -> list[Resource]:
90102
"""List all registered resources."""
91103
logger.debug("Listing resources", extra={"count": len(self._resources)})
92-
return list(self._resources.values())
104+
return [resource for uri, resource in self._resources.items() if self._authorizer.permit_list_resource(uri)]
93105

94106
def list_templates(self) -> list[ResourceTemplate]:
95107
"""List all registered templates."""
96108
logger.debug("Listing templates", extra={"count": len(self._templates)})
97-
return list(self._templates.values())
109+
return [template for uri, template in self._templates.items() if self._authorizer.permit_list_template(uri)]

src/mcp/server/fastmcp/server.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier
3434
from mcp.server.auth.settings import AuthSettings
3535
from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation
36+
from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer
3637
from mcp.server.fastmcp.exceptions import ResourceError
3738
from mcp.server.fastmcp.prompts import Prompt, PromptManager
3839
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
@@ -120,6 +121,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
120121
# Transport security settings (DNS rebinding protection)
121122
transport_security: TransportSecuritySettings | None = None
122123

124+
authorizer: Authorizer = AllAllAuthorizer()
125+
123126

124127
def lifespan_wrapper(
125128
app: FastMCP,
@@ -152,9 +155,19 @@ def __init__(
152155
instructions=instructions,
153156
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),
154157
)
155-
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
156-
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
157-
self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts)
158+
self._tool_manager = ToolManager(
159+
tools=tools,
160+
warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools,
161+
authorizer=self.settings.authorizer,
162+
)
163+
self._resource_manager = ResourceManager(
164+
warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources,
165+
authorizer=self.settings.authorizer,
166+
)
167+
self._prompt_manager = PromptManager(
168+
warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts,
169+
authorizer=self.settings.authorizer,
170+
)
158171
# Validate auth configuration
159172
if self.settings.auth is not None:
160173
if auth_server_provider and token_verifier:

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Callable
44
from typing import TYPE_CHECKING, Any
55

6+
from mcp.server.fastmcp.authorizer import AllAllAuthorizer, Authorizer
67
from mcp.server.fastmcp.exceptions import ToolError
78
from mcp.server.fastmcp.tools.base import Tool
89
from mcp.server.fastmcp.utilities.logging import get_logger
@@ -24,6 +25,7 @@ def __init__(
2425
warn_on_duplicate_tools: bool = True,
2526
*,
2627
tools: list[Tool] | None = None,
28+
authorizer: Authorizer = AllAllAuthorizer(),
2729
):
2830
self._tools: dict[str, Tool] = {}
2931
if tools is not None:
@@ -32,15 +34,19 @@ def __init__(
3234
logger.warning(f"Tool already exists: {tool.name}")
3335
self._tools[tool.name] = tool
3436

35-
self.warn_on_duplicate_tools = warn_on_duplicate_tools
37+
self.warn_on_duplicate_tools = (warn_on_duplicate_tools,)
38+
self._authorizer = authorizer
3639

3740
def get_tool(self, name: str) -> Tool | None:
3841
"""Get tool by name."""
39-
return self._tools.get(name)
42+
if self._authorizer.permit_get_tool(name):
43+
return self._tools.get(name)
44+
else:
45+
return None
4046

4147
def list_tools(self) -> list[Tool]:
4248
"""List all registered tools."""
43-
return list(self._tools.values())
49+
return [tool for name, tool in self._tools.items() if self._authorizer.permit_list_tool(name)]
4450

4551
def add_tool(
4652
self,
@@ -67,8 +73,8 @@ async def call_tool(
6773
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
6874
) -> Any:
6975
"""Call a tool by name with arguments."""
70-
tool = self.get_tool(name)
71-
if not tool:
76+
tool = self._tools.get(name)
77+
if not tool or not self._authorizer.permit_call_tool(name, arguments, context):
7278
raise ToolError(f"Unknown tool: {name}")
7379

7480
return await tool.run(arguments, context=context)

tests/server/fastmcp/test_tool_manager.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import BaseModel
66

77
from mcp.server.fastmcp import Context, FastMCP
8+
from mcp.server.fastmcp.authorizer import Authorizer
89
from mcp.server.fastmcp.exceptions import ToolError
910
from mcp.server.fastmcp.tools import Tool, ToolManager
1011
from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata
@@ -171,7 +172,7 @@ def f(x: int) -> int:
171172

172173
manager = ToolManager()
173174
manager.add_tool(f)
174-
manager.warn_on_duplicate_tools = False
175+
manager.warn_on_duplicate_tools = False # type: ignore
175176
with caplog.at_level(logging.WARNING):
176177
manager.add_tool(f)
177178
assert "Tool already exists: f" not in caplog.text
@@ -311,6 +312,30 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]:
311312
)
312313
assert result == ["rex", "gertrude"]
313314

315+
@pytest.mark.anyio
316+
async def test_call_tool_not_permitted(self):
317+
async def double(n: int) -> int:
318+
"""Double a number."""
319+
return n * 2
320+
321+
class TestAuthorizer(Authorizer):
322+
allow: bool = True
323+
324+
def permit_list_tool(self, name):
325+
return self.allow
326+
327+
def permit_call_tool(self, name, arguments, context=None):
328+
return self.allow
329+
330+
authorizer = TestAuthorizer()
331+
manager = ToolManager(authorizer=authorizer)
332+
manager.add_tool(double)
333+
result = await manager.call_tool("double", {"n": 5})
334+
assert result == 10
335+
authorizer.allow = False
336+
with pytest.raises(ToolError, match="Unknown tool: double"):
337+
await manager.call_tool("double", {"n": 5})
338+
314339

315340
class TestToolSchema:
316341
@pytest.mark.anyio

0 commit comments

Comments
 (0)