@@ -176,8 +176,10 @@ async def send_progress(
176176 ): ...
177177
178178 async def receive_response (
179- self , request_id : RequestId , timeout : float | None = None , fail_on_timeout : bool = True
180- ) -> JSONRPCResponse | JSONRPCError | None : ...
179+ self ,
180+ request_id : RequestId ,
181+ timeout : float | None = None ,
182+ ) -> JSONRPCResponse | JSONRPCError : ...
181183
182184 async def handle_response (self , message : JSONRPCResponse | JSONRPCError ) -> bool : ...
183185
@@ -242,8 +244,10 @@ async def send_progress(
242244 )
243245
244246 async def receive_response (
245- self , request_id : RequestId , timeout : float | None = None , fail_on_timeout : bool = True
246- ) -> JSONRPCResponse | JSONRPCError | None :
247+ self ,
248+ request_id : RequestId ,
249+ timeout : float | None = None ,
250+ ) -> JSONRPCResponse | JSONRPCError :
247251 request , _ , response_stream_reader = self ._response_streams .get (request_id , [None , None , None ])
248252
249253 if response_stream_reader is None :
@@ -254,24 +258,20 @@ async def receive_response(
254258 )
255259 )
256260
257- if fail_on_timeout :
258- try :
259- with anyio .fail_after (timeout ):
260- return await response_stream_reader .receive ()
261- except TimeoutError :
262- raise McpError (
263- ErrorData (
264- code = httpx .codes .REQUEST_TIMEOUT ,
265- message = (
266- f"Timed out while waiting for response to "
267- f"{ request .__class__ .__name__ } . Waited "
268- f"{ timeout } seconds."
269- ),
270- )
271- )
272- else :
273- with anyio .move_on_after (timeout ):
261+ try :
262+ with anyio .fail_after (timeout ):
274263 return await response_stream_reader .receive ()
264+ except TimeoutError :
265+ raise McpError (
266+ ErrorData (
267+ code = httpx .codes .REQUEST_TIMEOUT ,
268+ message = (
269+ f"Timed out while waiting for response to "
270+ f"{ request .__class__ .__name__ } . Waited "
271+ f"{ timeout } seconds."
272+ ),
273+ )
274+ )
275275
276276 async def handle_response (self , message : JSONRPCResponse | JSONRPCError ) -> bool :
277277 _ , stream , _ = self ._response_streams .get (message .id , [None , None , None ])
@@ -405,8 +405,8 @@ async def join_request(
405405 result_type : type [ReceiveResultT ],
406406 request_read_timeout_seconds : timedelta | None = None ,
407407 progress_callback : ProgressFnT | None = None ,
408- fail_on_timeout : bool = True ,
409- ) -> ReceiveResultT | None :
408+ done_on_timeout : bool = True ,
409+ ) -> ReceiveResultT :
410410 """
411411 Joins a request previously started via start_request
412412 """
@@ -420,16 +420,18 @@ async def join_request(
420420 elif self ._session_read_timeout_seconds is not None :
421421 timeout = self ._session_read_timeout_seconds .total_seconds ()
422422
423- response_or_error = await self ._request_state_manager .receive_response (request_id , timeout , fail_on_timeout )
423+ response_or_error = await self ._request_state_manager .receive_response (request_id , timeout )
424424
425- if response_or_error is None :
426- return None
425+ if isinstance (response_or_error , JSONRPCError ):
426+ if response_or_error .error .code == httpx .codes .REQUEST_TIMEOUT .value :
427+ if done_on_timeout :
428+ await self ._request_state_manager .close_request (request_id )
429+ else :
430+ await self ._request_state_manager .close_request (request_id )
431+ raise McpError (response_or_error .error )
427432 else :
428433 await self ._request_state_manager .close_request (request_id )
429- if isinstance (response_or_error , JSONRPCError ):
430- raise McpError (response_or_error .error )
431- else :
432- return result_type .model_validate (response_or_error .result )
434+ return result_type .model_validate (response_or_error .result )
433435
434436 async def cancel_request (self , request_id : RequestId ) -> bool :
435437 """
@@ -464,10 +466,10 @@ async def send_request(
464466 instead.
465467 """
466468 request_id = await self .start_request (request , metadata , progress_callback )
467- result = await self . join_request ( request_id , result_type , request_read_timeout_seconds )
468- if result is None :
469- raise RuntimeError ( "Should not be possible" )
470- return result
469+ try :
470+ return await self . join_request ( request_id , result_type , request_read_timeout_seconds )
471+ finally :
472+ await self . _request_state_manager . close_request ( request_id )
471473
472474 async def send_notification (
473475 self ,
0 commit comments