Skip to content

Commit 2e0e6db

Browse files
authored
Merge pull request modelcontextprotocol#10 from jkoelker/jk/fastmcp
refactor(server): share schwab client via lifespan context
2 parents 54b70ed + 0ec7387 commit 2e0e6db

21 files changed

+421
-217
lines changed

src/schwab_mcp/context.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
from typing import TYPE_CHECKING, Any, cast
5+
6+
from schwab.client import AsyncClient
7+
8+
if TYPE_CHECKING:
9+
from schwab_mcp.tools._protocols import (
10+
AccountClient,
11+
OptionsClient,
12+
OrdersClient,
13+
PriceHistoryClient,
14+
QuotesClient,
15+
ToolsClient,
16+
TransactionsClient,
17+
)
18+
else: # pragma: no cover - runtime only
19+
AccountClient = OptionsClient = OrdersClient = PriceHistoryClient = QuotesClient = (
20+
ToolsClient
21+
) = TransactionsClient = Any
22+
23+
24+
@dataclass(slots=True)
25+
class SchwabServerContext:
26+
"""Typed application context shared via FastMCP lifespan."""
27+
28+
client: AsyncClient
29+
tools: ToolsClient = field(init=False)
30+
accounts: AccountClient = field(init=False)
31+
price_history: PriceHistoryClient = field(init=False)
32+
options: OptionsClient = field(init=False)
33+
orders: OrdersClient = field(init=False)
34+
quotes: QuotesClient = field(init=False)
35+
transactions: TransactionsClient = field(init=False)
36+
37+
def __post_init__(self) -> None:
38+
self.tools = cast(ToolsClient, self.client)
39+
self.accounts = cast(AccountClient, self.client)
40+
self.price_history = cast(PriceHistoryClient, self.client)
41+
self.options = cast(OptionsClient, self.client)
42+
self.orders = cast(OrdersClient, self.client)
43+
self.quotes = cast(QuotesClient, self.client)
44+
self.transactions = cast(TransactionsClient, self.client)
45+
46+
47+
__all__ = ["SchwabServerContext"]

src/schwab_mcp/server.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,41 @@
11
from __future__ import annotations
22

3+
import logging
34
import sys
4-
from typing import Optional
5+
from collections.abc import AsyncGenerator
6+
from contextlib import asynccontextmanager
7+
from typing import AsyncContextManager, Callable, Optional
58

69
import mcp.types as types
710
from mcp.server.fastmcp import FastMCP
811
from schwab.client import AsyncClient
912

1013
from schwab_mcp.tools import register_tools
14+
from schwab_mcp.context import SchwabServerContext
15+
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def _client_lifespan(
21+
client: AsyncClient,
22+
) -> Callable[[FastMCP], AsyncContextManager[SchwabServerContext]]:
23+
"""Create a FastMCP lifespan context that exposes the Schwab async client."""
24+
25+
@asynccontextmanager
26+
async def lifespan(_: FastMCP) -> AsyncGenerator[SchwabServerContext, None]:
27+
context = SchwabServerContext(client=client)
28+
try:
29+
yield context
30+
finally:
31+
try:
32+
await client.close_async_session()
33+
except Exception:
34+
logger.exception(
35+
"Failed to close Schwab async client session during shutdown."
36+
)
37+
38+
return lifespan
1139

1240

1341
class SchwabMCPServer:
@@ -19,7 +47,7 @@ def __init__(
1947
client: AsyncClient,
2048
jesus_take_the_wheel: bool = False,
2149
) -> None:
22-
self._server = FastMCP(name=name)
50+
self._server = FastMCP(name=name, lifespan=_client_lifespan(client))
2351
register_tools(self._server, client, allow_write=jesus_take_the_wheel)
2452

2553
async def run(self) -> None:

src/schwab_mcp/tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from schwab_mcp.tools import quotes as _quotes # noqa: F401
1616
from schwab_mcp.tools import transactions as _txns # noqa: F401
1717

18+
1819
def register_tools(server: FastMCP, client: AsyncClient, *, allow_write: bool) -> None:
1920
"""Register all Schwab tools with the provided FastMCP server."""
2021
tool_utils.set_write_enabled(allow_write)
21-
setattr(server, "_schwab_client", client)
2222

2323
for func in iter_registered_tools():
2424
if getattr(func, "_write", False) and not allow_write:

src/schwab_mcp/tools/_protocols.py

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,37 @@ class PriceHistoryNamespace(Protocol):
1818
class PriceHistoryClient(Protocol):
1919
PriceHistory: PriceHistoryNamespace
2020

21-
def get_advanced_price_history(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
21+
def get_advanced_price_history(
22+
self, symbol: str, **kwargs: Any
23+
) -> Awaitable[Any]: ...
2224

23-
def get_price_history_every_minute(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
25+
def get_price_history_every_minute(
26+
self, symbol: str, **kwargs: Any
27+
) -> Awaitable[Any]: ...
2428

25-
def get_price_history_every_five_minutes(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
29+
def get_price_history_every_five_minutes(
30+
self, symbol: str, **kwargs: Any
31+
) -> Awaitable[Any]: ...
2632

27-
def get_price_history_every_ten_minutes(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
33+
def get_price_history_every_ten_minutes(
34+
self, symbol: str, **kwargs: Any
35+
) -> Awaitable[Any]: ...
2836

29-
def get_price_history_every_fifteen_minutes(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
37+
def get_price_history_every_fifteen_minutes(
38+
self, symbol: str, **kwargs: Any
39+
) -> Awaitable[Any]: ...
3040

31-
def get_price_history_every_thirty_minutes(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
41+
def get_price_history_every_thirty_minutes(
42+
self, symbol: str, **kwargs: Any
43+
) -> Awaitable[Any]: ...
3244

33-
def get_price_history_every_day(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
45+
def get_price_history_every_day(
46+
self, symbol: str, **kwargs: Any
47+
) -> Awaitable[Any]: ...
3448

35-
def get_price_history_every_week(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
49+
def get_price_history_every_week(
50+
self, symbol: str, **kwargs: Any
51+
) -> Awaitable[Any]: ...
3652

3753

3854
class OptionsNamespace(Protocol):
@@ -48,7 +64,9 @@ class OptionsClient(Protocol):
4864

4965
def get_option_chain(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
5066

51-
def get_option_expiration_chain(self, symbol: str, **kwargs: Any) -> Awaitable[Any]: ...
67+
def get_option_expiration_chain(
68+
self, symbol: str, **kwargs: Any
69+
) -> Awaitable[Any]: ...
5270

5371

5472
class QuoteFieldsNamespace(Protocol):
@@ -122,11 +140,17 @@ class OrderNamespace(Protocol):
122140
class OrdersClient(Protocol):
123141
Order: OrderNamespace
124142

125-
def get_orders_for_account(self, account_hash: str, **kwargs: Any) -> Awaitable[Any]: ...
143+
def get_orders_for_account(
144+
self, account_hash: str, **kwargs: Any
145+
) -> Awaitable[Any]: ...
126146

127-
def get_order(self, order_id: str, account_hash: str, **kwargs: Any) -> Awaitable[Any]: ...
147+
def get_order(
148+
self, order_id: str, account_hash: str, **kwargs: Any
149+
) -> Awaitable[Any]: ...
128150

129-
def cancel_order(self, order_id: str, account_hash: str, **kwargs: Any) -> Awaitable[Any]: ...
151+
def cancel_order(
152+
self, order_id: str, account_hash: str, **kwargs: Any
153+
) -> Awaitable[Any]: ...
130154

131155
def place_order(self, account_hash: str, **kwargs: Any) -> Awaitable[Any]: ...
132156

@@ -144,4 +168,6 @@ class TransactionsClient(Protocol):
144168

145169
def get_transactions(self, account_hash: str, **kwargs: Any) -> Awaitable[Any]: ...
146170

147-
def get_transaction(self, account_hash: str, transaction_id: str, **kwargs: Any) -> Awaitable[Any]: ...
171+
def get_transaction(
172+
self, account_hash: str, transaction_id: str, **kwargs: Any
173+
) -> Awaitable[Any]: ...

src/schwab_mcp/tools/account.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mcp.server.fastmcp import Context
66

77
from schwab_mcp.tools.registry import register
8-
from schwab_mcp.tools.utils import call, get_account_client
8+
from schwab_mcp.tools.utils import call, get_context
99

1010

1111
@register
@@ -15,8 +15,8 @@ async def get_account_numbers(
1515
"""
1616
Returns mapping of account IDs to account hashes. Hashes required for account-specific calls. Use first.
1717
"""
18-
client = get_account_client(ctx)
19-
return await call(client.get_account_numbers)
18+
context = get_context(ctx)
19+
return await call(context.accounts.get_account_numbers)
2020

2121

2222
@register
@@ -26,8 +26,8 @@ async def get_accounts(
2626
"""
2727
Returns balances/info for all linked accounts (funds, cash, margin). Does not return hashes; use get_account_numbers first.
2828
"""
29-
client = get_account_client(ctx)
30-
return await call(client.get_accounts)
29+
context = get_context(ctx)
30+
return await call(context.accounts.get_accounts)
3131

3232

3333
@register
@@ -37,8 +37,11 @@ async def get_accounts_with_positions(
3737
"""
3838
Returns balances, info, and positions (holdings, cost, gain/loss) for all linked accounts. Does not return hashes; use get_account_numbers first.
3939
"""
40-
client = get_account_client(ctx)
41-
return await call(client.get_accounts, fields=[client.Account.Fields.POSITIONS])
40+
context = get_context(ctx)
41+
return await call(
42+
context.accounts.get_accounts,
43+
fields=[context.accounts.Account.Fields.POSITIONS],
44+
)
4245

4346

4447
@register
@@ -49,8 +52,8 @@ async def get_account(
4952
"""
5053
Returns balance/info for a specific account via account_hash (from get_account_numbers). Includes funds, cash, margin info.
5154
"""
52-
client = get_account_client(ctx)
53-
return await call(client.get_account, account_hash)
55+
context = get_context(ctx)
56+
return await call(context.accounts.get_account, account_hash)
5457

5558

5659
@register
@@ -61,9 +64,11 @@ async def get_account_with_positions(
6164
"""
6265
Returns balance, info, and positions for a specific account via account_hash. Includes holdings, quantity, cost basis, unrealized gain/loss.
6366
"""
64-
client = get_account_client(ctx)
67+
context = get_context(ctx)
6568
return await call(
66-
client.get_account, account_hash, fields=[client.Account.Fields.POSITIONS]
69+
context.accounts.get_account,
70+
account_hash,
71+
fields=[context.accounts.Account.Fields.POSITIONS],
6772
)
6873

6974

@@ -74,5 +79,5 @@ async def get_user_preferences(
7479
"""
7580
Returns user preferences (nicknames, display settings, notifications) for all linked accounts.
7681
"""
77-
client = get_account_client(ctx)
78-
return await call(client.get_user_preferences)
82+
context = get_context(ctx)
83+
return await call(context.accounts.get_user_preferences)

src/schwab_mcp/tools/history.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mcp.server.fastmcp import Context
77

88
from schwab_mcp.tools.registry import register
9-
from schwab_mcp.tools.utils import call, get_price_history_client
9+
from schwab_mcp.tools.utils import call, get_context
1010

1111

1212
def _parse_iso_datetime(value: str | None) -> datetime.datetime | None:
@@ -17,7 +17,9 @@ def _parse_iso_datetime(value: str | None) -> datetime.datetime | None:
1717
async def get_advanced_price_history(
1818
ctx: Context,
1919
symbol: Annotated[str, "Symbol of the security"],
20-
period_type: Annotated[str | None, "Period type: DAY, MONTH, YEAR, YEAR_TO_DATE"] = None,
20+
period_type: Annotated[
21+
str | None, "Period type: DAY, MONTH, YEAR, YEAR_TO_DATE"
22+
] = None,
2123
period: Annotated[
2224
str | None,
2325
(
@@ -26,7 +28,8 @@ async def get_advanced_price_history(
2628
),
2729
] = None,
2830
frequency_type: Annotated[
29-
str | None, "Frequency type: MINUTE (for DAY), DAILY/WEEKLY (for MONTH/YTD), DAILY/WEEKLY/MONTHLY (for YEAR)"
31+
str | None,
32+
"Frequency type: MINUTE (for DAY), DAILY/WEEKLY (for MONTH/YTD), DAILY/WEEKLY/MONTHLY (for YEAR)",
3033
] = None,
3134
frequency: Annotated[
3235
int | str | None,
@@ -57,16 +60,15 @@ async def get_advanced_price_history(
5760
YEAR_TO_DATE: DAILY, WEEKLY (default)
5861
Dates must be in ISO format.
5962
"""
60-
client = get_price_history_client(ctx)
63+
context = get_context(ctx)
64+
client = context.price_history
6165

6266
start_dt = _parse_iso_datetime(start_datetime)
6367
end_dt = _parse_iso_datetime(end_datetime)
6468

6569
# Normalize enum-like strings
6670
period_type_enum = (
67-
client.PriceHistory.PeriodType[period_type.upper()]
68-
if period_type
69-
else None
71+
client.PriceHistory.PeriodType[period_type.upper()] if period_type else None
7072
)
7173
period_enum = client.PriceHistory.Period[period.upper()] if period else None
7274
frequency_type_enum = (
@@ -109,7 +111,8 @@ async def get_price_history_every_minute(
109111
"""
110112
Get OHLCV price history per minute. For detailed intraday analysis. Max 48 days history. Dates ISO format.
111113
"""
112-
client = get_price_history_client(ctx)
114+
context = get_context(ctx)
115+
client = context.price_history
113116

114117
start_dt = _parse_iso_datetime(start_datetime)
115118
end_dt = _parse_iso_datetime(end_datetime)
@@ -140,7 +143,8 @@ async def get_price_history_every_five_minutes(
140143
"""
141144
Get OHLCV price history per 5 minutes. Balance between detail and noise. Approx. 9 months history. Dates ISO format.
142145
"""
143-
client = get_price_history_client(ctx)
146+
context = get_context(ctx)
147+
client = context.price_history
144148

145149
start_dt = _parse_iso_datetime(start_datetime)
146150
end_dt = _parse_iso_datetime(end_datetime)
@@ -171,7 +175,8 @@ async def get_price_history_every_ten_minutes(
171175
"""
172176
Get OHLCV price history per 10 minutes. Good for intraday trends/levels. Approx. 9 months history. Dates ISO format.
173177
"""
174-
client = get_price_history_client(ctx)
178+
context = get_context(ctx)
179+
client = context.price_history
175180

176181
start_dt = _parse_iso_datetime(start_datetime)
177182
end_dt = _parse_iso_datetime(end_datetime)
@@ -202,7 +207,8 @@ async def get_price_history_every_fifteen_minutes(
202207
"""
203208
Get OHLCV price history per 15 minutes. Shows significant intraday moves, filters noise. Approx. 9 months history. Dates ISO format.
204209
"""
205-
client = get_price_history_client(ctx)
210+
context = get_context(ctx)
211+
client = context.price_history
206212

207213
start_dt = _parse_iso_datetime(start_datetime)
208214
end_dt = _parse_iso_datetime(end_datetime)
@@ -233,7 +239,8 @@ async def get_price_history_every_thirty_minutes(
233239
"""
234240
Get OHLCV price history per 30 minutes. For broader intraday trends, filters noise. Approx. 9 months history. Dates ISO format.
235241
"""
236-
client = get_price_history_client(ctx)
242+
context = get_context(ctx)
243+
client = context.price_history
237244

238245
start_dt = _parse_iso_datetime(start_datetime)
239246
end_dt = _parse_iso_datetime(end_datetime)
@@ -270,7 +277,8 @@ async def get_price_history_every_day(
270277
"""
271278
Get daily OHLCV price history. For medium/long-term analysis. Extensive history (back to 1985 possible). Dates ISO format.
272279
"""
273-
client = get_price_history_client(ctx)
280+
context = get_context(ctx)
281+
client = context.price_history
274282

275283
start_dt = _parse_iso_datetime(start_datetime)
276284
end_dt = _parse_iso_datetime(end_datetime)
@@ -301,7 +309,8 @@ async def get_price_history_every_week(
301309
"""
302310
Get weekly OHLCV price history. For long-term analysis, major cycles. Extensive history (back to 1985 possible). Dates ISO format.
303311
"""
304-
client = get_price_history_client(ctx)
312+
context = get_context(ctx)
313+
client = context.price_history
305314

306315
start_dt = _parse_iso_datetime(start_datetime)
307316
end_dt = _parse_iso_datetime(end_datetime)

0 commit comments

Comments
 (0)