From 5d8efc09e7253a809acbbe01c140be3d4dfbac1e Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Mon, 29 Dec 2025 11:56:01 +0000 Subject: [PATCH] refactor: agent uses log polling instead of Modal stdout Modal doesn't expose function call logs programmatically, so instead: - Modal function POSTs each log line to /agent/log/{call_id} - Modal function POSTs completion status to /agent/complete/{call_id} - API stores logs in memory for polling - UI polls /agent/logs/{call_id} every second This replaces the previous approach that tried to use call.get_logs() which doesn't exist on Modal FunctionCall objects. --- docs/src/components/policy-chat.tsx | 571 ++++++++++---------------- src/policyengine_api/agent_sandbox.py | 455 +++++--------------- src/policyengine_api/api/agent.py | 419 +++++++------------ 3 files changed, 455 insertions(+), 990 deletions(-) diff --git a/docs/src/components/policy-chat.tsx b/docs/src/components/policy-chat.tsx index 9e1d3cd..2a9865b 100644 --- a/docs/src/components/policy-chat.tsx +++ b/docs/src/components/policy-chat.tsx @@ -5,37 +5,15 @@ import ReactMarkdown from "react-markdown"; import remarkBreaks from "remark-breaks"; import { useApi } from "./api-context"; -// Types for Claude Code stream-json format -interface StreamEvent { - type: "system" | "assistant" | "user" | "result"; - subtype?: string; - message?: { - content: Array<{ type: string; text?: string; name?: string; input?: unknown }>; - }; - result?: string; - mcp_servers?: Array<{ name: string; status: string }>; - tool_use_result?: string | { stdout?: string }; - total_cost_usd?: number; - duration_ms?: number; -} - interface Message { role: "user" | "assistant"; content: string; - status?: "pending" | "streaming" | "completed" | "failed"; -} - -interface ToolCall { - name: string; - input: unknown; - result?: string; - isExpanded?: boolean; + status?: "pending" | "running" | "completed" | "failed"; } -interface StreamLine { - type: "text" | "tool" | "result" | "error"; - content: string; - timestamp: number; +interface LogEntry { + timestamp: string; + message: string; } export function PolicyChat() { @@ -43,13 +21,10 @@ export function PolicyChat() { const [messages, setMessages] = useState([]); const [input, setInput] = useState(""); const [isLoading, setIsLoading] = useState(false); - const [currentToolCalls, setCurrentToolCalls] = useState([]); - const [mcpConnected, setMcpConnected] = useState(null); - const [displayedContent, setDisplayedContent] = useState(""); - const [isTyping, setIsTyping] = useState(false); - const [streamLines, setStreamLines] = useState([]); + const [logs, setLogs] = useState([]); + const [callId, setCallId] = useState(null); const messagesEndRef = useRef(null); - const fullContentRef = useRef(""); + const pollIntervalRef = useRef(null); const scrollToBottom = () => { messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); @@ -57,25 +32,89 @@ export function PolicyChat() { useEffect(() => { scrollToBottom(); - }, [messages, currentToolCalls, displayedContent]); + }, [messages, logs]); - // Typing animation effect + // Cleanup polling on unmount useEffect(() => { - if (!isTyping || !fullContentRef.current) return; + return () => { + if (pollIntervalRef.current) { + clearInterval(pollIntervalRef.current); + } + }; + }, []); - const targetContent = fullContentRef.current; - if (displayedContent.length >= targetContent.length) { - setIsTyping(false); - return; - } + const pollLogs = async (id: string) => { + try { + const res = await fetch(`${baseUrl}/agent/logs/${id}`); + if (!res.ok) { + console.error("Failed to fetch logs:", res.status); + return; + } - const charsToAdd = Math.min(3, targetContent.length - displayedContent.length); - const timeout = setTimeout(() => { - setDisplayedContent(targetContent.slice(0, displayedContent.length + charsToAdd)); - }, 10); + const data = await res.json(); + setLogs(data.logs || []); - return () => clearTimeout(timeout); - }, [displayedContent, isTyping]); + // Check if completed or failed + if (data.status === "completed" || data.status === "failed") { + // Stop polling + if (pollIntervalRef.current) { + clearInterval(pollIntervalRef.current); + pollIntervalRef.current = null; + } + + setIsLoading(false); + setCallId(null); + + // Extract final result from logs or result field + let finalContent = ""; + if (data.result?.result) { + finalContent = data.result.result; + } else { + // Try to extract from logs - look for [CLAUDE] lines with result + const claudeLogs = data.logs + .map((l: LogEntry) => l.message) + .filter((m: string) => m.startsWith("[CLAUDE]")) + .map((m: string) => m.replace("[CLAUDE] ", "")); + + // Try to parse the last few lines for result + for (const log of claudeLogs.reverse()) { + try { + const event = JSON.parse(log); + if (event.type === "result" && event.result) { + finalContent = event.result; + break; + } + } catch { + // Not JSON, skip + } + } + + if (!finalContent) { + finalContent = + data.status === "completed" + ? "Analysis completed. Check logs for details." + : "Analysis failed. Check logs for errors."; + } + } + + // Update assistant message with final content + setMessages((prev) => { + const newMessages = [...prev]; + const lastIndex = newMessages.length - 1; + if (newMessages[lastIndex]?.role === "assistant") { + newMessages[lastIndex] = { + ...newMessages[lastIndex], + content: finalContent, + status: data.status, + }; + } + return newMessages; + }); + } + } catch (err) { + console.error("Error polling logs:", err); + } + }; const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); @@ -84,12 +123,14 @@ export function PolicyChat() { const userMessage = input.trim(); setInput(""); setIsLoading(true); - setCurrentToolCalls([]); - setMcpConnected(null); - setDisplayedContent(""); - setIsTyping(false); - setStreamLines([]); - fullContentRef.current = ""; + setLogs([]); + setCallId(null); + + // Stop any existing polling + if (pollIntervalRef.current) { + clearInterval(pollIntervalRef.current); + pollIntervalRef.current = null; + } // Add user message setMessages((prev) => [...prev, { role: "user", content: userMessage }]); @@ -101,7 +142,8 @@ export function PolicyChat() { ]); try { - const res = await fetch(`${baseUrl}/agent/stream`, { + // Start the agent + const res = await fetch(`${baseUrl}/agent/run`, { method: "POST", headers: { "Content-Type": "application/json" }, body: JSON.stringify({ question: userMessage }), @@ -111,159 +153,30 @@ export function PolicyChat() { throw new Error(`HTTP ${res.status}`); } - const reader = res.body?.getReader(); - if (!reader) throw new Error("No response body"); - - const decoder = new TextDecoder(); - let assistantText = ""; - let finalResult = ""; - const toolCalls: ToolCall[] = []; + const data = await res.json(); + const newCallId = data.call_id; + setCallId(newCallId); - // Update to streaming status + // Update to running status setMessages((prev) => { const newMessages = [...prev]; const lastIndex = newMessages.length - 1; if (newMessages[lastIndex]?.role === "assistant") { newMessages[lastIndex] = { ...newMessages[lastIndex], - status: "streaming", + status: "running", }; } return newMessages; }); - while (true) { - const { done, value } = await reader.read(); - if (done) break; - - const chunk = decoder.decode(value, { stream: true }); - const lines = chunk.split("\n"); + // Start polling for logs + pollIntervalRef.current = setInterval(() => { + pollLogs(newCallId); + }, 1000); - for (const line of lines) { - if (line.startsWith("data: ")) { - try { - const outerData = JSON.parse(line.slice(6)); - - if (outerData.type === "output" && outerData.content) { - const event: StreamEvent = JSON.parse(outerData.content); - - // Handle system init - check MCP connection - if (event.type === "system" && event.subtype === "init") { - const mcpServer = event.mcp_servers?.find( - (s) => s.name === "policyengine" - ); - setMcpConnected(mcpServer?.status === "connected"); - } - - // Handle assistant messages - if (event.type === "assistant" && event.message?.content) { - for (const item of event.message.content) { - if (item.type === "text" && item.text) { - assistantText += item.text + "\n"; - fullContentRef.current = assistantText.trim(); - setIsTyping(true); - // Add to stream lines - setStreamLines((prev) => [ - ...prev, - { type: "text", content: item.text!, timestamp: Date.now() }, - ]); - setMessages((prev) => { - const newMessages = [...prev]; - const lastIndex = newMessages.length - 1; - if (newMessages[lastIndex]?.role === "assistant") { - newMessages[lastIndex] = { - ...newMessages[lastIndex], - content: assistantText.trim(), - }; - } - return newMessages; - }); - } else if (item.type === "tool_use" && item.name) { - const toolCall: ToolCall = { - name: item.name, - input: item.input, - isExpanded: false, - }; - toolCalls.push(toolCall); - setCurrentToolCalls([...toolCalls]); - // Add tool use to stream lines - setStreamLines((prev) => [ - ...prev, - { type: "tool", content: item.name!, timestamp: Date.now() }, - ]); - } - } - } - - // Handle tool results - if (event.type === "user" && event.tool_use_result) { - if (toolCalls.length > 0) { - const result = - typeof event.tool_use_result === "string" - ? event.tool_use_result - : event.tool_use_result.stdout || ""; - toolCalls[toolCalls.length - 1].result = result; - setCurrentToolCalls([...toolCalls]); - } - } - - // Handle final result - if (event.type === "result" && event.result) { - finalResult = event.result; - fullContentRef.current = finalResult; - setIsTyping(true); - setMessages((prev) => { - const newMessages = [...prev]; - const lastIndex = newMessages.length - 1; - if (newMessages[lastIndex]?.role === "assistant") { - newMessages[lastIndex] = { - ...newMessages[lastIndex], - content: finalResult, - status: "completed", - }; - } - return newMessages; - }); - } - } else if (outerData.type === "error") { - setMessages((prev) => { - const newMessages = [...prev]; - const lastIndex = newMessages.length - 1; - if (newMessages[lastIndex]?.role === "assistant") { - newMessages[lastIndex] = { - role: "assistant", - content: `Error: ${outerData.content}`, - status: "failed", - }; - } - return newMessages; - }); - } else if (outerData.type === "done") { - setCurrentToolCalls([]); - setIsTyping(false); - setDisplayedContent(fullContentRef.current); - if (outerData.returncode !== 0) { - setMessages((prev) => { - const newMessages = [...prev]; - const lastIndex = newMessages.length - 1; - if (newMessages[lastIndex]?.role === "assistant") { - newMessages[lastIndex] = { - ...newMessages[lastIndex], - status: "failed", - }; - } - return newMessages; - }); - } - } - } catch { - // Ignore parse errors for incomplete chunks - } - } - } - } - - setIsLoading(false); + // Initial poll + pollLogs(newCallId); } catch (err) { setMessages((prev) => { const newMessages = [...prev]; @@ -281,45 +194,59 @@ export function PolicyChat() { } }; - const toggleToolExpanded = (index: number) => { - setCurrentToolCalls((prev) => - prev.map((t, i) => - i === index ? { ...t, isExpanded: !t.isExpanded } : t - ) - ); + // Parse log message to extract useful info + const parseLogMessage = (message: string): { type: string; content: string } => { + if (message.startsWith("[AGENT]")) { + return { type: "agent", content: message.replace("[AGENT] ", "") }; + } + if (message.startsWith("[CLAUDE]")) { + const claudeContent = message.replace("[CLAUDE] ", ""); + // Try to parse as JSON + try { + const event = JSON.parse(claudeContent); + if (event.type === "assistant" && event.message?.content) { + const textParts = event.message.content + .filter((c: { type: string }) => c.type === "text") + .map((c: { text: string }) => c.text) + .join(""); + if (textParts) { + return { type: "text", content: textParts }; + } + const toolParts = event.message.content + .filter((c: { type: string }) => c.type === "tool_use") + .map((c: { name: string }) => c.name); + if (toolParts.length > 0) { + return { type: "tool", content: `Using: ${toolParts.join(", ")}` }; + } + } + if (event.type === "system" && event.subtype === "init") { + const mcpStatus = event.mcp_servers?.find( + (s: { name: string }) => s.name === "policyengine" + ); + return { + type: "system", + content: mcpStatus?.status === "connected" ? "MCP connected" : "Starting...", + }; + } + if (event.type === "result") { + return { type: "result", content: "Analysis complete" }; + } + return { type: "claude", content: `[${event.type || "event"}]` }; + } catch { + return { type: "claude", content: claudeContent.slice(0, 100) }; + } + } + return { type: "log", content: message.slice(0, 100) }; }; const exampleQuestions = [ - // UK tax questions "How much would it cost to set the UK basic income tax rate to 19p?", "What would happen if we doubled child benefit?", - "How would a £15,000 personal allowance affect different income groups?", + "Calculate tax for a UK household earning 50,000", "What is the budgetary impact of abolishing the higher rate of income tax?", - "How much does universal credit cost the government?", - // US tax questions - "What would a $2,000 child tax credit cost in the US?", - "How would doubling SNAP benefits affect poverty rates?", - "What is the revenue impact of a 25% top marginal tax rate?", - // Household calculations - "Calculate tax for a UK household earning £50,000", "What benefits would a single parent with two children receive in California?", - "How much income tax does someone earning $100,000 in New York pay?", ]; - const formatToolName = (name: string) => { - return name - .replace("mcp__policyengine__", "") - .replace(/_/g, " ") - .replace(/\b\w/g, (c) => c.toUpperCase()); - }; - - const getDisplayContent = (message: Message, isLastMessage: boolean) => { - if (isLastMessage && isTyping) { - return displayedContent; - } - return message.content; - }; - return (
{/* Header */} @@ -327,22 +254,14 @@ export function PolicyChat() {
Policy analyst - {mcpConnected === true - ? "MCP connected" - : mcpConnected === false - ? "MCP failed" - : "Powered by Claude Code"} + Powered by Claude Code + MCP

@@ -371,80 +290,69 @@ export function PolicyChat() {

)} - {messages.map((message, i) => { - const isLastMessage = i === messages.length - 1; - const content = getDisplayContent(message, isLastMessage); - - return ( + {messages.map((message, i) => ( +
-
- {message.role === "assistant" && message.status === "pending" ? ( -
-
- Starting Claude Code... -
- ) : message.role === "assistant" && message.status === "streaming" ? ( -
- {content ? ( -
- {content} -
- ) : currentToolCalls.length === 0 ? ( - - Thinking... - - ) : null} - {isLastMessage && isTyping && ( - - )} -
- ) : message.status === "completed" ? ( -
- {content} - {isLastMessage && isTyping && ( - - )} -
- ) : ( -
{content}
- )} -
+ {message.role === "assistant" && + (message.status === "pending" || message.status === "running") ? ( +
+
+ + {message.status === "pending" ? "Starting..." : "Analysing..."} + +
+ ) : message.status === "completed" || message.status === "failed" ? ( +
+ + {message.content} + +
+ ) : ( +
{message.content}
+ )}
- ); - })} +
+ ))} - {/* Live streaming log */} - {isLoading && streamLines.length > 0 && ( -
-
- Live output + {/* Live logs */} + {isLoading && logs.length > 0 && ( +
+
+ Live output ({logs.length} entries)
- {streamLines.map((line, i) => ( -
- {">"} - - {line.type === "tool" ? `[Using: ${line.content}]` : line.content} - -
- ))} + {logs.slice(-30).map((log, i) => { + const parsed = parseLogMessage(log.message); + return ( +
+ + {">"} + + {parsed.content} +
+ ); + })}
{">"} @@ -452,63 +360,6 @@ export function PolicyChat() {
)} - {/* Live tool calls */} - {currentToolCalls.length > 0 && ( -
-
- API calls -
- {currentToolCalls.map((tool, i) => ( -
- - {tool.isExpanded && ( -
- {tool.input !== undefined && tool.input !== null && ( -
-
- Input: -
-
-                          {JSON.stringify(tool.input, null, 2)}
-                        
-
- )} - {tool.result && ( -
-
- Result: -
-
-                          {tool.result.slice(0, 500)}
-                          {tool.result.length > 500 && "..."}
-                        
-
- )} -
- )} -
- ))} -
- )} -
diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index f6edf37..dc6ba72 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -1,19 +1,20 @@ """Modal Sandbox for running Claude Code with PolicyEngine MCP server. -This runs the actual Claude Code CLI in an isolated sandbox, connected -to the PolicyEngine API via MCP. Outputs are streamed back in real-time. +This runs the Claude Code CLI connected to the PolicyEngine API via MCP. +Logs are POSTed back to the API for real-time streaming to the UI. """ import json +import subprocess import modal -from modal.stream_type import StreamType +import requests -# Sandbox image with Bun and Claude Code CLI (v3 - with stdbuf for unbuffered output) +# Sandbox image with Bun and Claude Code CLI sandbox_image = ( modal.Image.debian_slim(python_version="3.12") - .apt_install("curl", "git", "unzip", "coreutils") # coreutils provides stdbuf - .pip_install("logfire") + .apt_install("curl", "git", "unzip") + .pip_install("requests") .run_commands( # Install Bun "curl -fsSL https://bun.sh/install | bash", @@ -21,16 +22,16 @@ "export BUN_INSTALL=/root/.bun && export PATH=$BUN_INSTALL/bin:$PATH && " "ln -s $BUN_INSTALL/bin/bun /usr/local/bin/node && " "bun install -g @anthropic-ai/claude-code", - # Pre-accept ToS and configure for non-interactive use (v2) + # Pre-accept ToS and configure for non-interactive use "mkdir -p /root/.claude && " 'echo \'{"hasCompletedOnboarding": true, "hasAcknowledgedCostThreshold": true}\' ' - "> /root/.claude/settings.json && cat /root/.claude/settings.json", + "> /root/.claude/settings.json", ) .env( { "BUN_INSTALL": "/root/.bun", "PATH": "/root/.bun/bin:/usr/local/bin:/usr/bin:/bin", - "CLAUDE_CODE_SKIP_ONBOARDING": "1", # Cache bust + extra safety + "CLAUDE_CODE_SKIP_ONBOARDING": "1", } ) ) @@ -39,286 +40,81 @@ # Secrets anthropic_secret = modal.Secret.from_name("anthropic-api-key") -logfire_secret = modal.Secret.from_name("logfire-token") -async def run_claude_code_in_sandbox_async( +def post_log(api_base_url: str, call_id: str, message: str) -> None: + """POST a log entry to the API.""" + try: + requests.post( + f"{api_base_url}/agent/log/{call_id}", + json={"message": message}, + timeout=5, + ) + except Exception: + pass # Don't fail on log errors + + +def post_complete(api_base_url: str, call_id: str, result: dict) -> None: + """POST completion status to the API.""" + try: + requests.post( + f"{api_base_url}/agent/complete/{call_id}", + json=result, + timeout=10, + ) + except Exception: + pass + + +@app.function(image=sandbox_image, secrets=[anthropic_secret], timeout=600) +def run_agent( question: str, api_base_url: str = "https://v2.api.policyengine.org", -) -> tuple[modal.Sandbox, any]: - """Create a sandbox running Claude Code with MCP server configured. + call_id: str = "", +) -> dict: + """Run Claude Code with MCP server to answer a policy question. - Returns the sandbox and process handle for streaming output. - Uses Modal's async API for proper streaming support. + Logs are POSTed back to the API for real-time streaming. """ - import logfire - logfire.info( - "run_claude_code_in_sandbox: starting", - question=question[:100], - api_base_url=api_base_url, - ) + def log(msg: str) -> None: + print(msg) # Also print for debugging + if call_id: + post_log(api_base_url, call_id, msg) - # MCP config for Claude Code (type: sse for HTTP SSE transport) + log(f"[AGENT] Starting analysis for: {question[:200]}") + log(f"[AGENT] API URL: {api_base_url}") + log(f"[AGENT] MCP endpoint: {api_base_url}/mcp") + + # MCP config for Claude Code - connects to PolicyEngine API's MCP server mcp_config = { "mcpServers": {"policyengine": {"type": "sse", "url": f"{api_base_url}/mcp"}} } mcp_config_json = json.dumps(mcp_config) - # Get reference to deployed app (required when calling from outside Modal) - logfire.info("run_claude_code_in_sandbox: looking up Modal app") - sandbox_app = modal.App.lookup("policyengine-sandbox", create_if_missing=True) - logfire.info("run_claude_code_in_sandbox: Modal app found") - - logfire.info("run_claude_code_in_sandbox: creating sandbox") - sb = await modal.Sandbox.create.aio( - app=sandbox_app, - image=sandbox_image, - secrets=[anthropic_secret, logfire_secret], - timeout=600, - workdir="/tmp", - ) - logfire.info("run_claude_code_in_sandbox: sandbox created") - - # Escape the question and config for shell - escaped_question = question.replace("'", "'\"'\"'") - escaped_mcp_config = mcp_config_json.replace("'", "'\"'\"'") - # CRITICAL: < /dev/null closes stdin (otherwise Claude hangs waiting for input) - # 2>&1 merges stderr into stdout for unified streaming - # stdbuf -oL forces line-buffered stdout to prevent libc buffering - cmd = ( - f"stdbuf -oL claude -p '{escaped_question}' " - f"--mcp-config '{escaped_mcp_config}' " - "--output-format stream-json --verbose --max-turns 10 " - "--allowedTools 'mcp__policyengine__*,Bash,Read,Grep,Glob,Write,Edit' " - "< /dev/null 2>&1" - ) - logfire.info( - "run_claude_code_in_sandbox: executing", - cmd=cmd[:500], - question_len=len(question), - escaped_question_len=len(escaped_question), - ) - # Use async exec for proper streaming - # stdout=StreamType.PIPE allows us to consume the stream (default but explicit) - process = await sb.exec.aio( - "sh", - "-c", - cmd, - text=True, - bufsize=1, - stdout=StreamType.PIPE, - ) - logfire.info("run_claude_code_in_sandbox: claude CLI process started, returning.") - - return sb, process - - -def _get_api_system_prompt(api_base_url: str) -> str: - """Generate system prompt with PolicyEngine API documentation.""" - return f"""You are a PolicyEngine policy analyst. Answer questions about tax and benefit policy using the PolicyEngine API at {api_base_url}. - -Use curl to call the API. All responses are JSON. - -## WORKFLOW 1: Household calculation (single family taxes/benefits) - -Step 1: Calculate household taxes/benefits -```bash -curl -X POST {api_base_url}/household/calculate \\ - -H "Content-Type: application/json" \\ - -d '{{ - "tax_benefit_model_name": "policyengine_uk", - "people": [{{"employment_income": 50000, "age": 35}}], - "household": {{}}, - "year": 2026 - }}' -``` -Returns: {{"job_id": "uuid", "status": "pending"}} - -Step 2: Poll until status="completed" -```bash -curl {api_base_url}/household/calculate/{{job_id}} -``` -Returns: {{"status": "completed", "result": {{"person": [...], "household": {{...}}}}}} - -### UK household example: -```json -{{ - "tax_benefit_model_name": "policyengine_uk", - "people": [{{"employment_income": 50000, "age": 35}}], - "benunit": {{}}, - "household": {{}}, - "year": 2026 -}} -``` - -### US household example: -```json -{{ - "tax_benefit_model_name": "policyengine_us", - "people": [{{"employment_income": 70000, "age": 40}}], - "tax_unit": {{"state_code": "CA"}}, - "household": {{"state_fips": 6}}, - "year": 2024 -}} -``` - -IMPORTANT: Use FLAT values like {{"employment_income": 50000}}, NOT time-period format like {{"employment_income": {{"2024": 50000}}}}. - -## WORKFLOW 2: Economic impact analysis (budgetary/distributional effects) - -This workflow analyses how a policy reform affects the whole economy. - -Step 1: Search for the parameter you want to change -```bash -curl "{api_base_url}/parameters?search=basic_rate" -``` -Look for the parameter with a name like "gov.hmrc.income_tax.rates.uk[0].rate" and note its "id" field. - -Step 2: Get dataset ID for the country -```bash -curl {api_base_url}/datasets -``` -For UK, find the "enhanced_frs" dataset. For US, find "enhanced_cps". Note the "id" field. - -Step 3: Create a policy reform -```bash -curl -X POST {api_base_url}/policies \\ - -H "Content-Type: application/json" \\ - -d '{{ - "name": "Lower basic rate to 16p", - "description": "Reduce UK basic income tax rate from 20p to 16p", - "parameter_values": [ - {{ - "parameter_id": "", - "value_json": 0.16, - "start_date": "2026-01-01T00:00:00Z", - "end_date": null - }} + log(f"[AGENT] MCP config: {mcp_config_json}") + + # Build command + cmd = [ + "claude", + "-p", + question, + "--mcp-config", + mcp_config_json, + "--output-format", + "stream-json", + "--verbose", + "--max-turns", + "15", + "--allowedTools", + "mcp__policyengine__*,Bash,WebFetch,Read,Write,Edit", ] - }}' -``` -Returns: {{"id": "policy-uuid", ...}} - -Step 4: Run economic impact analysis -```bash -curl -X POST {api_base_url}/analysis/economic-impact \\ - -H "Content-Type: application/json" \\ - -d '{{ - "tax_benefit_model_name": "policyengine_uk", - "dataset_id": "", - "policy_id": "" - }}' -``` -Returns: {{"report_id": "uuid", "status": "pending", ...}} - -Step 5: Poll until status="completed" -```bash -curl {api_base_url}/analysis/economic-impact/{{report_id}} -``` -Returns: {{"status": "completed", "decile_impacts": [...], "program_statistics": [...]}} - -## WORKFLOW 3: Household impact comparison (baseline vs reform) - -Compare how a specific household is affected by a policy reform. - -Step 1: Create a policy reform (same as Workflow 2, Step 3) - -Step 2: Run household impact comparison -```bash -curl -X POST {api_base_url}/household/impact \\ - -H "Content-Type: application/json" \\ - -d '{{ - "tax_benefit_model_name": "policyengine_uk", - "people": [{{"employment_income": 50000, "age": 35}}], - "household": {{}}, - "year": 2026, - "policy_id": "" - }}' -``` - -Step 3: Poll until status="completed" -```bash -curl {api_base_url}/household/impact/{{job_id}} -``` -Returns: {{"baseline_result": {{...}}, "reform_result": {{...}}, "impact": {{...}}}} - -## API REFERENCE - -### Parameters (policy levers that can be changed) -- GET /parameters?search= - search by name/label/description -- GET /parameters/{{id}} - get parameter details - -Common UK parameters: -- "gov.hmrc.income_tax.rates.uk[0].rate" - basic rate (currently 0.20) -- "gov.hmrc.income_tax.rates.uk[1].rate" - higher rate (currently 0.40) -- "gov.hmrc.income_tax.allowances.personal_allowance.amount" - personal allowance -- "gov.dwp.child_benefit.weekly.eldest" - child benefit eldest child weekly amount -- "gov.dwp.universal_credit.elements.standard_allowance.single_young" - UC standard allowance - -Common US parameters: -- "gov.irs.income.bracket.rates" - federal income tax rates -- "gov.irs.credits.ctc.amount.base" - child tax credit amount - -### Variables (computed values like income_tax, net_income) -- GET /variables?search= - search variables -- GET /variables/{{id}} - get variable details - -Common variables: income_tax, national_insurance, universal_credit, child_benefit, net_income, household_net_income - -### Datasets (population microdata for economic analysis) -- GET /datasets - list all datasets - -UK dataset: Look for name containing "enhanced_frs" -US dataset: Look for name containing "enhanced_cps" - -### Policies (reform specifications) -- POST /policies - create policy reform -- GET /policies - list all policies -- GET /policies/{{id}} - get policy details - -## TIPS - -1. Always search for parameters FIRST before creating policies -2. Use the exact parameter_id (UUID) from the search results -3. Poll async endpoints until status="completed" (may take 10-60 seconds) -4. For UK, use year=2026; for US, use year=2024 -5. The result contains calculated values for all variables (income_tax, net_income, etc.) -6. Economic impact takes longer (30-120 seconds) as it simulates the full population""" - - -@app.function(image=sandbox_image, secrets=[anthropic_secret], timeout=600) -def stream_policy_analysis( - question: str, api_base_url: str = "https://v2.api.policyengine.org" -): - """Stream Claude Code output line by line. - - Uses direct API calls instead of MCP (MCP doesn't work in Modal containers). - Claude is given a system prompt explaining how to use the PolicyEngine API. - """ - import subprocess - system_prompt = _get_api_system_prompt(api_base_url) + log(f"[AGENT] Running: {' '.join(cmd[:5])}...") - print(f"[MODAL] Starting Claude Code (streaming) for question: {question[:100]}") - - # Use Popen for streaming output - no MCP, use system prompt instead - # stdin=DEVNULL prevents Claude from waiting for input (critical!) + # Run Claude Code - stdin=DEVNULL prevents hanging process = subprocess.Popen( - [ - "claude", - "-p", - question, - "--system-prompt", - system_prompt, - "--output-format", - "stream-json", - "--verbose", - "--max-turns", - "10", - "--allowedTools", - "Bash,WebFetch,Read,Write,Edit", - ], + cmd, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, @@ -326,88 +122,40 @@ def stream_policy_analysis( bufsize=1, # Line buffered ) - # Yield each line as it comes + # Stream output + full_output = [] + final_result = None + for line in process.stdout: - if line.strip(): - print(f"[MODAL] Claude output: {line[:100]}") - yield line + line = line.rstrip() + if line: + log(f"[CLAUDE] {line}") + full_output.append(line) + + # Try to parse stream-json events + try: + event = json.loads(line) + # Capture the final result + if event.get("type") == "result": + final_result = event.get("result", "") + except json.JSONDecodeError: + pass process.wait() - print(f"[MODAL] Claude Code finished with returncode: {process.returncode}") + log(f"[AGENT] Claude exited with code: {process.returncode}") + result = { + "status": "completed" if process.returncode == 0 else "failed", + "result": final_result, + "returncode": process.returncode, + "output_lines": len(full_output), + } -@app.function( - image=sandbox_image, secrets=[anthropic_secret, logfire_secret], timeout=600 -) -def run_policy_analysis( - question: str, api_base_url: str = "https://v2.api.policyengine.org" -) -> dict: - """Run Claude Code to answer a policy question. - - This is the non-streaming version that returns the full result. - """ - import os - import subprocess - - import logfire - - # Only configure logfire if token is available - if os.environ.get("LOGFIRE_TOKEN"): - logfire.configure( - service_name="policyengine-agent-sandbox", - token=os.environ["LOGFIRE_TOKEN"], - ) - - with logfire.span( - "run_policy_analysis", question=question[:100], api_base_url=api_base_url - ): - # MCP config for Claude Code (type: sse for HTTP SSE transport) - mcp_config = { - "mcpServers": { - "policyengine": {"type": "sse", "url": f"{api_base_url}/mcp"} - } - } - mcp_config_json = json.dumps(mcp_config) - - logfire.info( - "Starting Claude Code", - question=question[:100], - mcp_url=f"{api_base_url}/mcp", - ) - - # Run Claude Code with --mcp-config (no --dangerously-skip-permissions as root) - result = subprocess.run( - [ - "claude", - "-p", - question, - "--mcp-config", - mcp_config_json, - "--max-turns", - "10", - "--allowedTools", - "mcp__policyengine__*,Bash,Read,Grep,Glob,Write,Edit", - ], - capture_output=True, - text=True, - timeout=540, - ) - - logfire.info( - "Claude Code finished", - returncode=result.returncode, - stdout_len=len(result.stdout), - stderr_len=len(result.stderr), - ) - - if result.returncode != 0: - logfire.error("Claude Code failed", stderr=result.stderr[:500]) + # Notify API of completion + if call_id: + post_complete(api_base_url, call_id, result) - return { - "status": "completed" if result.returncode == 0 else "failed", - "report": result.stdout, - "error": result.stderr if result.returncode != 0 else None, - } + return result # For local testing @@ -415,9 +163,7 @@ def run_policy_analysis( import sys question = ( - sys.argv[1] - if len(sys.argv) > 1 - else "How much would it cost to set the UK basic rate to 19p?" + sys.argv[1] if len(sys.argv) > 1 else "What is the UK basic rate of income tax?" ) print(f"Question: {question}\n") @@ -425,7 +171,6 @@ def run_policy_analysis( # Run via Modal with modal.enable_local(): - result = run_policy_analysis.local(question) - print(result["report"]) - if result["error"]: - print(f"\nError: {result['error']}") + result = run_agent.local(question) + print("\n" + "=" * 60) + print(f"Result: {result}") diff --git a/src/policyengine_api/api/agent.py b/src/policyengine_api/api/agent.py index 16d51b6..33a4f21 100644 --- a/src/policyengine_api/api/agent.py +++ b/src/policyengine_api/api/agent.py @@ -2,17 +2,16 @@ This endpoint lets users ask natural language questions about tax/benefit policy and get AI-generated reports using Claude Code connected to the PolicyEngine MCP server. -Outputs are streamed back in real-time via SSE. + +The agent runs in a Modal sandbox and logs are fetched via Modal SDK. """ -import asyncio -import json -import os -from uuid import uuid4 +import uuid +from datetime import datetime import logfire +import modal from fastapi import APIRouter, HTTPException -from fastapi.responses import StreamingResponse from pydantic import BaseModel from policyengine_api.config import settings @@ -20,308 +19,178 @@ router = APIRouter(prefix="/agent", tags=["agent"]) -class AskRequest(BaseModel): - """Request to ask a policy question.""" +class RunRequest(BaseModel): + """Request to run the agent.""" question: str -class AskResponse(BaseModel): - """Response with job ID for polling.""" +class RunResponse(BaseModel): + """Response with function call ID for fetching logs.""" - job_id: str + call_id: str status: str -class JobStatusResponse(BaseModel): - """Status of an agent job.""" +class LogEntry(BaseModel): + """A single log entry.""" + + timestamp: str + message: str + + +class LogsResponse(BaseModel): + """Response with logs for a function call.""" + + call_id: str + status: str # "running", "completed", "failed" + logs: list[LogEntry] + result: dict | None = None + - job_id: str +class LogInput(BaseModel): + """Input for logging an entry.""" + + message: str + + +class StatusResponse(BaseModel): + """Response with job status.""" + + call_id: str status: str - report: str | None = None - error: str | None = None + result: dict | None = None -# In-memory job storage -_jobs: dict[str, dict] = {} +# In-memory storage for function calls and their logs +_calls: dict[str, dict] = {} +_logs: dict[str, list[LogEntry]] = {} -async def _stream_claude_code(question: str, api_base_url: str): - """Stream output from Claude Code running with MCP server.""" - # MCP config as JSON string (type: sse for HTTP SSE transport) - mcp_config = json.dumps( - {"mcpServers": {"policyengine": {"type": "sse", "url": f"{api_base_url}/mcp"}}} - ) +@router.post("/run", response_model=RunResponse) +async def run_agent(request: RunRequest) -> RunResponse: + """Start the agent to answer a policy question. - # Run Claude Code with streaming JSON output for realtime updates - process = await asyncio.create_subprocess_exec( - "claude", - "-p", - question, - "--output-format", - "stream-json", - "--verbose", - "--mcp-config", - mcp_config, - "--allowedTools", - "mcp__policyengine__*,Bash,Read,Grep,Glob,Write,Edit", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env={**os.environ, "ANTHROPIC_API_KEY": settings.anthropic_api_key}, - ) + Returns a call_id that can be used to fetch logs and status. - # Stream stdout - async for line in process.stdout: - text = line.decode("utf-8") - yield f"data: {json.dumps({'type': 'output', 'content': text})}\n\n" + Example: + ```bash + curl -X POST https://v2.api.policyengine.org/agent/run \\ + -H "Content-Type: application/json" \\ + -d '{"question": "What is the UK basic rate of income tax?"}' + ``` + + Response: + ```json + {"call_id": "fc-abc123", "status": "running"} + ``` + + Then poll /agent/logs/{call_id} to get logs and final result. + """ + logfire.info("agent_run", question=request.question[:100]) - # Wait for completion - await process.wait() + api_base_url = settings.policyengine_api_url + + # Look up the deployed function + run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent") + + # Generate a call_id before spawning so we can pass it to the function + call_id = f"fc-{uuid.uuid4().hex[:24]}" - if process.returncode != 0: - stderr = await process.stderr.read() - yield f"data: {json.dumps({'type': 'error', 'content': stderr.decode('utf-8')})}\n\n" + # Initialize logs storage + _logs[call_id] = [] - yield f"data: {json.dumps({'type': 'done', 'returncode': process.returncode})}\n\n" + # Spawn the function (non-blocking) - pass call_id so it can POST logs back + call = run_fn.spawn(request.question, api_base_url, call_id) + + # Store call info + _calls[call_id] = { + "call": call, + "modal_call_id": call.object_id, + "question": request.question, + "started_at": datetime.utcnow().isoformat(), + "status": "running", + "result": None, + } + logfire.info("agent_spawned", call_id=call_id, modal_call_id=call.object_id) -def _parse_claude_stream_event(line: str) -> dict | None: - """Parse a Claude Code stream-json event and extract useful content. + return RunResponse(call_id=call_id, status="running") - Returns a dict with 'type' and 'content' for streaming to client, - or None if the event should be skipped. + +@router.post("/log/{call_id}") +async def post_log(call_id: str, log_input: LogInput) -> dict: + """Receive a log entry from the running agent. + + This endpoint is called by the Modal function to stream logs back. """ - if not line or not line.strip(): - return None - - try: - event = json.loads(line) - except json.JSONDecodeError: - # Not JSON, pass through as raw output - return {"type": "raw", "content": line} - - event_type = event.get("type") - - # Assistant text output (the main response) - if event_type == "assistant": - message = event.get("message", {}) - content_blocks = message.get("content", []) - text_parts = [] - for block in content_blocks: - if block.get("type") == "text": - text_parts.append(block.get("text", "")) - elif block.get("type") == "tool_use": - tool_name = block.get("name", "unknown") - text_parts.append(f"[Using tool: {tool_name}]") - if text_parts: - return {"type": "assistant", "content": "".join(text_parts)} - - # Content block delta (streaming text chunks) - elif event_type == "content_block_delta": - delta = event.get("delta", {}) - if delta.get("type") == "text_delta": - text = delta.get("text", "") - if text: - return {"type": "text", "content": text} - - # Tool use events - elif event_type == "tool_use": - tool_name = event.get("name", "unknown") - return {"type": "tool", "content": f"Using tool: {tool_name}"} - - # Tool result - elif event_type == "tool_result": - content = event.get("content", "") - if isinstance(content, str) and content: - # Truncate long tool results - preview = content[:500] + "..." if len(content) > 500 else content - return {"type": "tool_result", "content": preview} - - # Result/completion - elif event_type == "result": - result_text = event.get("result", "") - if result_text: - return {"type": "result", "content": result_text} - - # System messages - elif event_type == "system": - msg = event.get("message", "") - if msg: - return {"type": "system", "content": msg} - - return None - - -async def _stream_modal_function(question: str, api_base_url: str): - """Stream output from Claude Code running in a Modal function. - - Uses Modal's generator function for streaming, which runs Claude via subprocess - directly in the Modal container (avoiding the sandbox exec MCP issue). + if call_id not in _logs: + _logs[call_id] = [] + + entry = LogEntry( + timestamp=datetime.utcnow().isoformat(), + message=log_input.message, + ) + _logs[call_id].append(entry) + + return {"status": "ok"} + + +@router.post("/complete/{call_id}") +async def complete_call(call_id: str, result: dict) -> dict: + """Mark a call as complete with its result. + + Called by the Modal function when it finishes. """ - import modal - - with logfire.span( - "agent_stream", question=question[:100], api_base_url=api_base_url - ): - try: - # Look up the deployed streaming function - stream_fn = modal.Function.from_name( - "policyengine-sandbox", "stream_policy_analysis" - ) - logfire.info("modal_function_found") - - # Iterate over the generator output - lines_received = 0 - events_sent = 0 - - # Use Modal's async generator API to avoid blocking the event loop - async for line in stream_fn.remote_gen.aio(question, api_base_url): - lines_received += 1 - print(f"[CLAUDE] {line[:300]}", flush=True) - logfire.info( - "raw_line", - line_num=lines_received, - line_len=len(line) if line else 0, - ) - # Send raw Claude Code output wrapped in 'output' event - # The frontend expects this format to parse the stream-json - if line and line.strip(): - events_sent += 1 - yield f"data: {json.dumps({'type': 'output', 'content': line})}\n\n" - - logfire.info( - "complete", - events_sent=events_sent, - lines_received=lines_received, - ) - yield f"data: {json.dumps({'type': 'done', 'returncode': 0})}\n\n" - - except Exception as e: - logfire.exception("failed", error=str(e)) - yield f"data: {json.dumps({'type': 'error', 'content': f'Modal error: {str(e)}'})}\n\n" - yield f"data: {json.dumps({'type': 'done', 'returncode': 1})}\n\n" - - -@router.post("/stream") -async def stream_analysis(request: AskRequest): - """Stream a policy analysis using Claude Code with MCP. - - Returns a Server-Sent Events stream with real-time output from Claude Code. - - Event types: - - output: A line of output from Claude Code - - error: An error message - - done: Analysis complete (includes returncode) + if call_id in _calls: + _calls[call_id]["status"] = result.get("status", "completed") + _calls[call_id]["result"] = result - Example: - ``` - data: {"type": "output", "content": "Searching for basic rate parameter...\\n"} + return {"status": "ok"} + + +@router.get("/logs/{call_id}", response_model=LogsResponse) +async def get_logs(call_id: str) -> LogsResponse: + """Get logs for an agent run. - data: {"type": "output", "content": "Found parameter: gov.hmrc.income_tax.rates.uk[0].rate\\n"} + Returns all logs emitted so far, plus status and result if completed. - data: {"type": "done", "returncode": 0} + Example: + ```bash + curl https://v2.api.policyengine.org/agent/logs/fc-abc123 ``` """ - print(f"[AGENT] /stream called, use_modal={settings.agent_use_modal}", flush=True) - api_base_url = settings.policyengine_api_url - logfire.info( - "stream_analysis: called", - question=request.question[:100], - agent_use_modal=settings.agent_use_modal, - api_base_url=api_base_url, + logfire.info("agent_get_logs", call_id=call_id) + + if call_id not in _calls: + raise HTTPException(status_code=404, detail="Call not found") + + call_info = _calls[call_id] + logs = _logs.get(call_id, []) + + return LogsResponse( + call_id=call_id, + status=call_info["status"], + logs=logs, + result=call_info["result"], ) - # SSE headers to prevent buffering by proxies (nginx, Cloud Run) - sse_headers = { - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", # Disable nginx buffering - } - if settings.agent_use_modal: - return StreamingResponse( - _stream_modal_function(request.question, api_base_url), - media_type="text/event-stream", - headers=sse_headers, - ) - else: - return StreamingResponse( - _stream_claude_code(request.question, api_base_url), - media_type="text/event-stream", - headers=sse_headers, - ) - - -@router.post("/ask", response_model=AskResponse) -async def ask_question(request: AskRequest) -> AskResponse: - """Ask a policy question (non-streaming). - - Starts the analysis in the background. Poll GET /agent/status/{job_id} for results. - For real-time streaming, use POST /agent/stream instead. +@router.get("/status/{call_id}", response_model=StatusResponse) +async def get_status(call_id: str) -> StatusResponse: + """Get just the status of an agent run (no logs). + + Faster than /logs if you just need to check if it's done. """ - job_id = str(uuid4()) - api_base_url = settings.policyengine_api_url + logfire.info("agent_get_status", call_id=call_id) - _jobs[job_id] = { - "status": "pending", - "question": request.question, - "report": None, - "error": None, - } + if call_id not in _calls: + raise HTTPException(status_code=404, detail="Call not found") + + call_info = _calls[call_id] - # Run in background - async def run_job(): - _jobs[job_id]["status"] = "running" - try: - if settings.agent_use_modal: - import modal - - run_policy_analysis = modal.Function.lookup( - "policyengine-sandbox", "run_policy_analysis" - ) - result = run_policy_analysis.remote(request.question, api_base_url) - else: - # Run locally - process = await asyncio.create_subprocess_exec( - "claude", - "-p", - request.question, - "--allowedTools", - "mcp__policyengine__*,Bash,Read,Grep,Glob,Write,Edit", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env={**os.environ, "ANTHROPIC_API_KEY": settings.anthropic_api_key}, - ) - stdout, stderr = await process.communicate() - result = { - "status": "completed" if process.returncode == 0 else "failed", - "report": stdout.decode("utf-8"), - "error": stderr.decode("utf-8") - if process.returncode != 0 - else None, - } - - _jobs[job_id]["status"] = result.get("status", "completed") - _jobs[job_id]["report"] = result.get("report") - _jobs[job_id]["error"] = result.get("error") - except Exception as e: - _jobs[job_id]["status"] = "failed" - _jobs[job_id]["error"] = str(e) - - asyncio.create_task(run_job()) - return AskResponse(job_id=job_id, status="pending") - - -@router.get("/status/{job_id}", response_model=JobStatusResponse) -async def get_job_status(job_id: str) -> JobStatusResponse: - """Get the status of an agent job.""" - if job_id not in _jobs: - raise HTTPException(status_code=404, detail="Job not found") - - job = _jobs[job_id] - return JobStatusResponse( - job_id=job_id, - status=job["status"], - report=job.get("report"), - error=job.get("error"), + return StatusResponse( + call_id=call_id, + status=call_info["status"], + result=call_info["result"], )