diff --git a/pyproject.toml b/pyproject.toml index 27eb310..37cac99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "fastapi-mcp>=0.4.0", "modal>=0.68.0", "anthropic>=0.40.0", + "claude-agent-sdk>=0.1.0", ] [project.optional-dependencies] diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index 57f7faf..f460320 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -1,16 +1,12 @@ -"""Modal agent using Claude API with tools auto-generated from OpenAPI spec.""" +"""Modal agent using Claude Agent SDK with MCP server connection.""" -import json -import re -import time -from typing import Callable +import asyncio -import anthropic import modal import requests image = modal.Image.debian_slim(python_version="3.12").pip_install( - "anthropic", "requests", "logfire[httpx]" + "claude-agent-sdk", "requests", "logfire[httpx]" ) app = modal.App("policyengine-sandbox") @@ -35,7 +31,6 @@ def configure_logfire(traceparent: str | None = None): console=False, ) - # If traceparent provided, attach to the current context if traceparent: from opentelemetry import context from opentelemetry.trace.propagation.tracecontext import ( @@ -49,7 +44,7 @@ def configure_logfire(traceparent: str | None = None): SYSTEM_PROMPT = """You are a PolicyEngine assistant that helps users understand tax and benefit policies. -You have access to the full PolicyEngine API. +You have access to the PolicyEngine API via MCP tools. ## CRITICAL: Always filter by country @@ -57,340 +52,62 @@ def configure_logfire(traceparent: str | None = None): - "policyengine-uk" for UK questions - "policyengine-us" for US questions -Parameters and datasets from both countries are in the same database. Without the filter, you'll get mixed results and waste turns finding the right ones. +Parameters and datasets from both countries are in the same database. Without the filter, you'll get mixed results. 
## Key workflows 1. **Household calculations**: - - POST /household/calculate with model_name and people array - - Poll GET /household/calculate/{job_id} until completed + - Use household_calculate with model_name and people array + - Poll household_calculate_status until completed 2. **Parameter lookup**: - - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk (ALWAYS include country filter) - - GET /parameter-values/?parameter_id=...&current=true for the current value + - Use parameters_list with search and tax_benefit_model_name filter + - Use parameter_values_list with parameter_id and current=true -3. **Economic impact analysis** (budget impact, decile impacts): - - GET /parameters/?search=...&tax_benefit_model_name=policyengine-uk to find parameter_id - - POST /policies/ to create reform with parameter_values - - GET /datasets/?tax_benefit_model_name=policyengine-uk to find dataset_id - - POST /analysis/economic-impact with tax_benefit_model_name, policy_id and dataset_id - - GET /analysis/economic-impact/{report_id} for results (includes decile_impacts and program_statistics) +3. **Economic impact analysis**: + - Find parameter_id with parameters_list + - Create policy with policies_create + - Find dataset_id with datasets_list + - Run analysis with analysis_economic_impact + - Get results with analysis_economic_impact_status ## Response formatting Follow PolicyEngine's writing style: -1. **Sentence case**: Use sentence case for all headings (e.g. "Tax breakdown" not "Tax Breakdown") -2. **Active voice**: "The reform reduces poverty by 3.2%" not "Poverty is reduced by 3.2%" -3. **Quantitative precision**: Use specific numbers, avoid vague words like "significantly" or "substantially" -4. **Neutral tone**: Describe what policies do, not whether they're good or bad -5. **Tables for data**: Present breakdowns and comparisons in markdown tables +1. **Sentence case**: Use sentence case for headings +2. 
**Active voice**: "The reform reduces poverty by 3.2%" +3. **Quantitative precision**: Use specific numbers +4. **Neutral tone**: Describe what policies do objectively +5. **Tables for data**: Use markdown tables for breakdowns -Example response format: +Example: | Item | Amount | |------|--------| | Income tax | £7,486 | | National Insurance | £2,994 | | **Total tax** | **£10,480** | -- Gross income: £50,000 -- Net income: £39,520 -- Effective tax rate: 21.0% - -Avoid: "significantly reduces", "substantial savings", "unfortunately", "great news" -Prefer: specific percentages, pound/dollar amounts, neutral descriptions - -## Guidelines - -1. Use the API tools to get accurate, current data -2. Be concise - lead with key numbers -3. For UK, amounts are in GBP (£). For US, amounts are in USD ($) -4. When polling async endpoints, use the sleep tool to wait 5-10 seconds between requests +Avoid vague words like "significantly" or "substantially" - use numbers. """ -# Sleep tool for polling delays -SLEEP_TOOL = { - "name": "sleep", - "description": "Wait for a specified number of seconds. 
Use this between polling requests to avoid hammering the API.", - "input_schema": { - "type": "object", - "properties": { - "seconds": { - "type": "number", - "description": "Number of seconds to sleep (1-60)", - } - }, - "required": ["seconds"], - }, -} - - -def fetch_openapi_spec(api_base_url: str) -> dict: - """Fetch and cache OpenAPI spec.""" - resp = requests.get(f"{api_base_url}/openapi.json", timeout=30) - resp.raise_for_status() - return resp.json() - - -def resolve_ref(spec: dict, ref: str) -> dict: - """Resolve a $ref pointer in the OpenAPI spec.""" - if not ref.startswith("#/"): - return {} - parts = ref[2:].split("/") - result = spec - for part in parts: - result = result.get(part, {}) - return result - - -def schema_to_json_schema(spec: dict, schema: dict) -> dict: - """Convert OpenAPI schema to JSON Schema for Claude tools.""" - if "$ref" in schema: - schema = resolve_ref(spec, schema["$ref"]) - - result = {} - - if "type" in schema: - result["type"] = schema["type"] - if "description" in schema: - result["description"] = schema["description"] - if "enum" in schema: - result["enum"] = schema["enum"] - if "default" in schema: - result["default"] = schema["default"] - if "format" in schema: - # Add format info to description - fmt = schema["format"] - if "description" in result: - result["description"] += f" (format: {fmt})" - else: - result["description"] = f"Format: {fmt}" - - # Handle anyOf (often used for Optional types) - if "anyOf" in schema: - non_null = [s for s in schema["anyOf"] if s.get("type") != "null"] - if len(non_null) == 1: - result.update(schema_to_json_schema(spec, non_null[0])) - elif non_null: - result.update(schema_to_json_schema(spec, non_null[0])) - - # Handle allOf - if "allOf" in schema: - for sub in schema["allOf"]: - result.update(schema_to_json_schema(spec, sub)) - - # Handle objects - if schema.get("type") == "object" or "properties" in schema: - result["type"] = "object" - if "properties" in schema: - result["properties"] 
= {} - for prop_name, prop_schema in schema["properties"].items(): - result["properties"][prop_name] = schema_to_json_schema( - spec, prop_schema - ) - if "required" in schema: - result["required"] = schema["required"] - - # Handle arrays - if schema.get("type") == "array" and "items" in schema: - result["items"] = schema_to_json_schema(spec, schema["items"]) - - return result - - -def openapi_to_claude_tools(spec: dict) -> list[dict]: - """Convert OpenAPI spec to Claude tool definitions.""" - tools = [] - - for path, methods in spec.get("paths", {}).items(): - for method, operation in methods.items(): - if method not in ("get", "post", "put", "patch", "delete"): - continue - - # Build tool name from operationId or path+method - op_id = operation.get("operationId", f"{method}_{path}") - # Clean up the name - tool_name = re.sub(r"[^a-zA-Z0-9_]", "_", op_id) - tool_name = re.sub(r"_+", "_", tool_name).strip("_") - - # Build description - summary = operation.get("summary", "") - description = operation.get("description", "") - full_desc = f"{method.upper()} {path}" - if summary: - full_desc += f"\n\n{summary}" - if description: - full_desc += f"\n\n{description}" - - # Build input schema - properties = {} - required = [] - - # Path parameters - for param in operation.get("parameters", []): - param_name = param.get("name") - param_in = param.get("in") - param_schema = param.get("schema", {}) - param_required = param.get("required", False) - - prop = schema_to_json_schema(spec, param_schema) - prop["description"] = ( - param.get("description", "") - + f" (in: {param_in})" - ) - properties[param_name] = prop - - if param_required: - required.append(param_name) - - # Request body - request_body = operation.get("requestBody", {}) - if request_body: - content = request_body.get("content", {}) - json_content = content.get("application/json", {}) - body_schema = json_content.get("schema", {}) - - if body_schema: - resolved = schema_to_json_schema(spec, body_schema) - # 
Flatten body properties into tool properties - if "properties" in resolved: - for prop_name, prop_schema in resolved["properties"].items(): - properties[prop_name] = prop_schema - if "required" in resolved: - required.extend(resolved["required"]) - else: - # Wrap the whole body as a "body" parameter - properties["body"] = resolved - if request_body.get("required"): - required.append("body") - - input_schema = {"type": "object", "properties": properties} - if required: - input_schema["required"] = list(set(required)) - - tools.append({ - "name": tool_name, - "description": full_desc[:1024], # Claude has limits - "input_schema": input_schema, - "_meta": { - "path": path, - "method": method, - "parameters": operation.get("parameters", []), - }, - }) - - return tools - - -def execute_api_tool( - tool: dict, - tool_input: dict, - api_base_url: str, - log_fn: Callable, - trace_headers: dict | None = None, -) -> str: - """Execute an API tool by making the HTTP request.""" - meta = tool.get("_meta", {}) - path = meta.get("path", "") - method = meta.get("method", "get") - parameters = meta.get("parameters", []) - - # Build URL with path parameters - url = f"{api_base_url}{path}" - query_params = {} - headers = {"Content-Type": "application/json"} - if trace_headers: - headers.update(trace_headers) - - # Separate path, query, and body parameters - body_data = {} - for param in parameters: - param_name = param.get("name") - param_in = param.get("in") - value = tool_input.get(param_name) - - if value is None: - continue - - if param_in == "path": - url = url.replace(f"{{{param_name}}}", str(value)) - elif param_in == "query": - query_params[param_name] = value - elif param_in == "header": - headers[param_name] = str(value) - - # Remaining input goes to body (for POST/PUT/PATCH) - param_names = {p.get("name") for p in parameters} - for key, value in tool_input.items(): - if key not in param_names: - body_data[key] = value - - try: - log_fn(f"[API] {method.upper()} {url}") - if 
query_params: - log_fn(f"[API] Query: {json.dumps(query_params)[:200]}") - if body_data: - log_fn(f"[API] Body: {json.dumps(body_data)[:200]}") - - if method == "get": - resp = requests.get(url, params=query_params, headers=headers, timeout=60) - elif method == "post": - resp = requests.post( - url, params=query_params, json=body_data, headers=headers, timeout=60 - ) - elif method == "put": - resp = requests.put( - url, params=query_params, json=body_data, headers=headers, timeout=60 - ) - elif method == "patch": - resp = requests.patch( - url, params=query_params, json=body_data, headers=headers, timeout=60 - ) - elif method == "delete": - resp = requests.delete(url, params=query_params, headers=headers, timeout=60) - else: - return f"Unsupported method: {method}" - - log_fn(f"[API] Response: {resp.status_code}") - - if resp.status_code >= 400: - return f"Error {resp.status_code}: {resp.text[:500]}" - - try: - data = resp.json() - # For lists, summarize if too long but keep key info - if isinstance(data, list) and len(data) > 50: - result = json.dumps(data[:50], indent=2) - result += f"\n... 
({len(data) - 50} more items)" - else: - result = json.dumps(data, indent=2) - return result - except json.JSONDecodeError: - return resp.text[:1000] - - except requests.RequestException as e: - return f"Request error: {str(e)}" - -def _run_agent_impl( +async def _run_agent_async( question: str, - api_base_url: str = "https://v2.api.policyengine.org", - call_id: str = "", + api_base_url: str, + call_id: str, history: list[dict] | None = None, - max_turns: int = 30, traceparent: str | None = None, ) -> dict: - """Core agent implementation.""" - import logfire + """Core async agent implementation using Claude Agent SDK.""" + from claude_agent_sdk import ClaudeAgentOptions, query - # Get traceparent for HTTP requests def get_trace_headers() -> dict: if traceparent: return {"traceparent": traceparent} return {} def log(msg: str) -> None: - logfire.info("agent_log", message=msg, call_id=call_id) print(msg) if call_id: try: @@ -404,91 +121,68 @@ def log(msg: str) -> None: pass log(f"[AGENT] Starting: {question[:200]}") + log(f"[AGENT] Connecting to MCP server at {api_base_url}/mcp/") - # Fetch and convert OpenAPI spec to tools - log("[AGENT] Fetching OpenAPI spec...") - spec = fetch_openapi_spec(api_base_url) - tools = openapi_to_claude_tools(spec) - log(f"[AGENT] Loaded {len(tools)} API tools") - - # Create tool lookup for execution - tool_lookup = {t["name"]: t for t in tools} - - # Strip _meta from tools before sending to Claude (it doesn't need it) - claude_tools = [ - {k: v for k, v in t.items() if k != "_meta"} for t in tools - ] - # Add the sleep tool - claude_tools.append(SLEEP_TOOL) - - client = anthropic.Anthropic() - - # Build messages with conversation history - messages = [] + # Build conversation with history if history: + # Format history as context in the prompt + context_parts = [] for msg in history: - messages.append({"role": msg["role"], "content": msg["content"]}) - messages.append({"role": "user", "content": question}) + role = msg.get("role", 
"user") + content = msg.get("content", "") + context_parts.append(f"{role.upper()}: {content}") + context = "\n\n".join(context_parts) + full_prompt = f"Previous conversation:\n{context}\n\nNew question: {question}" + else: + full_prompt = question + + # Configure Agent SDK with MCP server + options = ClaudeAgentOptions( + mcp_servers={ + "policyengine": { + "type": "sse", + "url": f"{api_base_url}/mcp/", + } + }, + allowed_tools=["mcp__policyengine__*"], + system_prompt=SYSTEM_PROMPT, + ) final_response = None turns = 0 - while turns < max_turns: - turns += 1 - log(f"[AGENT] Turn {turns}") - - response = client.messages.create( - model="claude-sonnet-4-20250514", - max_tokens=4096, - system=SYSTEM_PROMPT, - tools=claude_tools, - messages=messages, - ) + try: + async for message in query(prompt=full_prompt, options=options): + # Handle different message types + msg_type = type(message).__name__ + + if msg_type == "AssistantMessage": + turns += 1 + for block in message.content: + block_type = type(block).__name__ + if block_type == "TextBlock": + log(f"[ASSISTANT] {block.text[:500]}") + final_response = block.text + elif block_type == "ToolUseBlock": + log(f"[TOOL_USE] {block.name}: {str(block.input)[:200]}") + + elif msg_type == "ToolResultMessage": + for result in message.content: + result_str = str(result)[:300] + log(f"[TOOL_RESULT] {result_str}") + + elif msg_type == "ResultMessage": + log( + f"[AGENT] Completed - Cost: ${message.cost:.4f}, Duration: {message.duration:.1f}s" + ) - log(f"[AGENT] Stop reason: {response.stop_reason}") - - assistant_content = [] - tool_results = [] - - for block in response.content: - if block.type == "text": - log(f"[ASSISTANT] {block.text[:500]}") - assistant_content.append(block) - final_response = block.text - elif block.type == "tool_use": - log(f"[TOOL_USE] {block.name}: {json.dumps(block.input)[:200]}") - assistant_content.append(block) - - # Execute tool - if block.name == "sleep": - # Handle sleep tool specially - seconds 
= min(max(block.input.get("seconds", 5), 1), 60) - log(f"[SLEEP] Waiting {seconds} seconds...") - time.sleep(seconds) - result = f"Slept for {seconds} seconds" - else: - tool = tool_lookup.get(block.name) - if tool: - result = execute_api_tool( - tool, block.input, api_base_url, log, get_trace_headers() - ) - else: - result = f"Unknown tool: {block.name}" - - log(f"[TOOL_RESULT] {result[:300]}") - - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - }) - - messages.append({"role": "assistant", "content": assistant_content}) - - if tool_results: - messages.append({"role": "user", "content": tool_results}) - else: - break + except Exception as e: + log(f"[AGENT] Error: {str(e)}") + return { + "status": "failed", + "error": str(e), + "turns": turns, + } log(f"[AGENT] Completed in {turns} turns") @@ -512,6 +206,20 @@ def log(msg: str) -> None: return result +def _run_agent_impl( + question: str, + api_base_url: str = "https://v2.api.policyengine.org", + call_id: str = "", + history: list[dict] | None = None, + max_turns: int = 30, + traceparent: str | None = None, +) -> dict: + """Synchronous wrapper for the async agent implementation.""" + return asyncio.run( + _run_agent_async(question, api_base_url, call_id, history, traceparent) + ) + + @app.function(image=image, secrets=[anthropic_secret, logfire_secrets], timeout=600) def run_agent( question: str, @@ -526,18 +234,19 @@ def run_agent( configure_logfire(traceparent) - with logfire.span("run_agent", call_id=call_id, question=question[:200]): - result = _run_agent_impl( - question, - api_base_url, - call_id, - history=history, - max_turns=max_turns, - traceparent=traceparent, - ) + try: + with logfire.span("run_agent", call_id=call_id, question=question[:200]): + result = _run_agent_impl( + question, + api_base_url, + call_id, + history=history, + max_turns=max_turns, + traceparent=traceparent, + ) + finally: + logfire.force_flush() - # Ensure logfire sends all spans 
before Modal container exits - logfire.force_flush() return result diff --git a/tests/test_agent.py b/tests/test_agent.py index 2c591f5..d0fa765 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,18 +1,15 @@ -"""Tests for the agent streaming API endpoints. +"""Tests for the agent API endpoints. -Tests verify that Claude Code is invoked correctly with proper MCP configuration. +Tests verify the agent endpoint structure and integration with Claude Agent SDK. """ import pytest - -pytestmark = pytest.mark.integration - -import json -from unittest.mock import AsyncMock, MagicMock, patch from fastapi.testclient import TestClient from policyengine_api.main import app +pytestmark = pytest.mark.integration + client = TestClient(app) @@ -24,226 +21,54 @@ def test_status_not_found(self): response = client.get("/agent/status/nonexistent-job-id") assert response.status_code == 404 - def test_ask_request_model(self): - """AskRequest model should accept question field.""" - from policyengine_api.api.agent import AskRequest + def test_run_request_model(self): + """RunRequest model should accept question field.""" + from policyengine_api.api.agent import RunRequest - req = AskRequest(question="Test question") + req = RunRequest(question="Test question") assert req.question == "Test question" - def test_ask_response_model(self): - """AskResponse model should have job_id and status fields.""" - from policyengine_api.api.agent import AskResponse - - resp = AskResponse(job_id="test-123", status="pending") - assert resp.job_id == "test-123" - assert resp.status == "pending" - - -class TestClaudeCodeInvocation: - """Test that Claude Code is invoked correctly.""" - - def test_claude_cli_runs(self): - """Claude CLI must be installed and able to run a simple prompt.""" - import subprocess - - # Actually run Claude with a simple prompt (no tools needed) - result = subprocess.run( - ["claude", "--print", "Say 'hello' and nothing else"], - capture_output=True, - timeout=30, - ) - 
assert result.returncode == 0, f"claude failed: {result.stderr.decode()}" - assert "hello" in result.stdout.decode().lower(), "Expected 'hello' in output" - - @pytest.mark.asyncio - async def test_stream_claude_code_invokes_claude_cli(self): - """_stream_claude_code should invoke the claude CLI with correct args.""" - from policyengine_api.api.agent import _stream_claude_code - - captured_args = [] - - async def mock_create_subprocess(*args, **kwargs): - captured_args.append(args) - mock_process = MagicMock() - mock_process.returncode = 0 - - async def mock_stdout_iter(): - yield b"Test output\n" - - mock_process.stdout = mock_stdout_iter() - mock_process.stderr = AsyncMock() - mock_process.stderr.read = AsyncMock(return_value=b"") - mock_process.wait = AsyncMock() - return mock_process - - with patch( - "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess - ): - events = [] - async for event in _stream_claude_code( - "Test question", "http://localhost:8000" - ): - events.append(event) - - # Verify claude was called - assert len(captured_args) == 1 - args = captured_args[0] - - # Check command structure: claude -p --allowedTools - assert args[0] == "claude" - assert "-p" in args - assert "--allowedTools" in args - - # Check question is passed after -p - p_idx = args.index("-p") - assert args[p_idx + 1] == "Test question" - - # Check MCP tools are allowed - allowed_tools_idx = args.index("--allowedTools") + 1 - allowed_tools = args[allowed_tools_idx] - assert "mcp__policyengine__*" in allowed_tools - - @pytest.mark.asyncio - async def test_stream_claude_code_yields_sse_events(self): - """_stream_claude_code should yield properly formatted SSE events.""" - from policyengine_api.api.agent import _stream_claude_code - - async def mock_create_subprocess(*args, **kwargs): - mock_process = MagicMock() - mock_process.returncode = 0 - - async def mock_stdout_iter(): - yield b"Line 1\n" - yield b"Line 2\n" - - mock_process.stdout = mock_stdout_iter() - 
mock_process.stderr = AsyncMock() - mock_process.stderr.read = AsyncMock(return_value=b"") - mock_process.wait = AsyncMock() - return mock_process - - with patch( - "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess - ): - events = [] - async for event in _stream_claude_code("Test", "http://localhost"): - events.append(event) - - # Should have output events - output_events = [e for e in events if "output" in e] - assert len(output_events) == 2 - - # Each should be valid SSE format - for event in output_events: - assert event.startswith("data: ") - assert event.endswith("\n\n") - data = json.loads(event[6:].strip()) - assert data["type"] == "output" - assert "content" in data - - # Should have done event - done_events = [e for e in events if "done" in e] - assert len(done_events) == 1 - done_data = json.loads(done_events[0][6:].strip()) - assert done_data["type"] == "done" - assert done_data["returncode"] == 0 - - @pytest.mark.asyncio - async def test_stream_claude_code_handles_errors(self): - """_stream_claude_code should yield error events on non-zero exit.""" - from policyengine_api.api.agent import _stream_claude_code - - async def mock_create_subprocess(*args, **kwargs): - mock_process = MagicMock() - mock_process.returncode = 1 - - async def mock_stdout_iter(): - yield b"Partial output\n" - - mock_process.stdout = mock_stdout_iter() - mock_process.stderr = AsyncMock() - mock_process.stderr.read = AsyncMock( - return_value=b"Error: something went wrong" - ) - mock_process.wait = AsyncMock() - return mock_process - - with patch( - "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess - ): - events = [] - async for event in _stream_claude_code("Test", "http://localhost"): - events.append(event) - - # Should have error event - error_events = [e for e in events if "error" in e] - assert len(error_events) == 1 - error_data = json.loads(error_events[0][6:].strip()) - assert error_data["type"] == "error" - assert "something went wrong" 
in error_data["content"] - - @pytest.mark.asyncio - async def test_stream_claude_code_passes_anthropic_api_key(self): - """_stream_claude_code should pass ANTHROPIC_API_KEY in env.""" - from policyengine_api.api.agent import _stream_claude_code - - captured_kwargs = [] - - async def mock_create_subprocess(*args, **kwargs): - captured_kwargs.append(kwargs) - mock_process = MagicMock() - mock_process.returncode = 0 - - async def mock_stdout_iter(): - yield b"Done\n" - - mock_process.stdout = mock_stdout_iter() - mock_process.stderr = AsyncMock() - mock_process.stderr.read = AsyncMock(return_value=b"") - mock_process.wait = AsyncMock() - return mock_process - - with patch( - "asyncio.create_subprocess_exec", side_effect=mock_create_subprocess - ): - async for _ in _stream_claude_code("Test", "http://localhost"): - pass - - # Verify env was passed - assert len(captured_kwargs) == 1 - assert "env" in captured_kwargs[0] - assert "ANTHROPIC_API_KEY" in captured_kwargs[0]["env"] + def test_run_response_model(self): + """RunResponse model should have call_id and status fields.""" + from policyengine_api.api.agent import RunResponse + resp = RunResponse(call_id="fc-test123", status="running") + assert resp.call_id == "fc-test123" + assert resp.status == "running" -class TestAgentSandbox: - """Test the Modal sandbox configuration.""" + def test_logs_not_found(self): + """GET /agent/logs/{call_id} should return 404 for unknown call.""" + response = client.get("/agent/logs/nonexistent-call-id") + assert response.status_code == 404 - def test_sandbox_image_uses_bun(self): - """Sandbox image should use bun, not npm.""" - from policyengine_api.agent_sandbox import sandbox_image - # Just verify the image is defined - actual bun installation - # is tested when deploying to Modal - assert sandbox_image is not None +class TestAgentSandbox: + """Test the Modal agent sandbox configuration.""" - def test_run_function_signature(self): - """run_claude_code_in_sandbox should accept question 
and api_base_url.""" + def test_run_agent_function_signature(self): + """run_agent should accept expected parameters.""" import inspect - from policyengine_api.agent_sandbox import run_claude_code_in_sandbox + from policyengine_api.agent_sandbox import run_agent - sig = inspect.signature(run_claude_code_in_sandbox) + sig = inspect.signature(run_agent.local) params = list(sig.parameters.keys()) assert "question" in params assert "api_base_url" in params + assert "call_id" in params + assert "history" in params def test_modal_function_defined(self): - """run_policy_analysis Modal function should be defined.""" - from policyengine_api.agent_sandbox import run_policy_analysis - - # Modal functions are wrapped, so check it exists and has expected attributes - assert run_policy_analysis is not None - assert hasattr(run_policy_analysis, "remote") or hasattr( - run_policy_analysis, "local" - ) + """run_agent Modal function should be defined.""" + from policyengine_api.agent_sandbox import run_agent + + assert run_agent is not None + assert hasattr(run_agent, "remote") or hasattr(run_agent, "local") + + def test_system_prompt_defined(self): + """System prompt should be defined with key instructions.""" + from policyengine_api.agent_sandbox import SYSTEM_PROMPT + + assert "policyengine-uk" in SYSTEM_PROMPT + assert "policyengine-us" in SYSTEM_PROMPT + assert "filter by country" in SYSTEM_PROMPT.lower() diff --git a/tests/test_agent_sandbox.py b/tests/test_agent_sandbox.py index d92bf2d..ea6e91d 100644 --- a/tests/test_agent_sandbox.py +++ b/tests/test_agent_sandbox.py @@ -1,26 +1,8 @@ -"""Tests for the agent sandbox using direct Claude API with OpenAPI-generated tools.""" +"""Tests for the agent sandbox using Claude Agent SDK with MCP.""" import pytest -from policyengine_api.agent_sandbox import ( - _run_agent_impl, - fetch_openapi_spec, - openapi_to_claude_tools, -) - - -def test_openapi_tool_generation(): - """OpenAPI spec generates tools correctly.""" - spec = 
fetch_openapi_spec("https://v2.api.policyengine.org") - tools = openapi_to_claude_tools(spec) - - assert len(tools) > 30 # Should have many endpoints - tool_names = [t["name"] for t in tools] - - # Check key endpoints exist - assert any("parameters" in n for n in tool_names) - assert any("household" in n for n in tool_names) - assert any("policies" in n for n in tool_names) +from policyengine_api.agent_sandbox import _run_agent_impl @pytest.mark.integration @@ -51,7 +33,6 @@ def test_uk_household_calculation(): max_turns=20, ) assert result["status"] == "completed" - # Should mention income tax amount assert "tax" in result["result"].lower() @@ -65,7 +46,6 @@ def test_economic_impact_personal_allowance(): max_turns=25, ) assert result["status"] == "completed" - # Should mention some impact metric assert any( word in result["result"].lower() for word in ["budget", "cost", "revenue", "billion", "impact", "decile"] diff --git a/uv.lock b/uv.lock index 094ebf8..cc52036 100644 --- a/uv.lock +++ b/uv.lock @@ -348,6 +348,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, ] +[[package]] +name = "claude-agent-sdk" +version = "0.1.18" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "mcp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/3d/a8c6ad873e8448696d44441c9eb2c24dded620fb32415d68f576a542ccde/claude_agent_sdk-0.1.18.tar.gz", hash = "sha256:4fcb8730cc77dea562fbe9aa48c65eced3ef58a6bb1f34f77e50e8258902477d", size = 56162, upload-time = "2025-12-18T00:42:57.926Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/06/14/f529f7c4bab7c71dcbcc8c66f12f491e644ee8a027ac5111d13705df207e/claude_agent_sdk-0.1.18-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9e45b4e3c20c072c3e3325fa60bab9a4b5a7cbbce64ca274b8d7d0af42dd9dd8", size = 54560828, upload-time = "2025-12-18T00:42:44.71Z" }, + { url = "https://files.pythonhosted.org/packages/2c/68/6e83005aa7bb9056bfad0aef0605249f877dc0c78724c9c0fadebff600fb/claude_agent_sdk-0.1.18-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:3c41bd8f38848609ae0d5da8d7327a4c2d7057a363feafb6fd70df611ea204cc", size = 68743107, upload-time = "2025-12-18T00:42:48.255Z" }, + { url = "https://files.pythonhosted.org/packages/fb/85/7d6dd85f402135a610894734c442f1166ffed61d03eced39d6bfd14efccd/claude_agent_sdk-0.1.18-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:983f15e51253f40c55136a86d7cc63e023a3576428b05fa1459093d461b2d215", size = 70444964, upload-time = "2025-12-18T00:42:51.752Z" }, + { url = "https://files.pythonhosted.org/packages/3c/fa/d2b22b7a713c4c049cbd5f9f635836ea5429ff65c1f3bcf4658a8e1c1cf5/claude_agent_sdk-0.1.18-py3-none-win_amd64.whl", hash = "sha256:36f5b84d5c3c8773ee9b56aeb5ab345d1033231db37f80d1f20ac15239bef41c", size = 72637215, upload-time = "2025-12-18T00:42:55.269Z" }, +] + [[package]] name = "click" version = "8.3.1" @@ -1759,6 +1775,7 @@ source = { editable = "." } dependencies = [ { name = "anthropic" }, { name = "boto3" }, + { name = "claude-agent-sdk" }, { name = "fastapi" }, { name = "fastapi-cache2" }, { name = "fastapi-mcp" }, @@ -1795,6 +1812,7 @@ dev = [ requires-dist = [ { name = "anthropic", specifier = ">=0.40.0" }, { name = "boto3", specifier = ">=1.41.1" }, + { name = "claude-agent-sdk", specifier = ">=0.1.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "fastapi-cache2", specifier = ">=0.2.1" }, { name = "fastapi-mcp", specifier = ">=0.4.0" },