Skip to content

Commit 93b3c66

Browse files
committed
Add trivial authorisation check - confirms if user is the same
1 parent 58fc239 commit 93b3c66

File tree

1 file changed

+38
-14
lines changed

1 file changed

+38
-14
lines changed

src/mcp/server/lowlevel/result_cache.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import Awaitable, Callable
22
from dataclasses import dataclass, field
3+
from logging import getLogger
34
from time import time
45
from typing import Any
56
from uuid import uuid4
@@ -9,12 +10,17 @@
910
from cachetools import TTLCache
1011

1112
from mcp import types
13+
from mcp.server.auth.middleware.auth_context import auth_context_var as user_context
14+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
1215
from mcp.shared.context import BaseSession, RequestContext, SessionT
1316

17+
logger = getLogger(__name__)
18+
1419

1520
@dataclass
1621
class InProgress:
1722
token: str
23+
user: AuthenticatedUser | None = None
1824
task_group: TaskGroup | None = None
1925
sessions: list[BaseSession[Any, Any, Any, Any, Any]] = field(
2026
default_factory=lambda: []
@@ -27,10 +33,9 @@ class ResultCache:
2733
Its purpose is to act as a central point for managing in progress
2834
async calls, allowing multiple clients to join and receive progress
2935
updates, get results and/or cancel in progress calls
30-
TODO CRITICAL!! Decide how to limit Async tokens for security purposes
31-
suggest use authentication protocol for identity - may need to add an
32-
authorisation layer to decide if a user is allowed to join an existing
33-
async call
36+
TODO IMPORTANT! may need to add an authorisation layer to decide if
37+
a user is allowed to get/join/cancel an existing async call current
38+
simple logic only allows same user to perform these tasks
3439
TODO name is probably not quite right, more of a result broker?
3540
TODO externalise cachetools to allow for other implementations
3641
e.g. redis etal for production scenarios
@@ -86,6 +91,7 @@ async def call_tool():
8691
async with create_task_group() as tg:
8792
tg.start_soon(call_tool)
8893
in_progress.task_group = tg
94+
in_progress.user = user_context.get()
8995
in_progress.sessions.append(ctx.session)
9096
result = types.CallToolAsyncResult(
9197
token=in_progress.token,
@@ -107,17 +113,27 @@ async def join_call(
107113
# to get message describing why it wasn't accepted
108114
return types.CallToolAsyncResult(accepted=False)
109115
else:
110-
in_progress.sessions.append(ctx.session)
111-
return types.CallToolAsyncResult(accepted=True)
112-
113-
return
116+
# TODO consider adding authorisation layer to make this decision
117+
if in_progress.user == user_context.get():
118+
in_progress.sessions.append(ctx.session)
119+
return types.CallToolAsyncResult(accepted=True)
120+
else:
121+
# TODO consider creating new token to allow client
122+
# to get message describing why it wasn't accepted
123+
return types.CallToolAsyncResult(accepted=False)
114124

115125
async def cancel(self, notification: types.CancelToolAsyncNotification) -> None:
116126
async with self._lock:
117127
in_progress = self._in_progress.get(notification.params.token)
118128
if in_progress is not None and in_progress.task_group is not None:
119-
in_progress.task_group.cancel_scope.cancel()
120-
del self._in_progress[notification.params.token]
129+
if in_progress.user == user_context.get():
130+
in_progress.task_group.cancel_scope.cancel()
131+
del self._in_progress[notification.params.token]
132+
else:
133+
logger.warning(
134+
"Permission denied for cancel notification received"
135+
f"from {user_context.get()}"
136+
)
121137

122138
async def get_result(self, req: types.GetToolAsyncResultRequest):
123139
async with self._lock:
@@ -130,11 +146,19 @@ async def get_result(self, req: types.GetToolAsyncResultRequest):
130146
isError=True,
131147
)
132148
else:
133-
result = self._result_cache.get(in_progress.token)
134-
if result is None:
135-
return types.CallToolResult(content=[], isPending=True)
149+
if in_progress.user == user_context.get():
150+
result = self._result_cache.get(in_progress.token)
151+
if result is None:
152+
return types.CallToolResult(content=[], isPending=True)
153+
else:
154+
return result
136155
else:
137-
return result
156+
return types.CallToolResult(
157+
content=[
158+
types.TextContent(type="text", text="Permission denied")
159+
],
160+
isError=True,
161+
)
138162

139163
async def _new_in_progress(self) -> InProgress:
140164
async with self._lock:

0 commit comments

Comments
 (0)