Skip to content

Commit 7a6dfa7

Browse files
committed
Add tool enable/disable functionality with client notifications
1 parent f2f4dbd commit 7a6dfa7

File tree

3 files changed

+82
-3
lines changed

3 files changed

+82
-3
lines changed

src/mcp/server/fastmcp/tools/base.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from mcp.server.fastmcp.exceptions import ToolError
1111
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
12-
from mcp.types import ToolAnnotations
12+
from mcp.types import ServerNotification, ToolAnnotations, ToolListChangedNotification
1313

1414
if TYPE_CHECKING:
1515
from mcp.server.fastmcp.server import Context
@@ -35,6 +35,7 @@ class Tool(BaseModel):
3535
annotations: ToolAnnotations | None = Field(
3636
None, description="Optional annotations for the tool"
3737
)
38+
enabled: bool = Field(default=True, description="Whether the tool is enabled")
3839

3940
@classmethod
4041
def from_function(
@@ -100,6 +101,32 @@ async def run(
100101
except Exception as e:
101102
raise ToolError(f"Error executing tool {self.name}: {e}") from e
102103

104+
async def enable(
105+
self, context: Context[ServerSessionT, LifespanContextT] | None = None
106+
) -> None:
107+
"""Enable the tool and notify clients."""
108+
if not self.enabled:
109+
self.enabled = True
110+
if context and context.session:
111+
notification = ToolListChangedNotification(
112+
method="notifications/tools/list_changed"
113+
)
114+
server_notification = ServerNotification.model_validate(notification)
115+
await context.session.send_notification(server_notification)
116+
117+
async def disable(
118+
self, context: Context[ServerSessionT, LifespanContextT] | None = None
119+
) -> None:
120+
"""Disable the tool and notify clients."""
121+
if self.enabled:
122+
self.enabled = False
123+
if context and context.session:
124+
notification = ToolListChangedNotification(
125+
method="notifications/tools/list_changed"
126+
)
127+
server_notification = ServerNotification.model_validate(notification)
128+
await context.session.send_notification(server_notification)
129+
103130

104131
def _is_async_callable(obj: Any) -> bool:
105132
while isinstance(obj, functools.partial):

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def get_tool(self, name: str) -> Tool | None:
3939
return self._tools.get(name)
4040

4141
def list_tools(self) -> list[Tool]:
42-
"""List all registered tools."""
43-
return list(self._tools.values())
42+
"""List all enabled registered tools."""
43+
return [tool for tool in self._tools.values() if tool.enabled]
4444

4545
def add_tool(
4646
self,
@@ -72,4 +72,7 @@ async def call_tool(
7272
if not tool:
7373
raise ToolError(f"Unknown tool: {name}")
7474

75+
if not tool.enabled:
76+
raise ToolError(f"Tool is disabled: {name}")
77+
7578
return await tool.run(arguments, context=context)

tests/server/fastmcp/test_tool_manager.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,52 @@ def echo(message: str) -> str:
453453
assert tools[0].annotations is not None
454454
assert tools[0].annotations.title == "Echo Tool"
455455
assert tools[0].annotations.readOnlyHint is True
456+
457+
458+
class TestToolEnableDisable:
459+
"""Test enabling and disabling tools."""
460+
461+
@pytest.mark.anyio
462+
async def test_enable_disable_tool(self):
463+
"""Test enabling and disabling a tool."""
464+
465+
def add(a: int, b: int) -> int:
466+
"""Add two numbers."""
467+
return a + b
468+
469+
manager = ToolManager()
470+
tool = manager.add_tool(add)
471+
472+
# Tool should be enabled by default
473+
assert tool.enabled is True
474+
475+
# Disable the tool
476+
await tool.disable()
477+
assert tool.enabled is False
478+
479+
# Enable the tool
480+
await tool.enable()
481+
assert tool.enabled is True
482+
483+
@pytest.mark.anyio
484+
async def test_enable_disable_no_change(self):
485+
"""Test enabling and disabling a tool when there's no state change."""
486+
487+
def add(a: int, b: int) -> int:
488+
"""Add two numbers."""
489+
return a + b
490+
491+
manager = ToolManager()
492+
tool = manager.add_tool(add)
493+
494+
# Enable an already enabled tool (should not change state)
495+
await tool.enable()
496+
assert tool.enabled is True
497+
498+
# Disable the tool
499+
await tool.disable()
500+
assert tool.enabled is False
501+
502+
# Disable an already disabled tool (should not change state)
503+
await tool.disable()
504+
assert tool.enabled is False

0 commit comments

Comments
 (0)