diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py
index c30fc159..4958b216 100644
--- a/src/claude_agent_sdk/_internal/query.py
+++ b/src/claude_agent_sdk/_internal/query.py
@@ -113,6 +113,11 @@ def __init__(
             float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0
         )  # Convert ms to seconds
 
+        # Time budget for task group cleanup during close()
+        self._task_group_close_timeout = (
+            float(os.environ.get("CLAUDE_CODE_TASK_GROUP_CLOSE_TIMEOUT", "5000")) / 1000.0
+        )  # Convert ms to seconds, default 5s
+
     async def initialize(self) -> dict[str, Any] | None:
         """Initialize control protocol if in streaming mode.
 
@@ -604,9 +609,28 @@ async def close(self) -> None:
         self._closed = True
         if self._tg:
             self._tg.cancel_scope.cancel()
-            # Wait for task group to complete cancellation
+            # Bound the wait for task group teardown and report a stalled cleanup.
+            # The deadline goes on the task group's own cancel scope rather than a
+            # new wrapping scope to avoid cancel scope nesting issues.
+            # NOTE(review): cancel() has already been called above, and a scope
+            # deadline only triggers that same cancellation, so this presumably
+            # cannot interrupt a task that ignores cancellation - confirm against
+            # the anyio CancelScope documentation.
+            cleanup_started = anyio.current_time()
+            self._tg.cancel_scope.deadline = (
+                cleanup_started + self._task_group_close_timeout
+            )
             with suppress(anyio.get_cancelled_exc_class()):
                 await self._tg.__aexit__(None, None, None)
+            # cancel_called is always True here (cancel() ran above), so it cannot
+            # signal a timeout; compare the measured wait with the budget instead.
+            cleanup_elapsed = anyio.current_time() - cleanup_started
+            if cleanup_elapsed > self._task_group_close_timeout:
+                logger.warning(
+                    "Task group cleanup took %.3fs, exceeding the %.3fs close timeout",
+                    cleanup_elapsed,
+                    self._task_group_close_timeout,
+                )
         await self.transport.close()
 
     # Make Query an async iterator
diff --git a/tests/test_client.py b/tests/test_client.py
index 39c32895..855bb2e2 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,10 +1,13 @@
 """Tests for Claude SDK client functionality."""
 
+import os
 from unittest.mock import AsyncMock, Mock, patch
 
 import anyio
 
 from claude_agent_sdk import AssistantMessage, ClaudeAgentOptions, query
+from claude_agent_sdk._internal.query import Query
+from claude_agent_sdk._internal.transport import Transport
 from claude_agent_sdk.types import TextBlock
 
 
@@ -121,3 +124,56 @@ async def mock_receive():
                 assert call_kwargs["options"].cwd == "/custom/path"
 
         anyio.run(_test)
+
+
+class TestQueryClose:
+    """Test Query.close() behavior."""
+
+    def test_close_timeout_prevents_hang(self):
+        """Test that close() returns promptly with a short cleanup timeout.
+
+        Regression test for issue #378 - Query.close() could hang indefinitely
+        when tasks don't respond to cancellation.
+        """
+
+        async def _test():
+            # Create a mock transport
+            mock_transport = Mock(spec=Transport)
+            mock_transport.close = AsyncMock()
+            mock_transport.end_input = AsyncMock()
+            mock_transport.write = AsyncMock()
+
+            # Use a short timeout (100ms). patch.dict puts back whatever value
+            # the variable had before the test, even if an assertion fails.
+            with patch.dict(
+                os.environ, {"CLAUDE_CODE_TASK_GROUP_CLOSE_TIMEOUT": "100"}
+            ):
+                # Named query_obj so the imported query() is not shadowed
+                query_obj = Query(
+                    transport=mock_transport,
+                    is_streaming_mode=True,
+                )
+
+                # The 100ms setting is stored in seconds
+                assert query_obj._task_group_close_timeout == 0.1
+
+                # Start the query to create the task group
+                await query_obj.start()
+
+                # Verify task group was created
+                assert query_obj._tg is not None
+
+                # Time close() to check that it does not block.
+                # NOTE(review): no task that ignores cancellation is set up here,
+                # so the hang from the issue is presumably not reproduced - confirm.
+                start = anyio.current_time()
+                await query_obj.close()
+                elapsed = anyio.current_time() - start
+
+                # Should complete quickly (< 1 second), not hang indefinitely
+                assert elapsed < 1.0
+
+                # Transport close should have been called
+                mock_transport.close.assert_called_once()
+
+        anyio.run(_test)