diff --git a/src/policyengine_api/agent_sandbox.py b/src/policyengine_api/agent_sandbox.py index 1cd8be7..4f97062 100644 --- a/src/policyengine_api/agent_sandbox.py +++ b/src/policyengine_api/agent_sandbox.py @@ -10,11 +10,42 @@ import requests image = modal.Image.debian_slim(python_version="3.12").pip_install( - "anthropic", "requests" + "anthropic", "requests", "logfire[httpx]" ) app = modal.App("policyengine-sandbox") anthropic_secret = modal.Secret.from_name("anthropic-api-key") +logfire_secrets = modal.Secret.from_name("policyengine-logfire") + + +def configure_logfire(traceparent: str | None = None): + """Configure logfire with optional trace context propagation.""" + import os + + import logfire + + token = os.environ.get("LOGFIRE_TOKEN", "") + if not token: + return None + + logfire.configure( + service_name="policyengine-agent", + token=token, + environment=os.environ.get("LOGFIRE_ENVIRONMENT", "production"), + console=False, + ) + + # If traceparent provided, attach to the current context + if traceparent: + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + propagator = TraceContextTextMapPropagator() + ctx = propagator.extract(carrier={"traceparent": traceparent}) + return ctx + + return None SYSTEM_PROMPT = """You are a PolicyEngine assistant that helps users understand tax and benefit policies. @@ -256,6 +287,7 @@ def execute_api_tool( tool_input: dict, api_base_url: str, log_fn: Callable, + trace_headers: dict | None = None, ) -> str: """Execute an API tool by making the HTTP request.""" meta = tool.get("_meta", {}) @@ -267,6 +299,8 @@ def execute_api_tool( url = f"{api_base_url}{path}" query_params = {} headers = {"Content-Type": "application/json"} + if trace_headers: + headers.update(trace_headers) # Separate path, query, and body parameters body_data = {} @@ -344,16 +378,26 @@ def _run_agent_impl( call_id: str = "", history: list[dict] | None = None, max_turns: int = 30, + traceparent: str | None = None, ) -> dict: """Core agent implementation.""" + import logfire + + # Get traceparent for HTTP requests + def get_trace_headers() -> dict: + if traceparent: + return {"traceparent": traceparent} + return {} def log(msg: str) -> None: + logfire.info(msg, call_id=call_id) print(msg) if call_id: try: requests.post( f"{api_base_url}/agent/log/{call_id}", json={"message": msg}, + headers=get_trace_headers(), timeout=5, ) except Exception: @@ -425,7 +469,9 @@ def log(msg: str) -> None: else: tool = tool_lookup.get(block.name) if tool: - result = execute_api_tool(tool, block.input, api_base_url, log) + result = execute_api_tool( + tool, block.input, api_base_url, log, get_trace_headers() + ) else: result = f"Unknown tool: {block.name}" @@ -457,6 +503,7 @@ def log(msg: str) -> None: requests.post( f"{api_base_url}/agent/complete/{call_id}", json=result, + headers=get_trace_headers(), timeout=10, ) except Exception: @@ -465,22 +512,29 @@ def log(msg: str) -> None: return result -@app.function(image=image, secrets=[anthropic_secret], timeout=600) +@app.function(image=image, secrets=[anthropic_secret, logfire_secrets], timeout=600) def run_agent( question: str, api_base_url: str = "https://v2.api.policyengine.org", call_id: str = "", history: list[dict] | None = None, max_turns: int = 30, + traceparent: str | None = None, ) -> dict: """Run agentic loop to answer a policy question (Modal wrapper).""" - return _run_agent_impl( - question, - api_base_url, - call_id, - history=history, - max_turns=max_turns, - ) + import logfire + + ctx = configure_logfire(traceparent) + + with logfire.span("run_agent", call_id=call_id, question=question[:200], _context=ctx): + return _run_agent_impl( + question, + api_base_url, + call_id, + history=history, + max_turns=max_turns, + traceparent=traceparent, + ) if __name__ == "__main__": diff --git a/src/policyengine_api/api/agent.py b/src/policyengine_api/api/agent.py index 4e58485..7b7d108 100644 --- a/src/policyengine_api/api/agent.py +++ b/src/policyengine_api/api/agent.py @@ -12,10 +12,18 @@ import logfire from fastapi import APIRouter, HTTPException +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import BaseModel from policyengine_api.config import settings + +def get_traceparent() -> str | None: + """Get the current W3C traceparent header for distributed tracing.""" + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return carrier.get("traceparent") + router = APIRouter(prefix="/agent", tags=["agent"]) @@ -126,9 +134,12 @@ async def run_agent(request: RunRequest) -> RunResponse: # Production: use Modal import modal + traceparent = get_traceparent() run_fn = modal.Function.from_name("policyengine-sandbox", "run_agent") history_dicts = [{"role": m.role, "content": m.content} for m in request.history] - call = run_fn.spawn(request.question, api_base_url, call_id, history_dicts) + call = run_fn.spawn( + request.question, api_base_url, call_id, history_dicts, traceparent=traceparent + ) _calls[call_id] = { "call": call, @@ -137,6 +148,7 @@ async def run_agent(request: RunRequest) -> RunResponse: "started_at": datetime.utcnow().isoformat(), "status": "running", "result": None, + "trace_id": traceparent, # Store for linking } logfire.info("agent_spawned", call_id=call_id, modal_call_id=call.object_id) else: diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index 3db5426..c9aa86d 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -21,6 +21,7 @@ import logfire from fastapi import APIRouter, Depends, HTTPException +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import BaseModel, Field from sqlmodel import Session, select @@ -40,6 +41,13 @@ from policyengine_api.services.database import get_session +def get_traceparent() -> str | None: + """Get the current W3C traceparent header for distributed tracing.""" + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return carrier.get("traceparent") + + def _safe_float(value: float | None) -> float | None: """Convert NaN/inf to None for JSON serialization.""" if value is None: @@ -522,6 +530,8 @@ def _trigger_economy_comparison( """Trigger economy comparison analysis (local or Modal).""" from policyengine_api.config import settings + traceparent = get_traceparent() + if not settings.agent_use_modal and session is not None: # Run locally if tax_benefit_model_name == "policyengine_uk": @@ -531,7 +541,7 @@ def _trigger_economy_comparison( import modal fn = modal.Function.from_name("policyengine", "economy_comparison_us") - fn.spawn(job_id=job_id) + fn.spawn(job_id=job_id, traceparent=traceparent) else: # Use Modal import modal @@ -541,7 +551,7 @@ def _trigger_economy_comparison( else: fn = modal.Function.from_name("policyengine", "economy_comparison_us") - fn.spawn(job_id=job_id) + fn.spawn(job_id=job_id, traceparent=traceparent) @router.post("/economic-impact", response_model=EconomicImpactResponse) diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 79c41eb..b0e99c9 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -9,6 +9,7 @@ import logfire from fastapi import APIRouter, Depends, HTTPException +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import BaseModel, Field from sqlmodel import Session @@ -20,6 +21,13 @@ ) from policyengine_api.services.database import get_session + +def get_traceparent() -> str | None: + """Get the current W3C traceparent header for distributed tracing.""" + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return carrier.get("traceparent") + router = APIRouter(prefix="/household", tags=["household"]) @@ -400,6 +408,8 @@ def _trigger_modal_household( # Use Modal import modal + traceparent = get_traceparent() + if request.tax_benefit_model_name == "policyengine_uk": fn = modal.Function.from_name("policyengine", "simulate_household_uk") fn.spawn( @@ -410,6 +420,7 @@ def _trigger_modal_household( year=request.year or 2026, policy_data=policy_data, dynamic_data=dynamic_data, + traceparent=traceparent, ) else: fn = modal.Function.from_name("policyengine", "simulate_household_us") @@ -424,6 +435,7 @@ def _trigger_modal_household( year=request.year or 2024, policy_data=policy_data, dynamic_data=dynamic_data, + traceparent=traceparent, ) diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index cadf168..2cc9dc1 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -25,7 +25,8 @@ "sqlmodel>=0.0.22 " "psycopg2-binary>=2.9.10 " "supabase>=2.10.0 " - "rich>=13.9.4" + "rich>=13.9.4 " + "logfire[httpx]>=3.0.0" ) # Include the policyengine_api models package (copy=True allows subsequent build steps) .add_local_python_source("policyengine_api", copy=True) @@ -59,8 +60,44 @@ def _import_us(): app = modal.App("policyengine") -# Secrets for database access -secrets = modal.Secret.from_name("policyengine-db") +# Secrets for database and observability +db_secrets = modal.Secret.from_name("policyengine-db") +logfire_secrets = modal.Secret.from_name("policyengine-logfire") + + +def configure_logfire(service_name: str, traceparent: str | None = None): + """Configure logfire with optional trace context propagation. + + Args: + service_name: Service name for spans (e.g. "policyengine-modal-uk") + traceparent: W3C traceparent header for distributed tracing + """ + import os + + import logfire + + token = os.environ.get("LOGFIRE_TOKEN", "") + if not token: + return None + + logfire.configure( + service_name=service_name, + token=token, + environment=os.environ.get("LOGFIRE_ENVIRONMENT", "production"), + console=False, + ) + + # If traceparent provided, attach to the current context + if traceparent: + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator, + ) + + propagator = TraceContextTextMapPropagator() + ctx = propagator.extract(carrier={"traceparent": traceparent}) + return ctx + + return None def get_database_url() -> str: @@ -110,7 +147,9 @@ def download_dataset( return str(cache_path) -@app.function(image=uk_image, secrets=[secrets], memory=4096, cpu=4, timeout=600) +@app.function( + image=uk_image, secrets=[db_secrets, logfire_secrets], memory=4096, cpu=4, timeout=600 +) def simulate_household_uk( job_id: str, people: list[dict], @@ -119,120 +158,126 @@ def simulate_household_uk( year: int, policy_data: dict | None, dynamic_data: dict | None, + traceparent: str | None = None, ) -> None: """Calculate UK household and write result to database.""" import json from datetime import datetime, timezone - from rich.console import Console + import logfire from sqlmodel import Session, create_engine - console = Console() - console.print(f"[bold blue]Running UK household job {job_id}[/bold blue]") + ctx = configure_logfire("policyengine-modal-uk", traceparent) - database_url = get_database_url() - engine = create_engine(database_url) + with logfire.span("simulate_household_uk", job_id=job_id, _context=ctx): + logfire.info("Starting UK household calculation", job_id=job_id) - try: - from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.analysis import ( - UKHouseholdInput, - calculate_household_impact, - ) + database_url = get_database_url() + engine = create_engine(database_url) - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ( - ParameterValue as PEParameterValue, - ) - from policyengine.core.policy import ( - Policy as PEPolicy, + try: + from policyengine.tax_benefit_models.uk import uk_latest + from policyengine.tax_benefit_models.uk.analysis import ( + UKHouseholdInput, + calculate_household_impact, ) - pe_param_values = [] - param_lookup = {p.name: p for p in uk_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ( + ParameterValue as PEParameterValue, + ) + from policyengine.core.policy import ( + Policy as PEPolicy, + ) - pe_input = UKHouseholdInput( - people=people, - benunit=benunit, - household=household, - year=year, - ) + pe_param_values = [] + param_lookup = {p.name: p for p in uk_latest.parameters} + for pv in policy_data.get("parameter_values", []): + pe_param = param_lookup.get(pv["parameter_name"]) + if pe_param: + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv["value"], + start_date=datetime.fromisoformat(pv["start_date"]) + if pv.get("start_date") + else None, + end_date=datetime.fromisoformat(pv["end_date"]) + if pv.get("end_date") + else None, + ) + pe_param_values.append(pe_pv) + policy = PEPolicy( + name=policy_data.get("name", ""), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) - result = calculate_household_impact(pe_input, policy=policy) - - # Write result to database - with Session(engine) as session: - from sqlmodel import text - - session.exec( - text(""" - UPDATE household_jobs - SET status = 'COMPLETED', - result = :result, - completed_at = :completed_at - WHERE id = :job_id - """), - params={ - "job_id": job_id, - "result": json.dumps( - { - "person": result.person, - "benunit": result.benunit, - "household": result.household, - } - ), - "completed_at": datetime.now(timezone.utc), - }, - ) - session.commit() - - console.print(f"[bold green]UK household job {job_id} completed[/bold green]") - - except Exception as e: - console.print(f"[bold red]UK household job {job_id} failed: {e}[/bold red]") - with Session(engine) as session: - from sqlmodel import text - - session.exec( - text(""" - UPDATE household_jobs - SET status = 'FAILED', - error_message = :error, - completed_at = :completed_at - WHERE id = :job_id - """), - params={ - "job_id": job_id, - "error": str(e), - "completed_at": datetime.now(timezone.utc), - }, + pe_input = UKHouseholdInput( + people=people, + benunit=benunit, + household=household, + year=year, ) - session.commit() - raise + + with logfire.span("calculate_household_impact"): + result = calculate_household_impact(pe_input, policy=policy) + + # Write result to database + with Session(engine) as session: + from sqlmodel import text + + session.exec( + text(""" + UPDATE household_jobs + SET status = 'COMPLETED', + result = :result, + completed_at = :completed_at + WHERE id = :job_id + """), + params={ + "job_id": job_id, + "result": json.dumps( + { + "person": result.person, + "benunit": result.benunit, + "household": result.household, + } + ), + "completed_at": datetime.now(timezone.utc), + }, + ) + session.commit() + + logfire.info("UK household job completed", job_id=job_id) + + except Exception as e: + logfire.error("UK household job failed", job_id=job_id, error=str(e)) + with Session(engine) as session: + from sqlmodel import text + + session.exec( + text(""" + UPDATE household_jobs + SET status = 'FAILED', + error_message = :error, + completed_at = :completed_at + WHERE id = :job_id + """), + params={ + "job_id": job_id, + "error": str(e), + "completed_at": datetime.now(timezone.utc), + }, + ) + session.commit() + raise -@app.function(image=us_image, secrets=[secrets], memory=4096, cpu=4, timeout=600) +@app.function( + image=us_image, secrets=[db_secrets, logfire_secrets], memory=4096, cpu=4, timeout=600 +) def simulate_household_us( job_id: str, people: list[dict], @@ -244,816 +289,812 @@ def simulate_household_us( year: int, policy_data: dict | None, dynamic_data: dict | None, + traceparent: str | None = None, ) -> None: """Calculate US household and write result to database.""" import json from datetime import datetime, timezone - from rich.console import Console + import logfire from sqlmodel import Session, create_engine - console = Console() - console.print(f"[bold blue]Running US household job {job_id}[/bold blue]") + ctx = configure_logfire("policyengine-modal-us", traceparent) - database_url = get_database_url() - engine = create_engine(database_url) + with logfire.span("simulate_household_us", job_id=job_id, _context=ctx): + logfire.info("Starting US household calculation", job_id=job_id) - try: - from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.analysis import ( - USHouseholdInput, - calculate_household_impact, - ) + database_url = get_database_url() + engine = create_engine(database_url) - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ( - ParameterValue as PEParameterValue, - ) - from policyengine.core.policy import ( - Policy as PEPolicy, + try: + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.analysis import ( + USHouseholdInput, + calculate_household_impact, ) - pe_param_values = [] - param_lookup = {p.name: p for p in us_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ( + ParameterValue as PEParameterValue, + ) + from policyengine.core.policy import ( + Policy as PEPolicy, + ) - pe_input = USHouseholdInput( - people=people, - marital_unit=marital_unit, - family=family, - spm_unit=spm_unit, - tax_unit=tax_unit, - household=household, - year=year, - ) + pe_param_values = [] + param_lookup = {p.name: p for p in us_latest.parameters} + for pv in policy_data.get("parameter_values", []): + pe_param = param_lookup.get(pv["parameter_name"]) + if pe_param: + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv["value"], + start_date=datetime.fromisoformat(pv["start_date"]) + if pv.get("start_date") + else None, + end_date=datetime.fromisoformat(pv["end_date"]) + if pv.get("end_date") + else None, + ) + pe_param_values.append(pe_pv) + policy = PEPolicy( + name=policy_data.get("name", ""), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) - result = calculate_household_impact(pe_input, policy=policy) - - # Write result to database - with Session(engine) as session: - from sqlmodel import text - - session.exec( - text(""" - UPDATE household_jobs - SET status = 'COMPLETED', - result = :result, - completed_at = :completed_at - WHERE id = :job_id - """), - params={ - "job_id": job_id, - "result": json.dumps( - { - "person": result.person, - "marital_unit": result.marital_unit, - "family": result.family, - "spm_unit": result.spm_unit, - "tax_unit": result.tax_unit, - "household": result.household, - } - ), - "completed_at": datetime.now(timezone.utc), - }, + pe_input = USHouseholdInput( + people=people, + marital_unit=marital_unit, + family=family, + spm_unit=spm_unit, + tax_unit=tax_unit, + household=household, + year=year, ) - session.commit() - - console.print(f"[bold green]US household job {job_id} completed[/bold green]") - - except Exception as e: - console.print(f"[bold red]US household job {job_id} failed: {e}[/bold red]") - with Session(engine) as session: - from sqlmodel import text - - session.exec( - text(""" - UPDATE household_jobs - SET status = 'FAILED', - error_message = :error, - completed_at = :completed_at - WHERE id = :job_id - """), - params={ - "job_id": job_id, - "error": str(e), - "completed_at": datetime.now(timezone.utc), - }, - ) - session.commit() - raise + with logfire.span("calculate_household_impact"): + result = calculate_household_impact(pe_input, policy=policy) + + # Write result to database + with Session(engine) as session: + from sqlmodel import text + + session.exec( + text(""" + UPDATE household_jobs + SET status = 'COMPLETED', + result = :result, + completed_at = :completed_at + WHERE id = :job_id + """), + params={ + "job_id": job_id, + "result": json.dumps( + { + "person": result.person, + "marital_unit": result.marital_unit, + "family": result.family, + "spm_unit": result.spm_unit, + "tax_unit": result.tax_unit, + "household": result.household, + } + ), + "completed_at": datetime.now(timezone.utc), + }, + ) + session.commit() + + logfire.info("US household job completed", job_id=job_id) -@app.function(image=uk_image, secrets=[secrets], memory=8192, cpu=8, timeout=1800) -def simulate_economy_uk(simulation_id: str) -> None: + except Exception as e: + logfire.error("US household job failed", job_id=job_id, error=str(e)) + with Session(engine) as session: + from sqlmodel import text + + session.exec( + text(""" + UPDATE household_jobs + SET status = 'FAILED', + error_message = :error, + completed_at = :completed_at + WHERE id = :job_id + """), + params={ + "job_id": job_id, + "error": str(e), + "completed_at": datetime.now(timezone.utc), + }, + ) + session.commit() + raise + + +@app.function( + image=uk_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 +) +def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> None: """Run a single UK economy simulation and write results to database.""" import os from datetime import datetime, timezone from uuid import UUID - from rich.console import Console + import logfire from sqlmodel import Session, create_engine - console = Console() - console.print( - f"[bold blue]Running UK economy simulation {simulation_id}[/bold blue]" - ) - - database_url = get_database_url() - supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] - storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") - - engine = create_engine(database_url) + ctx = configure_logfire("policyengine-modal-uk", traceparent) - try: - from policyengine_api.models import ( - Dataset, - Simulation, - SimulationStatus, - ) - with Session(engine) as session: - simulation = session.get(Simulation, UUID(simulation_id)) - if not simulation: - raise ValueError(f"Simulation {simulation_id} not found") - - # Skip if already completed - if simulation.status == SimulationStatus.COMPLETED: - console.print( - f"[yellow]Simulation {simulation_id} already completed[/yellow]" - ) - return + with logfire.span("simulate_economy_uk", simulation_id=simulation_id, _context=ctx): + logfire.info("Starting UK economy simulation", simulation_id=simulation_id) - # Update status to running - simulation.status = SimulationStatus.RUNNING - session.add(simulation) - session.commit() + database_url = get_database_url() + supabase_url = os.environ["SUPABASE_URL"] + supabase_key = os.environ["SUPABASE_KEY"] + storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") - # Get dataset - dataset = session.get(Dataset, simulation.dataset_id) - if not dataset: - raise ValueError(f"Dataset {simulation.dataset_id} not found") + engine = create_engine(database_url) - # Import policyengine - from policyengine.core import Simulation as PESimulation - from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.datasets import ( - PolicyEngineUKDataset, + try: + from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, ) + with Session(engine) as session: + simulation = session.get(Simulation, UUID(simulation_id)) + if not simulation: + raise ValueError(f"Simulation {simulation_id} not found") + + # Skip if already completed + if simulation.status == SimulationStatus.COMPLETED: + logfire.info("Simulation already completed", simulation_id=simulation_id) + return + + # Update status to running + simulation.status = SimulationStatus.RUNNING + session.add(simulation) + session.commit() - pe_model_version = uk_latest - - # Get policy and dynamic - policy = _get_pe_policy_uk(simulation.policy_id, pe_model_version, session) - dynamic = _get_pe_dynamic_uk( - simulation.dynamic_id, pe_model_version, session - ) + # Get dataset + dataset = session.get(Dataset, simulation.dataset_id) + if not dataset: + raise ValueError(f"Dataset {simulation.dataset_id} not found") - # Download dataset - console.print(f" Loading dataset: {dataset.filepath}") - local_path = download_dataset( - dataset.filepath, supabase_url, supabase_key, storage_bucket - ) + # Import policyengine + from policyengine.core import Simulation as PESimulation + from policyengine.tax_benefit_models.uk import uk_latest + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + ) - pe_dataset = PolicyEngineUKDataset( - name=dataset.name, - description=dataset.description or "", - filepath=local_path, - year=dataset.year, - ) + pe_model_version = uk_latest - # Create and run simulation - console.print(" Running simulation...") - pe_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=policy, - dynamic=dynamic, - ) - pe_sim.ensure() + # Get policy and dynamic + policy = _get_pe_policy_uk(simulation.policy_id, pe_model_version, session) + dynamic = _get_pe_dynamic_uk( + simulation.dynamic_id, pe_model_version, session + ) - # Mark as completed - simulation.status = SimulationStatus.COMPLETED - simulation.completed_at = datetime.now(timezone.utc) - session.add(simulation) - session.commit() + # Download dataset + logfire.info("Loading dataset", filepath=dataset.filepath) + local_path = download_dataset( + dataset.filepath, supabase_url, supabase_key, storage_bucket + ) - console.print( - f"[bold green]UK economy simulation {simulation_id} completed[/bold green]" - ) + pe_dataset = PolicyEngineUKDataset( + name=dataset.name, + description=dataset.description or "", + filepath=local_path, + year=dataset.year, + ) - except Exception as e: - console.print( - f"[bold red]UK economy simulation {simulation_id} failed: {e}[/bold red]" - ) - # Use raw SQL to mark as failed - models may not be available - try: - from sqlmodel import text + # Create and run simulation + with logfire.span("run_simulation"): + pe_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=policy, + dynamic=dynamic, + ) + pe_sim.ensure() - with Session(engine) as session: - session.execute( - text( - "UPDATE simulations SET status = 'failed', error_message = :error " - "WHERE id = :sim_id" - ), - {"sim_id": simulation_id, "error": str(e)[:1000]}, - ) + # Mark as completed + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) session.commit() - except Exception as db_error: - console.print(f"[bold red]Failed to update DB: {db_error}[/bold red]") - raise + + logfire.info("UK economy simulation completed", simulation_id=simulation_id) + + except Exception as e: + logfire.error("UK economy simulation failed", simulation_id=simulation_id, error=str(e)) + # Use raw SQL to mark as failed - models may not be available + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE simulations SET status = 'failed', error_message = :error " + "WHERE id = :sim_id" + ), + {"sim_id": simulation_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise -@app.function(image=us_image, secrets=[secrets], memory=8192, cpu=8, timeout=1800) -def simulate_economy_us(simulation_id: str) -> None: +@app.function( + image=us_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 +) +def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> None: """Run a single US economy simulation and write results to database.""" import os from datetime import datetime, timezone from uuid import UUID - from rich.console import Console + import logfire from sqlmodel import Session, create_engine - console = Console() - console.print( - f"[bold blue]Running US economy simulation {simulation_id}[/bold blue]" - ) + ctx = configure_logfire("policyengine-modal-us", traceparent) - database_url = get_database_url() - supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] - storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") + with logfire.span("simulate_economy_us", simulation_id=simulation_id, _context=ctx): + logfire.info("Starting US economy simulation", simulation_id=simulation_id) - engine = create_engine(database_url) + database_url = get_database_url() + supabase_url = os.environ["SUPABASE_URL"] + supabase_key = os.environ["SUPABASE_KEY"] + storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") - try: - from policyengine_api.models import ( - Dataset, - Simulation, - SimulationStatus, - ) - with Session(engine) as session: - simulation = session.get(Simulation, UUID(simulation_id)) - if not simulation: - raise ValueError(f"Simulation {simulation_id} not found") - - # Skip if already completed - if simulation.status == SimulationStatus.COMPLETED: - console.print( - f"[yellow]Simulation {simulation_id} already completed[/yellow]" - ) - return - - # Update status to running - simulation.status = SimulationStatus.RUNNING - session.add(simulation) - session.commit() + engine = create_engine(database_url) - # Get dataset - dataset = session.get(Dataset, simulation.dataset_id) - if not dataset: - raise ValueError(f"Dataset {simulation.dataset_id} not found") - - # Import policyengine - from policyengine.core import Simulation as PESimulation - from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.datasets import ( - PolicyEngineUSDataset, + try: + from policyengine_api.models import ( + Dataset, + Simulation, + SimulationStatus, ) + with Session(engine) as session: + simulation = session.get(Simulation, UUID(simulation_id)) + if not simulation: + raise ValueError(f"Simulation {simulation_id} not found") + + # Skip if already completed + if simulation.status == SimulationStatus.COMPLETED: + logfire.info("Simulation already completed", simulation_id=simulation_id) + return + + # Update status to running + simulation.status = SimulationStatus.RUNNING + session.add(simulation) + session.commit() - pe_model_version = us_latest - - # Get policy and dynamic - policy = _get_pe_policy_us(simulation.policy_id, pe_model_version, session) - dynamic = _get_pe_dynamic_us( - simulation.dynamic_id, pe_model_version, session - ) + # Get dataset + dataset = session.get(Dataset, simulation.dataset_id) + if not dataset: + raise ValueError(f"Dataset {simulation.dataset_id} not found") - # Download dataset - console.print(f" Loading dataset: {dataset.filepath}") - local_path = download_dataset( - dataset.filepath, supabase_url, supabase_key, storage_bucket - ) + # Import policyengine + from policyengine.core import Simulation as PESimulation + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + ) - pe_dataset = PolicyEngineUSDataset( - name=dataset.name, - description=dataset.description or "", - filepath=local_path, - year=dataset.year, - ) + pe_model_version = us_latest - # Create and run simulation - console.print(" Running simulation...") - pe_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=policy, - dynamic=dynamic, - ) - pe_sim.ensure() + # Get policy and dynamic + policy = _get_pe_policy_us(simulation.policy_id, pe_model_version, session) + dynamic = _get_pe_dynamic_us( + simulation.dynamic_id, pe_model_version, session + ) - # Mark as completed - simulation.status = SimulationStatus.COMPLETED - simulation.completed_at = datetime.now(timezone.utc) - session.add(simulation) - session.commit() + # Download dataset + logfire.info("Loading dataset", filepath=dataset.filepath) + local_path = download_dataset( + dataset.filepath, supabase_url, supabase_key, storage_bucket + ) - console.print( - f"[bold green]US economy simulation {simulation_id} completed[/bold green]" - ) + pe_dataset = PolicyEngineUSDataset( + name=dataset.name, + description=dataset.description or "", + filepath=local_path, + year=dataset.year, + ) - except Exception as e: - console.print( - f"[bold red]US economy simulation {simulation_id} failed: {e}[/bold red]" - ) - # Use raw SQL to mark as failed - models may not be available - try: - from sqlmodel import text + # Create and run simulation + with logfire.span("run_simulation"): + pe_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=policy, + dynamic=dynamic, + ) + pe_sim.ensure() - with Session(engine) as session: - session.execute( - text( - "UPDATE simulations SET status = 'failed', error_message = :error " - "WHERE id = :sim_id" - ), - {"sim_id": simulation_id, "error": str(e)[:1000]}, - ) + # Mark as completed + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) session.commit() - except Exception as db_error: - console.print(f"[bold red]Failed to update DB: {db_error}[/bold red]") - raise + + logfire.info("US economy simulation completed", simulation_id=simulation_id) + + except Exception as e: + logfire.error("US economy simulation failed", simulation_id=simulation_id, error=str(e)) + # Use raw SQL to mark as failed - models may not be available + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE simulations SET status = 'failed', error_message = :error " + "WHERE id = :sim_id" + ), + {"sim_id": simulation_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise -@app.function(image=uk_image, secrets=[secrets], memory=8192, cpu=8, timeout=1800) -def economy_comparison_uk(job_id: str) -> None: +@app.function( + image=uk_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 +) +def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: """Run UK economy comparison analysis (decile impacts, budget impact, etc).""" import os from datetime import datetime, timezone from uuid import UUID - from rich.console import Console + import logfire from sqlmodel import Session, create_engine - console = Console() - console.print(f"[bold blue]Running UK economy comparison {job_id}[/bold blue]") + ctx = configure_logfire("policyengine-modal-uk", traceparent) - database_url = get_database_url() - supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] - storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") + with logfire.span("economy_comparison_uk", job_id=job_id, _context=ctx): + logfire.info("Starting UK economy comparison", job_id=job_id) - engine = create_engine(database_url) + database_url = get_database_url() + supabase_url = os.environ["SUPABASE_URL"] + supabase_key = os.environ["SUPABASE_KEY"] + storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") - try: - # Import models inline - from policyengine_api.models import ( - Dataset, - DecileImpact, - ProgramStatistics, - Report, - ReportStatus, - Simulation, - SimulationStatus, - TaxBenefitModelVersion, - ) - with Session(engine) as session: - # Load report and related data - report = session.get(Report, UUID(job_id)) - if not report: - raise ValueError(f"Report {job_id} not found") - - baseline_sim = session.get(Simulation, report.baseline_simulation_id) - reform_sim = session.get(Simulation, report.reform_simulation_id) - - if not baseline_sim or not reform_sim: - raise ValueError("Simulations not found") - - # Update status to running - report.status = ReportStatus.RUNNING - session.add(report) - session.commit() - - # Get dataset - dataset = session.get(Dataset, baseline_sim.dataset_id) - if not dataset: - raise ValueError(f"Dataset {baseline_sim.dataset_id} not found") - - # Get model version (unused but keeping for reference) - _ = session.get( - TaxBenefitModelVersion, baseline_sim.tax_benefit_model_version_id - ) + engine = create_engine(database_url) - # Import policyengine - from policyengine.core import Simulation as PESimulation - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.datasets import ( - PolicyEngineUKDataset, - ) - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, + try: + # Import models inline + from policyengine_api.models import ( + Dataset, + DecileImpact, + ProgramStatistics, + Report, + ReportStatus, + Simulation, + SimulationStatus, + TaxBenefitModelVersion, ) + with Session(engine) as session: + # Load report and related data + report = session.get(Report, UUID(job_id)) + if not report: + raise ValueError(f"Report {job_id} not found") - pe_model_version = uk_latest + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + reform_sim = session.get(Simulation, report.reform_simulation_id) - # Get policies - baseline_policy = _get_pe_policy_uk( - baseline_sim.policy_id, pe_model_version, session - ) - reform_policy = _get_pe_policy_uk( - reform_sim.policy_id, pe_model_version, session - ) - baseline_dynamic = _get_pe_dynamic_uk( - baseline_sim.dynamic_id, pe_model_version, session - ) - reform_dynamic = _get_pe_dynamic_uk( - reform_sim.dynamic_id, pe_model_version, session - ) + if not baseline_sim or not reform_sim: + raise ValueError("Simulations not found") - # Download dataset - console.print(f" Loading dataset: {dataset.filepath}") - local_path = download_dataset( - dataset.filepath, supabase_url, supabase_key, storage_bucket - ) + # Update status to running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() - pe_dataset = PolicyEngineUKDataset( - name=dataset.name, - description=dataset.description or "", - filepath=local_path, - year=dataset.year, - ) + # Get dataset + dataset = session.get(Dataset, baseline_sim.dataset_id) + if not dataset: + raise ValueError(f"Dataset {baseline_sim.dataset_id} not found") - # Create and run simulations - console.print(" Running baseline simulation...") - pe_baseline_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=baseline_policy, - dynamic=baseline_dynamic, - ) - pe_baseline_sim.ensure() - - console.print(" Running reform simulation...") - pe_reform_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=reform_policy, - dynamic=reform_dynamic, - ) - pe_reform_sim.ensure() - - # Calculate decile impacts - console.print(" Calculating decile impacts...") - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, + # Get model version (unused but keeping for reference) + _ = session.get( + TaxBenefitModelVersion, baseline_sim.tax_benefit_model_version_id + ) + + # Import policyengine + from policyengine.core import Simulation as PESimulation + from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.tax_benefit_models.uk import uk_latest + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, ) - di.run() - - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, + from policyengine.tax_benefit_models.uk.outputs import ( + ProgrammeStatistics as PEProgrammeStats, ) - session.add(decile_impact) - # Calculate program statistics - console.print(" Calculating program statistics...") - PEProgrammeStats.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) + pe_model_version = uk_latest - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "vat": {"entity": "household", "is_tax": True}, - "council_tax": {"entity": "household", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - "pension_credit": {"entity": "person", "is_tax": False}, - "income_support": {"entity": "person", "is_tax": False}, - "working_tax_credit": {"entity": "person", "is_tax": False}, - "child_tax_credit": {"entity": "person", "is_tax": False}, - } - - for prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError as e: - console.print(f" Skipping {prog_name}: variable not found ({e})") - - # Mark simulations and report as completed - baseline_sim.status = SimulationStatus.COMPLETED - baseline_sim.completed_at = datetime.now(timezone.utc) - reform_sim.status = SimulationStatus.COMPLETED - reform_sim.completed_at = datetime.now(timezone.utc) - report.status = ReportStatus.COMPLETED - - session.add(baseline_sim) - session.add(reform_sim) - session.add(report) - session.commit() - - console.print( - f"[bold green]UK economy comparison {job_id} completed[/bold green]" - ) + # Get policies + baseline_policy = _get_pe_policy_uk( + baseline_sim.policy_id, pe_model_version, session + ) + reform_policy = _get_pe_policy_uk( + reform_sim.policy_id, pe_model_version, session + ) + baseline_dynamic = _get_pe_dynamic_uk( + baseline_sim.dynamic_id, pe_model_version, session + ) + reform_dynamic = _get_pe_dynamic_uk( + reform_sim.dynamic_id, pe_model_version, session + ) - except Exception as e: - console.print( - f"[bold red]UK economy comparison {job_id} failed: {e}[/bold red]" - ) - # Use raw SQL to mark as failed - models may not be available - try: - from sqlmodel import text + # Download dataset + logfire.info("Loading dataset", filepath=dataset.filepath) + local_path = download_dataset( + dataset.filepath, supabase_url, supabase_key, storage_bucket + ) - with Session(engine) as session: - session.execute( - text( - "UPDATE reports SET status = 'failed', error_message = :error " - "WHERE id = :job_id" - ), - {"job_id": job_id, "error": str(e)[:1000]}, + pe_dataset = PolicyEngineUKDataset( + name=dataset.name, + description=dataset.description or "", + filepath=local_path, + year=dataset.year, ) + + # Create and run simulations + with logfire.span("run_baseline_simulation"): + pe_baseline_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=baseline_policy, + dynamic=baseline_dynamic, + ) + pe_baseline_sim.ensure() + + with logfire.span("run_reform_simulation"): + pe_reform_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=reform_policy, + dynamic=reform_dynamic, + ) + pe_reform_sim.ensure() + + # Calculate decile impacts + with logfire.span("calculate_decile_impacts"): + for decile_num in range(1, 11): + di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + decile=decile_num, + ) + di.run() + + decile_impact = DecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + income_variable=di.income_variable, + entity=di.entity, + decile=di.decile, + quantiles=di.quantiles, + baseline_mean=di.baseline_mean, + reform_mean=di.reform_mean, + absolute_change=di.absolute_change, + relative_change=di.relative_change, + count_better_off=di.count_better_off, + count_worse_off=di.count_worse_off, + count_no_change=di.count_no_change, + ) + session.add(decile_impact) + + # Calculate program statistics + with logfire.span("calculate_program_statistics"): + PEProgrammeStats.model_rebuild( + _types_namespace={"Simulation": PESimulation} + ) + + programmes = { + "income_tax": {"entity": "person", "is_tax": True}, + "national_insurance": {"entity": "person", "is_tax": True}, + "vat": {"entity": "household", "is_tax": True}, + "council_tax": {"entity": "household", "is_tax": True}, + "universal_credit": {"entity": "person", "is_tax": False}, + "child_benefit": {"entity": "person", "is_tax": False}, + "pension_credit": {"entity": "person", "is_tax": False}, + "income_support": {"entity": "person", "is_tax": False}, + "working_tax_credit": {"entity": "person", "is_tax": False}, + "child_tax_credit": {"entity": "person", "is_tax": False}, + } + + for prog_name, prog_info in programmes.items(): + try: + ps = PEProgrammeStats( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + programme_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + program_stat = ProgramStatistics( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(program_stat) + except KeyError as e: + logfire.warn(f"Skipping {prog_name}: variable not found", error=str(e)) + + # Mark simulations and report as completed + baseline_sim.status = SimulationStatus.COMPLETED + baseline_sim.completed_at = datetime.now(timezone.utc) + reform_sim.status = SimulationStatus.COMPLETED + reform_sim.completed_at = datetime.now(timezone.utc) + report.status = ReportStatus.COMPLETED + + session.add(baseline_sim) + session.add(reform_sim) + session.add(report) session.commit() - except Exception as db_error: - console.print(f"[bold red]Failed to update DB: {db_error}[/bold red]") - raise + + logfire.info("UK economy comparison completed", job_id=job_id) + + except Exception as e: + logfire.error("UK economy comparison failed", job_id=job_id, error=str(e)) + # Use raw SQL to mark as failed - models may not be available + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'failed', error_message = :error " + "WHERE id = :job_id" + ), + {"job_id": job_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise -@app.function(image=us_image, secrets=[secrets], memory=8192, cpu=8, timeout=1800) -def economy_comparison_us(job_id: str) -> None: +@app.function( + image=us_image, secrets=[db_secrets, logfire_secrets], memory=8192, cpu=8, timeout=1800 +) +def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: """Run US economy comparison analysis (decile impacts, budget impact, etc).""" import os from datetime import datetime, timezone from uuid import UUID - from rich.console import Console + import logfire from sqlmodel import Session, create_engine - console = Console() - console.print(f"[bold blue]Running US economy comparison {job_id}[/bold blue]") - - database_url = get_database_url() - supabase_url = os.environ["SUPABASE_URL"] - supabase_key = os.environ["SUPABASE_KEY"] - storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") + ctx = configure_logfire("policyengine-modal-us", traceparent) - engine = create_engine(database_url) - - try: - # Import models inline - from policyengine_api.models import ( - Dataset, - DecileImpact, - ProgramStatistics, - Report, - ReportStatus, - Simulation, - SimulationStatus, - ) + with logfire.span("economy_comparison_us", job_id=job_id, _context=ctx): + logfire.info("Starting US economy comparison", job_id=job_id) - with Session(engine) as session: - # Load report and related data - report = session.get(Report, UUID(job_id)) - if not report: - raise ValueError(f"Report {job_id} not found") + database_url = get_database_url() + supabase_url = os.environ["SUPABASE_URL"] + supabase_key = os.environ["SUPABASE_KEY"] + storage_bucket = os.environ.get("STORAGE_BUCKET", "datasets") - baseline_sim = session.get(Simulation, report.baseline_simulation_id) - reform_sim = session.get(Simulation, report.reform_simulation_id) + engine = create_engine(database_url) - if not baseline_sim or not reform_sim: - raise ValueError("Simulations not found") - - # Update status to running - report.status = ReportStatus.RUNNING - session.add(report) - session.commit() + try: + # Import models inline + from policyengine_api.models import ( + Dataset, + DecileImpact, + ProgramStatistics, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) - # Get dataset - dataset = session.get(Dataset, baseline_sim.dataset_id) - if not dataset: - raise ValueError(f"Dataset {baseline_sim.dataset_id} not found") + with Session(engine) as session: + # Load report and related data + report = session.get(Report, UUID(job_id)) + if not report: + raise ValueError(f"Report {job_id} not found") - # Import policyengine - from policyengine.core import Simulation as PESimulation - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.datasets import ( - PolicyEngineUSDataset, - ) - from policyengine.tax_benefit_models.us.outputs import ( - ProgramStatistics as PEProgramStats, - ) + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + reform_sim = session.get(Simulation, report.reform_simulation_id) - pe_model_version = us_latest + if not baseline_sim or not reform_sim: + raise ValueError("Simulations not found") - # Get policies - baseline_policy = _get_pe_policy_us( - baseline_sim.policy_id, pe_model_version, session - ) - reform_policy = _get_pe_policy_us( - reform_sim.policy_id, pe_model_version, session - ) - baseline_dynamic = _get_pe_dynamic_us( - baseline_sim.dynamic_id, pe_model_version, session - ) - reform_dynamic = _get_pe_dynamic_us( - reform_sim.dynamic_id, pe_model_version, session - ) + # Update status to running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() - # Download dataset - console.print(f" Loading dataset: {dataset.filepath}") - local_path = download_dataset( - dataset.filepath, supabase_url, supabase_key, storage_bucket - ) + # Get dataset + dataset = session.get(Dataset, baseline_sim.dataset_id) + if not dataset: + raise ValueError(f"Dataset {baseline_sim.dataset_id} not found") + + # Import policyengine + from policyengine.core import Simulation as PESimulation + from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + ) + from policyengine.tax_benefit_models.us.outputs import ( + ProgramStatistics as PEProgramStats, + ) - pe_dataset = PolicyEngineUSDataset( - name=dataset.name, - description=dataset.description or "", - filepath=local_path, - year=dataset.year, - ) + pe_model_version = us_latest - # Create and run simulations - console.print(" Running baseline simulation...") - pe_baseline_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=baseline_policy, - dynamic=baseline_dynamic, - ) - pe_baseline_sim.ensure() - - console.print(" Running reform simulation...") - pe_reform_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=reform_policy, - dynamic=reform_dynamic, - ) - pe_reform_sim.ensure() - - # Calculate decile impacts - console.print(" Calculating decile impacts...") - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, + # Get policies + baseline_policy = _get_pe_policy_us( + baseline_sim.policy_id, pe_model_version, session ) - di.run() - - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, + reform_policy = _get_pe_policy_us( + reform_sim.policy_id, pe_model_version, session + ) + baseline_dynamic = _get_pe_dynamic_us( + baseline_sim.dynamic_id, pe_model_version, session + ) + reform_dynamic = _get_pe_dynamic_us( + reform_sim.dynamic_id, pe_model_version, session ) - session.add(decile_impact) - - # Calculate program statistics - console.print(" Calculating program statistics...") - PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": {"entity": "person", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "spm_unit", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - } - - for prog_name, prog_info in programs.items(): - try: - ps = PEProgramStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError as e: - console.print(f" Skipping {prog_name}: variable not found ({e})") - - # Mark simulations and report as completed - baseline_sim.status = SimulationStatus.COMPLETED - baseline_sim.completed_at = datetime.now(timezone.utc) - reform_sim.status = SimulationStatus.COMPLETED - reform_sim.completed_at = datetime.now(timezone.utc) - report.status = ReportStatus.COMPLETED - - session.add(baseline_sim) - session.add(reform_sim) - session.add(report) - session.commit() - - console.print( - f"[bold green]US economy comparison {job_id} completed[/bold green]" - ) - except Exception as e: - console.print( - f"[bold red]US economy comparison {job_id} failed: {e}[/bold red]" - ) - # Use raw SQL to mark as failed - models may not be available - try: - from sqlmodel import text + # Download dataset + logfire.info("Loading dataset", filepath=dataset.filepath) + local_path = download_dataset( + dataset.filepath, supabase_url, supabase_key, storage_bucket + ) - with Session(engine) as session: - session.execute( - text( - "UPDATE reports SET status = 'failed', error_message = :error " - "WHERE id = :job_id" - ), - {"job_id": job_id, "error": str(e)[:1000]}, + pe_dataset = PolicyEngineUSDataset( + name=dataset.name, + description=dataset.description or "", + filepath=local_path, + year=dataset.year, ) + + # Create and run simulations + with logfire.span("run_baseline_simulation"): + pe_baseline_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=baseline_policy, + dynamic=baseline_dynamic, + ) + pe_baseline_sim.ensure() + + with logfire.span("run_reform_simulation"): + pe_reform_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=reform_policy, + dynamic=reform_dynamic, + ) + pe_reform_sim.ensure() + + # Calculate decile impacts + with logfire.span("calculate_decile_impacts"): + for decile_num in range(1, 11): + di = PEDecileImpact( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + decile=decile_num, + ) + di.run() + + decile_impact = DecileImpact( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + income_variable=di.income_variable, + entity=di.entity, + decile=di.decile, + quantiles=di.quantiles, + baseline_mean=di.baseline_mean, + reform_mean=di.reform_mean, + absolute_change=di.absolute_change, + relative_change=di.relative_change, + count_better_off=di.count_better_off, + count_worse_off=di.count_worse_off, + count_no_change=di.count_no_change, + ) + session.add(decile_impact) + + # Calculate program statistics + with logfire.span("calculate_program_statistics"): + PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) + + programs = { + "income_tax": {"entity": "tax_unit", "is_tax": True}, + "employee_payroll_tax": {"entity": "person", "is_tax": True}, + "snap": {"entity": "spm_unit", "is_tax": False}, + "tanf": {"entity": "spm_unit", "is_tax": False}, + "ssi": {"entity": "spm_unit", "is_tax": False}, + "social_security": {"entity": "person", "is_tax": False}, + } + + for prog_name, prog_info in programs.items(): + try: + ps = PEProgramStats( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + ) + ps.run() + program_stat = ProgramStatistics( + baseline_simulation_id=baseline_sim.id, + reform_simulation_id=reform_sim.id, + report_id=report.id, + program_name=prog_name, + entity=prog_info["entity"], + is_tax=prog_info["is_tax"], + baseline_total=ps.baseline_total, + reform_total=ps.reform_total, + change=ps.change, + baseline_count=ps.baseline_count, + reform_count=ps.reform_count, + winners=ps.winners, + losers=ps.losers, + ) + session.add(program_stat) + except KeyError as e: + logfire.warn(f"Skipping {prog_name}: variable not found", error=str(e)) + + # Mark simulations and report as completed + baseline_sim.status = SimulationStatus.COMPLETED + baseline_sim.completed_at = datetime.now(timezone.utc) + reform_sim.status = SimulationStatus.COMPLETED + reform_sim.completed_at = datetime.now(timezone.utc) + report.status = ReportStatus.COMPLETED + + session.add(baseline_sim) + session.add(reform_sim) + session.add(report) session.commit() - except Exception as db_error: - console.print(f"[bold red]Failed to update DB: {db_error}[/bold red]") - raise + + logfire.info("US economy comparison completed", job_id=job_id) + + except Exception as e: + logfire.error("US economy comparison failed", job_id=job_id, error=str(e)) + # Use raw SQL to mark as failed - models may not be available + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'failed', error_message = :error " + "WHERE id = :job_id" + ), + {"job_id": job_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise def _get_pe_policy_uk(policy_id, model_version, session):