@@ -164,33 +164,27 @@ async def cancel(self, notification: types.CancelToolAsyncNotification) -> None:
164164
165165 async def get_result (self , req : types .GetToolAsyncResultRequest ):
166166 logger .debug ("Getting result" )
167- in_progress = self . _in_progress . get ( req .params .token )
168- logger . debug ( f"Found in progress { in_progress } " )
167+ async_token = req .params .token
168+ in_progress = self . _in_progress . get ( async_token )
169169 if in_progress is None :
170170 return types .CallToolResult (
171- content = [types .TextContent (type = "text" , text = "Unknown progress token" )],
171+ content = [types .TextContent (type = "text" , text = "Unknown async token" )],
172172 isError = True ,
173173 )
174174 else :
175+ logger .debug (f"Found in progress { in_progress } " )
175176 if in_progress .user == user_context .get ():
176- if in_progress .future is None :
177+ assert in_progress .future is not None
178+ # TODO add timeout to get async result
179+ try :
180+ result = in_progress .future .result (1 )
181+ logger .debug (f"Found result { result } " )
182+ return result
183+ except TimeoutError :
177184 return types .CallToolResult (
178- content = [
179- types .TextContent (type = "text" , text = "Permission denied" )
180- ],
181- isError = True ,
185+ content = [],
186+ isPending = True ,
182187 )
183- else :
184- # TODO add timeout to get async result
185- try :
186- result = in_progress .future .result (1 )
187- logger .debug (f"Found result { result } " )
188- return result
189- except TimeoutError :
190- return types .CallToolResult (
191- content = [],
192- isPending = True ,
193- )
194188 else :
195189 return types .CallToolResult (
196190 content = [types .TextContent (type = "text" , text = "Permission denied" )],
@@ -200,19 +194,20 @@ async def get_result(self, req: types.GetToolAsyncResultRequest):
200194 async def notification_hook (
201195 self , session : ServerSession , notification : types .ServerNotification
202196 ):
203- logger .debug (f"received { notification } from { id (session )} " )
197+ session_id = id (session )
198+ logger .debug (f"received { notification } from { session_id } " )
204199 if type (notification .root ) is types .ProgressNotification :
205200 # async with self._lock:
206- async_token = self ._session_lookup .get (id ( session ) )
201+ async_token = self ._session_lookup .get (session_id )
207202 if async_token is None :
208203 # not all sessions are async so just debug
209204 logger .debug ("Discarding progress notification from unknown session" )
210205 else :
211206 in_progress = self ._in_progress .get (async_token )
212- assert in_progress is not None
207+ assert in_progress is not None , "lost in progress for {async_token}"
213208 for other_id , other_session in in_progress .sessions .items ():
214- logger .debug (f"Checking { other_id } == { id ( session ) } " )
215- if not other_id == id ( session ) :
209+ logger .debug (f"Checking { other_id } == { session_id } " )
210+ if not other_id == session_id :
216211 logger .debug (f"Sending progress to { other_id } " )
217212 progress_token = in_progress .session_progress .get (other_id )
218213 assert progress_token is not None
@@ -227,26 +222,23 @@ async def notification_hook(
227222 )
228223
229224 async def session_close_hook (self , session : ServerSession ):
230- logger .debug (f"Closing { id (session )} " )
231- dropped = self ._session_lookup .pop (id (session ), None )
232- if dropped is None :
233- logger .warning (f"Discarding callback from unknown session { id (session )} " )
225+ session_id = id (session )
226+ logger .debug (f"Closing { session_id } " )
227+ dropped = self ._session_lookup .pop (session_id , None )
228+ assert dropped is not None , f"Discarded callback, unknown session { session_id } "
229+
230+ in_progress = self ._in_progress .get (dropped )
231+ if in_progress is None :
232+ logger .warning ("In progress not found" )
234233 else :
235- in_progress = self ._in_progress .get (dropped )
236- if in_progress is None :
237- logger .warning ("In progress not found" )
238- else :
239- found = in_progress .sessions .pop (id (session ), None )
240- if found is None :
241- logger .warning ("No session found" )
242- if len (in_progress .sessions ) == 0 :
243- self ._in_progress .pop (dropped , None )
244- logger .debug ("In progress found" )
245- if in_progress .future is None :
246- logger .warning ("In progress future is none" )
247- else :
248- logger .debug ("Cancelled in progress future" )
249- in_progress .future .cancel ()
234+ found = in_progress .sessions .pop (session_id , None )
235+ if found is None :
236+ logger .warning ("No session found" )
237+ if len (in_progress .sessions ) == 0 :
238+ self ._in_progress .pop (dropped , None )
239+ assert in_progress .future is not None
240+ logger .debug ("Cancelled in progress future" )
241+ in_progress .future .cancel ()
250242
251243 async def _new_in_progress (self ) -> InProgress :
252244 while True :
0 commit comments