Skip to content

Commit ac4b822

Browse files
committed
add tests to assert cancellable=False behaves as expected
1 parent 8b7f1cd commit ac4b822

File tree

3 files changed

+97
-30
lines changed

3 files changed

+97
-30
lines changed

src/mcp/client/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ async def call_tool(
260260
name: str,
261261
arguments: dict[str, Any] | None = None,
262262
read_timeout_seconds: timedelta | None = None,
263+
cancellable: bool = True,
263264
) -> types.CallToolResult:
264265
"""Send a tools/call request."""
265266

@@ -272,6 +273,7 @@ async def call_tool(
272273
),
273274
types.CallToolResult,
274275
request_read_timeout_seconds=read_timeout_seconds,
276+
cancellable=cancellable,
275277
)
276278

277279
async def list_prompts(self) -> types.ListPromptsResult:

src/mcp/shared/session.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -263,39 +263,40 @@ async def send_request(
263263
elif self._session_read_timeout_seconds is not None:
264264
timeout = self._session_read_timeout_seconds.total_seconds()
265265

266-
try:
267-
with anyio.fail_after(timeout) as scope:
268-
response_or_error = await response_stream_reader.receive()
269-
270-
if cancellable and scope.cancel_called:
271-
with anyio.CancelScope(shield=True):
272-
notification = CancelledNotification(
273-
method="notifications/cancelled",
274-
params=CancelledNotificationParams(
275-
requestId=request_id, reason="cancelled"
276-
),
277-
)
278-
await self._send_notification_internal(
279-
notification, request_id
266+
with anyio.CancelScope(shield=not cancellable):
267+
try:
268+
with anyio.fail_after(timeout) as scope:
269+
response_or_error = await response_stream_reader.receive()
270+
271+
if scope.cancel_called:
272+
with anyio.CancelScope(shield=True):
273+
notification = CancelledNotification(
274+
method="notifications/cancelled",
275+
params=CancelledNotificationParams(
276+
requestId=request_id, reason="cancelled"
277+
),
278+
)
279+
await self._send_notification_internal(
280+
notification, request_id
281+
)
282+
283+
raise McpError(
284+
ErrorData(
285+
code=REQUEST_CANCELLED, message="Request cancelled"
286+
)
280287
)
281288

282-
raise McpError(
283-
ErrorData(
284-
code=REQUEST_CANCELLED, message="Request cancelled"
285-
)
289+
except TimeoutError:
290+
raise McpError(
291+
ErrorData(
292+
code=httpx.codes.REQUEST_TIMEOUT,
293+
message=(
294+
f"Timed out while waiting for response to "
295+
f"{request.__class__.__name__}. Waited "
296+
f"{timeout} seconds."
297+
),
286298
)
287-
288-
except TimeoutError:
289-
raise McpError(
290-
ErrorData(
291-
code=httpx.codes.REQUEST_TIMEOUT,
292-
message=(
293-
f"Timed out while waiting for response to "
294-
f"{request.__class__.__name__}. Waited "
295-
f"{timeout} seconds."
296-
),
297299
)
298-
)
299300

300301
if isinstance(response_or_error, JSONRPCError):
301302
raise McpError(response_or_error.error)

tests/shared/test_session.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ async def test_in_flight_requests_cleared_after_completion(
4242
@pytest.mark.anyio
4343
async def test_request_cancellation():
4444
"""Test that requests can be cancelled while in-flight."""
45-
# The tool is already registered in the fixture
4645

4746
ev_tool_called = anyio.Event()
4847
ev_tool_cancelled = anyio.Event()
@@ -121,3 +120,68 @@ async def make_request(client_session: ClientSession):
121120
# Give cancellation time to process on server
122121
with anyio.fail_after(1):
123122
await ev_tool_cancelled.wait()
123+
124+
@pytest.mark.anyio
125+
async def test_request_cancellation_uncancellable():
126+
"""Test that asserts."""
127+
# The tool is already registered in the fixture
128+
129+
ev_tool_called = anyio.Event()
130+
ev_tool_commplete = anyio.Event()
131+
ev_cancelled = anyio.Event()
132+
133+
# Start the request in a separate task so we can cancel it
134+
def make_server() -> Server:
135+
server = Server(name="TestSessionServer")
136+
137+
# Register the tool handler
138+
@server.call_tool()
139+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
140+
nonlocal ev_tool_called, ev_tool_commplete
141+
if name == "slow_tool":
142+
ev_tool_called.set()
143+
with anyio.CancelScope():
144+
with anyio.fail_after(10): # Long enough to ensure we can cancel
145+
await ev_cancelled.wait()
146+
ev_tool_commplete.set()
147+
return []
148+
149+
raise ValueError(f"Unknown tool: {name}")
150+
151+
# Register the tool so it shows up in list_tools
152+
@server.list_tools()
153+
async def handle_list_tools() -> list[types.Tool]:
154+
return [
155+
types.Tool(
156+
name="slow_tool",
157+
description="A slow tool that takes 10 seconds to complete",
158+
inputSchema={},
159+
)
160+
]
161+
162+
return server
163+
164+
async def make_request(client_session: ClientSession):
165+
nonlocal ev_cancelled
166+
try:
167+
await client_session.call_tool("slow_tool", cancellable=False)
168+
except McpError as e:
169+
pytest.fail("Request should not have been cancelled")
170+
171+
async with create_connected_server_and_client_session(
172+
make_server()
173+
) as client_session:
174+
async with anyio.create_task_group() as tg:
175+
tg.start_soon(make_request, client_session)
176+
177+
# Wait for the request to be in-flight
178+
with anyio.fail_after(1): # Timeout after 1 second
179+
await ev_tool_called.wait()
180+
181+
# Cancel the task via task group
182+
tg.cancel_scope.cancel()
183+
ev_cancelled.set()
184+
185+
# Check server completed regardless
186+
with anyio.fail_after(1):
187+
await ev_tool_commplete.wait()

0 commit comments

Comments
 (0)