Skip to content

Commit dd93d9a

Browse files
authored
Merge pull request modelcontextprotocol#11 from jkoelker/jk/fastmcp
refactor(tools): use fastmcp context
2 parents 2e0e6db + 6aad954 commit dd93d9a

File tree

17 files changed

+166
-245
lines changed

17 files changed

+166
-245
lines changed

src/schwab_mcp/context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Any, cast
55

66
from schwab.client import AsyncClient
7+
from mcp.server.fastmcp import Context as MCPContext
78

89
if TYPE_CHECKING:
910
from schwab_mcp.tools._protocols import (
@@ -44,4 +45,7 @@ def __post_init__(self) -> None:
4445
self.transactions = cast(TransactionsClient, self.client)
4546

4647

47-
__all__ = ["SchwabServerContext"]
48+
SchwabContext = MCPContext[Any, SchwabServerContext, Any]
49+
50+
51+
__all__ = ["SchwabServerContext", "SchwabContext"]

src/schwab_mcp/tools/__init__.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

3+
import inspect
4+
35
from mcp.server.fastmcp import FastMCP
46
from schwab.client import AsyncClient
57

6-
from schwab_mcp.tools.registry import iter_registered_tools, register
7-
from schwab_mcp.tools import utils as tool_utils
8+
from schwab_mcp.tools.registry import register
89

910
# Import tool modules for their registration side effects
1011
from schwab_mcp.tools import tools as _tools # noqa: F401
@@ -15,22 +16,34 @@
1516
from schwab_mcp.tools import quotes as _quotes # noqa: F401
1617
from schwab_mcp.tools import transactions as _txns # noqa: F401
1718

19+
_TOOL_MODULES = (
20+
_tools,
21+
_account,
22+
_history,
23+
_options,
24+
_orders,
25+
_quotes,
26+
_txns,
27+
)
28+
1829

1930
def register_tools(server: FastMCP, client: AsyncClient, *, allow_write: bool) -> None:
2031
"""Register all Schwab tools with the provided FastMCP server."""
21-
tool_utils.set_write_enabled(allow_write)
22-
23-
for func in iter_registered_tools():
24-
if getattr(func, "_write", False) and not allow_write:
25-
continue
26-
annotations = getattr(func, "_tool_annotations", None)
27-
tool_kwargs = {
28-
"name": func.__name__,
29-
"description": func.__doc__,
30-
}
31-
if annotations is not None:
32-
tool_kwargs["annotations"] = annotations
33-
server.tool(**tool_kwargs)(func)
32+
_ = client
33+
34+
for module in _TOOL_MODULES:
35+
for _, func in inspect.getmembers(module, inspect.iscoroutinefunction):
36+
if not getattr(func, "_registered_tool", False):
37+
continue
38+
if getattr(func, "_write", False) and not allow_write:
39+
continue
40+
annotations = getattr(func, "_tool_annotations", None)
41+
server.add_tool(
42+
func,
43+
name=func.__name__,
44+
description=func.__doc__,
45+
annotations=annotations,
46+
)
3447

3548

3649
__all__ = ["register_tools", "register"]

src/schwab_mcp/tools/account.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,41 @@
22

33
from typing import Annotated
44

5-
from mcp.server.fastmcp import Context
6-
5+
from schwab_mcp.context import SchwabContext, SchwabServerContext
76
from schwab_mcp.tools.registry import register
8-
from schwab_mcp.tools.utils import call, get_context
7+
from schwab_mcp.tools.utils import call
98

109

1110
@register
1211
async def get_account_numbers(
13-
ctx: Context,
12+
ctx: SchwabContext,
1413
) -> str:
1514
"""
1615
Returns mapping of account IDs to account hashes. Hashes required for account-specific calls. Use first.
1716
"""
18-
context = get_context(ctx)
17+
context: SchwabServerContext = ctx.request_context.lifespan_context
1918
return await call(context.accounts.get_account_numbers)
2019

2120

2221
@register
2322
async def get_accounts(
24-
ctx: Context,
23+
ctx: SchwabContext,
2524
) -> str:
2625
"""
2726
Returns balances/info for all linked accounts (funds, cash, margin). Does not return hashes; use get_account_numbers first.
2827
"""
29-
context = get_context(ctx)
28+
context: SchwabServerContext = ctx.request_context.lifespan_context
3029
return await call(context.accounts.get_accounts)
3130

3231

3332
@register
3433
async def get_accounts_with_positions(
35-
ctx: Context,
34+
ctx: SchwabContext,
3635
) -> str:
3736
"""
3837
Returns balances, info, and positions (holdings, cost, gain/loss) for all linked accounts. Does not return hashes; use get_account_numbers first.
3938
"""
40-
context = get_context(ctx)
39+
context: SchwabServerContext = ctx.request_context.lifespan_context
4140
return await call(
4241
context.accounts.get_accounts,
4342
fields=[context.accounts.Account.Fields.POSITIONS],
@@ -46,25 +45,25 @@ async def get_accounts_with_positions(
4645

4746
@register
4847
async def get_account(
49-
ctx: Context,
48+
ctx: SchwabContext,
5049
account_hash: Annotated[str, "Account hash for the Schwab account"],
5150
) -> str:
5251
"""
5352
Returns balance/info for a specific account via account_hash (from get_account_numbers). Includes funds, cash, margin info.
5453
"""
55-
context = get_context(ctx)
54+
context: SchwabServerContext = ctx.request_context.lifespan_context
5655
return await call(context.accounts.get_account, account_hash)
5756

5857

5958
@register
6059
async def get_account_with_positions(
61-
ctx: Context,
60+
ctx: SchwabContext,
6261
account_hash: Annotated[str, "Account hash for the Schwab account"],
6362
) -> str:
6463
"""
6564
Returns balance, info, and positions for a specific account via account_hash. Includes holdings, quantity, cost basis, unrealized gain/loss.
6665
"""
67-
context = get_context(ctx)
66+
context: SchwabServerContext = ctx.request_context.lifespan_context
6867
return await call(
6968
context.accounts.get_account,
7069
account_hash,
@@ -74,10 +73,10 @@ async def get_account_with_positions(
7473

7574
@register
7675
async def get_user_preferences(
77-
ctx: Context,
76+
ctx: SchwabContext,
7877
) -> str:
7978
"""
8079
Returns user preferences (nicknames, display settings, notifications) for all linked accounts.
8180
"""
82-
context = get_context(ctx)
81+
context: SchwabServerContext = ctx.request_context.lifespan_context
8382
return await call(context.accounts.get_user_preferences)

src/schwab_mcp/tools/history.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from typing import Annotated
44

55
import datetime
6-
from mcp.server.fastmcp import Context
76

7+
from schwab_mcp.context import SchwabContext, SchwabServerContext
88
from schwab_mcp.tools.registry import register
9-
from schwab_mcp.tools.utils import call, get_context
9+
from schwab_mcp.tools.utils import call
1010

1111

1212
def _parse_iso_datetime(value: str | None) -> datetime.datetime | None:
@@ -15,7 +15,7 @@ def _parse_iso_datetime(value: str | None) -> datetime.datetime | None:
1515

1616
@register
1717
async def get_advanced_price_history(
18-
ctx: Context,
18+
ctx: SchwabContext,
1919
symbol: Annotated[str, "Symbol of the security"],
2020
period_type: Annotated[
2121
str | None, "Period type: DAY, MONTH, YEAR, YEAR_TO_DATE"
@@ -60,7 +60,7 @@ async def get_advanced_price_history(
6060
YEAR_TO_DATE: DAILY, WEEKLY (default)
6161
Dates must be in ISO format.
6262
"""
63-
context = get_context(ctx)
63+
context: SchwabServerContext = ctx.request_context.lifespan_context
6464
client = context.price_history
6565

6666
start_dt = _parse_iso_datetime(start_datetime)
@@ -97,7 +97,7 @@ async def get_advanced_price_history(
9797

9898
@register
9999
async def get_price_history_every_minute(
100-
ctx: Context,
100+
ctx: SchwabContext,
101101
symbol: Annotated[str, "Symbol of the security"],
102102
start_datetime: Annotated[
103103
str | None, "Start date for history (ISO format, e.g., '2023-01-01T09:30:00')"
@@ -111,7 +111,7 @@ async def get_price_history_every_minute(
111111
"""
112112
Get OHLCV price history per minute. For detailed intraday analysis. Max 48 days history. Dates ISO format.
113113
"""
114-
context = get_context(ctx)
114+
context: SchwabServerContext = ctx.request_context.lifespan_context
115115
client = context.price_history
116116

117117
start_dt = _parse_iso_datetime(start_datetime)
@@ -129,7 +129,7 @@ async def get_price_history_every_minute(
129129

130130
@register
131131
async def get_price_history_every_five_minutes(
132-
ctx: Context,
132+
ctx: SchwabContext,
133133
symbol: Annotated[str, "Symbol of the security"],
134134
start_datetime: Annotated[
135135
str | None, "Start date for history (ISO format, e.g., '2023-01-01T09:30:00')"
@@ -143,7 +143,7 @@ async def get_price_history_every_five_minutes(
143143
"""
144144
Get OHLCV price history per 5 minutes. Balance between detail and noise. Approx. 9 months history. Dates ISO format.
145145
"""
146-
context = get_context(ctx)
146+
context: SchwabServerContext = ctx.request_context.lifespan_context
147147
client = context.price_history
148148

149149
start_dt = _parse_iso_datetime(start_datetime)
@@ -161,7 +161,7 @@ async def get_price_history_every_five_minutes(
161161

162162
@register
163163
async def get_price_history_every_ten_minutes(
164-
ctx: Context,
164+
ctx: SchwabContext,
165165
symbol: Annotated[str, "Symbol of the security"],
166166
start_datetime: Annotated[
167167
str | None, "Start date for history (ISO format, e.g., '2023-01-01T09:30:00')"
@@ -175,7 +175,7 @@ async def get_price_history_every_ten_minutes(
175175
"""
176176
Get OHLCV price history per 10 minutes. Good for intraday trends/levels. Approx. 9 months history. Dates ISO format.
177177
"""
178-
context = get_context(ctx)
178+
context: SchwabServerContext = ctx.request_context.lifespan_context
179179
client = context.price_history
180180

181181
start_dt = _parse_iso_datetime(start_datetime)
@@ -193,7 +193,7 @@ async def get_price_history_every_ten_minutes(
193193

194194
@register
195195
async def get_price_history_every_fifteen_minutes(
196-
ctx: Context,
196+
ctx: SchwabContext,
197197
symbol: Annotated[str, "Symbol of the security"],
198198
start_datetime: Annotated[
199199
str | None, "Start date for history (ISO format, e.g., '2023-01-01T09:30:00')"
@@ -207,7 +207,7 @@ async def get_price_history_every_fifteen_minutes(
207207
"""
208208
Get OHLCV price history per 15 minutes. Shows significant intraday moves, filters noise. Approx. 9 months history. Dates ISO format.
209209
"""
210-
context = get_context(ctx)
210+
context: SchwabServerContext = ctx.request_context.lifespan_context
211211
client = context.price_history
212212

213213
start_dt = _parse_iso_datetime(start_datetime)
@@ -225,7 +225,7 @@ async def get_price_history_every_fifteen_minutes(
225225

226226
@register
227227
async def get_price_history_every_thirty_minutes(
228-
ctx: Context,
228+
ctx: SchwabContext,
229229
symbol: Annotated[str, "Symbol of the security"],
230230
start_datetime: Annotated[
231231
str | None, "Start date for history (ISO format, e.g., '2023-01-01T09:30:00')"
@@ -239,7 +239,7 @@ async def get_price_history_every_thirty_minutes(
239239
"""
240240
Get OHLCV price history per 30 minutes. For broader intraday trends, filters noise. Approx. 9 months history. Dates ISO format.
241241
"""
242-
context = get_context(ctx)
242+
context: SchwabServerContext = ctx.request_context.lifespan_context
243243
client = context.price_history
244244

245245
start_dt = _parse_iso_datetime(start_datetime)
@@ -257,7 +257,7 @@ async def get_price_history_every_thirty_minutes(
257257

258258
@register
259259
async def get_price_history_every_day(
260-
ctx: Context,
260+
ctx: SchwabContext,
261261
symbol: Annotated[str, "Symbol of the security to fetch price history for"],
262262
start_datetime: Annotated[
263263
str | None,
@@ -277,7 +277,7 @@ async def get_price_history_every_day(
277277
"""
278278
Get daily OHLCV price history. For medium/long-term analysis. Extensive history (back to 1985 possible). Dates ISO format.
279279
"""
280-
context = get_context(ctx)
280+
context: SchwabServerContext = ctx.request_context.lifespan_context
281281
client = context.price_history
282282

283283
start_dt = _parse_iso_datetime(start_datetime)
@@ -295,7 +295,7 @@ async def get_price_history_every_day(
295295

296296
@register
297297
async def get_price_history_every_week(
298-
ctx: Context,
298+
ctx: SchwabContext,
299299
symbol: Annotated[str, "Symbol of the security"],
300300
start_datetime: Annotated[
301301
str | None, "Start date for history (ISO format, e.g., '2023-01-01T00:00:00')"
@@ -309,7 +309,7 @@ async def get_price_history_every_week(
309309
"""
310310
Get weekly OHLCV price history. For long-term analysis, major cycles. Extensive history (back to 1985 possible). Dates ISO format.
311311
"""
312-
context = get_context(ctx)
312+
context: SchwabServerContext = ctx.request_context.lifespan_context
313313
client = context.price_history
314314

315315
start_dt = _parse_iso_datetime(start_datetime)

src/schwab_mcp/tools/options.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from typing import Annotated
44

55
import datetime
6-
from mcp.server.fastmcp import Context
76

7+
from schwab_mcp.context import SchwabContext, SchwabServerContext
88
from schwab_mcp.tools.registry import register
9-
from schwab_mcp.tools.utils import call, get_context
9+
from schwab_mcp.tools.utils import call
1010

1111

1212
def _parse_date(value: str | datetime.date | None) -> datetime.date | None:
@@ -21,7 +21,7 @@ def _parse_date(value: str | datetime.date | None) -> datetime.date | None:
2121

2222
@register
2323
async def get_option_chain(
24-
ctx: Context,
24+
ctx: SchwabContext,
2525
symbol: Annotated[str, "Symbol of the underlying security (e.g., 'AAPL', 'SPY')"],
2626
contract_type: Annotated[
2727
str | None, "Type of option contracts: CALL, PUT, or ALL (default)"
@@ -47,7 +47,7 @@ async def get_option_chain(
4747
Params: symbol, contract_type (CALL/PUT/ALL), strike_count (default 25), include_quotes (bool), from_date (YYYY-MM-DD), to_date (YYYY-MM-DD).
4848
Limit data returned using strike_count and date parameters.
4949
"""
50-
context = get_context(ctx)
50+
context: SchwabServerContext = ctx.request_context.lifespan_context
5151
client = context.options
5252

5353
from_date_obj = _parse_date(from_date)
@@ -68,7 +68,7 @@ async def get_option_chain(
6868

6969
@register
7070
async def get_advanced_option_chain(
71-
ctx: Context,
71+
ctx: SchwabContext,
7272
symbol: Annotated[str, "Symbol of the underlying security"],
7373
contract_type: Annotated[
7474
str | None, "Type of contracts: CALL, PUT, or ALL (default)"
@@ -123,7 +123,7 @@ async def get_advanced_option_chain(
123123
Params: symbol, contract_type, strike_count, include_quotes, strategy (SINGLE/ANALYTICAL/etc.), interval, strike, strike_range (ITM/NTM/etc.), from/to_date, volatility/underlying_price/interest_rate/days_to_expiration (for ANALYTICAL), exp_month, option_type (STANDARD/NON_STANDARD/ALL).
124124
Limit data returned using strike_count and date parameters.
125125
"""
126-
context = get_context(ctx)
126+
context: SchwabServerContext = ctx.request_context.lifespan_context
127127
client = context.options
128128

129129
from_date_obj = _parse_date(from_date)
@@ -158,12 +158,12 @@ async def get_advanced_option_chain(
158158

159159
@register
160160
async def get_option_expiration_chain(
161-
ctx: Context,
161+
ctx: SchwabContext,
162162
symbol: Annotated[str, "Symbol of the underlying security"],
163163
) -> str:
164164
"""
165165
Returns available option expiration dates for a symbol, without contract details. Lightweight call to find available cycles. Param: symbol.
166166
"""
167-
context = get_context(ctx)
167+
context: SchwabServerContext = ctx.request_context.lifespan_context
168168
client = context.options
169169
return await call(client.get_option_expiration_chain, symbol)

0 commit comments

Comments
 (0)