11from collections .abc import Awaitable , Callable
22from dataclasses import dataclass , field
3+ from logging import getLogger
34from time import time
45from typing import Any
56from uuid import uuid4
910from cachetools import TTLCache
1011
1112from mcp import types
13+ from mcp .server .auth .middleware .auth_context import auth_context_var as user_context
14+ from mcp .server .auth .middleware .bearer_auth import AuthenticatedUser
1215from mcp .shared .context import BaseSession , RequestContext , SessionT
1316
17+ logger = getLogger (__name__ )
18+
1419
1520@dataclass
1621class InProgress :
1722 token : str
23+ user : AuthenticatedUser | None = None
1824 task_group : TaskGroup | None = None
1925 sessions : list [BaseSession [Any , Any , Any , Any , Any ]] = field (
2026 default_factory = lambda : []
@@ -27,10 +33,9 @@ class ResultCache:
2733 Its purpose is to act as a central point for managing in progress
2834 async calls, allowing multiple clients to join and receive progress
2935 updates, get results and/or cancel in progress calls
30- TODO CRITICAL!! Decide how to limit Async tokens for security purposes
31- suggest use authentication protocol for identity - may need to add an
32- authorisation layer to decide if a user is allowed to join an existing
33- async call
36+ TODO IMPORTANT! may need to add an authorisation layer to decide if
37+ a user is allowed to get/join/cancel an existing async call current
38+ simple logic only allows same user to perform these tasks
3439 TODO name is probably not quite right, more of a result broker?
3540 TODO externalise cachetools to allow for other implementations
3641 e.g. redis etal for production scenarios
@@ -86,6 +91,7 @@ async def call_tool():
8691 async with create_task_group () as tg :
8792 tg .start_soon (call_tool )
8893 in_progress .task_group = tg
94+ in_progress .user = user_context .get ()
8995 in_progress .sessions .append (ctx .session )
9096 result = types .CallToolAsyncResult (
9197 token = in_progress .token ,
@@ -107,17 +113,27 @@ async def join_call(
107113 # to get message describing why it wasn't accepted
108114 return types .CallToolAsyncResult (accepted = False )
109115 else :
110- in_progress .sessions .append (ctx .session )
111- return types .CallToolAsyncResult (accepted = True )
112-
113- return
116+ # TODO consider adding authorisation layer to make this decision
117+ if in_progress .user == user_context .get ():
118+ in_progress .sessions .append (ctx .session )
119+ return types .CallToolAsyncResult (accepted = True )
120+ else :
121+ # TODO consider creating new token to allow client
122+ # to get message describing why it wasn't accepted
123+ return types .CallToolAsyncResult (accepted = False )
114124
115125 async def cancel (self , notification : types .CancelToolAsyncNotification ) -> None :
116126 async with self ._lock :
117127 in_progress = self ._in_progress .get (notification .params .token )
118128 if in_progress is not None and in_progress .task_group is not None :
119- in_progress .task_group .cancel_scope .cancel ()
120- del self ._in_progress [notification .params .token ]
129+ if in_progress .user == user_context .get ():
130+ in_progress .task_group .cancel_scope .cancel ()
131+ del self ._in_progress [notification .params .token ]
132+ else :
133+ logger .warning (
134+ "Permission denied for cancel notification received"
135+ f"from { user_context .get ()} "
136+ )
121137
122138 async def get_result (self , req : types .GetToolAsyncResultRequest ):
123139 async with self ._lock :
@@ -130,11 +146,19 @@ async def get_result(self, req: types.GetToolAsyncResultRequest):
130146 isError = True ,
131147 )
132148 else :
133- result = self ._result_cache .get (in_progress .token )
134- if result is None :
135- return types .CallToolResult (content = [], isPending = True )
149+ if in_progress .user == user_context .get ():
150+ result = self ._result_cache .get (in_progress .token )
151+ if result is None :
152+ return types .CallToolResult (content = [], isPending = True )
153+ else :
154+ return result
136155 else :
137- return result
156+ return types .CallToolResult (
157+ content = [
158+ types .TextContent (type = "text" , text = "Permission denied" )
159+ ],
160+ isError = True ,
161+ )
138162
139163 async def _new_in_progress (self ) -> InProgress :
140164 async with self ._lock :
0 commit comments