Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions src/policyengine_api/api/household.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Poll the status endpoint until the job is complete.
"""

import math
from typing import Any, Literal
from uuid import UUID

Expand All @@ -22,12 +23,26 @@
from policyengine_api.services.database import get_session


def _sanitize_for_json(obj: Any) -> Any:
"""Replace NaN/Inf values with None for JSON serialization."""
if isinstance(obj, float):
if math.isnan(obj) or math.isinf(obj):
return None
return obj
elif isinstance(obj, dict):
return {k: _sanitize_for_json(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_sanitize_for_json(v) for v in obj]
return obj


def get_traceparent() -> str | None:
    """Return the W3C ``traceparent`` header value for the current trace context.

    The propagator writes its headers into a fresh dict; the value stored
    under ``"traceparent"`` is returned, or ``None`` when the propagator did
    not write that key.
    """
    propagator = TraceContextTextMapPropagator()
    headers: dict[str, str] = {}
    propagator.inject(headers)
    return headers.get("traceparent")


# Router for this module's endpoints: mounted under the "/household" prefix
# and tagged "household".
router = APIRouter(prefix="/household", tags=["household"])


Expand Down Expand Up @@ -254,11 +269,13 @@ def _run_local_household_uk(
job = session.get(HouseholdJob, job_id)
if job:
job.status = HouseholdJobStatus.COMPLETED
job.result = {
"person": result.person,
"benunit": result.benunit,
"household": result.household,
}
job.result = _sanitize_for_json(
{
"person": result.person,
"benunit": result.benunit,
"household": result.household,
}
)
job.completed_at = datetime.now(timezone.utc)
session.add(job)
session.commit()
Expand Down Expand Up @@ -343,14 +360,16 @@ def _run_local_household_us(
job = session.get(HouseholdJob, job_id)
if job:
job.status = HouseholdJobStatus.COMPLETED
job.result = {
"person": result.person,
"marital_unit": result.marital_unit,
"family": result.family,
"spm_unit": result.spm_unit,
"tax_unit": result.tax_unit,
"household": result.household,
}
job.result = _sanitize_for_json(
{
"person": result.person,
"marital_unit": result.marital_unit,
"family": result.family,
"spm_unit": result.spm_unit,
"tax_unit": result.tax_unit,
"household": result.household,
}
)
job.completed_at = datetime.now(timezone.utc)
session.add(job)
session.commit()
Expand Down
94 changes: 94 additions & 0 deletions src/policyengine_api/modal_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,8 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
from policyengine_api.models import (
Dataset,
DecileImpact,
Inequality,
Poverty,
ProgramStatistics,
Report,
ReportStatus,
Expand Down Expand Up @@ -748,6 +750,12 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
# Import policyengine
from policyengine.core import Simulation as PESimulation
from policyengine.outputs import DecileImpact as PEDecileImpact
from policyengine.outputs.inequality import (
calculate_uk_inequality,
)
from policyengine.outputs.poverty import (
calculate_uk_poverty_rates,
)
from policyengine.tax_benefit_models.uk import uk_latest
from policyengine.tax_benefit_models.uk.datasets import (
PolicyEngineUKDataset,
Expand Down Expand Up @@ -881,6 +889,45 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None:
except KeyError:
pass # Variable not in model, skip silently

# Calculate poverty rates
with logfire.span("calculate_poverty"):
for sim, sim_id in [
(pe_baseline_sim, baseline_sim.id),
(pe_reform_sim, reform_sim.id),
]:
poverty_collection = calculate_uk_poverty_rates(sim)
for pov in poverty_collection.outputs:
poverty_record = Poverty(
simulation_id=sim_id,
report_id=report.id,
poverty_type=pov.poverty_type,
entity=pov.entity,
filter_variable=pov.filter_variable,
headcount=pov.headcount,
total_population=pov.total_population,
rate=pov.rate,
)
session.add(poverty_record)

# Calculate inequality
with logfire.span("calculate_inequality"):
for sim, sim_id in [
(pe_baseline_sim, baseline_sim.id),
(pe_reform_sim, reform_sim.id),
]:
ineq = calculate_uk_inequality(sim)
inequality_record = Inequality(
simulation_id=sim_id,
report_id=report.id,
income_variable=ineq.income_variable,
entity=ineq.entity,
gini=ineq.gini,
top_10_share=ineq.top_10_share,
top_1_share=ineq.top_1_share,
bottom_50_share=ineq.bottom_50_share,
)
session.add(inequality_record)

# Mark simulations and report as completed
baseline_sim.status = SimulationStatus.COMPLETED
baseline_sim.completed_at = datetime.now(timezone.utc)
Expand Down Expand Up @@ -949,6 +996,8 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
from policyengine_api.models import (
Dataset,
DecileImpact,
Inequality,
Poverty,
ProgramStatistics,
Report,
ReportStatus,
Expand Down Expand Up @@ -983,6 +1032,12 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
# Import policyengine
from policyengine.core import Simulation as PESimulation
from policyengine.outputs import DecileImpact as PEDecileImpact
from policyengine.outputs.inequality import (
calculate_us_inequality,
)
from policyengine.outputs.poverty import (
calculate_us_poverty_rates,
)
from policyengine.tax_benefit_models.us import us_latest
from policyengine.tax_benefit_models.us.datasets import (
PolicyEngineUSDataset,
Expand Down Expand Up @@ -1113,6 +1168,45 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None:
except KeyError:
pass # Variable not in model, skip silently

# Calculate poverty rates
with logfire.span("calculate_poverty"):
for sim, sim_id in [
(pe_baseline_sim, baseline_sim.id),
(pe_reform_sim, reform_sim.id),
]:
poverty_collection = calculate_us_poverty_rates(sim)
for pov in poverty_collection.outputs:
poverty_record = Poverty(
simulation_id=sim_id,
report_id=report.id,
poverty_type=pov.poverty_type,
entity=pov.entity,
filter_variable=pov.filter_variable,
headcount=pov.headcount,
total_population=pov.total_population,
rate=pov.rate,
)
session.add(poverty_record)

# Calculate inequality
with logfire.span("calculate_inequality"):
for sim, sim_id in [
(pe_baseline_sim, baseline_sim.id),
(pe_reform_sim, reform_sim.id),
]:
ineq = calculate_us_inequality(sim)
inequality_record = Inequality(
simulation_id=sim_id,
report_id=report.id,
income_variable=ineq.income_variable,
entity=ineq.entity,
gini=ineq.gini,
top_10_share=ineq.top_10_share,
top_1_share=ineq.top_1_share,
bottom_50_share=ineq.bottom_50_share,
)
session.add(inequality_record)

# Mark simulations and report as completed
baseline_sim.status = SimulationStatus.COMPLETED
baseline_sim.completed_at = datetime.now(timezone.utc)
Expand Down
8 changes: 8 additions & 0 deletions src/policyengine_api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
HouseholdJobRead,
HouseholdJobStatus,
)
from .inequality import Inequality, InequalityCreate, InequalityRead
from .output import (
AggregateOutput,
AggregateOutputCreate,
Expand All @@ -25,6 +26,7 @@
from .parameter import Parameter, ParameterCreate, ParameterRead
from .parameter_value import ParameterValue, ParameterValueCreate, ParameterValueRead
from .policy import Policy, PolicyCreate, PolicyRead
from .poverty import Poverty, PovertyCreate, PovertyRead
from .program_statistics import (
ProgramStatistics,
ProgramStatisticsCreate,
Expand Down Expand Up @@ -70,6 +72,9 @@
"HouseholdJobCreate",
"HouseholdJobRead",
"HouseholdJobStatus",
"Inequality",
"InequalityCreate",
"InequalityRead",
"Parameter",
"ParameterCreate",
"ParameterRead",
Expand All @@ -79,6 +84,9 @@
"Policy",
"PolicyCreate",
"PolicyRead",
"Poverty",
"PovertyCreate",
"PovertyRead",
"ProgramStatistics",
"ProgramStatisticsCreate",
"ProgramStatisticsRead",
Expand Down
41 changes: 41 additions & 0 deletions src/policyengine_api/models/inequality.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Inequality output model."""

from datetime import datetime, timezone
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel


class InequalityBase(SQLModel):
    """Base inequality fields.

    Holds the columns shared by the ``Inequality`` table model and the
    ``InequalityCreate`` / ``InequalityRead`` schemas that subclass it.
    """

    # Required link to the simulation (foreign key to simulations.id).
    simulation_id: UUID = Field(foreign_key="simulations.id")
    # Optional link to a report (foreign key to reports.id); defaults to None.
    report_id: UUID | None = Field(default=None, foreign_key="reports.id")
    # Required name of the income variable.
    # NOTE(review): presumably the variable the metrics below were computed
    # over - confirm against the code that creates these records.
    income_variable: str
    # Entity level of the record; defaults to "household".
    entity: str = "household"
    # Inequality metrics. Each is optional and defaults to None.
    # NOTE(review): presumably a Gini coefficient and income shares held by
    # the named population slices - confirm meaning/units with the producer.
    gini: float | None = None
    top_10_share: float | None = None
    top_1_share: float | None = None
    bottom_50_share: float | None = None


class Inequality(InequalityBase, table=True):
    """Inequality database model.

    Table-backed model (``table=True``) stored in the ``inequality`` table;
    adds a primary key and a creation timestamp to the base fields.
    """

    __tablename__ = "inequality"

    # Primary key; a random UUID (uuid4) is generated when none is supplied.
    id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Timezone-aware creation time, defaulting to the current time in UTC.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class InequalityCreate(InequalityBase):
    """Payload schema for creating an inequality record.

    Adds nothing to ``InequalityBase``; it exists as a distinct schema type.
    """


class InequalityRead(InequalityBase):
    """Schema for reading inequality records.

    Extends the base fields with the two fields declared on the table model,
    both required here.
    """

    # Identifier of the record.
    id: UUID
    # Creation timestamp of the record.
    created_at: datetime
41 changes: 41 additions & 0 deletions src/policyengine_api/models/poverty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Poverty output model."""

from datetime import datetime, timezone
from uuid import UUID, uuid4

from sqlmodel import Field, SQLModel


class PovertyBase(SQLModel):
    """Base poverty fields.

    Holds the columns shared by the ``Poverty`` table model and the
    ``PovertyCreate`` / ``PovertyRead`` schemas that subclass it.
    """

    # Required link to the simulation (foreign key to simulations.id).
    simulation_id: UUID = Field(foreign_key="simulations.id")
    # Optional link to a report (foreign key to reports.id); defaults to None.
    report_id: UUID | None = Field(default=None, foreign_key="reports.id")
    # Required label of the poverty measure.
    poverty_type: str  # e.g. "absolute_bhc", "spm", etc.
    # Entity level of the record; defaults to "person".
    entity: str = "person"
    # Optional filter variable name; defaults to None.
    # NOTE(review): presumably identifies a sub-population the figures are
    # restricted to - confirm against the code that creates these records.
    filter_variable: str | None = None
    # Poverty figures. Each is optional and defaults to None.
    # NOTE(review): rate is presumably headcount / total_population - confirm.
    headcount: float | None = None
    total_population: float | None = None
    rate: float | None = None


class Poverty(PovertyBase, table=True):
    """Poverty database model.

    Table-backed model (``table=True``) stored in the ``poverty`` table;
    adds a primary key and a creation timestamp to the base fields.
    """

    __tablename__ = "poverty"

    # Primary key; a random UUID (uuid4) is generated when none is supplied.
    id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Timezone-aware creation time, defaulting to the current time in UTC.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class PovertyCreate(PovertyBase):
    """Payload schema for creating a poverty record.

    Adds nothing to ``PovertyBase``; it exists as a distinct schema type.
    """


class PovertyRead(PovertyBase):
    """Schema for reading poverty records.

    Extends the base fields with the two fields declared on the table model,
    both required here.
    """

    # Identifier of the record.
    id: UUID
    # Creation timestamp of the record.
    created_at: datetime
33 changes: 33 additions & 0 deletions supabase/migrations/20260103000000_add_poverty_inequality.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
-- Add poverty and inequality tables for economic analysis

-- Poverty records, one row per simulation / poverty_type / entity / filter
-- combination written by the application. IF NOT EXISTS makes the statement
-- a no-op when the table is already present.
CREATE TABLE IF NOT EXISTS poverty (
    -- Primary key; generated by the database when not supplied.
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    -- Owning simulation; rows are removed when the simulation is deleted.
    simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE,
    -- Optional owning report; rows are removed when the report is deleted.
    report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
    poverty_type VARCHAR NOT NULL,
    entity VARCHAR NOT NULL DEFAULT 'person',
    -- Nullable: no filter variable recorded.
    filter_variable VARCHAR,
    -- Nullable numeric results.
    headcount FLOAT,
    total_population FLOAT,
    rate FLOAT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Inequality records linked to a simulation and, optionally, a report.
-- IF NOT EXISTS makes the statement a no-op when the table is already present.
CREATE TABLE IF NOT EXISTS inequality (
    -- Primary key; generated by the database when not supplied.
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    -- Owning simulation; rows are removed when the simulation is deleted.
    simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE,
    -- Optional owning report; rows are removed when the report is deleted.
    report_id UUID REFERENCES reports(id) ON DELETE CASCADE,
    income_variable VARCHAR NOT NULL,
    entity VARCHAR NOT NULL DEFAULT 'household',
    -- Nullable numeric results.
    gini FLOAT,
    top_10_share FLOAT,
    top_1_share FLOAT,
    bottom_50_share FLOAT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Indexes for efficient querying: one per foreign-key column
-- (simulation_id, report_id) on each of the two tables. IF NOT EXISTS keeps
-- the statements safe to re-run.
CREATE INDEX IF NOT EXISTS idx_poverty_simulation_id ON poverty(simulation_id);
CREATE INDEX IF NOT EXISTS idx_poverty_report_id ON poverty(report_id);
CREATE INDEX IF NOT EXISTS idx_inequality_simulation_id ON inequality(simulation_id);
CREATE INDEX IF NOT EXISTS idx_inequality_report_id ON inequality(report_id);
Loading