Skip to content

Commit d0e463e

Browse files
committed
rename result_cache class and allow other implementations for enterprise scenarios where memory cache may not be sufficient
1 parent f1d1116 commit d0e463e

File tree

3 files changed

+58
-25
lines changed

3 files changed

+58
-25
lines changed

src/mcp/server/lowlevel/result_cache.py renamed to src/mcp/server/lowlevel/async_request_manager.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,36 @@
2020
logger = getLogger(__name__)
2121

2222

23+
class AsyncRequestManager:
24+
async def __aenter__(self): ...
25+
async def __aexit__(
26+
self,
27+
exc_type: type[BaseException] | None,
28+
exc_val: BaseException | None,
29+
exc_tb: TracebackType | None,
30+
) -> bool | None: ...
31+
async def start_call(
32+
self,
33+
call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]],
34+
req: types.CallToolAsyncRequest,
35+
ctx: RequestContext[ServerSession, Any, Any],
36+
) -> types.CallToolAsyncResult: ...
37+
async def join_call(
38+
self,
39+
req: types.JoinCallToolAsyncRequest,
40+
ctx: RequestContext[ServerSession, Any, Any],
41+
) -> types.CallToolAsyncResult: ...
42+
async def cancel(self, notification: types.CancelToolAsyncNotification) -> None: ...
43+
async def get_result(
44+
self, req: types.GetToolAsyncResultRequest
45+
) -> types.CallToolResult: ...
46+
47+
async def notification_hook(
48+
self, session: ServerSession, notification: types.ServerNotification
49+
) -> None: ...
50+
async def session_close_hook(self, session: ServerSession): ...
51+
52+
2353
@dataclass
2454
class InProgress:
2555
token: str
@@ -40,23 +70,18 @@ def is_expired(self):
4070
return int(self.timer()) > self.keep_alive_start + self.keep_alive
4171

4272

43-
class ResultCache:
73+
class SimpleInMemoryAsyncRequestManager(AsyncRequestManager):
4474
"""
4575
Note this class is a work in progress
4676
Its purpose is to act as a central point for managing in progress
4777
async calls, allowing multiple clients to join and receive progress
4878
updates, get results and/or cancel in progress calls
49-
TODO CRITICAL not obvious user context will be passed to background thread
50-
add tests to assert behaviour with authenticated calls
5179
TODO MAJOR needs a lot more testing around edge cases/failure scenarios
5280
TODO MAJOR decide if async.Locks are required for integrity of internal
5381
data structures
54-
TODO ENHANCEMENT externalise cachetools to allow for other implementations
55-
e.g. redis etal for production scenarios
5682
TODO ENHANCEMENT may need to add an authorisation layer to decide if
5783
a user is allowed to get/join/cancel an existing async call current
5884
simple logic only allows same user to perform these tasks
59-
TODO TRIVIAL name is probably not quite right, more of a result broker?
6085
"""
6186

6287
_in_progress: dict[types.AsyncToken, InProgress]
@@ -178,7 +203,9 @@ async def cancel(self, notification: types.CancelToolAsyncNotification) -> None:
178203
f"from {user_context.get()}"
179204
)
180205

181-
async def get_result(self, req: types.GetToolAsyncResultRequest):
206+
async def get_result(
207+
self, req: types.GetToolAsyncResultRequest
208+
) -> types.CallToolResult:
182209
logger.debug("Getting result")
183210
async_token = req.params.token
184211
in_progress = self._in_progress.get(async_token)

src/mcp/server/lowlevel/server.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,11 @@ async def main():
8080
from typing_extensions import TypeVar
8181

8282
import mcp.types as types
83+
from mcp.server.lowlevel.async_request_manager import (
84+
AsyncRequestManager,
85+
SimpleInMemoryAsyncRequestManager,
86+
)
8387
from mcp.server.lowlevel.helper_types import ReadResourceContents
84-
from mcp.server.lowlevel.result_cache import ResultCache
8588
from mcp.server.models import InitializationOptions
8689
from mcp.server.session import ServerSession
8790
from mcp.server.stdio import stdio_server as stdio_server
@@ -136,8 +139,9 @@ def __init__(
136139
[Server[LifespanResultT, RequestT]],
137140
AbstractAsyncContextManager[LifespanResultT],
138141
] = lifespan,
139-
max_cache_size: int = 1000,
140-
max_cache_ttl: int = 60,
142+
async_request_manager: AsyncRequestManager = SimpleInMemoryAsyncRequestManager(
143+
max_size=1000, max_keep_alive=60
144+
),
141145
):
142146
self.name = name
143147
self.version = version
@@ -148,7 +152,7 @@ def __init__(
148152
] = {
149153
types.PingRequest: _ping_handler,
150154
}
151-
self.result_cache = ResultCache(max_cache_size, max_cache_ttl)
155+
self.async_request_manager = async_request_manager
152156
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
153157
self.notification_options = NotificationOptions()
154158
logger.debug("Initializing server %r", name)
@@ -432,19 +436,19 @@ async def handler(req: types.CallToolRequest):
432436

433437
async def async_call_handler(req: types.CallToolAsyncRequest):
434438
ctx = request_ctx.get()
435-
result = await self.result_cache.start_call(handler, req, ctx)
439+
result = await self.async_request_manager.start_call(handler, req, ctx)
436440
return types.ServerResult(result)
437441

438442
async def async_join_handler(req: types.JoinCallToolAsyncRequest):
439443
ctx = request_ctx.get()
440-
result = await self.result_cache.join_call(req, ctx)
444+
result = await self.async_request_manager.join_call(req, ctx)
441445
return types.ServerResult(result)
442446

443447
async def async_cancel_handler(req: types.CancelToolAsyncNotification):
444-
await self.result_cache.cancel(req)
448+
await self.async_request_manager.cancel(req)
445449

446450
async def async_result_handler(req: types.GetToolAsyncResultRequest):
447-
result = await self.result_cache.get_result(req)
451+
result = await self.async_request_manager.get_result(req)
448452
return types.ServerResult(result)
449453

450454
self.request_handlers[types.CallToolRequest] = handler
@@ -534,11 +538,11 @@ async def run(
534538
write_stream,
535539
initialization_options,
536540
stateless=stateless,
537-
notification_hook=self.result_cache.notification_hook,
538-
session_close_hook=self.result_cache.session_close_hook,
541+
notification_hook=self.async_request_manager.notification_hook,
542+
session_close_hook=self.async_request_manager.session_close_hook,
539543
)
540544
)
541-
await stack.enter_async_context(self.result_cache)
545+
await stack.enter_async_context(self.async_request_manager)
542546

543547
async with anyio.create_task_group() as tg:
544548
async for message in session.incoming_messages:

tests/server/lowlevel/test_result_cache.py renamed to tests/server/lowlevel/test_async_request_manager.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mcp.server.auth.middleware.auth_context import (
99
auth_context_var as user_context,
1010
)
11-
from mcp.server.lowlevel.result_cache import ResultCache
11+
from mcp.server.lowlevel.async_request_manager import SimpleInMemoryAsyncRequestManager
1212

1313

1414
@pytest.mark.anyio
@@ -27,7 +27,7 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
2727
mock_session = AsyncMock()
2828
mock_context = Mock()
2929
mock_context.session = mock_session
30-
result_cache = ResultCache(max_size=1, max_keep_alive=1)
30+
result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1)
3131
async with AsyncExitStack() as stack:
3232
await stack.enter_async_context(result_cache)
3333
async_call_ref = await result_cache.start_call(
@@ -73,7 +73,7 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
7373

7474
mock_context_2.session = mock_session_2
7575

76-
result_cache = ResultCache(max_size=1, max_keep_alive=1)
76+
result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1)
7777
async with AsyncExitStack() as stack:
7878
await stack.enter_async_context(result_cache)
7979
async_call_ref = await result_cache.start_call(
@@ -160,7 +160,7 @@ async def slow_call(call: types.CallToolRequest) -> types.ServerResult:
160160
mock_context_1 = Mock()
161161
mock_context_1.session = mock_session_1
162162

163-
result_cache = ResultCache(max_size=1, max_keep_alive=1)
163+
result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1)
164164
async with AsyncExitStack() as stack:
165165
await stack.enter_async_context(result_cache)
166166
async_call_ref = await result_cache.start_call(
@@ -228,7 +228,7 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
228228

229229
mock_context_2.session = mock_session_2
230230

231-
result_cache = ResultCache(max_size=1, max_keep_alive=10)
231+
result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=10)
232232
async with AsyncExitStack() as stack:
233233
await stack.enter_async_context(result_cache)
234234
async_call_ref = await result_cache.start_call(
@@ -307,7 +307,9 @@ async def test_call(call: types.CallToolRequest) -> types.ServerResult:
307307
def test_timer():
308308
return time
309309

310-
result_cache = ResultCache(max_size=1, max_keep_alive=1, timer=test_timer)
310+
result_cache = SimpleInMemoryAsyncRequestManager(
311+
max_size=1, max_keep_alive=1, timer=test_timer
312+
)
311313
async with AsyncExitStack() as stack:
312314
await stack.enter_async_context(result_cache)
313315
async_call_ref = await result_cache.start_call(
@@ -391,7 +393,7 @@ async def test_async_call_pass_auth():
391393
mock_session = AsyncMock()
392394
mock_context = Mock()
393395
mock_context.session = mock_session
394-
result_cache = ResultCache(max_size=1, max_keep_alive=1)
396+
result_cache = SimpleInMemoryAsyncRequestManager(max_size=1, max_keep_alive=1)
395397

396398
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
397399
user = user_context.get()

0 commit comments

Comments
 (0)