Skip to content

Commit 2d2c20e

Browse files
authored
Merge pull request modelcontextprotocol#13 from jkoelker/jk/fastmcp
refactor(tools): use FastMCP registration hooks
2 parents f1a6f59 + 74c8c4f commit 2d2c20e

File tree

11 files changed

+251
-222
lines changed

11 files changed

+251
-222
lines changed

src/schwab_mcp/tools/__init__.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
from __future__ import annotations
22

3-
import inspect
4-
53
from mcp.server.fastmcp import FastMCP
64
from schwab.client import AsyncClient
75

8-
from schwab_mcp.tools.registry import register
9-
10-
# Import tool modules for their registration side effects
11-
from schwab_mcp.tools import tools as _tools # noqa: F401
12-
from schwab_mcp.tools import account as _account # noqa: F401
13-
from schwab_mcp.tools import history as _history # noqa: F401
14-
from schwab_mcp.tools import options as _options # noqa: F401
15-
from schwab_mcp.tools import orders as _orders # noqa: F401
16-
from schwab_mcp.tools import quotes as _quotes # noqa: F401
17-
from schwab_mcp.tools import transactions as _txns # noqa: F401
6+
from schwab_mcp.tools import account as _account
7+
from schwab_mcp.tools import history as _history
8+
from schwab_mcp.tools import options as _options
9+
from schwab_mcp.tools import orders as _orders
10+
from schwab_mcp.tools import quotes as _quotes
11+
from schwab_mcp.tools import tools as _tools
12+
from schwab_mcp.tools import transactions as _txns
1813

1914
_TOOL_MODULES = (
2015
_tools,
@@ -32,18 +27,10 @@ def register_tools(server: FastMCP, client: AsyncClient, *, allow_write: bool) -
3227
_ = client
3328

3429
for module in _TOOL_MODULES:
35-
for _, func in inspect.getmembers(module, inspect.iscoroutinefunction):
36-
if not getattr(func, "_registered_tool", False):
37-
continue
38-
if getattr(func, "_write", False) and not allow_write:
39-
continue
40-
annotations = getattr(func, "_tool_annotations", None)
41-
server.add_tool(
42-
func,
43-
name=func.__name__,
44-
description=func.__doc__,
45-
annotations=annotations,
46-
)
47-
48-
49-
__all__ = ["register_tools", "register"]
30+
register_module = getattr(module, "register", None)
31+
if register_module is None:
32+
raise AttributeError(f"Tool module {module.__name__} missing register()")
33+
register_module(server, allow_write=allow_write)
34+
35+
36+
__all__ = ["register_tools"]
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Awaitable, Callable
4+
from typing import Any
5+
6+
from mcp.server.fastmcp import FastMCP
7+
from mcp.types import ToolAnnotations
8+
9+
ToolFn = Callable[..., Awaitable[Any]]
10+
11+
12+
def register_tool(
13+
server: FastMCP,
14+
func: ToolFn,
15+
*,
16+
write: bool = False,
17+
annotations: ToolAnnotations | None = None,
18+
) -> None:
19+
"""Register a Schwab tool using FastMCP's decorator plumbing."""
20+
21+
tool_annotations = annotations
22+
if tool_annotations is None:
23+
if write:
24+
tool_annotations = ToolAnnotations(
25+
readOnlyHint=False,
26+
destructiveHint=True,
27+
)
28+
else:
29+
tool_annotations = ToolAnnotations(
30+
readOnlyHint=True,
31+
)
32+
else:
33+
update: dict[str, Any] = {}
34+
if tool_annotations.readOnlyHint is None:
35+
update["readOnlyHint"] = not write
36+
if write and tool_annotations.destructiveHint is None:
37+
update["destructiveHint"] = True
38+
if update:
39+
tool_annotations = tool_annotations.model_copy(update=update)
40+
41+
server.tool(
42+
name=func.__name__,
43+
description=func.__doc__,
44+
annotations=tool_annotations,
45+
)(func)
46+
47+
48+
__all__ = ["register_tool"]

src/schwab_mcp/tools/account.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from typing import Annotated
44

5+
from mcp.server.fastmcp import FastMCP
6+
57
from schwab_mcp.context import SchwabContext, SchwabServerContext
6-
from schwab_mcp.tools.registry import register
8+
from schwab_mcp.tools._registration import register_tool
79
from schwab_mcp.tools.utils import JSONType, call
810

911

10-
@register
1112
async def get_account_numbers(
1213
ctx: SchwabContext,
1314
) -> JSONType:
@@ -18,7 +19,6 @@ async def get_account_numbers(
1819
return await call(context.accounts.get_account_numbers)
1920

2021

21-
@register
2222
async def get_accounts(
2323
ctx: SchwabContext,
2424
) -> JSONType:
@@ -29,7 +29,6 @@ async def get_accounts(
2929
return await call(context.accounts.get_accounts)
3030

3131

32-
@register
3332
async def get_accounts_with_positions(
3433
ctx: SchwabContext,
3534
) -> JSONType:
@@ -43,7 +42,6 @@ async def get_accounts_with_positions(
4342
)
4443

4544

46-
@register
4745
async def get_account(
4846
ctx: SchwabContext,
4947
account_hash: Annotated[str, "Account hash for the Schwab account"],
@@ -55,7 +53,6 @@ async def get_account(
5553
return await call(context.accounts.get_account, account_hash)
5654

5755

58-
@register
5956
async def get_account_with_positions(
6057
ctx: SchwabContext,
6158
account_hash: Annotated[str, "Account hash for the Schwab account"],
@@ -71,7 +68,6 @@ async def get_account_with_positions(
7168
)
7269

7370

74-
@register
7571
async def get_user_preferences(
7672
ctx: SchwabContext,
7773
) -> JSONType:
@@ -80,3 +76,19 @@ async def get_user_preferences(
8076
"""
8177
context: SchwabServerContext = ctx.request_context.lifespan_context
8278
return await call(context.accounts.get_user_preferences)
79+
80+
81+
_READ_ONLY_TOOLS = (
82+
get_account_numbers,
83+
get_accounts,
84+
get_accounts_with_positions,
85+
get_account,
86+
get_account_with_positions,
87+
get_user_preferences,
88+
)
89+
90+
91+
def register(server: FastMCP, *, allow_write: bool) -> None:
92+
_ = allow_write
93+
for func in _READ_ONLY_TOOLS:
94+
register_tool(server, func)

src/schwab_mcp/tools/history.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33
from typing import Annotated
44

55
import datetime
6+
from mcp.server.fastmcp import FastMCP
67

78
from schwab_mcp.context import SchwabContext, SchwabServerContext
8-
from schwab_mcp.tools.registry import register
9+
from schwab_mcp.tools._registration import register_tool
910
from schwab_mcp.tools.utils import JSONType, call
1011

1112

1213
def _parse_iso_datetime(value: str | None) -> datetime.datetime | None:
1314
return datetime.datetime.fromisoformat(value) if value is not None else None
1415

1516

16-
@register
1717
async def get_advanced_price_history(
1818
ctx: SchwabContext,
1919
symbol: Annotated[str, "Symbol of the security"],
@@ -95,7 +95,6 @@ async def get_advanced_price_history(
9595
)
9696

9797

98-
@register
9998
async def get_price_history_every_minute(
10099
ctx: SchwabContext,
101100
symbol: Annotated[str, "Symbol of the security"],
@@ -127,7 +126,6 @@ async def get_price_history_every_minute(
127126
)
128127

129128

130-
@register
131129
async def get_price_history_every_five_minutes(
132130
ctx: SchwabContext,
133131
symbol: Annotated[str, "Symbol of the security"],
@@ -159,7 +157,6 @@ async def get_price_history_every_five_minutes(
159157
)
160158

161159

162-
@register
163160
async def get_price_history_every_ten_minutes(
164161
ctx: SchwabContext,
165162
symbol: Annotated[str, "Symbol of the security"],
@@ -191,7 +188,6 @@ async def get_price_history_every_ten_minutes(
191188
)
192189

193190

194-
@register
195191
async def get_price_history_every_fifteen_minutes(
196192
ctx: SchwabContext,
197193
symbol: Annotated[str, "Symbol of the security"],
@@ -223,7 +219,6 @@ async def get_price_history_every_fifteen_minutes(
223219
)
224220

225221

226-
@register
227222
async def get_price_history_every_thirty_minutes(
228223
ctx: SchwabContext,
229224
symbol: Annotated[str, "Symbol of the security"],
@@ -255,7 +250,6 @@ async def get_price_history_every_thirty_minutes(
255250
)
256251

257252

258-
@register
259253
async def get_price_history_every_day(
260254
ctx: SchwabContext,
261255
symbol: Annotated[str, "Symbol of the security to fetch price history for"],
@@ -293,7 +287,6 @@ async def get_price_history_every_day(
293287
)
294288

295289

296-
@register
297290
async def get_price_history_every_week(
298291
ctx: SchwabContext,
299292
symbol: Annotated[str, "Symbol of the security"],
@@ -323,3 +316,21 @@ async def get_price_history_every_week(
323316
need_extended_hours_data=extended_hours,
324317
need_previous_close=previous_close,
325318
)
319+
320+
321+
_READ_ONLY_TOOLS = (
322+
get_advanced_price_history,
323+
get_price_history_every_minute,
324+
get_price_history_every_five_minutes,
325+
get_price_history_every_ten_minutes,
326+
get_price_history_every_fifteen_minutes,
327+
get_price_history_every_thirty_minutes,
328+
get_price_history_every_day,
329+
get_price_history_every_week,
330+
)
331+
332+
333+
def register(server: FastMCP, *, allow_write: bool) -> None:
334+
_ = allow_write
335+
for func in _READ_ONLY_TOOLS:
336+
register_tool(server, func)

src/schwab_mcp/tools/options.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from typing import Annotated
44

55
import datetime
6+
from mcp.server.fastmcp import FastMCP
67

78
from schwab_mcp.context import SchwabContext, SchwabServerContext
8-
from schwab_mcp.tools.registry import register
9+
from schwab_mcp.tools._registration import register_tool
910
from schwab_mcp.tools.utils import JSONType, call
1011

1112

@@ -19,7 +20,6 @@ def _parse_date(value: str | datetime.date | None) -> datetime.date | None:
1920
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
2021

2122

22-
@register
2323
async def get_option_chain(
2424
ctx: SchwabContext,
2525
symbol: Annotated[str, "Symbol of the underlying security (e.g., 'AAPL', 'SPY')"],
@@ -66,7 +66,6 @@ async def get_option_chain(
6666
)
6767

6868

69-
@register
7069
async def get_advanced_option_chain(
7170
ctx: SchwabContext,
7271
symbol: Annotated[str, "Symbol of the underlying security"],
@@ -156,7 +155,6 @@ async def get_advanced_option_chain(
156155
)
157156

158157

159-
@register
160158
async def get_option_expiration_chain(
161159
ctx: SchwabContext,
162160
symbol: Annotated[str, "Symbol of the underlying security"],
@@ -167,3 +165,16 @@ async def get_option_expiration_chain(
167165
context: SchwabServerContext = ctx.request_context.lifespan_context
168166
client = context.options
169167
return await call(client.get_option_expiration_chain, symbol)
168+
169+
170+
_READ_ONLY_TOOLS = (
171+
get_option_chain,
172+
get_advanced_option_chain,
173+
get_option_expiration_chain,
174+
)
175+
176+
177+
def register(server: FastMCP, *, allow_write: bool) -> None:
178+
_ = allow_write
179+
for func in _READ_ONLY_TOOLS:
180+
register_tool(server, func)

0 commit comments

Comments
 (0)