Skip to content

Commit 04ff73a

Browse files
committed
refactor args for clearer meaning, use error vs returning none on timeout
1 parent 4165200 commit 04ff73a

File tree

3 files changed

+49
-41
lines changed

3 files changed

+49
-41
lines changed

src/mcp/client/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,14 +307,14 @@ async def join_call_tool(
307307
request_id: types.RequestId,
308308
progress_callback: ProgressFnT | None = None,
309309
request_read_timeout_seconds: timedelta | None = None,
310-
fail_on_timeout: bool = True,
310+
done_on_timeout: bool = True,
311311
) -> types.CallToolResult | None:
312312
return await self.join_request(
313313
request_id,
314314
types.CallToolResult,
315315
request_read_timeout_seconds=request_read_timeout_seconds,
316316
progress_callback=progress_callback,
317-
fail_on_timeout=fail_on_timeout,
317+
done_on_timeout=done_on_timeout,
318318
)
319319

320320
async def cancel_call_tool(

src/mcp/shared/session.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,10 @@ async def send_progress(
176176
): ...
177177

178178
async def receive_response(
179-
self, request_id: RequestId, timeout: float | None = None, fail_on_timeout: bool = True
180-
) -> JSONRPCResponse | JSONRPCError | None: ...
179+
self,
180+
request_id: RequestId,
181+
timeout: float | None = None,
182+
) -> JSONRPCResponse | JSONRPCError: ...
181183

182184
async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool: ...
183185

@@ -242,8 +244,10 @@ async def send_progress(
242244
)
243245

244246
async def receive_response(
245-
self, request_id: RequestId, timeout: float | None = None, fail_on_timeout: bool = True
246-
) -> JSONRPCResponse | JSONRPCError | None:
247+
self,
248+
request_id: RequestId,
249+
timeout: float | None = None,
250+
) -> JSONRPCResponse | JSONRPCError:
247251
request, _, response_stream_reader = self._response_streams.get(request_id, [None, None, None])
248252

249253
if response_stream_reader is None:
@@ -254,24 +258,20 @@ async def receive_response(
254258
)
255259
)
256260

257-
if fail_on_timeout:
258-
try:
259-
with anyio.fail_after(timeout):
260-
return await response_stream_reader.receive()
261-
except TimeoutError:
262-
raise McpError(
263-
ErrorData(
264-
code=httpx.codes.REQUEST_TIMEOUT,
265-
message=(
266-
f"Timed out while waiting for response to "
267-
f"{request.__class__.__name__}. Waited "
268-
f"{timeout} seconds."
269-
),
270-
)
271-
)
272-
else:
273-
with anyio.move_on_after(timeout):
261+
try:
262+
with anyio.fail_after(timeout):
274263
return await response_stream_reader.receive()
264+
except TimeoutError:
265+
raise McpError(
266+
ErrorData(
267+
code=httpx.codes.REQUEST_TIMEOUT,
268+
message=(
269+
f"Timed out while waiting for response to "
270+
f"{request.__class__.__name__}. Waited "
271+
f"{timeout} seconds."
272+
),
273+
)
274+
)
275275

276276
async def handle_response(self, message: JSONRPCResponse | JSONRPCError) -> bool:
277277
_, stream, _ = self._response_streams.get(message.id, [None, None, None])
@@ -405,8 +405,8 @@ async def join_request(
405405
result_type: type[ReceiveResultT],
406406
request_read_timeout_seconds: timedelta | None = None,
407407
progress_callback: ProgressFnT | None = None,
408-
fail_on_timeout: bool = True,
409-
) -> ReceiveResultT | None:
408+
done_on_timeout: bool = True,
409+
) -> ReceiveResultT:
410410
"""
411411
Joins a request previously started via start_request
412412
"""
@@ -420,16 +420,18 @@ async def join_request(
420420
elif self._session_read_timeout_seconds is not None:
421421
timeout = self._session_read_timeout_seconds.total_seconds()
422422

423-
response_or_error = await self._request_state_manager.receive_response(request_id, timeout, fail_on_timeout)
423+
response_or_error = await self._request_state_manager.receive_response(request_id, timeout)
424424

425-
if response_or_error is None:
426-
return None
425+
if isinstance(response_or_error, JSONRPCError):
426+
if response_or_error.error.code == httpx.codes.REQUEST_TIMEOUT.value:
427+
if done_on_timeout:
428+
await self._request_state_manager.close_request(request_id)
429+
else:
430+
await self._request_state_manager.close_request(request_id)
431+
raise McpError(response_or_error.error)
427432
else:
428433
await self._request_state_manager.close_request(request_id)
429-
if isinstance(response_or_error, JSONRPCError):
430-
raise McpError(response_or_error.error)
431-
else:
432-
return result_type.model_validate(response_or_error.result)
434+
return result_type.model_validate(response_or_error.result)
433435

434436
async def cancel_request(self, request_id: RequestId) -> bool:
435437
"""
@@ -464,10 +466,10 @@ async def send_request(
464466
instead.
465467
"""
466468
request_id = await self.start_request(request, metadata, progress_callback)
467-
result = await self.join_request(request_id, result_type, request_read_timeout_seconds)
468-
if result is None:
469-
raise RuntimeError("Should not be possible")
470-
return result
469+
try:
470+
return await self.join_request(request_id, result_type, request_read_timeout_seconds)
471+
finally:
472+
await self._request_state_manager.close_request(request_id)
471473

472474
async def send_notification(
473475
self,

tests/client/test_session.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from typing import Any
33

44
import anyio
5+
import httpx
56
import pytest
67

78
import mcp.types as types
89
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
910
from mcp.shared.context import RequestContext
11+
from mcp.shared.exceptions import McpError
1012
from mcp.shared.message import SessionMessage
1113
from mcp.shared.session import ImMemoryRequestStateManager, RequestResponder
1214
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -636,13 +638,17 @@ async def message_handler(
636638
request_id = await session.request_call_tool("hello", {"name": "world"})
637639

638640
with anyio.fail_after(1):
639-
result = await session.join_call_tool(
640-
request_id, request_read_timeout_seconds=timedelta(microseconds=1), fail_on_timeout=False
641-
)
642-
assert result is None
641+
try:
642+
result = await session.join_call_tool(
643+
request_id, request_read_timeout_seconds=timedelta(microseconds=1), done_on_timeout=False
644+
)
645+
except McpError as e:
646+
if not e.error.code == httpx.codes.REQUEST_TIMEOUT:
647+
raise e
648+
643649
send_result.set()
644650
result = await session.join_call_tool(
645-
request_id, request_read_timeout_seconds=timedelta(seconds=1), fail_on_timeout=False
651+
request_id, request_read_timeout_seconds=timedelta(seconds=1), done_on_timeout=False
646652
)
647653

648654
# Assert the result

0 commit comments

Comments
 (0)