From 59e8f35f12d11ef9c1505aeb4aa5463361e2cbd9 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Fri, 2 Jan 2026 20:24:04 +0000 Subject: [PATCH] Revert "refactor: migrate agent to Claude Agent SDK with MCP (#63)" This reverts commit 58aba5fd1a0039bc05da379e1360542de1b67bbd. --- pyproject.toml | 1 - src/policyengine_api/agent_sandbox.py | 505 ++++++++++++++++++++------ tests/test_agent.py | 253 +++++++++++-- tests/test_agent_sandbox.py | 24 +- uv.lock | 18 - 5 files changed, 634 insertions(+), 167 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 37cac99..27eb310 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "fastapi-mcp>=0.4.0", "modal>=0.68.0", "anthropic>=0.40.0", - "claude-agent-sdk>=0.1.0", ] [project.optional-dependencies] diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index f460320..57f7faf 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -1,12 +1,16 @@ -"""Modal agent using Claude Agent SDK with MCP server connection.""" +"""Modal agent using Claude API with tools auto-generated from OpenAPI spec.""" -import asyncio +import json +import re +import time +from typing import Callable +import anthropic import modal import requests image = modal.Image.debian_slim(python_version="3.12").pip_install( - "claude-agent-sdk", "requests", "logfire[httpx]" + "anthropic", "requests", "logfire[httpx]" ) app = modal.App("policyengine-sandbox") @@ -31,6 +35,7 @@ def configure_logfire(traceparent: str | None = None): console=False, ) + # If traceparent provided, attach to the current context if traceparent: from opentelemetry import context from opentelemetry.trace.propagation.tracecontext import ( @@ -44,7 +49,7 @@ def configure_logfire(traceparent: str | None = None): SYSTEM_PROMPT = """You are a PolicyEngine assistant that helps users understand tax and benefit policies. -You have access to the PolicyEngine API via MCP tools. 
+You have access to the full PolicyEngine API. ## CRITICAL: Always filter by country @@ -52,62 +57,340 @@ def configure_logfire(traceparent: str | None = None): - "policyengine-uk" for UK questions - "policyengine-us" for US questions -Parameters and datasets from both countries are in the same database. Without the filter, you'll get mixed results. +Parameters and datasets from both countries are in the same database. Without the filter, you'll get mixed results and waste turns finding the right ones. ## Key workflows 1. **Household calculations**: - - Use household_calculate with model_name and people array - - Poll household_calculate_status until completed + - POST /household/calculate with model_name and people array + - Poll GET /household/calculate/{job_id} until completed 2. **Parameter lookup**: - - Use parameters_list with search and tax_benefit_model_name filter - - Use parameter_values_list with parameter_id and current=true + - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk (ALWAYS include country filter) + - GET /parameter-values/?parameter_id=...&current=true for the current value -3. **Economic impact analysis**: - - Find parameter_id with parameters_list - - Create policy with policies_create - - Find dataset_id with datasets_list - - Run analysis with analysis_economic_impact - - Get results with analysis_economic_impact_status +3. **Economic impact analysis** (budget impact, decile impacts): + - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk to find parameter_id + - POST /policies/ to create reform with parameter_values + - GET /datasets/?tax_benefit_model_name=policyengine-uk to find dataset_id + - POST /analysis/economic-impact with tax_benefit_model_name, policy_id and dataset_id + - GET /analysis/economic-impact/{report_id} for results (includes decile_impacts and program_statistics) ## Response formatting Follow PolicyEngine's writing style: -1. **Sentence case**: Use sentence case for headings -2. 
**Active voice**: "The reform reduces poverty by 3.2%" -3. **Quantitative precision**: Use specific numbers -4. **Neutral tone**: Describe what policies do objectively -5. **Tables for data**: Use markdown tables for breakdowns +1. **Sentence case**: Use sentence case for all headings (e.g. "Tax breakdown" not "Tax Breakdown") +2. **Active voice**: "The reform reduces poverty by 3.2%" not "Poverty is reduced by 3.2%" +3. **Quantitative precision**: Use specific numbers, avoid vague words like "significantly" or "substantially" +4. **Neutral tone**: Describe what policies do, not whether they're good or bad +5. **Tables for data**: Present breakdowns and comparisons in markdown tables -Example: +Example response format: | Item | Amount | |------|--------| | Income tax | £7,486 | | National Insurance | £2,994 | | **Total tax** | **£10,480** | -Avoid vague words like "significantly" or "substantially" - use numbers. +- Gross income: £50,000 +- Net income: £39,520 +- Effective tax rate: 21.0% + +Avoid: "significantly reduces", "substantial savings", "unfortunately", "great news" +Prefer: specific percentages, pound/dollar amounts, neutral descriptions + +## Guidelines + +1. Use the API tools to get accurate, current data +2. Be concise - lead with key numbers +3. For UK, amounts are in GBP (£). For US, amounts are in USD ($) +4. When polling async endpoints, use the sleep tool to wait 5-10 seconds between requests """ +# Sleep tool for polling delays +SLEEP_TOOL = { + "name": "sleep", + "description": "Wait for a specified number of seconds. 
Use this between polling requests to avoid hammering the API.", + "input_schema": { + "type": "object", + "properties": { + "seconds": { + "type": "number", + "description": "Number of seconds to sleep (1-60)", + } + }, + "required": ["seconds"], + }, +} + -async def _run_agent_async( - question: str, +def fetch_openapi_spec(api_base_url: str) -> dict: + """Fetch and cache OpenAPI spec.""" + resp = requests.get(f"{api_base_url}/openapi.json", timeout=30) + resp.raise_for_status() + return resp.json() + + +def resolve_ref(spec: dict, ref: str) -> dict: + """Resolve a $ref pointer in the OpenAPI spec.""" + if not ref.startswith("#/"): + return {} + parts = ref[2:].split("/") + result = spec + for part in parts: + result = result.get(part, {}) + return result + + +def schema_to_json_schema(spec: dict, schema: dict) -> dict: + """Convert OpenAPI schema to JSON Schema for Claude tools.""" + if "$ref" in schema: + schema = resolve_ref(spec, schema["$ref"]) + + result = {} + + if "type" in schema: + result["type"] = schema["type"] + if "description" in schema: + result["description"] = schema["description"] + if "enum" in schema: + result["enum"] = schema["enum"] + if "default" in schema: + result["default"] = schema["default"] + if "format" in schema: + # Add format info to description + fmt = schema["format"] + if "description" in result: + result["description"] += f" (format: {fmt})" + else: + result["description"] = f"Format: {fmt}" + + # Handle anyOf (often used for Optional types) + if "anyOf" in schema: + non_null = [s for s in schema["anyOf"] if s.get("type") != "null"] + if len(non_null) == 1: + result.update(schema_to_json_schema(spec, non_null[0])) + elif non_null: + result.update(schema_to_json_schema(spec, non_null[0])) + + # Handle allOf + if "allOf" in schema: + for sub in schema["allOf"]: + result.update(schema_to_json_schema(spec, sub)) + + # Handle objects + if schema.get("type") == "object" or "properties" in schema: + result["type"] = "object" + if 
"properties" in schema: + result["properties"] = {} + for prop_name, prop_schema in schema["properties"].items(): + result["properties"][prop_name] = schema_to_json_schema( + spec, prop_schema + ) + if "required" in schema: + result["required"] = schema["required"] + + # Handle arrays + if schema.get("type") == "array" and "items" in schema: + result["items"] = schema_to_json_schema(spec, schema["items"]) + + return result + + +def openapi_to_claude_tools(spec: dict) -> list[dict]: + """Convert OpenAPI spec to Claude tool definitions.""" + tools = [] + + for path, methods in spec.get("paths", {}).items(): + for method, operation in methods.items(): + if method not in ("get", "post", "put", "patch", "delete"): + continue + + # Build tool name from operationId or path+method + op_id = operation.get("operationId", f"{method}_{path}") + # Clean up the name + tool_name = re.sub(r"[^a-zA-Z0-9_]", "_", op_id) + tool_name = re.sub(r"_+", "_", tool_name).strip("_") + + # Build description + summary = operation.get("summary", "") + description = operation.get("description", "") + full_desc = f"{method.upper()} {path}" + if summary: + full_desc += f"\n\n{summary}" + if description: + full_desc += f"\n\n{description}" + + # Build input schema + properties = {} + required = [] + + # Path parameters + for param in operation.get("parameters", []): + param_name = param.get("name") + param_in = param.get("in") + param_schema = param.get("schema", {}) + param_required = param.get("required", False) + + prop = schema_to_json_schema(spec, param_schema) + prop["description"] = ( + param.get("description", "") + + f" (in: {param_in})" + ) + properties[param_name] = prop + + if param_required: + required.append(param_name) + + # Request body + request_body = operation.get("requestBody", {}) + if request_body: + content = request_body.get("content", {}) + json_content = content.get("application/json", {}) + body_schema = json_content.get("schema", {}) + + if body_schema: + resolved = 
schema_to_json_schema(spec, body_schema) + # Flatten body properties into tool properties + if "properties" in resolved: + for prop_name, prop_schema in resolved["properties"].items(): + properties[prop_name] = prop_schema + if "required" in resolved: + required.extend(resolved["required"]) + else: + # Wrap the whole body as a "body" parameter + properties["body"] = resolved + if request_body.get("required"): + required.append("body") + + input_schema = {"type": "object", "properties": properties} + if required: + input_schema["required"] = list(set(required)) + + tools.append({ + "name": tool_name, + "description": full_desc[:1024], # Claude has limits + "input_schema": input_schema, + "_meta": { + "path": path, + "method": method, + "parameters": operation.get("parameters", []), + }, + }) + + return tools + + +def execute_api_tool( + tool: dict, + tool_input: dict, api_base_url: str, - call_id: str, + log_fn: Callable, + trace_headers: dict | None = None, +) -> str: + """Execute an API tool by making the HTTP request.""" + meta = tool.get("_meta", {}) + path = meta.get("path", "") + method = meta.get("method", "get") + parameters = meta.get("parameters", []) + + # Build URL with path parameters + url = f"{api_base_url}{path}" + query_params = {} + headers = {"Content-Type": "application/json"} + if trace_headers: + headers.update(trace_headers) + + # Separate path, query, and body parameters + body_data = {} + for param in parameters: + param_name = param.get("name") + param_in = param.get("in") + value = tool_input.get(param_name) + + if value is None: + continue + + if param_in == "path": + url = url.replace(f"{{{param_name}}}", str(value)) + elif param_in == "query": + query_params[param_name] = value + elif param_in == "header": + headers[param_name] = str(value) + + # Remaining input goes to body (for POST/PUT/PATCH) + param_names = {p.get("name") for p in parameters} + for key, value in tool_input.items(): + if key not in param_names: + body_data[key] = 
value + + try: + log_fn(f"[API] {method.upper()} {url}") + if query_params: + log_fn(f"[API] Query: {json.dumps(query_params)[:200]}") + if body_data: + log_fn(f"[API] Body: {json.dumps(body_data)[:200]}") + + if method == "get": + resp = requests.get(url, params=query_params, headers=headers, timeout=60) + elif method == "post": + resp = requests.post( + url, params=query_params, json=body_data, headers=headers, timeout=60 + ) + elif method == "put": + resp = requests.put( + url, params=query_params, json=body_data, headers=headers, timeout=60 + ) + elif method == "patch": + resp = requests.patch( + url, params=query_params, json=body_data, headers=headers, timeout=60 + ) + elif method == "delete": + resp = requests.delete(url, params=query_params, headers=headers, timeout=60) + else: + return f"Unsupported method: {method}" + + log_fn(f"[API] Response: {resp.status_code}") + + if resp.status_code >= 400: + return f"Error {resp.status_code}: {resp.text[:500]}" + + try: + data = resp.json() + # For lists, summarize if too long but keep key info + if isinstance(data, list) and len(data) > 50: + result = json.dumps(data[:50], indent=2) + result += f"\n... 
({len(data) - 50} more items)" + else: + result = json.dumps(data, indent=2) + return result + except json.JSONDecodeError: + return resp.text[:1000] + + except requests.RequestException as e: + return f"Request error: {str(e)}" + + +def _run_agent_impl( + question: str, + api_base_url: str = "https://v2.api.policyengine.org", + call_id: str = "", history: list[dict] | None = None, + max_turns: int = 30, traceparent: str | None = None, ) -> dict: - """Core async agent implementation using Claude Agent SDK.""" - from claude_agent_sdk import ClaudeAgentOptions, query + """Core agent implementation.""" + import logfire + # Get traceparent for HTTP requests def get_trace_headers() -> dict: if traceparent: return {"traceparent": traceparent} return {} def log(msg: str) -> None: + logfire.info("agent_log", message=msg, call_id=call_id) print(msg) if call_id: try: @@ -121,68 +404,91 @@ def log(msg: str) -> None: pass log(f"[AGENT] Starting: {question[:200]}") - log(f"[AGENT] Connecting to MCP server at {api_base_url}/mcp/") - # Build conversation with history + # Fetch and convert OpenAPI spec to tools + log("[AGENT] Fetching OpenAPI spec...") + spec = fetch_openapi_spec(api_base_url) + tools = openapi_to_claude_tools(spec) + log(f"[AGENT] Loaded {len(tools)} API tools") + + # Create tool lookup for execution + tool_lookup = {t["name"]: t for t in tools} + + # Strip _meta from tools before sending to Claude (it doesn't need it) + claude_tools = [ + {k: v for k, v in t.items() if k != "_meta"} for t in tools + ] + # Add the sleep tool + claude_tools.append(SLEEP_TOOL) + + client = anthropic.Anthropic() + + # Build messages with conversation history + messages = [] if history: - # Format history as context in the prompt - context_parts = [] for msg in history: - role = msg.get("role", "user") - content = msg.get("content", "") - context_parts.append(f"{role.upper()}: {content}") - context = "\n\n".join(context_parts) - full_prompt = f"Previous 
conversation:\n{context}\n\nNew question: {question}" - else: - full_prompt = question - - # Configure Agent SDK with MCP server - options = ClaudeAgentOptions( - mcp_servers={ - "policyengine": { - "type": "sse", - "url": f"{api_base_url}/mcp/", - } - }, - allowed_tools=["mcp__policyengine__*"], - system_prompt=SYSTEM_PROMPT, - ) + messages.append({"role": msg["role"], "content": msg["content"]}) + messages.append({"role": "user", "content": question}) final_response = None turns = 0 - try: - async for message in query(prompt=full_prompt, options=options): - # Handle different message types - msg_type = type(message).__name__ - - if msg_type == "AssistantMessage": - turns += 1 - for block in message.content: - block_type = type(block).__name__ - if block_type == "TextBlock": - log(f"[ASSISTANT] {block.text[:500]}") - final_response = block.text - elif block_type == "ToolUseBlock": - log(f"[TOOL_USE] {block.name}: {str(block.input)[:200]}") - - elif msg_type == "ToolResultMessage": - for result in message.content: - result_str = str(result)[:300] - log(f"[TOOL_RESULT] {result_str}") - - elif msg_type == "ResultMessage": - log( - f"[AGENT] Completed - Cost: ${message.cost:.4f}, Duration: {message.duration:.1f}s" - ) + while turns < max_turns: + turns += 1 + log(f"[AGENT] Turn {turns}") - except Exception as e: - log(f"[AGENT] Error: {str(e)}") - return { - "status": "failed", - "error": str(e), - "turns": turns, - } + response = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=4096, + system=SYSTEM_PROMPT, + tools=claude_tools, + messages=messages, + ) + + log(f"[AGENT] Stop reason: {response.stop_reason}") + + assistant_content = [] + tool_results = [] + + for block in response.content: + if block.type == "text": + log(f"[ASSISTANT] {block.text[:500]}") + assistant_content.append(block) + final_response = block.text + elif block.type == "tool_use": + log(f"[TOOL_USE] {block.name}: {json.dumps(block.input)[:200]}") + 
assistant_content.append(block) + + # Execute tool + if block.name == "sleep": + # Handle sleep tool specially + seconds = min(max(block.input.get("seconds", 5), 1), 60) + log(f"[SLEEP] Waiting {seconds} seconds...") + time.sleep(seconds) + result = f"Slept for {seconds} seconds" + else: + tool = tool_lookup.get(block.name) + if tool: + result = execute_api_tool( + tool, block.input, api_base_url, log, get_trace_headers() + ) + else: + result = f"Unknown tool: {block.name}" + + log(f"[TOOL_RESULT] {result[:300]}") + + tool_results.append({ + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + }) + + messages.append({"role": "assistant", "content": assistant_content}) + + if tool_results: + messages.append({"role": "user", "content": tool_results}) + else: + break log(f"[AGENT] Completed in {turns} turns") @@ -206,20 +512,6 @@ def log(msg: str) -> None: return result -def _run_agent_impl( - question: str, - api_base_url: str = "https://v2.api.policyengine.org", - call_id: str = "", - history: list[dict] | None = None, - max_turns: int = 30, - traceparent: str | None = None, -) -> dict: - """Synchronous wrapper for the async agent implementation.""" - return asyncio.run( - _run_agent_async(question, api_base_url, call_id, history, traceparent) - ) - - @app.function(image=image, secrets=[anthropic_secret, logfire_secrets], timeout=600) def run_agent( question: str, @@ -234,19 +526,18 @@ def run_agent( configure_logfire(traceparent) - try: - with logfire.span("run_agent", call_id=call_id, question=question[:200]): - result = _run_agent_impl( - question, - api_base_url, - call_id, - history=history, - max_turns=max_turns, - traceparent=traceparent, - ) - finally: - logfire.force_flush() + with logfire.span("run_agent", call_id=call_id, question=question[:200]): + result = _run_agent_impl( + question, + api_base_url, + call_id, + history=history, + max_turns=max_turns, + traceparent=traceparent, + ) + # Ensure logfire sends all spans before Modal 
container exits + logfire.force_flush() return result diff --git a/tests/test_agent.py b/tests/test_agent.py index d0fa765..2c591f5 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,15 +1,18 @@ -"""Tests for the agent API endpoints. +"""Tests for the agent streaming API endpoints. -Tests verify the agent endpoint structure and integration with Claude Agent SDK. +Tests verify that Claude Code is invoked correctly with proper MCP configuration. """ import pytest + +pytestmark = pytest.mark.integration + +import json +from unittest.mock import AsyncMock, MagicMock, patch from fastapi.testclient import TestClient from policyengine_api.main import app -pytestmark = pytest.mark.integration - client = TestClient(app) @@ -21,54 +24,226 @@ def test_status_not_found(self): response = client.get("/agent/status/nonexistent-job-id") assert response.status_code == 404 - def test_run_request_model(self): - """RunRequest model should accept question field.""" - from policyengine_api.api.agent import RunRequest + def test_ask_request_model(self): + """AskRequest model should accept question field.""" + from policyengine_api.api.agent import AskRequest - req = RunRequest(question="Test question") + req = AskRequest(question="Test question") assert req.question == "Test question" - def test_run_response_model(self): - """RunResponse model should have call_id and status fields.""" - from policyengine_api.api.agent import RunResponse + def test_ask_response_model(self): + """AskResponse model should have job_id and status fields.""" + from policyengine_api.api.agent import AskResponse + + resp = AskResponse(job_id="test-123", status="pending") + assert resp.job_id == "test-123" + assert resp.status == "pending" + + +class TestClaudeCodeInvocation: + """Test that Claude Code is invoked correctly.""" + + def test_claude_cli_runs(self): + """Claude CLI must be installed and able to run a simple prompt.""" + import subprocess + + # Actually run Claude with a simple prompt 
(no tools needed) + result = subprocess.run( + ["claude", "--print", "Say 'hello' and nothing else"], + capture_output=True, + timeout=30, + ) + assert result.returncode == 0, f"claude failed: {result.stderr.decode()}" + assert "hello" in result.stdout.decode().lower(), "Expected 'hello' in output" + + @pytest.mark.asyncio + async def test_stream_claude_code_invokes_claude_cli(self): + """_stream_claude_code should invoke the claude CLI with correct args.""" + from policyengine_api.api.agent import _stream_claude_code + + captured_args = [] + + async def mock_create_subprocess(*args, **kwargs): + captured_args.append(args) + mock_process = MagicMock() + mock_process.returncode = 0 + + async def mock_stdout_iter(): + yield b"Test output\n" + + mock_process.stdout = mock_stdout_iter() + mock_process.stderr = AsyncMock() + mock_process.stderr.read = AsyncMock(return_value=b"") + mock_process.wait = AsyncMock() + return mock_process + + with patch( + "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess + ): + events = [] + async for event in _stream_claude_code( + "Test question", "http://localhost:8000" + ): + events.append(event) + + # Verify claude was called + assert len(captured_args) == 1 + args = captured_args[0] + + # Check command structure: claude -p --allowedTools + assert args[0] == "claude" + assert "-p" in args + assert "--allowedTools" in args + + # Check question is passed after -p + p_idx = args.index("-p") + assert args[p_idx + 1] == "Test question" + + # Check MCP tools are allowed + allowed_tools_idx = args.index("--allowedTools") + 1 + allowed_tools = args[allowed_tools_idx] + assert "mcp__policyengine__*" in allowed_tools + + @pytest.mark.asyncio + async def test_stream_claude_code_yields_sse_events(self): + """_stream_claude_code should yield properly formatted SSE events.""" + from policyengine_api.api.agent import _stream_claude_code + + async def mock_create_subprocess(*args, **kwargs): + mock_process = MagicMock() + 
mock_process.returncode = 0 + + async def mock_stdout_iter(): + yield b"Line 1\n" + yield b"Line 2\n" + + mock_process.stdout = mock_stdout_iter() + mock_process.stderr = AsyncMock() + mock_process.stderr.read = AsyncMock(return_value=b"") + mock_process.wait = AsyncMock() + return mock_process + + with patch( + "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess + ): + events = [] + async for event in _stream_claude_code("Test", "http://localhost"): + events.append(event) + + # Should have output events + output_events = [e for e in events if "output" in e] + assert len(output_events) == 2 + + # Each should be valid SSE format + for event in output_events: + assert event.startswith("data: ") + assert event.endswith("\n\n") + data = json.loads(event[6:].strip()) + assert data["type"] == "output" + assert "content" in data + + # Should have done event + done_events = [e for e in events if "done" in e] + assert len(done_events) == 1 + done_data = json.loads(done_events[0][6:].strip()) + assert done_data["type"] == "done" + assert done_data["returncode"] == 0 + + @pytest.mark.asyncio + async def test_stream_claude_code_handles_errors(self): + """_stream_claude_code should yield error events on non-zero exit.""" + from policyengine_api.api.agent import _stream_claude_code + + async def mock_create_subprocess(*args, **kwargs): + mock_process = MagicMock() + mock_process.returncode = 1 + + async def mock_stdout_iter(): + yield b"Partial output\n" + + mock_process.stdout = mock_stdout_iter() + mock_process.stderr = AsyncMock() + mock_process.stderr.read = AsyncMock( + return_value=b"Error: something went wrong" + ) + mock_process.wait = AsyncMock() + return mock_process + + with patch( + "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess + ): + events = [] + async for event in _stream_claude_code("Test", "http://localhost"): + events.append(event) + + # Should have error event + error_events = [e for e in events if "error" in e] + assert 
len(error_events) == 1 + error_data = json.loads(error_events[0][6:].strip()) + assert error_data["type"] == "error" + assert "something went wrong" in error_data["content"] + + @pytest.mark.asyncio + async def test_stream_claude_code_passes_anthropic_api_key(self): + """_stream_claude_code should pass ANTHROPIC_API_KEY in env.""" + from policyengine_api.api.agent import _stream_claude_code + + captured_kwargs = [] + + async def mock_create_subprocess(*args, **kwargs): + captured_kwargs.append(kwargs) + mock_process = MagicMock() + mock_process.returncode = 0 + + async def mock_stdout_iter(): + yield b"Done\n" + + mock_process.stdout = mock_stdout_iter() + mock_process.stderr = AsyncMock() + mock_process.stderr.read = AsyncMock(return_value=b"") + mock_process.wait = AsyncMock() + return mock_process + + with patch( + "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess + ): + async for _ in _stream_claude_code("Test", "http://localhost"): + pass + + # Verify env was passed + assert len(captured_kwargs) == 1 + assert "env" in captured_kwargs[0] + assert "ANTHROPIC_API_KEY" in captured_kwargs[0]["env"] - resp = RunResponse(call_id="fc-test123", status="running") - assert resp.call_id == "fc-test123" - assert resp.status == "running" - def test_logs_not_found(self): - """GET /agent/logs/{call_id} should return 404 for unknown call.""" - response = client.get("/agent/logs/nonexistent-call-id") - assert response.status_code == 404 +class TestAgentSandbox: + """Test the Modal sandbox configuration.""" + def test_sandbox_image_uses_bun(self): + """Sandbox image should use bun, not npm.""" + from policyengine_api.agent_sandbox import sandbox_image -class TestAgentSandbox: - """Test the Modal agent sandbox configuration.""" + # Just verify the image is defined - actual bun installation + # is tested when deploying to Modal + assert sandbox_image is not None - def test_run_agent_function_signature(self): - """run_agent should accept expected parameters.""" + 
def test_run_function_signature(self): + """run_claude_code_in_sandbox should accept question and api_base_url.""" import inspect - from policyengine_api.agent_sandbox import run_agent + from policyengine_api.agent_sandbox import run_claude_code_in_sandbox - sig = inspect.signature(run_agent.local) + sig = inspect.signature(run_claude_code_in_sandbox) params = list(sig.parameters.keys()) assert "question" in params assert "api_base_url" in params - assert "call_id" in params - assert "history" in params def test_modal_function_defined(self): - """run_agent Modal function should be defined.""" - from policyengine_api.agent_sandbox import run_agent - - assert run_agent is not None - assert hasattr(run_agent, "remote") or hasattr(run_agent, "local") - - def test_system_prompt_defined(self): - """System prompt should be defined with key instructions.""" - from policyengine_api.agent_sandbox import SYSTEM_PROMPT - - assert "policyengine-uk" in SYSTEM_PROMPT - assert "policyengine-us" in SYSTEM_PROMPT - assert "filter by country" in SYSTEM_PROMPT.lower() + """run_policy_analysis Modal function should be defined.""" + from policyengine_api.agent_sandbox import run_policy_analysis + + # Modal functions are wrapped, so check it exists and has expected attributes + assert run_policy_analysis is not None + assert hasattr(run_policy_analysis, "remote") or hasattr( + run_policy_analysis, "local" + ) diff --git a/tests/test_agent_sandbox.py b/tests/test_agent_sandbox.py index ea6e91d..d92bf2d 100644 --- a/tests/test_agent_sandbox.py +++ b/tests/test_agent_sandbox.py @@ -1,8 +1,26 @@ -"""Tests for the agent sandbox using Claude Agent SDK with MCP.""" +"""Tests for the agent sandbox using direct Claude API with OpenAPI-generated tools.""" import pytest -from policyengine_api.agent_sandbox import _run_agent_impl +from policyengine_api.agent_sandbox import ( + _run_agent_impl, + fetch_openapi_spec, + openapi_to_claude_tools, +) + + +def test_openapi_tool_generation(): + """OpenAPI 
spec generates tools correctly.""" + spec = fetch_openapi_spec("https://v2.api.policyengine.org") + tools = openapi_to_claude_tools(spec) + + assert len(tools) > 30 # Should have many endpoints + tool_names = [t["name"] for t in tools] + + # Check key endpoints exist + assert any("parameters" in n for n in tool_names) + assert any("household" in n for n in tool_names) + assert any("policies" in n for n in tool_names) @pytest.mark.integration @@ -33,6 +51,7 @@ def test_uk_household_calculation(): max_turns=20, ) assert result["status"] == "completed" + # Should mention income tax amount assert "tax" in result["result"].lower() @@ -46,6 +65,7 @@ def test_economic_impact_personal_allowance(): max_turns=25, ) assert result["status"] == "completed" + # Should mention some impact metric assert any( word in result["result"].lower() for word in ["budget", "cost", "revenue", "billion", "impact", "decile"] diff --git a/uv.lock b/uv.lock index cc52036..094ebf8 100644 --- a/uv.lock +++ b/uv.lock @@ -348,22 +348,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, ] -[[package]] -name = "claude-agent-sdk" -version = "0.1.18" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "mcp" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fd/3d/a8c6ad873e8448696d44441c9eb2c24dded620fb32415d68f576a542ccde/claude_agent_sdk-0.1.18.tar.gz", hash = "sha256:4fcb8730cc77dea562fbe9aa48c65eced3ef58a6bb1f34f77e50e8258902477d", size = 56162, upload-time = "2025-12-18T00:42:57.926Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/06/14/f529f7c4bab7c71dcbcc8c66f12f491e644ee8a027ac5111d13705df207e/claude_agent_sdk-0.1.18-py3-none-macosx_11_0_arm64.whl", hash = 
"sha256:9e45b4e3c20c072c3e3325fa60bab9a4b5a7cbbce64ca274b8d7d0af42dd9dd8", size = 54560828, upload-time = "2025-12-18T00:42:44.71Z" }, - { url = "https://files.pythonhosted.org/packages/2c/68/6e83005aa7bb9056bfad0aef0605249f877dc0c78724c9c0fadebff600fb/claude_agent_sdk-0.1.18-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:3c41bd8f38848609ae0d5da8d7327a4c2d7057a363feafb6fd70df611ea204cc", size = 68743107, upload-time = "2025-12-18T00:42:48.255Z" }, - { url = "https://files.pythonhosted.org/packages/fb/85/7d6dd85f402135a610894734c442f1166ffed61d03eced39d6bfd14efccd/claude_agent_sdk-0.1.18-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:983f15e51253f40c55136a86d7cc63e023a3576428b05fa1459093d461b2d215", size = 70444964, upload-time = "2025-12-18T00:42:51.752Z" }, - { url = "https://files.pythonhosted.org/packages/3c/fa/d2b22b7a713c4c049cbd5f9f635836ea5429ff65c1f3bcf4658a8e1c1cf5/claude_agent_sdk-0.1.18-py3-none-win_amd64.whl", hash = "sha256:36f5b84d5c3c8773ee9b56aeb5ab345d1033231db37f80d1f20ac15239bef41c", size = 72637215, upload-time = "2025-12-18T00:42:55.269Z" }, -] - [[package]] name = "click" version = "8.3.1" @@ -1775,7 +1759,6 @@ source = { editable = "." } dependencies = [ { name = "anthropic" }, { name = "boto3" }, - { name = "claude-agent-sdk" }, { name = "fastapi" }, { name = "fastapi-cache2" }, { name = "fastapi-mcp" }, @@ -1812,7 +1795,6 @@ dev = [ requires-dist = [ { name = "anthropic", specifier = ">=0.40.0" }, { name = "boto3", specifier = ">=1.41.1" }, - { name = "claude-agent-sdk", specifier = ">=0.1.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "fastapi-cache2", specifier = ">=0.2.1" }, { name = "fastapi-mcp", specifier = ">=0.4.0" },