From ed3d7160c64a3ea28f4c1c463e7838f3abed3637 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sat, 3 Jan 2026 07:44:13 +0000 Subject: [PATCH 1/2] feat: add poverty and inequality outputs to economy comparison Adds poverty rates and inequality metrics to the economy comparison analysis, matching policyengine.py PR #207. Poverty outputs: - UK: absolute BHC/AHC, relative BHC/AHC - US: SPM, deep SPM Inequality outputs: - Gini coefficient - Top 10%, top 1%, bottom 50% income shares New models: Poverty, Inequality Migration: 20260103000000_add_poverty_inequality.sql Co-Authored-By: Claude Opus 4.5 --- src/policyengine_api/modal_app.py | 94 +++++++++++++++++++ src/policyengine_api/models/__init__.py | 8 ++ src/policyengine_api/models/inequality.py | 41 ++++++++ src/policyengine_api/models/poverty.py | 41 ++++++++ .../20260103000000_add_poverty_inequality.sql | 33 +++++++ 5 files changed, 217 insertions(+) create mode 100644 src/policyengine_api/models/inequality.py create mode 100644 src/policyengine_api/models/poverty.py create mode 100644 supabase/migrations/20260103000000_add_poverty_inequality.sql diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 5e36ca3..d78f796 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -707,6 +707,8 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: from policyengine_api.models import ( Dataset, DecileImpact, + Inequality, + Poverty, ProgramStatistics, Report, ReportStatus, @@ -748,6 +750,12 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.outputs.inequality import ( + calculate_uk_inequality, + ) + from policyengine.outputs.poverty import ( + calculate_uk_poverty_rates, + ) from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, @@ -881,6 +889,45 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: except KeyError: pass # Variable not in model, skip silently + # Calculate poverty rates + with logfire.span("calculate_poverty"): + for sim, sim_id in [ + (pe_baseline_sim, baseline_sim.id), + (pe_reform_sim, reform_sim.id), + ]: + poverty_collection = calculate_uk_poverty_rates(sim) + for pov in poverty_collection.outputs: + poverty_record = Poverty( + simulation_id=sim_id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate inequality + with logfire.span("calculate_inequality"): + for sim, sim_id in [ + (pe_baseline_sim, baseline_sim.id), + (pe_reform_sim, reform_sim.id), + ]: + ineq = calculate_uk_inequality(sim) + inequality_record = Inequality( + simulation_id=sim_id, + report_id=report.id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + session.add(inequality_record) + # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = datetime.now(timezone.utc) @@ -949,6 +996,8 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: from policyengine_api.models import ( Dataset, DecileImpact, + Inequality, + Poverty, ProgramStatistics, Report, ReportStatus, @@ -983,6 +1032,12 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation from policyengine.outputs import DecileImpact as PEDecileImpact + from policyengine.outputs.inequality import ( + calculate_us_inequality, + ) + from policyengine.outputs.poverty import ( + calculate_us_poverty_rates, + ) from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, @@ -1113,6 +1168,45 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: except KeyError: pass # Variable not in model, skip silently + # Calculate poverty rates + with logfire.span("calculate_poverty"): + for sim, sim_id in [ + (pe_baseline_sim, baseline_sim.id), + (pe_reform_sim, reform_sim.id), + ]: + poverty_collection = calculate_us_poverty_rates(sim) + for pov in poverty_collection.outputs: + poverty_record = Poverty( + simulation_id=sim_id, + report_id=report.id, + poverty_type=pov.poverty_type, + entity=pov.entity, + filter_variable=pov.filter_variable, + headcount=pov.headcount, + total_population=pov.total_population, + rate=pov.rate, + ) + session.add(poverty_record) + + # Calculate inequality + with logfire.span("calculate_inequality"): + for sim, sim_id in [ + (pe_baseline_sim, baseline_sim.id), + (pe_reform_sim, reform_sim.id), + ]: + ineq = calculate_us_inequality(sim) + inequality_record = Inequality( + simulation_id=sim_id, + report_id=report.id, + income_variable=ineq.income_variable, + entity=ineq.entity, + gini=ineq.gini, + top_10_share=ineq.top_10_share, + top_1_share=ineq.top_1_share, + bottom_50_share=ineq.bottom_50_share, + ) + session.add(inequality_record) + # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED baseline_sim.completed_at = datetime.now(timezone.utc) diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index d820ea4..408ae3f 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -16,6 +16,7 @@ HouseholdJobRead, HouseholdJobStatus, ) +from .inequality import Inequality, InequalityCreate, InequalityRead from .output import ( AggregateOutput, AggregateOutputCreate, @@ -25,6 +26,7 @@ from .parameter import Parameter, ParameterCreate, ParameterRead from .parameter_value import ParameterValue, ParameterValueCreate, ParameterValueRead from .policy import Policy, PolicyCreate, PolicyRead +from .poverty import Poverty, PovertyCreate, PovertyRead from .program_statistics import ( ProgramStatistics, ProgramStatisticsCreate, @@ -70,6 +72,9 @@ "HouseholdJobCreate", "HouseholdJobRead", "HouseholdJobStatus", + "Inequality", + "InequalityCreate", + "InequalityRead", "Parameter", "ParameterCreate", "ParameterRead", @@ -79,6 +84,9 @@ "Policy", "PolicyCreate", "PolicyRead", + "Poverty", + "PovertyCreate", + "PovertyRead", "ProgramStatistics", "ProgramStatisticsCreate", "ProgramStatisticsRead", diff --git a/src/policyengine_api/models/inequality.py b/src/policyengine_api/models/inequality.py new file mode 100644 index 0000000..322a702 --- /dev/null +++ b/src/policyengine_api/models/inequality.py @@ -0,0 +1,41 @@ +"""Inequality output model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +class InequalityBase(SQLModel): + """Base inequality fields.""" + + simulation_id: UUID = Field(foreign_key="simulations.id") + report_id: UUID | None = Field(default=None, foreign_key="reports.id") + income_variable: str + entity: str = "household" + gini: float | None = None + top_10_share: float | None = None + top_1_share: float | None = None + bottom_50_share: float | None = None + + +class Inequality(InequalityBase, table=True): + """Inequality database model.""" + + __tablename__ = "inequality" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class InequalityCreate(InequalityBase): + """Schema for creating inequality records.""" + + pass + + +class InequalityRead(InequalityBase): + """Schema for reading inequality records.""" + + id: UUID + created_at: datetime diff --git a/src/policyengine_api/models/poverty.py b/src/policyengine_api/models/poverty.py new file mode 100644 index 0000000..05912c8 --- /dev/null +++ b/src/policyengine_api/models/poverty.py @@ -0,0 +1,41 @@ +"""Poverty output model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +class PovertyBase(SQLModel): + """Base poverty fields.""" + + simulation_id: UUID = Field(foreign_key="simulations.id") + report_id: UUID | None = Field(default=None, foreign_key="reports.id") + poverty_type: str # e.g. "absolute_bhc", "spm", etc. + entity: str = "person" + filter_variable: str | None = None + headcount: float | None = None + total_population: float | None = None + rate: float | None = None + + +class Poverty(PovertyBase, table=True): + """Poverty database model.""" + + __tablename__ = "poverty" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class PovertyCreate(PovertyBase): + """Schema for creating poverty records.""" + + pass + + +class PovertyRead(PovertyBase): + """Schema for reading poverty records.""" + + id: UUID + created_at: datetime diff --git a/supabase/migrations/20260103000000_add_poverty_inequality.sql b/supabase/migrations/20260103000000_add_poverty_inequality.sql new file mode 100644 index 0000000..f315d93 --- /dev/null +++ b/supabase/migrations/20260103000000_add_poverty_inequality.sql @@ -0,0 +1,33 @@ +-- Add poverty and inequality tables for economic analysis + +CREATE TABLE IF NOT EXISTS poverty ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE, + report_id UUID REFERENCES reports(id) ON DELETE CASCADE, + poverty_type VARCHAR NOT NULL, + entity VARCHAR NOT NULL DEFAULT 'person', + filter_variable VARCHAR, + headcount FLOAT, + total_population FLOAT, + rate FLOAT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE TABLE IF NOT EXISTS inequality ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE, + report_id UUID REFERENCES reports(id) ON DELETE CASCADE, + income_variable VARCHAR NOT NULL, + entity VARCHAR NOT NULL DEFAULT 'household', + gini FLOAT, + top_10_share FLOAT, + top_1_share FLOAT, + bottom_50_share FLOAT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Indexes for efficient querying +CREATE INDEX IF NOT EXISTS idx_poverty_simulation_id ON poverty(simulation_id); +CREATE INDEX IF NOT EXISTS idx_poverty_report_id ON poverty(report_id); +CREATE INDEX IF NOT EXISTS idx_inequality_simulation_id ON inequality(simulation_id); +CREATE INDEX IF NOT EXISTS idx_inequality_report_id ON inequality(report_id); From 3f29aa469747f667566acd08bdf52d3aac66f421 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Sat, 3 Jan 2026 08:40:41 +0000 Subject: [PATCH 2/2] fix: sanitize NaN/Inf values in household results for JSON serialization Also updated tests to use the async job polling pattern. --- src/policyengine_api/api/household.py | 45 +++++++++++----- tests/test_household.py | 77 ++++++++++++++++++--------- 2 files changed, 83 insertions(+), 39 deletions(-) diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index b0e99c9..9ed1aef 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -4,6 +4,7 @@ Poll the status endpoint until the job is complete. """ +import math from typing import Any, Literal from uuid import UUID @@ -22,12 +23,26 @@ from policyengine_api.services.database import get_session +def _sanitize_for_json(obj: Any) -> Any: + """Replace NaN/Inf values with None for JSON serialization.""" + if isinstance(obj, float): + if math.isnan(obj) or math.isinf(obj): + return None + return obj + elif isinstance(obj, dict): + return {k: _sanitize_for_json(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_sanitize_for_json(v) for v in obj] + return obj + + def get_traceparent() -> str | None: """Get the current W3C traceparent header for distributed tracing.""" carrier: dict[str, str] = {} TraceContextTextMapPropagator().inject(carrier) return carrier.get("traceparent") + router = APIRouter(prefix="/household", tags=["household"]) @@ -254,11 +269,13 @@ def _run_local_household_uk( job = session.get(HouseholdJob, job_id) if job: job.status = HouseholdJobStatus.COMPLETED - job.result = { - "person": result.person, - "benunit": result.benunit, - "household": result.household, - } + job.result = _sanitize_for_json( + { + "person": result.person, + "benunit": result.benunit, + "household": result.household, + } + ) job.completed_at = datetime.now(timezone.utc) session.add(job) session.commit() @@ -343,14 +360,16 @@ def _run_local_household_us( job = session.get(HouseholdJob, job_id) if job: job.status = HouseholdJobStatus.COMPLETED - job.result = { - "person": result.person, - "marital_unit": result.marital_unit, - "family": result.family, - "spm_unit": result.spm_unit, - "tax_unit": result.tax_unit, - "household": result.household, - } + job.result = _sanitize_for_json( + { + "person": result.person, + "marital_unit": result.marital_unit, + "family": result.family, + "spm_unit": result.spm_unit, + "tax_unit": result.tax_unit, + "household": result.household, + } + ) job.completed_at = datetime.now(timezone.utc) session.add(job) session.commit() diff --git a/tests/test_household.py b/tests/test_household.py index 8f17176..3a79fe1 100644 --- a/tests/test_household.py +++ b/tests/test_household.py @@ -1,16 +1,31 @@ """Tests for household calculation endpoint.""" -import pytest - -pytestmark = pytest.mark.integration +import time +import pytest from fastapi.testclient import TestClient from policyengine_api.main import app +pytestmark = pytest.mark.integration + client = TestClient(app) +def _poll_job(job_id: str, max_attempts: int = 10) -> dict: + """Poll for job completion.""" + for _ in range(max_attempts): + response = client.get(f"/household/calculate/{job_id}") + assert response.status_code == 200 + data = response.json() + if data["status"] == "completed": + return data + if data["status"] == "failed": + raise AssertionError(f"Job failed: {data.get('error_message')}") + time.sleep(0.1) + raise AssertionError("Job timed out") + + class TestUKHouseholdCalculate: """Tests for UK household calculations.""" @@ -25,11 +40,15 @@ def test_single_adult(self): }, ) assert response.status_code == 200 - data = response.json() - assert "person" in data - assert "benunit" in data - assert "household" in data - assert len(data["person"]) == 1 + job_data = response.json() + assert "job_id" in job_data + + data = _poll_job(job_data["job_id"]) + assert data["result"] is not None + assert "person" in data["result"] + assert "benunit" in data["result"] + assert "household" in data["result"] + assert len(data["result"]["person"]) == 1 def test_couple_with_children(self): """Test calculation for a couple with children.""" @@ -47,8 +66,9 @@ def test_couple_with_children(self): }, ) assert response.status_code == 200 - data = response.json() - assert len(data["person"]) == 4 + job_data = response.json() + data = _poll_job(job_data["job_id"]) + assert len(data["result"]["person"]) == 4 def test_with_household_data(self): """Test calculation with household-level data.""" @@ -65,8 +85,9 @@ def test_with_household_data(self): }, ) assert response.status_code == 200 - data = response.json() - assert "household" in data + job_data = response.json() + data = _poll_job(job_data["job_id"]) + assert "household" in data["result"] def test_output_contains_tax_variables(self): """Test that output contains expected tax/benefit variables.""" @@ -79,10 +100,11 @@ def test_output_contains_tax_variables(self): }, ) assert response.status_code == 200 - data = response.json() - assert isinstance(data["person"], list) - assert len(data["person"]) > 0 - person_data = data["person"][0] + job_data = response.json() + data = _poll_job(job_data["job_id"]) + assert isinstance(data["result"]["person"], list) + assert len(data["result"]["person"]) > 0 + person_data = data["result"]["person"][0] assert isinstance(person_data, dict) @@ -100,14 +122,16 @@ def test_single_adult(self): }, ) assert response.status_code == 200 - data = response.json() - assert "person" in data - assert "household" in data - assert "tax_unit" in data - assert "spm_unit" in data - assert "family" in data - assert "marital_unit" in data - assert len(data["person"]) == 1 + job_data = response.json() + data = _poll_job(job_data["job_id"]) + result = data["result"] + assert "person" in result + assert "household" in result + assert "tax_unit" in result + assert "spm_unit" in result + assert "family" in result + assert "marital_unit" in result + assert len(result["person"]) == 1 def test_family_with_children(self): """Test calculation for a family with children.""" @@ -125,8 +149,9 @@ def test_family_with_children(self): }, ) assert response.status_code == 200 - data = response.json() - assert len(data["person"]) == 4 + job_data = response.json() + data = _poll_job(job_data["job_id"]) + assert len(data["result"]["person"]) == 4 class TestValidation: