Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 64 additions & 10 deletions src/policyengine_api/agent_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,42 @@
import requests

image = modal.Image.debian_slim(python_version="3.12").pip_install(
"anthropic", "requests"
"anthropic", "requests", "logfire[httpx]"
)

app = modal.App("policyengine-sandbox")
anthropic_secret = modal.Secret.from_name("anthropic-api-key")
logfire_secrets = modal.Secret.from_name("policyengine-logfire")


def configure_logfire(traceparent: str | None = None):
"""Configure logfire with optional trace context propagation."""
import os

import logfire

token = os.environ.get("LOGFIRE_TOKEN", "")
if not token:
return None

logfire.configure(
service_name="policyengine-agent",
token=token,
environment=os.environ.get("LOGFIRE_ENVIRONMENT", "production"),
console=False,
)

# If traceparent provided, attach to the current context
if traceparent:
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)

propagator = TraceContextTextMapPropagator()
ctx = propagator.extract(carrier={"traceparent": traceparent})
return ctx

return None

SYSTEM_PROMPT = """You are a PolicyEngine assistant that helps users understand tax and benefit policies.

Expand Down Expand Up @@ -256,6 +287,7 @@ def execute_api_tool(
tool_input: dict,
api_base_url: str,
log_fn: Callable,
trace_headers: dict | None = None,
) -> str:
"""Execute an API tool by making the HTTP request."""
meta = tool.get("_meta", {})
Expand All @@ -267,6 +299,8 @@ def execute_api_tool(
url = f"{api_base_url}{path}"
query_params = {}
headers = {"Content-Type": "application/json"}
if trace_headers:
headers.update(trace_headers)

# Separate path, query, and body parameters
body_data = {}
Expand Down Expand Up @@ -344,16 +378,26 @@ def _run_agent_impl(
call_id: str = "",
history: list[dict] | None = None,
max_turns: int = 30,
traceparent: str | None = None,
) -> dict:
"""Core agent implementation."""
import logfire

# Get traceparent for HTTP requests
def get_trace_headers() -> dict:
if traceparent:
return {"traceparent": traceparent}
return {}

def log(msg: str) -> None:
logfire.info(msg, call_id=call_id)
print(msg)
if call_id:
try:
requests.post(
f"{api_base_url}/agent/log/{call_id}",
json={"message": msg},
headers=get_trace_headers(),
timeout=5,
)
except Exception:
Expand Down Expand Up @@ -425,7 +469,9 @@ def log(msg: str) -> None:
else:
tool = tool_lookup.get(block.name)
if tool:
result = execute_api_tool(tool, block.input, api_base_url, log)
result = execute_api_tool(
tool, block.input, api_base_url, log, get_trace_headers()
)
else:
result = f"Unknown tool: {block.name}"

Expand Down Expand Up @@ -457,6 +503,7 @@ def log(msg: str) -> None:
requests.post(
f"{api_base_url}/agent/complete/{call_id}",
json=result,
headers=get_trace_headers(),
timeout=10,
)
except Exception:
Expand All @@ -465,22 +512,29 @@ def log(msg: str) -> None:
return result


@app.function(image=image, secrets=[anthropic_secret], timeout=600)
@app.function(image=image, secrets=[anthropic_secret, logfire_secrets], timeout=600)
def run_agent(
question: str,
api_base_url: str = "https://v2.api.policyengine.org",
call_id: str = "",
history: list[dict] | None = None,
max_turns: int = 30,
traceparent: str | None = None,
) -> dict:
"""Run agentic loop to answer a policy question (Modal wrapper)."""
return _run_agent_impl(
question,
api_base_url,
call_id,
history=history,
max_turns=max_turns,
)
import logfire

ctx = configure_logfire(traceparent)

with logfire.span("run_agent", call_id=call_id, question=question[:200], _context=ctx):
return _run_agent_impl(
question,
api_base_url,
call_id,
history=history,
max_turns=max_turns,
traceparent=traceparent,
)


if __name__ == "__main__":
Expand Down
14 changes: 13 additions & 1 deletion src/policyengine_api/api/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@

import logfire
from fastapi import APIRouter, HTTPException
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel

from policyengine_api.config import settings


def get_traceparent() -> str | None:
"""Get the current W3C traceparent header for distributed tracing."""
carrier: dict[str, str] = {}
TraceContextTextMapPropagator().inject(carrier)
return carrier.get("traceparent")

router = APIRouter(prefix="/agent", tags=["agent"])


Expand Down Expand Up @@ -126,9 +134,12 @@ async def run_agent(request: RunRequest) -> RunResponse:
# Production: use Modal
import modal

traceparent = get_traceparent()
run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent")
history_dicts = [{"role": m.role, "content": m.content} for m in request.history]
call = run_fn.spawn(request.question, api_base_url, call_id, history_dicts)
call = run_fn.spawn(
request.question, api_base_url, call_id, history_dicts, traceparent=traceparent
)

_calls[call_id] = {
"call": call,
Expand All @@ -137,6 +148,7 @@ async def run_agent(request: RunRequest) -> RunResponse:
"started_at": datetime.utcnow().isoformat(),
"status": "running",
"result": None,
"trace_id": traceparent, # Store for linking
}
logfire.info("agent_spawned", call_id=call_id, modal_call_id=call.object_id)
else:
Expand Down
14 changes: 12 additions & 2 deletions src/policyengine_api/api/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import logfire
from fastapi import APIRouter, Depends, HTTPException
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field
from sqlmodel import Session, select

Expand All @@ -40,6 +41,13 @@
from policyengine_api.services.database import get_session


def get_traceparent() -> str | None:
"""Get the current W3C traceparent header for distributed tracing."""
carrier: dict[str, str] = {}
TraceContextTextMapPropagator().inject(carrier)
return carrier.get("traceparent")


def _safe_float(value: float | None) -> float | None:
"""Convert NaN/inf to None for JSON serialization."""
if value is None:
Expand Down Expand Up @@ -522,6 +530,8 @@ def _trigger_economy_comparison(
"""Trigger economy comparison analysis (local or Modal)."""
from policyengine_api.config import settings

traceparent = get_traceparent()

if not settings.agent_use_modal and session is not None:
# Run locally
if tax_benefit_model_name == "policyengine_uk":
Expand All @@ -531,7 +541,7 @@ def _trigger_economy_comparison(
import modal

fn = modal.Function.from_name("policyengine", "economy_comparison_us")
fn.spawn(job_id=job_id)
fn.spawn(job_id=job_id, traceparent=traceparent)
else:
# Use Modal
import modal
Expand All @@ -541,7 +551,7 @@ def _trigger_economy_comparison(
else:
fn = modal.Function.from_name("policyengine", "economy_comparison_us")

fn.spawn(job_id=job_id)
fn.spawn(job_id=job_id, traceparent=traceparent)


@router.post("/economic-impact", response_model=EconomicImpactResponse)
Expand Down
12 changes: 12 additions & 0 deletions src/policyengine_api/api/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import logfire
from fastapi import APIRouter, Depends, HTTPException
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field
from sqlmodel import Session

Expand All @@ -20,6 +21,13 @@
)
from policyengine_api.services.database import get_session


def get_traceparent() -> str | None:
"""Get the current W3C traceparent header for distributed tracing."""
carrier: dict[str, str] = {}
TraceContextTextMapPropagator().inject(carrier)
return carrier.get("traceparent")

router = APIRouter(prefix="/household", tags=["household"])


Expand Down Expand Up @@ -400,6 +408,8 @@ def _trigger_modal_household(
# Use Modal
import modal

traceparent = get_traceparent()

if request.tax_benefit_model_name == "policyengine_uk":
fn = modal.Function.from_name("policyengine", "simulate_household_uk")
fn.spawn(
Expand All @@ -410,6 +420,7 @@ def _trigger_modal_household(
year=request.year or 2026,
policy_data=policy_data,
dynamic_data=dynamic_data,
traceparent=traceparent,
)
else:
fn = modal.Function.from_name("policyengine", "simulate_household_us")
Expand All @@ -424,6 +435,7 @@ def _trigger_modal_household(
year=request.year or 2024,
policy_data=policy_data,
dynamic_data=dynamic_data,
traceparent=traceparent,
)


Expand Down
Loading