diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 52466272..9590c961 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -90,7 +90,8 @@ async def process_query( sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item] # Create Query to handle control protocol - is_streaming = not isinstance(prompt, str) + # Force streaming mode if SDK MCP servers are present (they require bidirectional communication) + is_streaming = not isinstance(prompt, str) or bool(sdk_mcp_servers) query = Query( transport=chosen_transport, is_streaming_mode=is_streaming, @@ -109,11 +110,12 @@ async def process_query( if is_streaming: await query.initialize() - # Stream input if it's an AsyncIterable - if isinstance(prompt, AsyncIterable) and query._tg: - # Start streaming in background - # Create a task that will run in the background - query._tg.start_soon(query.stream_input, prompt) + # Stream input if in streaming mode + if query._tg: + # Use the (possibly converted) prompt from transport + stream_prompt = getattr(chosen_transport, "_prompt", prompt) + if isinstance(stream_prompt, AsyncIterable): + query._tg.start_soon(query.stream_input, stream_prompt) # For string prompts, the prompt is already passed via CLI args # Yield parsed messages diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index c30fc159..5c7fa115 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -14,6 +14,7 @@ ListToolsRequest, ) +from .._errors import CLIConnectionError from ..types import ( PermissionResultAllow, PermissionResultDeny, @@ -322,7 +323,14 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: "response": response_data, }, } - await self.transport.write(json.dumps(success_response) + "\n") + try: + await self.transport.write(json.dumps(success_response) + "\n") + except CLIConnectionError: + logger.debug( + "Transport closed before sending control response for %s (request_id=%s)", + subtype, + request_id, + ) except Exception as e: # Send error response @@ -334,7 +342,15 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: "error": str(e), }, } - await self.transport.write(json.dumps(error_response) + "\n") + try: + await self.transport.write(json.dumps(error_response) + "\n") + except CLIConnectionError: + logger.debug( + "Transport closed before sending error response for %s (request_id=%s): %s", + subtype, + request_id, + e, + ) async def _send_control_request( self, request: dict[str, Any], timeout: float = 60.0 diff --git a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py index a4882db1..3c2aa2ca 100644 --- a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py @@ -37,6 +37,27 @@ _CMD_LENGTH_LIMIT = 8000 if platform.system() == "Windows" else 100000 +async def _string_to_async_iterable(prompt: str) -> AsyncIterator[dict[str, Any]]: + """Convert a string prompt to an async iterable for streaming mode. + + When SDK MCP servers are present, we need streaming mode for bidirectional + communication. This helper converts a string prompt to the expected + stream-json format. + + Args: + prompt: The string prompt to convert + + Yields: + A single user message dict in stream-json format + """ + yield { + "type": "user", + "message": {"role": "user", "content": prompt}, + "parent_tool_use_id": None, + "session_id": "default", + } + + class SubprocessCLITransport(Transport): """Subprocess transport using Claude Code CLI.""" @@ -45,8 +66,22 @@ def __init__( prompt: str | AsyncIterable[dict[str, Any]], options: ClaudeAgentOptions, ): - self._prompt = prompt - self._is_streaming = not isinstance(prompt, str) + # Check if SDK MCP servers are present - they require streaming mode + # for bidirectional communication + has_sdk_mcp = self._has_sdk_mcp_servers(options) + + # Determine streaming mode: either explicit AsyncIterable or + # forced by SDK MCP servers presence + if isinstance(prompt, str) and has_sdk_mcp: + # Convert string prompt to async iterable for SDK MCP support + self._prompt: str | AsyncIterable[dict[str, Any]] = ( + _string_to_async_iterable(prompt) + ) + self._is_streaming = True + else: + self._prompt = prompt + self._is_streaming = not isinstance(prompt, str) + self._options = options self._cli_path = ( str(options.cli_path) if options.cli_path is not None else self._find_cli() @@ -67,6 +102,27 @@ def __init__( self._temp_files: list[str] = [] # Track temporary files for cleanup self._write_lock: anyio.Lock = anyio.Lock() + def _has_sdk_mcp_servers(self, options: ClaudeAgentOptions) -> bool: + """Check if any SDK MCP servers are configured. + + SDK MCP servers require bidirectional communication through stdin/stdout, + so when present, streaming mode must be forced even for string prompts. + + Args: + options: The agent options to check + + Returns: + True if any SDK MCP server is configured, False otherwise + """ + if not options.mcp_servers: + return False + if not isinstance(options.mcp_servers, dict): + return False + return any( + isinstance(config, dict) and config.get("type") == "sdk" + for config in options.mcp_servers.values() + ) + def _find_cli(self) -> str: """Find Claude Code CLI binary.""" # First, check for bundled CLI diff --git a/tests/test_query_control_request.py b/tests/test_query_control_request.py new file mode 100644 index 00000000..71606cd7 --- /dev/null +++ b/tests/test_query_control_request.py @@ -0,0 +1,267 @@ +"""Tests for Query._handle_control_request() race condition handling.""" + +import json +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock + +import anyio +import pytest + +from claude_agent_sdk import CLIConnectionError +from claude_agent_sdk._internal.query import Query +from claude_agent_sdk.types import ( + PermissionResultAllow, + SDKControlRequest, +) + + +def create_mock_transport_raising_on_write() -> MagicMock: + """Create a mock transport that raises CLIConnectionError on write.""" + mock_transport = MagicMock() + mock_transport.write = AsyncMock(side_effect=CLIConnectionError("Transport closed")) + return mock_transport + + +def create_mock_transport_working() -> tuple[MagicMock, list[str]]: + """Create a working mock transport that records written data.""" + mock_transport = MagicMock() + written_data: list[str] = [] + + async def mock_write(data: str) -> None: + written_data.append(data) + + mock_transport.write = AsyncMock(side_effect=mock_write) + return mock_transport, written_data + + +def _make_can_use_tool_request(request_id: str) -> SDKControlRequest: + """Create a can_use_tool control request for testing.""" + return cast( + SDKControlRequest, + { + "type": "control_request", + "request_id": request_id, + "request": { + "subtype": "can_use_tool", + "tool_name": "test_tool", + "input": {"arg": "value"}, + }, + }, + ) + + +def _make_mcp_message_request(request_id: str) -> SDKControlRequest: + """Create an mcp_message control request for testing (with None server_name).""" + return cast( + SDKControlRequest, + { + "type": "control_request", + "request_id": request_id, + "request": { + "subtype": "mcp_message", + "server_name": None, # Will trigger error path + "message": {}, + }, + }, + ) + + +def _make_hook_callback_request(request_id: str) -> SDKControlRequest: + """Create a hook_callback control request for testing.""" + return cast( + SDKControlRequest, + { + "type": "control_request", + "request_id": request_id, + "request": { + "subtype": "hook_callback", + "callback_id": "hook_0", + "input": {"test": "data"}, + "tool_use_id": "tool-123", + }, + }, + ) + + +class TestHandleControlRequestRaceCondition: + """Test race condition handling in _handle_control_request().""" + + def test_success_response_on_closed_transport(self) -> None: + """Transport closed before success response - should not raise.""" + + async def _test() -> None: + mock_transport = create_mock_transport_raising_on_write() + + # Create a can_use_tool callback that returns Allow + async def mock_can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + context: object, + ) -> PermissionResultAllow: + return PermissionResultAllow() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + can_use_tool=mock_can_use_tool, + ) + + request = _make_can_use_tool_request("test-request-1") + + # Should not raise - CLIConnectionError should be caught + await query._handle_control_request(request) + + # Verify write was attempted + mock_transport.write.assert_called_once() + + anyio.run(_test) + + def test_error_response_on_closed_transport(self) -> None: + """Transport closed before error response - should not raise.""" + + async def _test() -> None: + mock_transport = create_mock_transport_raising_on_write() + + # Create a can_use_tool callback that raises an exception + async def mock_can_use_tool_error( + tool_name: str, + tool_input: dict[str, Any], + context: object, + ) -> PermissionResultAllow: + raise ValueError("Callback error") + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + can_use_tool=mock_can_use_tool_error, + ) + + request = _make_can_use_tool_request("test-request-2") + + # Should not raise - CLIConnectionError in error path should be caught + await query._handle_control_request(request) + + # Verify write was attempted (for error response) + mock_transport.write.assert_called_once() + + anyio.run(_test) + + def test_mcp_message_on_closed_transport(self) -> None: + """MCP message response on closed transport - should not raise.""" + + async def _test() -> None: + mock_transport = create_mock_transport_raising_on_write() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + ) + + request = _make_mcp_message_request("test-request-3") + + # Should not raise - error response CLIConnectionError should be caught + await query._handle_control_request(request) + + # Verify write was attempted + mock_transport.write.assert_called_once() + + anyio.run(_test) + + def test_normal_operation_unaffected(self) -> None: + """Normal operation still works correctly.""" + + async def _test() -> None: + mock_transport, written_data = create_mock_transport_working() + + # Create a can_use_tool callback that returns Allow + async def mock_can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + context: object, + ) -> PermissionResultAllow: + return PermissionResultAllow(updated_input={"modified": True}) + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + can_use_tool=mock_can_use_tool, + ) + + request = _make_can_use_tool_request("test-request-4") + + await query._handle_control_request(request) + + # Verify response was written correctly + assert len(written_data) == 1 + response = json.loads(written_data[0].strip()) + assert response["type"] == "control_response" + assert response["response"]["subtype"] == "success" + assert response["response"]["request_id"] == "test-request-4" + assert response["response"]["response"]["behavior"] == "allow" + assert response["response"]["response"]["updatedInput"] == { + "modified": True + } + + anyio.run(_test) + + def test_hook_callback_on_closed_transport(self) -> None: + """Hook callback response on closed transport - should not raise.""" + + async def _test() -> None: + mock_transport = create_mock_transport_raising_on_write() + + # Create a hook callback + async def mock_hook( + input_data: dict[str, Any] | None, + tool_use_id: str | None, + context: dict[str, Any], + ) -> dict[str, Any]: + return {"continue_": True} + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + ) + # Register the hook callback + query.hook_callbacks["hook_0"] = mock_hook + + request = _make_hook_callback_request("test-request-5") + + # Should not raise - CLIConnectionError should be caught + await query._handle_control_request(request) + + # Verify write was attempted + mock_transport.write.assert_called_once() + + anyio.run(_test) + + def test_other_exceptions_still_propagate(self) -> None: + """Non-CLIConnectionError exceptions should still propagate.""" + + async def _test() -> None: + mock_transport = MagicMock() + mock_transport.write = AsyncMock( + side_effect=RuntimeError("Unexpected error") + ) + + # Create a can_use_tool callback that returns Allow + async def mock_can_use_tool( + tool_name: str, + tool_input: dict[str, Any], + context: object, + ) -> PermissionResultAllow: + return PermissionResultAllow() + + query = Query( + transport=mock_transport, + is_streaming_mode=True, + can_use_tool=mock_can_use_tool, + ) + + request = _make_can_use_tool_request("test-request-6") + + # RuntimeError should still propagate + with pytest.raises(RuntimeError, match="Unexpected error"): + await query._handle_control_request(request) + + anyio.run(_test) diff --git a/tests/test_transport.py b/tests/test_transport.py index fe9b6b22..99a74f9d 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -826,3 +826,152 @@ async def do_write(i: int): await process.wait() anyio.run(_test, backend="trio") + + def test_has_sdk_mcp_servers_detection_no_servers(self): + """Test SDK MCP server detection with no servers.""" + transport = SubprocessCLITransport(prompt="test", options=make_options()) + assert transport._has_sdk_mcp_servers(transport._options) is False + + def test_has_sdk_mcp_servers_detection_external_only(self): + """Test SDK MCP server detection with only external servers.""" + external_server = {"type": "stdio", "command": "echo"} + transport = SubprocessCLITransport( + prompt="test", + options=make_options(mcp_servers={"ext": external_server}), + ) + assert transport._has_sdk_mcp_servers(transport._options) is False + + def test_has_sdk_mcp_servers_detection_sdk_present(self): + """Test SDK MCP server detection with SDK server present.""" + # Mock SDK server config (instance field would normally have McpServer) + sdk_server = {"type": "sdk", "name": "test", "instance": MagicMock()} + transport = SubprocessCLITransport( + prompt="test", + options=make_options(mcp_servers={"sdk": sdk_server}), + ) + assert transport._has_sdk_mcp_servers(transport._options) is True + + def test_streaming_mode_forced_with_sdk_mcp_servers(self): + """Test that SDK MCP servers force streaming mode even with string prompt.""" + sdk_server = {"type": "sdk", "name": "test", "instance": MagicMock()} + transport = SubprocessCLITransport( + prompt="Hello", # String prompt + options=make_options(mcp_servers={"test": sdk_server}), + ) + + # Should force streaming mode due to SDK MCP server + assert transport._is_streaming is True + + cmd = transport._build_command() + assert "--input-format" in cmd + assert "stream-json" in cmd + assert "--print" not in cmd + + def test_string_prompt_without_sdk_mcp_stays_non_streaming(self): + """Test that string prompts without SDK MCP servers use non-streaming mode.""" + transport = SubprocessCLITransport( + prompt="Hello", + options=make_options(), + ) + + assert transport._is_streaming is False + + cmd = transport._build_command() + assert "--print" in cmd + + def test_string_prompt_with_external_mcp_stays_non_streaming(self): + """Test that external MCP servers don't force streaming mode.""" + external_server = {"type": "stdio", "command": "echo", "args": ["test"]} + transport = SubprocessCLITransport( + prompt="Hello", + options=make_options(mcp_servers={"external": external_server}), + ) + + # External servers don't need bidirectional communication + assert transport._is_streaming is False + + cmd = transport._build_command() + assert "--print" in cmd + + def test_string_to_async_iterable_output_format(self): + """Test that _string_to_async_iterable produces correct stream-json format.""" + from claude_agent_sdk._internal.transport.subprocess_cli import ( + _string_to_async_iterable, + ) + + async def _test(): + messages = [msg async for msg in _string_to_async_iterable("Hello world")] + + assert len(messages) == 1 + msg = messages[0] + assert msg["type"] == "user" + assert msg["message"]["role"] == "user" + assert msg["message"]["content"] == "Hello world" + assert msg["parent_tool_use_id"] is None + assert msg["session_id"] == "default" + + anyio.run(_test) + + def test_string_prompt_converted_to_async_iterable_with_sdk_mcp(self): + """Test that string prompt is converted to AsyncIterable with SDK MCP servers.""" + from collections.abc import AsyncIterable + + sdk_server = {"type": "sdk", "name": "test", "instance": MagicMock()} + transport = SubprocessCLITransport( + prompt="Hello", + options=make_options(mcp_servers={"test": sdk_server}), + ) + + # _prompt should be an AsyncIterable, not a string + assert isinstance(transport._prompt, AsyncIterable) + assert not isinstance(transport._prompt, str) + + def test_mixed_sdk_and_external_mcp_servers_forces_streaming(self): + """Test that mixed SDK + external MCP servers force streaming mode.""" + sdk_server = {"type": "sdk", "name": "test", "instance": MagicMock()} + external_server = {"type": "stdio", "command": "echo", "args": ["test"]} + transport = SubprocessCLITransport( + prompt="Hello", + options=make_options( + mcp_servers={"sdk": sdk_server, "external": external_server} + ), + ) + + # SDK server presence should force streaming mode + assert transport._is_streaming is True + + cmd = transport._build_command() + assert "--input-format" in cmd + assert "stream-json" in cmd + assert "--print" not in cmd + + def test_has_sdk_mcp_servers_with_file_path_returns_false(self): + """Test that mcp_servers as file path returns False for SDK detection.""" + transport = SubprocessCLITransport( + prompt="test", + options=make_options(mcp_servers="/path/to/mcp-config.json"), + ) + assert transport._has_sdk_mcp_servers(transport._options) is False + + def test_has_sdk_mcp_servers_with_invalid_config_values(self): + """Test SDK detection with various invalid config values.""" + # Config value is None + transport = SubprocessCLITransport( + prompt="test", + options=make_options(mcp_servers={"server": None}), + ) + assert transport._has_sdk_mcp_servers(transport._options) is False + + # Config value is a string + transport = SubprocessCLITransport( + prompt="test", + options=make_options(mcp_servers={"server": "invalid"}), + ) + assert transport._has_sdk_mcp_servers(transport._options) is False + + # Config value is a dict without type field + transport = SubprocessCLITransport( + prompt="test", + options=make_options(mcp_servers={"server": {"command": "echo"}}), + ) + assert transport._has_sdk_mcp_servers(transport._options) is False