diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index f65345a..8536ea8 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -1,22 +1,27 @@ -"""Economic impact analysis endpoints. +"""Economic impact and household analysis endpoints. -Use these endpoints to analyse the economy-wide effects of policy reforms. -The /analysis/economic-impact endpoint compares baseline vs reform scenarios -across a population dataset, computing distributional impacts and program statistics. +Use these endpoints to analyse policy effects at both household and population levels. -This is an async operation - the endpoint returns immediately with a report_id, -and you poll /analysis/economic-impact/{report_id} until status is "completed". +HOUSEHOLD-LEVEL ANALYSIS: +- /analysis/marginal-rate: Compute effective marginal tax rate for a household +- /analysis/budget-constraint: Compute net income across income range +- /analysis/cliffs: Identify benefit cliffs and high marginal rate regions +- /analysis/compare-policies: Compare multiple policy reforms for a household -WORKFLOW for full economic analysis: +ECONOMY-WIDE ANALYSIS: +- /analysis/economic-impact: Compare baseline vs reform across population dataset + This is async - poll until status="completed" to get results + +WORKFLOW for economic analysis: 1. Create a policy with parameter changes: POST /policies 2. Get a dataset: GET /datasets (look for UK/US datasets) 3. Start analysis: POST /analysis/economic-impact with policy_id and dataset_id -4. Check status: GET /analysis/economic-impact/{report_id} - repeat until status="completed" -5. Review results: The completed response includes decile_impacts and program_statistics +4. Check status: GET /analysis/economic-impact/{report_id} until status="completed" +5. 
Review results: The response includes decile_impacts and program_statistics """ import math -from typing import Literal +from typing import Any, Literal from uuid import UUID, uuid5 from fastapi import APIRouter, Depends, HTTPException @@ -356,3 +361,639 @@ def get_economic_impact_status( raise HTTPException(status_code=500, detail="Simulation data missing") return _build_response(report, baseline_sim, reform_sim, session) + + +# ============================================================================ +# Household-level analysis endpoints +# ============================================================================ + + +class MarginalRateRequest(BaseModel): + """Request for marginal rate analysis. + + Computes effective marginal tax rate by calculating net income + at base income and base income + delta, then computing the + marginal rate as 1 - (change in net income / delta). + """ + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + people: list[dict[str, Any]] = Field( + description="List of people with flat variable values" + ) + benunit: dict[str, Any] = Field(default_factory=dict) + marital_unit: dict[str, Any] = Field(default_factory=dict) + family: dict[str, Any] = Field(default_factory=dict) + spm_unit: dict[str, Any] = Field(default_factory=dict) + tax_unit: dict[str, Any] = Field(default_factory=dict) + household: dict[str, Any] = Field(default_factory=dict) + year: int | None = Field(default=None) + policy_id: UUID | None = Field(default=None) + dynamic_id: UUID | None = Field(default=None) + person_index: int = Field( + default=0, + description="Index of person to vary income for (0-indexed)", + ) + income_variable: str = Field( + default="employment_income", + description="Which income variable to vary", + ) + delta: float = Field( + default=1.0, + description="Amount to increase income by (in currency units)", + ) + net_income_variable: str = Field( + 
@router.post("/marginal-rate", response_model=MarginalRateResponse)
def calculate_marginal_rate(
    request: MarginalRateRequest,
    session: Session = Depends(get_session),
) -> MarginalRateResponse:
    """Calculate effective marginal tax rate for a household.

    Runs the household calculation twice - once as submitted and once with
    ``income_variable`` of the selected person raised by ``delta`` - and
    reports ``1 - (change in net income / delta)``.

    Example: If £1 extra income results in £0.32 extra net income,
    the marginal rate is 68%.

    Args:
        request: Household description plus the person, income variable and
            delta to perturb.
        session: Database session used to resolve ``policy_id`` and
            ``dynamic_id``.

    Returns:
        The base and incremented net incomes and the resulting marginal rate.

    Raises:
        HTTPException: 400 if ``person_index`` does not refer to an entry in
            ``people``, or if ``delta`` is zero or not finite.
    """
    import logfire

    from policyengine_api.api.household import (
        HouseholdCalculateRequest,
        _calculate_uk,
        _calculate_us,
        _get_pe_dynamic,
        _get_pe_policy,
    )

    with logfire.span(
        "calculate_marginal_rate",
        model=request.tax_benefit_model_name,
        person_index=request.person_index,
        delta=request.delta,
    ):
        # Reject negative indices too: Python would otherwise silently pick a
        # person counted from the end of the list (or raise IndexError -> 500).
        if not 0 <= request.person_index < len(request.people):
            raise HTTPException(
                status_code=400,
                detail=f"person_index {request.person_index} out of range "
                f"(have {len(request.people)} people)",
            )

        # A zero delta has no defined marginal rate; reporting 0 would look
        # like a genuine "no tax on the next unit" result.
        if request.delta == 0 or not math.isfinite(request.delta):
            raise HTTPException(
                status_code=400,
                detail="delta must be a non-zero, finite number",
            )

        is_uk = request.tax_benefit_model_name == "policyengine_uk"
        calculate = _calculate_uk if is_uk else _calculate_us

        # Everything except `people` is identical in both scenarios.
        shared_kwargs: dict[str, Any] = dict(
            tax_benefit_model_name=request.tax_benefit_model_name,
            benunit=request.benunit,
            marital_unit=request.marital_unit,
            family=request.family,
            spm_unit=request.spm_unit,
            tax_unit=request.tax_unit,
            household=request.household,
            year=request.year,
            policy_id=request.policy_id,
            dynamic_id=request.dynamic_id,
        )

        # Current value of the income variable (treated as 0 when absent).
        base_income = request.people[request.person_index].get(
            request.income_variable, 0
        )

        # Load model and policy/dynamic
        with logfire.span("load_model"):
            if is_uk:
                from policyengine.tax_benefit_models.uk import uk_latest

                pe_model_version = uk_latest
            else:
                from policyengine.tax_benefit_models.us import us_latest

                pe_model_version = us_latest

            policy = _get_pe_policy(request.policy_id, pe_model_version, session)
            dynamic = _get_pe_dynamic(request.dynamic_id, pe_model_version, session)

        # Calculate base scenario
        with logfire.span("calculate_base"):
            base_request = HouseholdCalculateRequest(
                people=request.people, **shared_kwargs
            )
            base_result = calculate(base_request, policy, dynamic)

        # Copy each person dict so the caller's request is not mutated.
        incremented_people = [p.copy() for p in request.people]
        incremented_people[request.person_index][request.income_variable] = (
            base_income + request.delta
        )

        # Calculate incremented scenario
        with logfire.span("calculate_incremented"):
            incremented_request = HouseholdCalculateRequest(
                people=incremented_people, **shared_kwargs
            )
            incremented_result = calculate(incremented_request, policy, dynamic)

        # NOTE(review): a net_income_variable missing from the household
        # result falls back to 0 for both runs, which yields a marginal rate
        # of exactly 1 - confirm whether that should be a 400 instead.
        base_net = base_result.household.get(request.net_income_variable, 0)
        incremented_net = incremented_result.household.get(
            request.net_income_variable, 0
        )

        change_in_net = incremented_net - base_net
        marginal_rate = 1 - (change_in_net / request.delta)

        return MarginalRateResponse(
            base_net_income=base_net,
            incremented_net_income=incremented_net,
            delta=request.delta,
            marginal_rate=marginal_rate,
            person_index=request.person_index,
            income_variable=request.income_variable,
        )
def _budget_constraint_incomes(
    min_income: float, max_income: float, step: float
) -> list[float]:
    """Return the grid min_income, min_income + step, ... up to max_income.

    Points are derived from an integer index rather than by repeated
    addition, so floating-point error does not accumulate and drop the final
    point. ``step`` must be positive; an empty list is returned when
    ``max_income < min_income``.
    """
    if max_income < min_income:
        return []
    # Small tolerance so a quotient such as 2.9999999999999996 still counts
    # as three whole steps.
    n_steps = math.floor((max_income - min_income) / step + 1e-9)
    # Clamp so the tolerance can never push a point past max_income.
    return [min(min_income + i * step, max_income) for i in range(n_steps + 1)]


@router.post("/budget-constraint", response_model=BudgetConstraintResponse)
def calculate_budget_constraint(
    request: BudgetConstraintRequest,
    session: Session = Depends(get_session),
) -> BudgetConstraintResponse:
    """Calculate budget constraint across income range.

    Returns net income for each gross income level, useful for
    visualising effective tax schedules and identifying cliffs. Each point
    after the first also carries the marginal rate relative to the previous
    point, ``1 - (change in net income / step)``.

    Args:
        request: Household description plus the income range to scan.
        session: Database session used to resolve ``policy_id`` and
            ``dynamic_id``.

    Returns:
        One point per income level, in increasing order of gross income.

    Raises:
        HTTPException: 400 if ``person_index`` does not refer to an entry in
            ``people``, if ``step`` is not a positive finite number, or if
            ``min_income`` / ``max_income`` are not finite.
    """
    import logfire

    from policyengine_api.api.household import (
        HouseholdCalculateRequest,
        _calculate_uk,
        _calculate_us,
        _get_pe_dynamic,
        _get_pe_policy,
    )

    with logfire.span(
        "calculate_budget_constraint",
        model=request.tax_benefit_model_name,
        min_income=request.min_income,
        max_income=request.max_income,
        step=request.step,
    ):
        # Reject negative indices too: Python would otherwise silently pick a
        # person counted from the end of the list.
        if not 0 <= request.person_index < len(request.people):
            raise HTTPException(
                status_code=400,
                detail=f"person_index {request.person_index} out of range",
            )

        # A zero or negative step would never advance past max_income and
        # the scan would not terminate.
        if not math.isfinite(request.step) or request.step <= 0:
            raise HTTPException(
                status_code=400,
                detail="step must be a positive, finite number",
            )
        if not (
            math.isfinite(request.min_income) and math.isfinite(request.max_income)
        ):
            raise HTTPException(
                status_code=400,
                detail="min_income and max_income must be finite",
            )

        # NOTE(review): the number of points (one full household simulation
        # each) is unbounded - consider capping (max - min) / step.
        incomes = _budget_constraint_incomes(
            request.min_income, request.max_income, request.step
        )

        is_uk = request.tax_benefit_model_name == "policyengine_uk"
        calculate = _calculate_uk if is_uk else _calculate_us

        # Load model
        with logfire.span("load_model"):
            if is_uk:
                from policyengine.tax_benefit_models.uk import uk_latest

                pe_model_version = uk_latest
            else:
                from policyengine.tax_benefit_models.us import us_latest

                pe_model_version = us_latest

            policy = _get_pe_policy(request.policy_id, pe_model_version, session)
            dynamic = _get_pe_dynamic(request.dynamic_id, pe_model_version, session)

        points = []
        prev_net = None

        for income in incomes:
            with logfire.span("calculate_point", income=income):
                # Copy each person dict so the caller's request is not mutated.
                people_copy = [p.copy() for p in request.people]
                people_copy[request.person_index][request.income_variable] = income

                calc_request = HouseholdCalculateRequest(
                    tax_benefit_model_name=request.tax_benefit_model_name,
                    people=people_copy,
                    benunit=request.benunit,
                    marital_unit=request.marital_unit,
                    family=request.family,
                    spm_unit=request.spm_unit,
                    tax_unit=request.tax_unit,
                    household=request.household,
                    year=request.year,
                    policy_id=request.policy_id,
                    dynamic_id=request.dynamic_id,
                )

                result = calculate(calc_request, policy, dynamic)
                net_income = result.household.get(request.net_income_variable, 0)

                # The first point has no predecessor, so no marginal rate.
                marginal_rate = None
                if prev_net is not None:
                    change_in_net = net_income - prev_net
                    marginal_rate = 1 - (change_in_net / request.step)

                points.append(
                    BudgetConstraintPoint(
                        gross_income=income,
                        net_income=net_income,
                        marginal_rate=marginal_rate,
                    )
                )

                prev_net = net_income

        return BudgetConstraintResponse(
            points=points,
            person_index=request.person_index,
            income_variable=request.income_variable,
            net_income_variable=request.net_income_variable,
        )
+ """ + + tax_benefit_model_name: Literal["policyengine_uk", "policyengine_us"] = Field( + description="Which country model to use" + ) + people: list[dict[str, Any]] = Field( + description="List of people with flat variable values" + ) + benunit: dict[str, Any] = Field(default_factory=dict) + marital_unit: dict[str, Any] = Field(default_factory=dict) + family: dict[str, Any] = Field(default_factory=dict) + spm_unit: dict[str, Any] = Field(default_factory=dict) + tax_unit: dict[str, Any] = Field(default_factory=dict) + household: dict[str, Any] = Field(default_factory=dict) + year: int | None = Field(default=None) + policy_id: UUID | None = Field(default=None) + person_index: int = Field(default=0) + income_variable: str = Field(default="employment_income") + net_income_variable: str = Field(default="household_net_income") + min_income: float = Field(default=0) + max_income: float = Field(default=100000) + step: float = Field(default=500) + cliff_threshold: float = Field( + default=0.7, + description="Marginal rate threshold to consider a cliff (0.7 = 70%)", + ) + + +class CliffRegion(BaseModel): + """A region where marginal rate exceeds threshold.""" + + start_income: float + end_income: float + peak_marginal_rate: float + avg_marginal_rate: float + + +class CliffAnalysisResponse(BaseModel): + """Response from cliff analysis.""" + + cliff_regions: list[CliffRegion] + max_marginal_rate: float + avg_marginal_rate: float + cliff_threshold: float + + +@router.post("/cliffs", response_model=CliffAnalysisResponse) +def analyse_cliffs( + request: CliffAnalysisRequest, + session: Session = Depends(get_session), +) -> CliffAnalysisResponse: + """Identify benefit cliffs and high marginal rate regions. + + Scans income range to find regions where marginal rates + exceed the specified threshold, indicating cliffs or + aggressive phase-outs. 
+ """ + import logfire + + with logfire.span( + "analyse_cliffs", + model=request.tax_benefit_model_name, + threshold=request.cliff_threshold, + ): + # First get budget constraint + bc_request = BudgetConstraintRequest( + tax_benefit_model_name=request.tax_benefit_model_name, + people=request.people, + benunit=request.benunit, + marital_unit=request.marital_unit, + family=request.family, + spm_unit=request.spm_unit, + tax_unit=request.tax_unit, + household=request.household, + year=request.year, + policy_id=request.policy_id, + person_index=request.person_index, + income_variable=request.income_variable, + net_income_variable=request.net_income_variable, + min_income=request.min_income, + max_income=request.max_income, + step=request.step, + ) + + bc_result = calculate_budget_constraint(bc_request, session) + + # Identify cliff regions + cliff_regions = [] + current_cliff_start = None + current_cliff_rates: list[float] = [] + + all_rates: list[float] = [] + + for point in bc_result.points: + if point.marginal_rate is not None: + all_rates.append(point.marginal_rate) + + if point.marginal_rate >= request.cliff_threshold: + if current_cliff_start is None: + current_cliff_start = point.gross_income - request.step + current_cliff_rates.append(point.marginal_rate) + else: + if current_cliff_start is not None: + cliff_regions.append( + CliffRegion( + start_income=current_cliff_start, + end_income=point.gross_income - request.step, + peak_marginal_rate=max(current_cliff_rates), + avg_marginal_rate=sum(current_cliff_rates) + / len(current_cliff_rates), + ) + ) + current_cliff_start = None + current_cliff_rates = [] + + # Handle cliff that extends to end + if current_cliff_start is not None: + avg_rate = sum(current_cliff_rates) / len(current_cliff_rates) + cliff_regions.append( + CliffRegion( + start_income=current_cliff_start, + end_income=request.max_income, + peak_marginal_rate=max(current_cliff_rates), + avg_marginal_rate=avg_rate, + ) + ) + + return 
@router.post("/compare-policies", response_model=MultiPolicyCompareResponse)
def compare_multiple_policies(
    request: MultiPolicyCompareRequest,
    session: Session = Depends(get_session),
) -> MultiPolicyCompareResponse:
    """Compare a household under baseline and multiple reform policies.

    Useful for evaluating alternative policy proposals side-by-side. The
    baseline is calculated with no policy and no dynamic; each reform is
    calculated with its policy and no dynamic.

    Args:
        request: Household description plus the policy IDs to compare.
        session: Database session used to look up the policies.

    Returns:
        Baseline and per-reform results, plus a ``summary`` whose
        ``net_income`` mapping holds ``household_net_income`` under
        ``"baseline"`` and under each reform's name. When a name is already
        taken (duplicate policy names, or a policy called "baseline"), the
        key becomes ``"<name> (<policy_id>)"`` so no entry is overwritten.

    Raises:
        HTTPException: 404 if any requested policy does not exist.
    """
    import logfire

    from policyengine_api.api.household import (
        HouseholdCalculateRequest,
        _calculate_uk,
        _calculate_us,
        _get_pe_policy,
    )
    from policyengine_api.models import Policy

    with logfire.span(
        "compare_multiple_policies",
        model=request.tax_benefit_model_name,
        num_policies=len(request.policy_ids),
    ):
        # Resolve every policy up front so an unknown ID fails fast, before
        # any simulation has been run.
        db_policies = []
        for policy_id in request.policy_ids:
            db_policy = session.get(Policy, policy_id)
            if not db_policy:
                raise HTTPException(
                    status_code=404, detail=f"Policy {policy_id} not found"
                )
            db_policies.append(db_policy)

        is_uk = request.tax_benefit_model_name == "policyengine_uk"
        calculate = _calculate_uk if is_uk else _calculate_us

        # Load model
        with logfire.span("load_model"):
            if is_uk:
                from policyengine.tax_benefit_models.uk import uk_latest

                pe_model_version = uk_latest
            else:
                from policyengine.tax_benefit_models.us import us_latest

                pe_model_version = us_latest

        # The household description is the same for baseline and reforms.
        household_kwargs: dict[str, Any] = dict(
            tax_benefit_model_name=request.tax_benefit_model_name,
            people=request.people,
            benunit=request.benunit,
            marital_unit=request.marital_unit,
            family=request.family,
            spm_unit=request.spm_unit,
            tax_unit=request.tax_unit,
            household=request.household,
            year=request.year,
        )

        # Calculate baseline
        with logfire.span("calculate_baseline"):
            base_request = HouseholdCalculateRequest(**household_kwargs)
            baseline_result = calculate(base_request, None, None)

            baseline = PolicyResult(
                policy_id=None,
                policy_name="Baseline (current law)",
                household=baseline_result.household,
                person=baseline_result.person,
            )

        # Calculate each reform
        reforms = []
        for policy_id, db_policy in zip(request.policy_ids, db_policies):
            with logfire.span("calculate_reform", policy_id=str(policy_id)):
                policy = _get_pe_policy(policy_id, pe_model_version, session)

                reform_request = HouseholdCalculateRequest(
                    policy_id=policy_id, **household_kwargs
                )
                reform_result = calculate(reform_request, policy, None)

                reforms.append(
                    PolicyResult(
                        policy_id=policy_id,
                        policy_name=db_policy.name,
                        household=reform_result.household,
                        person=reform_result.person,
                    )
                )

        # Build summary comparing key variables
        net_income_summary: dict[str, Any] = {
            "baseline": baseline.household.get("household_net_income", 0),
        }
        for reform in reforms:
            key = reform.policy_name
            # Policy names are not guaranteed unique; qualify a clashing name
            # with the policy ID instead of overwriting an earlier entry.
            if key in net_income_summary:
                key = f"{reform.policy_name} ({reform.policy_id})"
            net_income_summary[key] = reform.household.get(
                "household_net_income", 0
            )

        summary: dict[str, Any] = {"net_income": net_income_summary}

        return MultiPolicyCompareResponse(
            baseline=baseline,
            reforms=reforms,
            summary=summary,
        )
@router.get("/", response_model=List[VariableRead])
@cache(expire=3600)  # Cache for 1 hour
def list_variables(
    skip: int = 0,
    limit: int = 100,
    search: str | None = Query(
        default=None, description="Search by variable name, entity, or description"
    ),
    entity: str | None = Query(
        default=None,
        description="Filter by entity type (e.g. person, household, benunit, tax_unit)",
    ),
    session: Session = Depends(get_session),
):
    """List available variables with pagination and search.

    Variables are inputs (e.g. employment_income, age) and outputs
    (e.g. income_tax, household_net_income) of tax-benefit calculations.
    Use variable names in household calculation requests.

    Use the `search` parameter to filter by name, entity, or description
    (case-insensitive substring match; `%` and `_` are matched literally).
    For example: search="income_tax" or search="universal credit"

    Use the `entity` parameter to keep only variables whose entity equals
    the given value exactly. Results are ordered by variable name.
    """
    query = select(Variable)

    if search:
        # Escape LIKE metacharacters so user input is matched literally;
        # otherwise the "_" in "income_tax" would match any single character.
        escaped = (
            search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        search_pattern = f"%{escaped}%"
        # No Python-level NULL guard is needed: the column attributes here
        # are class-level expressions, not row values, and a NULL column
        # simply fails to match inside the SQL OR.
        query = query.where(
            or_(
                Variable.name.ilike(search_pattern, escape="\\"),
                Variable.entity.ilike(search_pattern, escape="\\"),
                Variable.description.ilike(search_pattern, escape="\\"),
            )
        )

    if entity:
        query = query.where(Variable.entity == entity)

    query = query.order_by(Variable.name).offset(skip).limit(limit)
    variables = session.exec(query).all()
    return variables
assert response.status_code == 200 + data = response.json() + assert "marginal_rate" in data + + def test_marginal_rate_higher_delta(self): + """Test marginal rate with larger delta.""" + response = client.post( + "/analysis/marginal-rate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 50000}], + "year": 2026, + "delta": 1000.0, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["delta"] == 1000.0 + + def test_marginal_rate_invalid_person_index(self): + """Test marginal rate with invalid person index.""" + response = client.post( + "/analysis/marginal-rate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 2026, + "person_index": 5, # Invalid - only 1 person + }, + ) + assert response.status_code == 400 + + def test_marginal_rate_at_different_incomes(self): + """Test that marginal rates differ at different income levels.""" + responses = [] + for income in [20000, 50000, 100000]: + response = client.post( + "/analysis/marginal-rate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": income}], + "year": 2026, + }, + ) + assert response.status_code == 200 + responses.append(response.json()["marginal_rate"]) + + # Higher earner should have higher marginal rate (generally) + # This is a soft check - tax systems are complex + assert len(set(responses)) >= 1 # At least some variation expected + + +class TestBudgetConstraint: + """Tests for budget constraint endpoint.""" + + def test_budget_constraint_uk_basic(self): + """Test basic budget constraint for UK.""" + response = client.post( + "/analysis/budget-constraint", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30}], + "year": 2026, + "min_income": 0, + "max_income": 50000, + "step": 10000, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "points" 
in data + assert len(data["points"]) == 6 # 0, 10k, 20k, 30k, 40k, 50k + # Verify points are sorted by income + incomes = [p["gross_income"] for p in data["points"]] + assert incomes == sorted(incomes) + + def test_budget_constraint_us(self): + """Test budget constraint for US.""" + response = client.post( + "/analysis/budget-constraint", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 35}], + "tax_unit": {"state_code": "TX"}, + "year": 2024, + "min_income": 0, + "max_income": 100000, + "step": 25000, + }, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["points"]) == 5 + + def test_budget_constraint_net_income_increases(self): + """Test that net income generally increases with gross income.""" + response = client.post( + "/analysis/budget-constraint", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30}], + "year": 2026, + "min_income": 20000, + "max_income": 80000, + "step": 20000, + }, + ) + assert response.status_code == 200 + data = response.json() + net_incomes = [p["net_income"] for p in data["points"]] + # Net income should generally increase (may have exceptions due to cliffs) + assert net_incomes[-1] > net_incomes[0] + + def test_budget_constraint_has_marginal_rates(self): + """Test that marginal rates are computed in budget constraint.""" + response = client.post( + "/analysis/budget-constraint", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30}], + "year": 2026, + "min_income": 0, + "max_income": 30000, + "step": 10000, + }, + ) + assert response.status_code == 200 + data = response.json() + # First point has no marginal rate (no previous point) + assert data["points"][0]["marginal_rate"] is None + # Subsequent points should have marginal rates + for point in data["points"][1:]: + assert point["marginal_rate"] is not None + + +class TestCliffAnalysis: + """Tests for cliff analysis endpoint.""" + + def test_cliff_analysis_uk(self): + 
"""Test cliff analysis for UK household.""" + response = client.post( + "/analysis/cliffs", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + {"age": 30}, + {"age": 5}, # Child for benefit cliff potential + ], + "year": 2026, + "min_income": 0, + "max_income": 100000, + "step": 5000, + "cliff_threshold": 0.5, # 50% marginal rate + }, + ) + assert response.status_code == 200 + data = response.json() + assert "cliff_regions" in data + assert "max_marginal_rate" in data + assert "avg_marginal_rate" in data + assert data["cliff_threshold"] == 0.5 + + def test_cliff_analysis_us(self): + """Test cliff analysis for US household.""" + response = client.post( + "/analysis/cliffs", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 30}, {"age": 6}], + "tax_unit": {"state_code": "CA"}, + "year": 2024, + "min_income": 0, + "max_income": 80000, + "step": 4000, + "cliff_threshold": 0.6, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "cliff_regions" in data + + def test_cliff_analysis_strict_threshold(self): + """Test cliff analysis with very strict threshold.""" + response = client.post( + "/analysis/cliffs", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30}], + "year": 2026, + "min_income": 0, + "max_income": 60000, + "step": 5000, + "cliff_threshold": 0.9, # Very high - should find fewer cliffs + }, + ) + assert response.status_code == 200 + + +class TestMultiPolicyCompare: + """Tests for multi-policy comparison endpoint.""" + + def test_compare_no_policies(self): + """Test comparison with empty policy list (just baseline).""" + response = client.post( + "/analysis/compare-policies", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 40000}], + "year": 2026, + "policy_ids": [], + }, + ) + assert response.status_code == 200 + data = response.json() + assert "baseline" in data + assert data["baseline"]["policy_name"] 
== "Baseline (current law)" + assert data["reforms"] == [] + + def test_compare_invalid_policy(self): + """Test comparison with non-existent policy.""" + response = client.post( + "/analysis/compare-policies", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 40000}], + "year": 2026, + "policy_ids": ["00000000-0000-0000-0000-000000000000"], + }, + ) + assert response.status_code == 404 + + +class TestVariableSearch: + """Tests for variable search functionality.""" + + def test_variable_list(self): + """Test basic variable listing.""" + response = client.get("/variables?limit=10") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_variable_search(self): + """Test variable search by name.""" + response = client.get("/variables?search=income&limit=50") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_variable_filter_by_entity(self): + """Test variable filtering by entity.""" + response = client.get("/variables?entity=person&limit=20") + assert response.status_code == 200 + data = response.json() + # All returned variables should be for 'person' entity + for var in data: + if "entity" in var: + assert var["entity"] == "person" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_stress.py b/tests/test_stress.py new file mode 100644 index 0000000..fcff7c0 --- /dev/null +++ b/tests/test_stress.py @@ -0,0 +1,617 @@ +"""Stress tests for PolicyEngine API. 
+ +These tests validate API robustness under various conditions: +- Complex household structures +- Edge cases in inputs +- Concurrent request handling +- Large/unusual parameter values +- Error recovery +""" + +import concurrent.futures +import time +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from policyengine_api.main import app + +client = TestClient(app) + + +class TestComplexHouseholds: + """Tests for complex household structures.""" + + def test_large_household_uk(self): + """Test UK calculation with many household members.""" + # 8-person household: grandparents, parents, 4 children + people = [ + {"age": 70, "employment_income": 0, "state_pension": 9000}, # Grandpa + {"age": 68, "employment_income": 0, "state_pension": 8500}, # Grandma + {"age": 45, "employment_income": 55000}, # Parent 1 + {"age": 43, "employment_income": 35000}, # Parent 2 + {"age": 17, "employment_income": 5000}, # Teen with part-time job + {"age": 14}, # Child + {"age": 10}, # Child + {"age": 3}, # Toddler (childcare age) + ] + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": people, + "year": 2026, + }, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["person"]) == 8 + + def test_large_household_us(self): + """Test US calculation with many household members.""" + people = [ + {"age": 72, "social_security": 24000}, # Grandparent + {"age": 40, "employment_income": 85000}, # Parent 1 + {"age": 38, "employment_income": 45000}, # Parent 2 + {"age": 16, "employment_income": 3000}, # Teen + {"age": 12}, # Child + {"age": 8}, # Child + {"age": 2}, # Toddler + ] + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": people, + "tax_unit": {"state_code": "NY"}, + "household": {"state_fips": 36}, + "year": 2024, + }, + ) + assert response.status_code == 200 + data = response.json() + 
assert len(data["person"]) == 7 + + def test_single_parent_multiple_children(self): + """Test single parent with multiple children.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + {"age": 32, "employment_income": 28000}, + {"age": 8}, + {"age": 5}, + {"age": 2}, + ], + "year": 2026, + }, + ) + assert response.status_code == 200 + data = response.json() + assert len(data["person"]) == 4 + + +class TestEdgeCaseInputs: + """Tests for edge case inputs.""" + + def test_zero_income(self): + """Test household with zero income.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 0}], + "year": 2026, + }, + ) + assert response.status_code == 200 + + def test_very_high_income_uk(self): + """Test UK with very high income (over 500k).""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 45, "employment_income": 1000000}], + "year": 2026, + }, + ) + assert response.status_code == 200 + data = response.json() + # Verify high earner pays tax + assert data["person"][0].get("income_tax", 0) > 0 + + def test_very_high_income_us(self): + """Test US with very high income.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 50, "employment_income": 2000000}], + "tax_unit": {"state_code": "CA"}, + "year": 2024, + }, + ) + assert response.status_code == 200 + + def test_elderly_household(self): + """Test elderly household (pension age).""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + { + "age": 80, + "state_pension": 12000, + "private_pension_income": 15000, + }, + {"age": 78, "state_pension": 11000}, + ], + "year": 2026, + }, + ) + assert response.status_code == 200 + + 
def test_newborn(self): + """Test household with newborn.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + {"age": 28, "employment_income": 30000}, + {"age": 0}, # Newborn + ], + "year": 2026, + }, + ) + assert response.status_code == 200 + + def test_negative_income_handled(self): + """Test that negative income is handled appropriately.""" + # Self-employment loss scenario + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 35, "self_employment_income": -5000}], + "year": 2026, + }, + ) + # Should either work or return a validation error, not crash + assert response.status_code in [200, 422] + + +class TestAllUSStates: + """Test calculations work for all US states.""" + + US_STATES = [ + "AL", + "AK", + "AZ", + "AR", + "CA", + "CO", + "CT", + "DE", + "FL", + "GA", + "HI", + "ID", + "IL", + "IN", + "IA", + "KS", + "KY", + "LA", + "ME", + "MD", + "MA", + "MI", + "MN", + "MS", + "MO", + "MT", + "NE", + "NV", + "NH", + "NJ", + "NM", + "NY", + "NC", + "ND", + "OH", + "OK", + "OR", + "PA", + "RI", + "SC", + "SD", + "TN", + "TX", + "UT", + "VT", + "VA", + "WA", + "WV", + "WI", + "WY", + "DC", + ] + + @pytest.mark.parametrize("state", US_STATES[:10]) # Test first 10 for speed + def test_state_calculation(self, state): + """Test calculation works for a given US state.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 35, "employment_income": 75000}], + "tax_unit": {"state_code": state}, + "year": 2024, + }, + ) + assert response.status_code == 200, f"Failed for state {state}" + + +class TestUKRegions: + """Test UK regional calculations.""" + + UK_REGIONS = [ + "NORTH_EAST", + "NORTH_WEST", + "YORKSHIRE", + "EAST_MIDLANDS", + "WEST_MIDLANDS", + "EAST_OF_ENGLAND", + "LONDON", + "SOUTH_EAST", + "SOUTH_WEST", + "WALES", + "SCOTLAND", + 
"NORTHERN_IRELAND", + ] + + @pytest.mark.parametrize("region", UK_REGIONS) + def test_region_calculation(self, region): + """Test calculation works for a given UK region.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 35, "employment_income": 40000}], + "household": {"region": region}, + "year": 2026, + }, + ) + assert response.status_code == 200, f"Failed for region {region}" + + +class TestConcurrentRequests: + """Tests for concurrent request handling.""" + + def test_concurrent_household_calculations(self): + """Test multiple concurrent household calculations.""" + + def make_request(income: int) -> dict[str, Any]: + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": income}], + "year": 2026, + }, + ) + return {"income": income, "status": response.status_code} + + incomes = [20000, 30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000] + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(make_request, incomes)) + + # All requests should succeed + for result in results: + assert result["status"] == 200, f"Failed for income {result['income']}" + + def test_concurrent_different_models(self): + """Test concurrent requests to different models.""" + requests_data = [ + { + "model": "policyengine_uk", + "people": [{"age": 30, "employment_income": 40000}], + "year": 2026, + }, + { + "model": "policyengine_us", + "people": [{"age": 30, "employment_income": 60000}], + "year": 2024, + }, + { + "model": "policyengine_uk", + "people": [{"age": 40, "employment_income": 60000}], + "year": 2026, + }, + { + "model": "policyengine_us", + "people": [{"age": 40, "employment_income": 80000}], + "year": 2024, + }, + ] + + def make_request(data: dict) -> int: + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": 
data["model"], + "people": data["people"], + "year": data["year"], + }, + ) + return response.status_code + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(make_request, requests_data)) + + assert all(status == 200 for status in results) + + +class TestResponseTimes: + """Tests for response time benchmarks.""" + + def test_simple_calculation_latency(self): + """Test that simple calculations are fast.""" + start = time.time() + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 35000}], + "year": 2026, + }, + ) + elapsed = time.time() - start + + assert response.status_code == 200 + # Simple calculation should be under 10 seconds (generous for model load) + assert elapsed < 10, f"Simple calculation took {elapsed:.2f}s (expected < 10s)" + + def test_complex_calculation_latency(self): + """Test that complex calculations complete in reasonable time.""" + start = time.time() + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + { + "age": 65, + "state_pension": 10000, + "private_pension_income": 20000, + }, + {"age": 40, "employment_income": 55000}, + { + "age": 38, + "employment_income": 35000, + "self_employment_income": 10000, + }, + {"age": 16, "employment_income": 3000}, + {"age": 12}, + {"age": 8}, + ], + "household": {"region": "LONDON", "rent": 2000}, + "year": 2026, + }, + ) + elapsed = time.time() - start + + assert response.status_code == 200 + assert elapsed < 15, f"Complex calculation took {elapsed:.2f}s (expected < 15s)" + + +class TestMetadataEndpoints: + """Stress tests for metadata endpoints.""" + + def test_list_all_variables(self): + """Test listing variables with pagination.""" + response = client.get("/variables?limit=1000") + assert response.status_code == 200 + data = response.json() + assert len(data) > 0 + + def 
test_list_all_parameters(self): + """Test listing parameters with pagination.""" + response = client.get("/parameters?limit=1000") + assert response.status_code == 200 + data = response.json() + assert len(data) > 0 + + def test_parameter_search(self): + """Test parameter search functionality.""" + # Search for income tax related parameters + response = client.get("/parameters?search=income_tax") + assert response.status_code == 200 + data = response.json() + # Should find some parameters + assert isinstance(data, list) + + def test_concurrent_metadata_requests(self): + """Test concurrent metadata requests.""" + endpoints = [ + "/variables?limit=100", + "/parameters?limit=100", + "/tax-benefit-models", + "/datasets", + ] + + def make_request(endpoint: str) -> int: + return client.get(endpoint).status_code + + with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + results = list(executor.map(make_request, endpoints)) + + assert all(status == 200 for status in results) + + +class TestErrorHandling: + """Tests for error handling and recovery.""" + + def test_invalid_model_name(self): + """Test invalid model name returns proper error.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "nonexistent_model", + "people": [{"age": 30}], + }, + ) + assert response.status_code == 422 + + def test_missing_required_fields(self): + """Test missing required fields returns proper error.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + # Missing 'people' + }, + ) + assert response.status_code == 422 + + def test_invalid_age(self): + """Test invalid age handling.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": -5, "employment_income": 30000}], + "year": 2026, + }, + ) + # Should either handle gracefully or return validation error + assert response.status_code in [200, 422] + + 
def test_invalid_year(self): + """Test invalid year handling.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 1800, # Invalid year + }, + ) + # Should either handle gracefully or return validation error + assert response.status_code in [200, 422, 500] + + def test_empty_people_list(self): + """Test empty people list handling.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [], + "year": 2026, + }, + ) + # Should return validation error + assert response.status_code in [422, 500] + + def test_malformed_json(self): + """Test malformed JSON handling.""" + response = client.post( + "/household/calculate", + content="not valid json", + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 422 + + +class TestHouseholdImpact: + """Stress tests for household impact comparison.""" + + def test_impact_without_policy(self): + """Test impact comparison without policy (baseline vs baseline).""" + response = client.post( + "/household/impact", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 40000}], + "year": 2026, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "baseline" in data + assert "reform" in data + assert "impact" in data + + def test_impact_complex_household(self): + """Test impact comparison with complex household.""" + response = client.post( + "/household/impact", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + {"age": 40, "employment_income": 50000}, + {"age": 38, "employment_income": 30000}, + {"age": 10}, + {"age": 7}, + ], + "year": 2026, + }, + ) + assert response.status_code == 200 + + +class TestMultipleIncomeSources: + """Tests for households with multiple income sources.""" + + def test_uk_multiple_income_sources(self): 
+ """Test UK household with diverse income sources.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [ + { + "age": 55, + "employment_income": 40000, + "self_employment_income": 15000, + "savings_interest_income": 2000, + "dividend_income": 5000, + "private_pension_income": 3000, + } + ], + "year": 2026, + }, + ) + assert response.status_code == 200 + + def test_us_multiple_income_sources(self): + """Test US household with diverse income sources.""" + response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [ + { + "age": 50, + "employment_income": 100000, + "self_employment_income": 25000, + "taxable_interest_income": 5000, + "qualified_dividend_income": 10000, + "long_term_capital_gains": 20000, + } + ], + "tax_unit": {"state_code": "CA"}, + "year": 2024, + }, + ) + assert response.status_code == 200 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])