diff --git a/.beads/.gitignore b/.beads/.gitignore new file mode 100644 index 000000000..f438450fc --- /dev/null +++ b/.beads/.gitignore @@ -0,0 +1,29 @@ +# SQLite databases +*.db +*.db?* +*.db-journal +*.db-wal +*.db-shm + +# Daemon runtime files +daemon.lock +daemon.log +daemon.pid +bd.sock + +# Legacy database files +db.sqlite +bd.db + +# Merge artifacts (temporary files from 3-way merge) +beads.base.jsonl +beads.base.meta.json +beads.left.jsonl +beads.left.meta.json +beads.right.jsonl +beads.right.meta.json + +# Keep JSONL exports and config (source of truth for git) +!issues.jsonl +!metadata.json +!config.json diff --git a/.beads/README.md b/.beads/README.md new file mode 100644 index 000000000..8d603245b --- /dev/null +++ b/.beads/README.md @@ -0,0 +1,81 @@ +# Beads - AI-Native Issue Tracking + +Welcome to Beads! This repository uses **Beads** for issue tracking - a modern, AI-native tool designed to live directly in your codebase alongside your code. + +## What is Beads? + +Beads is issue tracking that lives in your repo, making it perfect for AI coding agents and developers who want their issues close to their code. No web UI required - everything works through the CLI and integrates seamlessly with git. + +**Learn more:** [github.com/steveyegge/beads](https://github.com/steveyegge/beads) + +## Quick Start + +### Essential Commands + +```bash +# Create new issues +bd create "Add user authentication" + +# View all issues +bd list + +# View issue details +bd show + +# Update issue status +bd update --status in-progress +bd update --status done + +# Sync with git remote +bd sync +``` + +### Working with Issues + +Issues in Beads are: +- **Git-native**: Stored in `.beads/issues.jsonl` and synced like code +- **AI-friendly**: CLI-first design works perfectly with AI coding agents +- **Branch-aware**: Issues can follow your branch workflow +- **Always in sync**: Auto-syncs with your commits + +## Why Beads? 
+ +✨ **AI-Native Design** +- Built specifically for AI-assisted development workflows +- CLI-first interface works seamlessly with AI coding agents +- No context switching to web UIs + +🚀 **Developer Focused** +- Issues live in your repo, right next to your code +- Works offline, syncs when you push +- Fast, lightweight, and stays out of your way + +🔧 **Git Integration** +- Automatic sync with git commits +- Branch-aware issue tracking +- Intelligent JSONL merge resolution + +## Get Started with Beads + +Try Beads in your own projects: + +```bash +# Install Beads +curl -sSL https://raw.githubusercontent.com/steveyegge/beads/main/scripts/install.sh | bash + +# Initialize in your repo +bd init + +# Create your first issue +bd create "Try out Beads" +``` + +## Learn More + +- **Documentation**: [github.com/steveyegge/beads/docs](https://github.com/steveyegge/beads/tree/main/docs) +- **Quick Start Guide**: Run `bd quickstart` +- **Examples**: [github.com/steveyegge/beads/examples](https://github.com/steveyegge/beads/tree/main/examples) + +--- + +*Beads: Issue tracking that moves at the speed of thought* ⚡ diff --git a/.beads/config.yaml b/.beads/config.yaml new file mode 100644 index 000000000..95c5f3e70 --- /dev/null +++ b/.beads/config.yaml @@ -0,0 +1,56 @@ +# Beads Configuration File +# This file configures default behavior for all bd commands in this repository +# All settings can also be set via environment variables (BD_* prefix) +# or overridden with command-line flags + +# Issue prefix for this repository (used by bd init) +# If not set, bd init will auto-detect from directory name +# Example: issue-prefix: "myproject" creates issues like "myproject-1", "myproject-2", etc. +# issue-prefix: "" + +# Use no-db mode: load from JSONL, no SQLite, write back after each command +# When true, bd will use .beads/issues.jsonl as the source of truth +# instead of SQLite database +# no-db: false + +# Disable daemon for RPC communication (forces direct database access) +# no-daemon: false + +# Disable auto-flush of database to JSONL after mutations +# no-auto-flush: false + +# Disable auto-import from JSONL when it's newer than database +# no-auto-import: false + +# Enable JSON output by default +# json: false + +# Default actor for audit trails (overridden by BD_ACTOR or --actor) +# actor: "" + +# Path to database (overridden by BEADS_DB or --db) +# db: "" + +# Auto-start daemon if not running (can also use BEADS_AUTO_START_DAEMON) +# auto-start-daemon: true + +# Debounce interval for auto-flush (can also use BEADS_FLUSH_DEBOUNCE) +# flush-debounce: "5s" + +# Multi-repo configuration (experimental - bd-307) +# Allows hydrating from multiple repositories and routing writes to the correct JSONL +# repos: +# primary: "." 
# Primary repo (where this database lives) +# additional: # Additional repos to hydrate from (read-only) +# - ~/beads-planning # Personal planning repo +# - ~/work-planning # Work planning repo + +# Integration settings (access with 'bd config get/set') +# These are stored in the database, not in this file: +# - jira.url +# - jira.project +# - linear.url +# - linear.api-key +# - github.org +# - github.repo +# - sync.branch - Git branch for beads commits (use BEADS_SYNC_BRANCH env var or bd config set) diff --git a/.beads/metadata.json b/.beads/metadata.json new file mode 100644 index 000000000..4faf148a1 --- /dev/null +++ b/.beads/metadata.json @@ -0,0 +1,5 @@ +{ + "database": "beads.db", + "jsonl_export": "issues.jsonl", + "last_bd_version": "0.26.0" +} \ No newline at end of file diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..807d5983d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ + +# Use bd merge for beads JSONL files +.beads/issues.jsonl merge=beads diff --git a/Makefile b/Makefile index 09d984a96..92529d5cf 100644 --- a/Makefile +++ b/Makefile @@ -63,6 +63,7 @@ database: python policyengine_us_data/db/etl_snap.py python policyengine_us_data/db/etl_state_income_tax.py python policyengine_us_data/db/etl_irs_soi.py + python policyengine_us_data/db/reconcile_targets.py python policyengine_us_data/db/validate_database.py database-refresh: diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..bd1b9320d 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: minor + changes: + added: + - Added two-pass geographic target reconciliation to ETL pipeline, ensuring state targets sum to national and CD targets sum to state. + - Added raw_value column to Target model to preserve original source values before reconciliation. + - Added geographic reconciliation validation to validate_database.py. diff --git a/docs/pipeline-comparison.png b/docs/pipeline-comparison.png new file mode 100644 index 000000000..2b68b5a06 Binary files /dev/null and b/docs/pipeline-comparison.png differ diff --git a/paper/scripts/generate_validation_metrics.py b/paper/scripts/generate_validation_metrics.py deleted file mode 100644 index db586959d..000000000 --- a/paper/scripts/generate_validation_metrics.py +++ /dev/null @@ -1,244 +0,0 @@ -""" -Generate validation metrics for the enhanced CPS paper. - -This script computes all validation metrics comparing the Enhanced CPS -to the baseline CPS and PUF datasets. Results are saved as CSV files -for inclusion in the paper. -""" - -import pandas as pd -import numpy as np -from policyengine_us import Microsimulation -from policyengine_us_data.datasets.cps.enhanced_cps import EnhancedCPS -from policyengine_us_data.datasets.cps.cps import CPS -from policyengine_us_data.datasets.irs.puf import PUF -import json -import os - - -def calculate_validation_metrics(year: int = 2024): - """ - Calculate validation metrics across all three datasets. 
- - Args: - year: Tax year to analyze - - Returns: - dict: Validation results by dataset and metric type - """ - results = {} - - # Initialize datasets - print(f"Loading datasets for {year}...") - - try: - # Load each dataset - enhanced_cps = EnhancedCPS(year=year) - baseline_cps = CPS(year=year) - puf = PUF(year=year) - - # Get the loss matrix targets that Enhanced CPS was calibrated to - from policyengine_us_data.utils.loss import build_loss_matrix - - loss_matrix, targets, names = build_loss_matrix(EnhancedCPS, str(year)) - - print(f"Found {len(targets)} calibration targets") - - # Calculate how well each dataset matches the targets - for dataset_name, dataset in [ - ("Enhanced CPS", enhanced_cps), - ("Baseline CPS", baseline_cps), - ("PUF", puf), - ]: - print(f"\nCalculating metrics for {dataset_name}...") - - # Create microsimulation - sim = Microsimulation(dataset=dataset) - - # Calculate achieved values for each target - # This is placeholder - actual implementation would compute - # each target value using the microsimulation - - results[dataset_name] = { - "total_targets": len(targets), - "dataset_year": year, - "status": "TO BE CALCULATED", - } - - except Exception as e: - print(f"Error loading datasets: {e}") - print("Creating placeholder results...") - - results = { - "Enhanced CPS": { - "total_targets": "[TO BE CALCULATED]", - "outperforms_cps_pct": "[TO BE CALCULATED]", - "outperforms_puf_pct": "[TO BE CALCULATED]", - }, - "Baseline CPS": { - "total_targets": "[TO BE CALCULATED]", - }, - "PUF": { - "total_targets": "[TO BE CALCULATED]", - }, - } - - return results - - -def calculate_poverty_metrics(year: int = 2024): - """ - Calculate poverty metrics for each dataset. - - Args: - year: Tax year to analyze - - Returns: - pd.DataFrame: Poverty rates by dataset - """ - print(f"\nCalculating poverty metrics for {year}...") - - # Placeholder implementation - # Actual implementation would calculate SPM poverty rates - - results = pd.DataFrame( - { - "Dataset": ["CPS", "PUF", "Enhanced CPS"], - "SPM Poverty Rate": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - "Child Poverty Rate": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - "Senior Poverty Rate": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - } - ) - - return results - - -def calculate_income_distribution_metrics(year: int = 2024): - """ - Calculate income distribution metrics. - - Args: - year: Tax year to analyze - - Returns: - pd.DataFrame: Distribution metrics by dataset - """ - print(f"\nCalculating income distribution metrics for {year}...") - - # Placeholder implementation - results = pd.DataFrame( - { - "Dataset": ["CPS", "PUF", "Enhanced CPS"], - "Gini Coefficient": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - "Top 1% Share": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - "Top 10% Share": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - "Bottom 50% Share": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - } - ) - - return results - - -def calculate_policy_reform_impacts(year: int = 2024): - """ - Calculate revenue impacts of top rate reform. 
- - Args: - year: Tax year to analyze - - Returns: - pd.DataFrame: Revenue projections by dataset - """ - print(f"\nCalculating policy reform impacts for {year}...") - - # Placeholder for top marginal rate increase from 37% to 39.6% - results = pd.DataFrame( - { - "Dataset": ["CPS", "PUF", "Enhanced CPS"], - "Revenue Impact ($B)": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - "Affected Tax Units": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - "Average Tax Increase": [ - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - "[TO BE CALCULATED]", - ], - } - ) - - return results - - -def main(): - """Generate all paper results.""" - - # Create results directory - results_dir = "paper/results" - os.makedirs(results_dir, exist_ok=True) - - print("Generating validation metrics for PolicyEngine Enhanced CPS paper") - print("=" * 70) - - # Generate validation metrics - validation_results = calculate_validation_metrics() - - # Save as JSON for reference - with open(f"{results_dir}/validation_summary.json", "w") as f: - json.dump(validation_results, f, indent=2) - - # Generate poverty metrics table - poverty_df = calculate_poverty_metrics() - poverty_df.to_csv(f"{results_dir}/poverty_metrics.csv", index=False) - - # Generate income distribution table - dist_df = calculate_income_distribution_metrics() - dist_df.to_csv(f"{results_dir}/income_distribution.csv", index=False) - - # Generate policy reform table - reform_df = calculate_policy_reform_impacts() - reform_df.to_csv(f"{results_dir}/policy_reform_impacts.csv", index=False) - - print(f"\nResults saved to {results_dir}/") - print("\nNOTE: All metrics marked as [TO BE CALCULATED] require full") - print( - "dataset generation and microsimulation runs to compute actual values." - ) - - -if __name__ == "__main__": - main() diff --git a/policyengine_us_data/calibration/__init__.py b/policyengine_us_data/calibration/__init__.py new file mode 100644 index 000000000..4113f8b17 --- /dev/null +++ b/policyengine_us_data/calibration/__init__.py @@ -0,0 +1,7 @@ +from policyengine_us_data.calibration.national_matrix_builder import ( + NationalMatrixBuilder, +) +from policyengine_us_data.calibration.fit_national_weights import ( + fit_national_weights, + build_calibration_inputs, +) diff --git a/policyengine_us_data/calibration/base_matrix_builder.py b/policyengine_us_data/calibration/base_matrix_builder.py new file mode 100644 index 000000000..9ebc5a49c --- /dev/null +++ b/policyengine_us_data/calibration/base_matrix_builder.py @@ -0,0 +1,184 @@ +""" +Base matrix builder with shared logic for calibration. + +Provides common functionality used by both ``NationalMatrixBuilder`` +(national dense matrix) and ``SparseMatrixBuilder`` (geo-stacking +sparse matrix): + +- SQLAlchemy engine setup +- Entity relationship mapping (person -> household/tax_unit/spm_unit) +- Person-level constraint evaluation with household aggregation +- Database query for stratum constraints +""" + +import logging +from typing import List, Optional + +import numpy as np +import pandas as pd +from sqlalchemy import create_engine + +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, +) + +logger = logging.getLogger(__name__) + + +class BaseMatrixBuilder: + """Shared base for calibration matrix builders. + + Handles engine creation, entity relationship caching, constraint + evaluation at person level with household-level aggregation, and + stratum constraint queries. 
+ + Args: + db_uri: SQLAlchemy-style database URI, e.g. + ``"sqlite:///path/to/policy_data.db"``. + time_period: Tax year for the calibration (e.g. 2024). + """ + + def __init__(self, db_uri: str, time_period: int): + self.db_uri = db_uri + self.engine = create_engine(db_uri) + self.time_period = time_period + self._entity_rel_cache: Optional[pd.DataFrame] = None + + # ------------------------------------------------------------------ + # Entity relationship mapping + # ------------------------------------------------------------------ + + def _build_entity_relationship(self, sim) -> pd.DataFrame: + """Build entity relationship DataFrame mapping persons to + all entity IDs. + + Args: + sim: Microsimulation instance. + + Returns: + DataFrame with columns person_id, household_id, + tax_unit_id, spm_unit_id (one row per person). + """ + if self._entity_rel_cache is not None: + return self._entity_rel_cache + + self._entity_rel_cache = pd.DataFrame( + { + "person_id": sim.calculate( + "person_id", map_to="person" + ).values, + "household_id": sim.calculate( + "household_id", map_to="person" + ).values, + "tax_unit_id": sim.calculate( + "tax_unit_id", map_to="person" + ).values, + "spm_unit_id": sim.calculate( + "spm_unit_id", map_to="person" + ).values, + } + ) + return self._entity_rel_cache + + # ------------------------------------------------------------------ + # Constraint evaluation + # ------------------------------------------------------------------ + + def _evaluate_constraints_entity_aware( + self, + sim, + constraints: List[dict], + n_households: int, + ) -> np.ndarray: + """Evaluate constraints at person level and aggregate to + household level using ``.any()``. + + Each constraint variable is calculated at person level; the + boolean intersection is then rolled up so that a household + passes if *at least one person* satisfies all constraints. + + Args: + sim: Microsimulation instance. + constraints: List of constraint dicts with keys + ``variable``, ``operation``, ``value``. + n_households: Total number of households. + + Returns: + Boolean mask array of length *n_households*. + """ + if not constraints: + return np.ones(n_households, dtype=bool) + + entity_rel = self._build_entity_relationship(sim) + n_persons = len(entity_rel) + + person_mask = np.ones(n_persons, dtype=bool) + + for c in constraints: + var = c["variable"] + op = c["operation"] + val = c["value"] + + try: + constraint_values = sim.calculate( + var, self.time_period, map_to="person" + ).values + except Exception as exc: + logger.warning( + "Cannot evaluate constraint variable " + "'%s': %s -- returning all-False mask", + var, + exc, + ) + return np.zeros(n_households, dtype=bool) + + person_mask &= apply_op(constraint_values, op, val) + + # Aggregate to household using .any() + entity_rel_with_mask = entity_rel.copy() + entity_rel_with_mask["satisfies"] = person_mask + + household_mask_series = entity_rel_with_mask.groupby("household_id")[ + "satisfies" + ].any() + + household_ids = sim.calculate( + "household_id", map_to="household" + ).values + household_mask = np.array( + [ + household_mask_series.get(hh_id, False) + for hh_id in household_ids + ] + ) + + return household_mask + + # ------------------------------------------------------------------ + # Database queries + # ------------------------------------------------------------------ + + def _get_stratum_constraints(self, stratum_id: int) -> List[dict]: + """Get the direct constraints for a single stratum. 
+ + Args: + stratum_id: Primary key in the ``strata`` table. + + Returns: + List of dicts with keys ``variable``, ``operation``, + ``value``. + """ + query = """ + SELECT constraint_variable AS variable, + operation, + value + FROM stratum_constraints + WHERE stratum_id = :stratum_id + """ + with self.engine.connect() as conn: + df = pd.read_sql( + query, + conn, + params={"stratum_id": int(stratum_id)}, + ) + return df.to_dict("records") diff --git a/policyengine_us_data/calibration/clone_and_assign.py b/policyengine_us_data/calibration/clone_and_assign.py new file mode 100644 index 000000000..0033b5299 --- /dev/null +++ b/policyengine_us_data/calibration/clone_and_assign.py @@ -0,0 +1,150 @@ +"""Clone CPS records and assign random geography.""" + +import logging +from functools import lru_cache +from dataclasses import dataclass + +import numpy as np +import pandas as pd + +from policyengine_us_data.storage import STORAGE_FOLDER + +logger = logging.getLogger(__name__) + + +@dataclass +class GeographyAssignment: + """Random geography assignment for cloned CPS records. + + All arrays have length n_records * n_clones. + Index i corresponds to clone i // n_records, + record i % n_records. + """ + + block_geoid: np.ndarray # str array, 15-char block GEOIDs + cd_geoid: np.ndarray # str array of CD GEOIDs + state_fips: np.ndarray # int array of 2-digit state FIPS + n_records: int + n_clones: int + + +@lru_cache(maxsize=1) +def load_global_block_distribution(): + """Load block_cd_distributions.csv.gz and build + global distribution. + + Returns: + Tuple of (block_geoids, cd_geoids, state_fips, + probabilities) where each is a numpy array indexed + by block row. Probabilities are normalized to sum + to 1 globally. + + Raises: + FileNotFoundError: If the CSV file does not exist. + """ + csv_path = STORAGE_FOLDER / "block_cd_distributions.csv.gz" + if not csv_path.exists(): + raise FileNotFoundError( + f"{csv_path} not found. " + "Run make_block_cd_distributions.py to generate." + ) + + df = pd.read_csv(csv_path, dtype={"block_geoid": str}) + + block_geoids = df["block_geoid"].values + cd_geoids = df["cd_geoid"].astype(str).values + # State FIPS is first 2 digits of block GEOID + state_fips = np.array([int(b[:2]) for b in block_geoids]) + + probs = df["probability"].values.astype(np.float64) + probs = probs / probs.sum() # Normalize globally + + return block_geoids, cd_geoids, state_fips, probs + + +def assign_random_geography( + n_records: int, + n_clones: int = 10, + seed: int = 42, +) -> GeographyAssignment: + """Assign random census block geography to cloned + CPS records. + + Each of n_records * n_clones total records gets a + random census block sampled from the global + population-weighted distribution. State and CD are + derived from the block GEOID. + + Args: + n_records: Number of households in the base CPS + dataset. + n_clones: Number of clones (default 10). + seed: Random seed for reproducibility. + + Returns: + GeographyAssignment with arrays of length + n_records * n_clones. + """ + blocks, cds, states, probs = load_global_block_distribution() + + n_total = n_records * n_clones + rng = np.random.default_rng(seed) + indices = rng.choice(len(blocks), size=n_total, p=probs) + + return GeographyAssignment( + block_geoid=blocks[indices], + cd_geoid=cds[indices], + state_fips=states[indices], + n_records=n_records, + n_clones=n_clones, + ) + + +def double_geography_for_puf( + geography: GeographyAssignment, +) -> GeographyAssignment: + """Double geography arrays for PUF clone step. 
+ + After PUF cloning doubles the base records, the geography + assignment must also double: each record and its PUF copy + share the same geographic assignment. + + The output has n_records = 2 * geography.n_records, with + the first half being the CPS records and the second half + being the PUF copies. + + Args: + geography: Original geography assignment. + + Returns: + New GeographyAssignment with doubled n_records. + """ + n_old = geography.n_records + n_new = n_old * 2 + n_clones = geography.n_clones + + # For each clone, interleave: [CPS records, PUF records] + # Original layout: clone0_rec0..rec_N, clone1_rec0..rec_N, ... + # New layout: clone0_cps0..N_puf0..N, clone1_cps0..N_puf0..N + new_blocks = [] + new_cds = [] + new_states = [] + + for c in range(n_clones): + start = c * n_old + end = start + n_old + clone_blocks = geography.block_geoid[start:end] + clone_cds = geography.cd_geoid[start:end] + clone_states = geography.state_fips[start:end] + # CPS half + PUF half (same geography) + new_blocks.append(np.concatenate([clone_blocks, clone_blocks])) + new_cds.append(np.concatenate([clone_cds, clone_cds])) + new_states.append(np.concatenate([clone_states, clone_states])) + + return GeographyAssignment( + block_geoid=np.concatenate(new_blocks), + cd_geoid=np.concatenate(new_cds), + state_fips=np.concatenate(new_states), + n_records=n_new, + n_clones=n_clones, + ) diff --git a/policyengine_us_data/calibration/fit_national_weights.py b/policyengine_us_data/calibration/fit_national_weights.py new file mode 100644 index 000000000..9ae476138 --- /dev/null +++ b/policyengine_us_data/calibration/fit_national_weights.py @@ -0,0 +1,406 @@ +""" +National L0 calibration for Enhanced CPS. + +L0-regularized optimization via l0-python's SparseCalibrationWeights. +Reads active targets from policy_data.db via NationalMatrixBuilder. 
+ +Usage: + python -m policyengine_us_data.calibration.fit_national_weights \\ + --dataset path/to/extended_cps_2024.h5 \\ + --db-path path/to/policy_data.db \\ + --output path/to/enhanced_cps_2024.h5 \\ + --epochs 1000 \\ + --lambda-l0 1e-6 +""" + +import argparse +import logging +from pathlib import Path +from typing import Tuple + +import h5py +import numpy as np +import scipy.sparse + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +# ============================================================================ +# HYPERPARAMETERS (higher L0 than local mode for more sparsification) +# ============================================================================ +LAMBDA_L0 = 1e-6 +LAMBDA_L2 = 1e-12 +LEARNING_RATE = 0.15 +DEFAULT_EPOCHS = 1000 +BETA = 0.35 +GAMMA = -0.1 +ZETA = 1.1 +INIT_KEEP_PROB = 0.999 +LOG_WEIGHT_JITTER_SD = 0.05 +LOG_ALPHA_JITTER_SD = 0.01 + +# Minimum weight floor for zero/negative initial weights +_WEIGHT_FLOOR = 0.01 + + +def parse_args(argv=None): + """Parse CLI arguments.""" + parser = argparse.ArgumentParser( + description="National L0 calibration for Enhanced CPS" + ) + parser.add_argument( + "--dataset", + default=None, + help="Path to Extended CPS h5 file", + ) + parser.add_argument( + "--db-path", + default=None, + help="Path to policy_data.db", + ) + parser.add_argument( + "--geo-level", + default="all", + choices=["national", "state", "cd", "all"], + help="Geographic level filter (default: all)", + ) + parser.add_argument( + "--output", + default=None, + help="Path to output enhanced_cps h5 file", + ) + parser.add_argument( + "--epochs", + type=int, + default=DEFAULT_EPOCHS, + help=f"Training epochs (default: {DEFAULT_EPOCHS})", + ) + parser.add_argument( + "--lambda-l0", + type=float, + default=LAMBDA_L0, + help=f"L0 penalty (default: {LAMBDA_L0})", + ) + parser.add_argument( + "--device", + default="cpu", + choices=["cpu", "cuda"], + help="Device for training (default: cpu)", + ) + return parser.parse_args(argv) + + +def initialize_weights(original_weights: np.ndarray) -> np.ndarray: + """ + Initialize calibration weights from original household weights. + + Zero and negative weights are floored to a small positive value + to avoid log-domain issues in the L0 optimizer. + + Args: + original_weights: Array of household weights from the CPS. + + Returns: + Array of positive initial weights. + """ + weights = original_weights.copy().astype(np.float64) + weights[weights <= 0] = _WEIGHT_FLOOR + return weights + + +def build_calibration_inputs( + dataset_class, + time_period: int, + db_path: str, + sim=None, + geo_level: str = "all", +) -> Tuple[np.ndarray, np.ndarray, list]: + """ + Build calibration matrix and targets from the database. + + Reads targets from policy_data.db via NationalMatrixBuilder. + + Args: + dataset_class: The input dataset class (e.g., ExtendedCPS_2024). + time_period: Tax year for calibration. + db_path: Path to policy_data.db. + sim: Optional pre-built Microsimulation instance. + geo_level: Geographic filter -- ``"national"``, + ``"state"``, ``"cd"``, or ``"all"`` (default). 
+ + Returns: + Tuple of (matrix, targets, target_names) where: + - matrix: shape (n_households, n_targets) float32 + - targets: shape (n_targets,) float64 + - target_names: list of str + """ + from policyengine_us_data.calibration.national_matrix_builder import ( + NationalMatrixBuilder, + ) + + db_uri = f"sqlite:///{db_path}" + builder = NationalMatrixBuilder(db_uri=db_uri, time_period=time_period) + + if sim is None: + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=dataset_class) + sim.default_calculation_period = time_period + + matrix, targets, names = builder.build_matrix( + sim=sim, dataset_class=dataset_class, geo_level=geo_level + ) + return ( + matrix.astype(np.float32), + targets.astype(np.float64), + names, + ) + + +def compute_diagnostics( + targets: np.ndarray, + estimates: np.ndarray, + names: list, + threshold: float = 0.10, + n_worst: int = 20, +) -> dict: + """ + Compute calibration diagnostics. + + Args: + targets: Target values. + estimates: Predicted values from weighted matrix. + names: Target names. + threshold: Fraction for "within X%" check. + n_worst: Number of worst targets to report. + + Returns: + Dict with keys: + - pct_within_10: % of targets within threshold of target + - worst_targets: list of (name, rel_error) tuples + """ + with np.errstate(divide="ignore", invalid="ignore"): + rel_errors = np.where( + np.abs(targets) > 1e-6, + (estimates - targets) / targets, + 0.0, + ) + + abs_rel_errors = np.abs(rel_errors) + within = np.mean(abs_rel_errors <= threshold) * 100.0 + + # Sort by absolute relative error descending + sorted_idx = np.argsort(-abs_rel_errors) + worst = [(names[i], float(rel_errors[i])) for i in sorted_idx[:n_worst]] + + return { + "pct_within_10": float(within), + "worst_targets": worst, + } + + +def fit_national_weights( + matrix: np.ndarray, + targets: np.ndarray, + initial_weights: np.ndarray, + epochs: int = DEFAULT_EPOCHS, + lambda_l0: float = LAMBDA_L0, + lambda_l2: float = LAMBDA_L2, + learning_rate: float = LEARNING_RATE, + device: str = "cpu", +) -> np.ndarray: + """Run L0-regularized calibration to find optimal household + weights. + + Uses l0-python's ``SparseCalibrationWeights`` which expects:: + + M @ w = y_hat + + where ``M`` has shape ``(n_targets, n_features)`` and ``w`` has + shape ``(n_features,)``. The input *matrix* is provided in the + more natural ``(n_households, n_targets)`` layout and is + transposed internally before being passed to the optimizer. + + Args: + matrix: Calibration matrix, shape + (n_households, n_targets). + targets: Target values, shape (n_targets,). + initial_weights: Starting household weights + (n_households,). + epochs: Number of training epochs. + lambda_l0: L0 regularization strength. + lambda_l2: L2 regularization strength. + learning_rate: Adam learning rate. + device: Torch device ("cpu" or "cuda"). + + Returns: + Calibrated weight array, shape (n_households,). + """ + try: + from l0.calibration import SparseCalibrationWeights + except ImportError: + raise ImportError( + "l0-python is required for L0 calibration. 
" + "Install with: pip install l0-python" + ) + + n_households, n_targets = matrix.shape + logger.info( + f"Starting L0 calibration: {n_households} households, " + f"{n_targets} targets, {epochs} epochs" + ) + logger.info( + f"Hyperparams: lambda_l0={lambda_l0}, " + f"lambda_l2={lambda_l2}, lr={learning_rate}" + ) + + # Transpose to (n_targets, n_households) for l0-python: + # M @ w = y_hat + # (n_targets, n_households) @ (n_households,) = (n_targets,) + # l0-python expects a scipy sparse matrix, not dense numpy. + M = scipy.sparse.csr_matrix(matrix.T) + + model = SparseCalibrationWeights( + n_features=n_households, + beta=BETA, + gamma=GAMMA, + zeta=ZETA, + init_keep_prob=INIT_KEEP_PROB, + init_weights=initial_weights, + log_weight_jitter_sd=LOG_WEIGHT_JITTER_SD, + log_alpha_jitter_sd=LOG_ALPHA_JITTER_SD, + device=device, + ) + + model.fit( + M=M, + y=targets, + target_groups=None, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + lr=learning_rate, + epochs=epochs, + loss_type="relative", + verbose=True, + verbose_freq=max(1, epochs // 10), + ) + + import torch + + with torch.no_grad(): + weights = model.get_weights(deterministic=True).cpu().numpy() + + logger.info( + f"Calibration complete. " + f"Non-zero weights: {(weights > 0).sum():,} " + f"/ {len(weights):,}" + ) + return weights + + +def save_weights_to_h5(h5_path: str, weights: np.ndarray, year: int = 2024): + """ + Save calibrated weights into an existing h5 dataset file. + + Overwrites the household_weight/{year} dataset while preserving + all other data in the file. + + Args: + h5_path: Path to the h5 file. + weights: Calibrated weight array. + year: Time period key. + """ + key = f"household_weight/{year}" + with h5py.File(h5_path, "a") as f: + if key in f: + del f[key] + f.create_dataset(key, data=weights) + logger.info(f"Saved weights to {h5_path} [{key}]") + + +def run_validation(weights, matrix, targets, names): + """Print quick validation of key aggregates.""" + estimates = weights @ matrix + + diag = compute_diagnostics(targets, estimates, names) + logger.info(f"Targets within 10%%: {diag['pct_within_10']:.1f}%%") + logger.info("Worst targets:") + for name, rel_err in diag["worst_targets"][:10]: + logger.info(f" {name:60s} {rel_err:+.2%}") + + # Highlight key programs if present + for keyword in ["population", "income_tax", "snap"]: + matches = [ + (n, e, t) + for n, e, t in zip(names, estimates, targets) + if keyword in n.lower() + ] + if matches: + n, e, t = matches[0] + logger.info(f" {n}: est={e:,.0f}, target={t:,.0f}") + + +def main(argv=None): + """Entry point for CLI usage.""" + args = parse_args(argv) + + from policyengine_us import Microsimulation + from policyengine_us_data.datasets.cps.extended_cps import ( + ExtendedCPS_2024, + ) + from policyengine_us_data.storage import STORAGE_FOLDER + + dataset_path = args.dataset or str(STORAGE_FOLDER / "extended_cps_2024.h5") + output_path = args.output or str(STORAGE_FOLDER / "enhanced_cps_2024.h5") + + logger.info(f"Loading dataset from {dataset_path}") + sim = Microsimulation(dataset=dataset_path) + original_weights = sim.calculate("household_weight").values + + logger.info( + f"Loaded {len(original_weights):,} households, " + f"total weight: {original_weights.sum():,.0f}" + ) + + # Build calibration inputs + matrix, targets, names = build_calibration_inputs( + dataset_class=ExtendedCPS_2024, + time_period=2024, + db_path=args.db_path, + geo_level=args.geo_level, + ) + + logger.info( + f"Calibration matrix: {matrix.shape[0]} households x " + f"{matrix.shape[1]} targets" 
+ ) + + # Initialize and run + init_weights = initialize_weights(original_weights) + calibrated_weights = fit_national_weights( + matrix=matrix, + targets=targets, + initial_weights=init_weights, + epochs=args.epochs, + lambda_l0=args.lambda_l0, + device=args.device, + ) + + # Diagnostics + run_validation(calibrated_weights, matrix, targets, names) + + # Save + # Copy source to output if different + import shutil + + if dataset_path != output_path: + shutil.copy2(dataset_path, output_path) + save_weights_to_h5(output_path, calibrated_weights, year=2024) + logger.info(f"Enhanced CPS saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/calibration/l0_sweep.py b/policyengine_us_data/calibration/l0_sweep.py new file mode 100644 index 000000000..4f7ff83b8 --- /dev/null +++ b/policyengine_us_data/calibration/l0_sweep.py @@ -0,0 +1,418 @@ +""" +L0 sweep: build matrix once, fit at many L0 values, plot results. + +Designed to run overnight. Saves intermediate results so it can +resume if interrupted. + +Usage: + python -m policyengine_us_data.calibration.l0_sweep + +Output: + - storage/calibration/l0_sweep_matrix.npz (sparse matrix) + - storage/calibration/l0_sweep_targets.npy + - storage/calibration/l0_sweep_results.csv + - storage/calibration/l0_sweep_plot.png +""" + +import argparse +import logging +import time +from pathlib import Path + +import numpy as np +import pandas as pd +import scipy.sparse + +from policyengine_us_data.calibration.unified_calibration import ( + fit_l0_weights, + log_achievable_targets, +) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +# L0 values to sweep (log-spaced from 1e-10 to 1e-2) +DEFAULT_L0_VALUES = [ + 1e-10, + 3e-10, + 1e-9, + 3e-9, + 1e-8, + 3e-8, + 1e-7, + 3e-7, + 1e-6, + 3e-6, + 1e-5, + 3e-5, + 1e-4, + 3e-4, + 1e-3, + 3e-3, + 1e-2, +] + +DEFAULT_EPOCHS = 200 +DEFAULT_N_CLONES = 130 + + +def build_and_save_matrix( + dataset_path: str, + db_path: str, + output_dir: Path, + n_clones: int, + seed: int, +): + """Build sparse matrix and save to disk. + + Returns: + Tuple of (X_sparse, targets, target_names, n_total). 
+ """ + matrix_path = output_dir / "l0_sweep_matrix.npz" + targets_path = output_dir / "l0_sweep_targets.npy" + names_path = output_dir / "l0_sweep_names.txt" + + if matrix_path.exists() and targets_path.exists(): + logger.info("Loading cached matrix from %s", matrix_path) + X = scipy.sparse.load_npz(str(matrix_path)) + targets = np.load(str(targets_path)) + names = names_path.read_text().strip().split("\n") + logger.info( + "Loaded matrix: %d targets x %d columns", + X.shape[0], + X.shape[1], + ) + return X, targets, names, X.shape[1] + + from policyengine_us import Microsimulation + + from policyengine_us_data.calibration.clone_and_assign import ( + assign_random_geography, + ) + from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + ) + + # Load dataset + logger.info("Loading dataset from %s", dataset_path) + sim = Microsimulation(dataset=dataset_path) + n_records = len(sim.calculate("household_id", map_to="household").values) + logger.info("Loaded %d households", n_records) + + # Assign geography + n_total = n_records * n_clones + logger.info( + "Assigning geography: %d x %d = %d", + n_records, + n_clones, + n_total, + ) + geography = assign_random_geography( + n_records=n_records, + n_clones=n_clones, + seed=seed, + ) + + # Build matrix + db_uri = f"sqlite:///{db_path}" + builder = UnifiedMatrixBuilder(db_uri=db_uri, time_period=2024) + clone_cache_dir = str(output_dir / "clones") + targets_df, X_sparse, target_names = builder.build_matrix( + dataset_path=dataset_path, + geography=geography, + cache_dir=clone_cache_dir, + ) + + log_achievable_targets(X_sparse) + targets = targets_df["value"].values + + # Save + scipy.sparse.save_npz(str(matrix_path), X_sparse) + np.save(str(targets_path), targets) + names_path.write_text("\n".join(target_names)) + logger.info("Saved matrix to %s", output_dir) + + return X_sparse, targets, target_names, n_total + + +def fit_one_l0( + X_sparse, + targets, + target_names, + lambda_l0: float, + epochs: int, + device: str, + output_dir: Path, +) -> dict: + """Fit at one L0 value. Saves full weights array and + per-target diagnostics. + + Returns: + Summary dict for the results CSV. + """ + n_total = X_sparse.shape[1] + + t0 = time.time() + weights = fit_l0_weights( + X_sparse=X_sparse, + targets=targets, + lambda_l0=lambda_l0, + epochs=epochs, + device=device, + verbose_freq=max(1, epochs // 5), + ) + + n_nonzero = int((weights > 0).sum()) + total_weight = float(weights.sum()) + elapsed = time.time() - t0 + + # Save the full weight array. + l0_tag = f"{lambda_l0:.0e}".replace("+", "") + weights_path = output_dir / f"weights_l0_{l0_tag}.npy" + np.save(str(weights_path), weights) + logger.info("Saved weights to %s", weights_path) + + # Compute per-target diagnostics. + estimates = weights @ X_sparse.T.toarray() + with np.errstate(divide="ignore", invalid="ignore"): + rel_errors = np.where( + np.abs(targets) > 1e-6, + (estimates - targets) / targets, + 0.0, + ) + abs_rel_errors = np.abs(rel_errors) + + # Save per-target diagnostics. + diag_df = pd.DataFrame( + { + "target_name": target_names, + "target_value": targets, + "estimate": estimates, + "rel_error": rel_errors, + "abs_rel_error": abs_rel_errors, + } + ) + diag_path = output_dir / f"diagnostics_l0_{l0_tag}.csv" + diag_df.to_csv(diag_path, index=False) + + # Summary metrics. 
+ pct_within_5 = float(np.mean(abs_rel_errors <= 0.05) * 100) + pct_within_10 = float(np.mean(abs_rel_errors <= 0.10) * 100) + pct_within_25 = float(np.mean(abs_rel_errors <= 0.25) * 100) + median_rel_error = float(np.median(abs_rel_errors)) + mean_rel_error = float(np.mean(abs_rel_errors)) + max_rel_error = float(np.max(abs_rel_errors)) + # Relative MSE (the actual loss the optimizer minimizes) + relative_mse = float( + np.mean( + np.where( + np.abs(targets) > 1e-6, + ((estimates - targets) / targets) ** 2, + 0.0, + ) + ) + ) + + return { + "lambda_l0": lambda_l0, + "n_nonzero": n_nonzero, + "n_total": n_total, + "sparsity_pct": (1 - n_nonzero / n_total) * 100, + "total_weight": total_weight, + "relative_mse": relative_mse, + "mean_rel_error": mean_rel_error, + "median_rel_error": median_rel_error, + "max_rel_error": max_rel_error, + "pct_within_5": pct_within_5, + "pct_within_10": pct_within_10, + "pct_within_25": pct_within_25, + "elapsed_s": elapsed, + "weights_path": str(weights_path), + "diagnostics_path": str(diag_path), + } + + +def make_plot(results_df: pd.DataFrame, output_path: Path): + """Create L0 vs record count plot.""" + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + fig, ax1 = plt.subplots(figsize=(12, 7)) + + # Left axis: record count + ax1.semilogx( + results_df["lambda_l0"], + results_df["n_nonzero"], + "b-o", + linewidth=2, + markersize=8, + label="Non-zero records", + ) + ax1.set_xlabel("L0 regularization (lambda)", fontsize=13) + ax1.set_ylabel("Non-zero records", color="b", fontsize=13) + ax1.tick_params(axis="y", labelcolor="b") + + # Reference lines + ax1.axhline( + y=4_000_000, + color="b", + linestyle="--", + alpha=0.5, + label="Target: ~4M (local)", + ) + ax1.axhline( + y=50_000, + color="b", + linestyle=":", + alpha=0.5, + label="Target: ~50K (national)", + ) + + # Right axis: accuracy + ax2 = ax1.twinx() + ax2.semilogx( + results_df["lambda_l0"], + results_df["pct_within_10"], + "r-s", + linewidth=2, + markersize=8, + label="% targets within 10%", + ) + ax2.set_ylabel("% targets within 10%", color="r", fontsize=13) + ax2.tick_params(axis="y", labelcolor="r") + ax2.set_ylim(0, 100) + + # Combined legend + lines1, labels1 = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend( + lines1 + lines2, + labels1 + labels2, + loc="center left", + fontsize=11, + ) + + ax1.set_title( + "Unified calibration: L0 vs sparsity and accuracy", + fontsize=15, + ) + ax1.grid(True, alpha=0.3) + + fig.tight_layout() + fig.savefig(str(output_path), dpi=150) + logger.info("Plot saved to %s", output_path) + plt.close(fig) + + +def parse_args(argv=None): + parser = argparse.ArgumentParser( + description="L0 sweep for unified calibration" + ) + parser.add_argument("--dataset", default=None) + parser.add_argument("--db-path", default=None) + parser.add_argument("--output-dir", default=None) + parser.add_argument( + "--n-clones", + type=int, + default=DEFAULT_N_CLONES, + ) + parser.add_argument("--epochs", type=int, default=DEFAULT_EPOCHS) + parser.add_argument("--device", default="cpu") + parser.add_argument("--seed", type=int, default=42) + return parser.parse_args(argv) + + +def main(argv=None): + args = parse_args(argv) + + from policyengine_us_data.storage import STORAGE_FOLDER + + dataset_path = args.dataset or str(STORAGE_FOLDER / "extended_cps_2024.h5") + db_path = args.db_path or str( + STORAGE_FOLDER / "calibration" / "policy_data.db" + ) + output_dir = Path(args.output_dir or str(STORAGE_FOLDER / 
"calibration")) + output_dir.mkdir(parents=True, exist_ok=True) + + # Step 1: Build matrix (cached) + X_sparse, targets, target_names, n_total = build_and_save_matrix( + dataset_path=dataset_path, + db_path=db_path, + output_dir=output_dir, + n_clones=args.n_clones, + seed=args.seed, + ) + + # Step 2: Sweep L0 values + results_path = output_dir / "l0_sweep_results.csv" + + # Resume from existing results if available + completed = set() + if results_path.exists(): + existing = pd.read_csv(results_path) + completed = set(existing["lambda_l0"].values) + results = existing.to_dict("records") + logger.info( + "Resuming: %d L0 values already completed", + len(completed), + ) + else: + results = [] + + for l0_val in DEFAULT_L0_VALUES: + if l0_val in completed: + logger.info("Skipping L0=%.1e (already done)", l0_val) + continue + + logger.info( + "=" * 60 + f"\nFitting L0={l0_val:.1e} " + f"({len(results)+1}/{len(DEFAULT_L0_VALUES)})" + ) + result = fit_one_l0( + X_sparse=X_sparse, + targets=targets, + target_names=target_names, + lambda_l0=l0_val, + epochs=args.epochs, + device=args.device, + output_dir=output_dir, + ) + results.append(result) + logger.info( + "L0=%.1e: %d non-zero records (%.1f%% " + "sparsity), %.1f%% within 10%%", + l0_val, + result["n_nonzero"], + result["sparsity_pct"], + result["pct_within_10"], + ) + + # Save incrementally + df = pd.DataFrame(results) + df.to_csv(results_path, index=False) + + # Step 3: Plot + df = pd.DataFrame(results).sort_values("lambda_l0") + plot_path = output_dir / "l0_sweep_plot.png" + make_plot(df, plot_path) + + # Summary + logger.info("\n" + "=" * 60) + logger.info("SWEEP COMPLETE") + logger.info("=" * 60) + for _, row in df.iterrows(): + logger.info( + "L0=%.1e: %8d records, %5.1f%% within 10%%", + row["lambda_l0"], + row["n_nonzero"], + row["pct_within_10"], + ) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/calibration/national_matrix_builder.py b/policyengine_us_data/calibration/national_matrix_builder.py new file mode 100644 index 000000000..23bc54417 --- /dev/null +++ b/policyengine_us_data/calibration/national_matrix_builder.py @@ -0,0 +1,634 @@ +""" +National matrix builder for calibration. + +Reads ALL active targets from policy_data.db and builds a dense loss +matrix for the full Extended CPS dataset (~200k households). This +replaces the legacy ``build_loss_matrix()`` in +``policyengine_us_data/utils/loss.py``. + +The builder evaluates stratum constraints (geographic, demographic, +filing-status, AGI band, variable-specific) to produce boolean masks, +then computes target variable values under those masks. + +Tax expenditure targets (reform_id > 0) trigger counterfactual +simulations that neutralize specific deduction variables so the +matrix column captures the income_tax difference. +""" + +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +from sqlalchemy import text + +from policyengine_us_data.calibration.base_matrix_builder import ( + BaseMatrixBuilder, +) +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, +) + +logger = logging.getLogger(__name__) + +# Variables that indicate the target is a *count* of entities +# (value = 1 per entity satisfying constraints) rather than a sum. +COUNT_VARIABLES = { + "person_count", + "tax_unit_count", + "household_count", + "spm_unit_count", +} + +# Variables evaluated at person level (need person-to-household +# aggregation even for non-count targets). 
+PERSON_LEVEL_VARIABLES = { + "person_count", +} + +# Variables evaluated at SPM unit level. +SPM_UNIT_VARIABLES = { + "spm_unit_count", +} + +# Mapping from reform_id -> deduction variable to neutralize. +# reform_id 1 is used by the ETL for all JCT tax expenditure targets. +# The target's own ``variable`` column tells us *which* deduction the +# row pertains to; the reform neutralises that deduction and records +# the income_tax delta. +REFORM_ID_NEUTRALIZE: Dict[int, None] = { + 1: None, # sentinel -- per-target variable is used directly +} + +# Pseudo-constraint variables added by the ETL for hash uniqueness +# that do not correspond to real simulation variables. +_SYNTHETIC_CONSTRAINT_VARS = {"target_category"} + + +class NationalMatrixBuilder(BaseMatrixBuilder): + """Build a dense calibration matrix for national reweighting. + + Reads all active targets from the database, evaluates their + stratum constraints against a ``Microsimulation``, and returns + a matrix suitable for ``SparseCalibrationWeights.fit()`` or the + legacy ``microcalibrate`` interface. + + Args: + db_uri: SQLAlchemy-style database URI, e.g. + ``"sqlite:///path/to/policy_data.db"``. + time_period: Tax year for the calibration (e.g. 2024). + """ + + def __init__( + self, + db_uri: str, + time_period: int, + ): + super().__init__(db_uri, time_period) + + # ------------------------------------------------------------------ + # Database queries + # ------------------------------------------------------------------ + + # Geographic constraint variables used to classify targets. + _GEO_STATE_VARS = {"state_fips", "state_code"} + _GEO_CD_VARS = {"congressional_district_geoid"} + + def _classify_target_geo(self, stratum_id: int) -> str: + """Return geographic level of a target: 'national', 'state', + or 'cd'. + + Walks the stratum chain and checks constraints at each level. + """ + visited: set = set() + current_id = stratum_id + + while current_id is not None and current_id not in visited: + visited.add(current_id) + constraints = self._get_stratum_constraints(current_id) + for c in constraints: + if c["variable"] in self._GEO_CD_VARS: + return "cd" + if c["variable"] in self._GEO_STATE_VARS: + return "state" + current_id = self._get_parent_stratum_id(current_id) + + return "national" + + def _query_active_targets( + self, + geo_level: str = "all", + ) -> pd.DataFrame: + """Query active, non-zero targets, optionally filtered by + geographic level. + + Args: + geo_level: One of ``"national"``, ``"state"``, + ``"cd"``, or ``"all"`` (default). When set to a + specific level, only targets whose stratum chain + resolves to that geographic level are returned. + + Returns: + DataFrame with columns: target_id, stratum_id, variable, + value, period, reform_id, tolerance, stratum_group_id, + stratum_notes, target_notes. + """ + query = """ + SELECT t.target_id, + t.stratum_id, + t.variable, + t.value, + t.period, + t.reform_id, + t.tolerance, + t.notes AS target_notes, + s.stratum_group_id, + s.notes AS stratum_notes + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE t.active = 1 + ORDER BY t.target_id + """ + with self.engine.connect() as conn: + df = pd.read_sql(query, conn) + + if geo_level == "all": + return df + + # Classify each target by its geographic level. 
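+        # The first geographic constraint found while walking up a
+        # stratum's parent chain determines its level (CD checked
+        # before state at each step); chains with no geographic
+        # constraint fall back to "national".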
+ geo_cache: dict = {} + levels = [] + for stratum_id in df["stratum_id"]: + sid = int(stratum_id) + if sid not in geo_cache: + geo_cache[sid] = self._classify_target_geo(sid) + levels.append(geo_cache[sid]) + + df["_geo_level"] = levels + df = ( + df[df["_geo_level"] == geo_level] + .drop(columns=["_geo_level"]) + .reset_index(drop=True) + ) + return df + + def _get_parent_stratum_id(self, stratum_id: int) -> Optional[int]: + """Return the parent_stratum_id for a stratum, or None.""" + query = """ + SELECT parent_stratum_id + FROM strata + WHERE stratum_id = :stratum_id + """ + with self.engine.connect() as conn: + row = conn.execute( + text(query), + {"stratum_id": int(stratum_id)}, + ).fetchone() + if row is None: + return None + return row[0] + + def _get_all_constraints(self, stratum_id: int) -> List[dict]: + """Get constraints for a stratum *and* all its ancestors. + + Walks up the ``parent_stratum_id`` chain, collecting + constraints from each level. This ensures that if a parent + stratum defines geographic or demographic constraints, they + are applied to the child target as well. + + Synthetic constraint variables (e.g. ``target_category``) + used only for hash uniqueness are filtered out. + + Args: + stratum_id: Starting stratum whose full constraint chain + is needed. + + Returns: + De-duplicated list of constraint dicts. + """ + all_constraints: List[dict] = [] + visited: set = set() + current_id: Optional[int] = stratum_id + + while current_id is not None and current_id not in visited: + visited.add(current_id) + constraints = self._get_stratum_constraints(current_id) + all_constraints.extend(constraints) + current_id = self._get_parent_stratum_id(current_id) + + # Remove synthetic/non-simulation constraint variables. + all_constraints = [ + c + for c in all_constraints + if c["variable"] not in _SYNTHETIC_CONSTRAINT_VARS + ] + + return all_constraints + + # ------------------------------------------------------------------ + # Target value computation + # ------------------------------------------------------------------ + + def _compute_target_column( + self, + sim, + variable: str, + constraints: List[dict], + n_households: int, + ) -> np.ndarray: + """Compute the loss matrix column for a single target. + + For count variables (``person_count``, ``tax_unit_count``, + ``household_count``), the column value is 1 per entity + satisfying constraints, summed to household level. For sum + variables, the column is the variable value masked by + constraints, mapped to household level. + + Args: + sim: Microsimulation instance. + variable: PolicyEngine variable name. + constraints: Fully-resolved constraint list. + n_households: Number of households. + + Returns: + Array of length *n_households* (float64). + """ + is_count = variable in COUNT_VARIABLES + + if is_count and variable == "person_count": + # Count persons satisfying constraints at person level, + # then sum per household. 
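+            # E.g. a person_count target constrained to age >= 65
+            # (illustrative constraint) yields, per household, the
+            # number of members aged 65+ via the map_result call
+            # below.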
+ entity_rel = self._build_entity_relationship(sim) + n_persons = len(entity_rel) + person_mask = np.ones(n_persons, dtype=bool) + + for c in constraints: + try: + vals = sim.calculate( + c["variable"], + self.time_period, + map_to="person", + ).values + except Exception: + logger.warning( + "Skipping constraint '%s' for " + "person_count (variable not found)", + c["variable"], + ) + return np.zeros(n_households, dtype=np.float64) + + person_mask &= apply_op(vals, c["operation"], c["value"]) + + values = sim.map_result( + person_mask.astype(float), "person", "household" + ) + return np.asarray(values, dtype=np.float64) + + if is_count and variable == "tax_unit_count": + # Count tax units satisfying constraints. The + # household-level mask already tells us which households + # contain at least one qualifying tax unit; for a count + # target we need the *number* of qualifying tax units per + # household. In practice most constraints produce a + # 0/1-per-household result. + mask = self._evaluate_constraints_entity_aware( + sim, constraints, n_households + ) + return mask.astype(np.float64) + + if is_count and variable == "household_count": + mask = self._evaluate_constraints_entity_aware( + sim, constraints, n_households + ) + return mask.astype(np.float64) + + # Non-count variable: compute value at household level and + # apply the constraint mask. + mask = self._evaluate_constraints_entity_aware( + sim, constraints, n_households + ) + + try: + values = sim.calculate( + variable, + self.time_period, + map_to="household", + ).values.astype(np.float64) + except Exception as exc: + logger.warning( + "Cannot calculate target variable '%s': %s " + "-- returning zeros", + variable, + exc, + ) + return np.zeros(n_households, dtype=np.float64) + + return values * mask + + # ------------------------------------------------------------------ + # Tax expenditure (reform) targets + # ------------------------------------------------------------------ + + def _compute_tax_expenditure_column( + self, + sim_baseline, + variable: str, + constraints: List[dict], + n_households: int, + dataset_class=None, + ) -> np.ndarray: + """Compute a tax expenditure column by running a reform. + + The reform neutralises *variable* (the deduction), and the + column is ``income_tax_reform - income_tax_baseline`` masked + by the stratum constraints. + + Args: + sim_baseline: Baseline Microsimulation instance. + variable: Deduction variable to neutralise + (e.g. ``"salt_deduction"``). + constraints: Stratum constraints to evaluate. + n_households: Number of households. + dataset_class: Dataset class (or path) to pass to the + reform Microsimulation constructor. + + Returns: + Array of length *n_households* (float64). + """ + from policyengine_core.reforms import Reform + from policyengine_us import Microsimulation + + # Get baseline income_tax (cached from sim_baseline). + income_tax_baseline = sim_baseline.calculate( + "income_tax", map_to="household" + ).values.astype(np.float64) + + # Build a reform that neutralises the deduction variable. + def make_repeal_class(deduction_var: str): + class RepealDeduction(Reform): + def apply(self): + self.neutralize_variable(deduction_var) + + return RepealDeduction + + RepealDeduction = make_repeal_class(variable) + + dataset_arg = dataset_class + if dataset_arg is None: + # Fall back to whatever the baseline sim was loaded with. 
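+            # Assumption: the baseline sim exposes its dataset via
+            # a `dataset` attribute; getattr keeps this best-effort
+            # so a missing attribute simply yields None.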
+ dataset_arg = getattr(sim_baseline, "dataset", None) + + sim_reform = Microsimulation( + dataset=dataset_arg, reform=RepealDeduction + ) + sim_reform.default_calculation_period = self.time_period + + income_tax_reform = sim_reform.calculate( + "income_tax", map_to="household" + ).values.astype(np.float64) + + te_values = income_tax_reform - income_tax_baseline + + # Apply stratum constraints mask. + mask = self._evaluate_constraints_entity_aware( + sim_baseline, constraints, n_households + ) + return te_values * mask + + # ------------------------------------------------------------------ + # Target name generation + # ------------------------------------------------------------------ + + @staticmethod + def _make_target_name( + variable: str, + constraints: List[dict], + stratum_notes: Optional[str] = None, + reform_id: int = 0, + ) -> str: + """Generate a human-readable label for a target. + + Args: + variable: Target variable name. + constraints: Resolved constraints for the stratum. + stratum_notes: Optional notes string from the stratum. + reform_id: Reform identifier (0 = baseline). + + Returns: + A slash-separated label string. + """ + parts: List[str] = [] + + # Geographic level + geo_parts: List[str] = [] + for c in constraints: + if c["variable"] == "state_fips": + geo_parts.append(f"state_{c['value']}") + elif c["variable"] == "state_code": + geo_parts.append(f"state_{c['value']}") + elif c["variable"] == "congressional_district_geoid": + geo_parts.append(f"cd_{c['value']}") + + if geo_parts: + parts.append("/".join(geo_parts)) + else: + parts.append("national") + + if reform_id > 0: + parts.append(f"{variable}_expenditure") + else: + parts.append(variable) + + # Non-geo constraint summary + non_geo = [ + c + for c in constraints + if c["variable"] + not in ( + "state_fips", + "state_code", + "congressional_district_geoid", + ) + ] + if non_geo: + constraint_strs = [ + f"{c['variable']}{c['operation']}{c['value']}" for c in non_geo + ] + parts.append("[" + ",".join(constraint_strs) + "]") + + return "/".join(parts) + + # ------------------------------------------------------------------ + # Main build method + # ------------------------------------------------------------------ + + def build_matrix( + self, + sim, + include_tax_expenditures: bool = True, + dataset_class=None, + geo_level: str = "all", + ) -> Tuple[np.ndarray, np.ndarray, List[str]]: + """Build the national calibration matrix from DB targets. + + For each active target in the database: + + 1. Retrieve the stratum and walk up the parent chain to + collect **all** constraints. + 2. Evaluate each constraint variable against the simulation + to produce a boolean mask. + 3. Calculate the target variable, apply the mask, and map to + household level. + 4. Place the result in one column of the loss matrix. + + Tax expenditure targets (``reform_id > 0``) run a + counterfactual simulation that neutralises the target's + deduction variable and records the ``income_tax`` delta. + + Args: + sim: Microsimulation instance loaded with the dataset. + include_tax_expenditures: If ``True`` (default), include + targets with ``reform_id > 0``. Set to ``False`` to + skip the expensive counterfactual simulations. + dataset_class: Dataset class (or path) required for + tax expenditure counterfactual simulations. Ignored + when ``include_tax_expenditures`` is ``False``. + geo_level: Geographic filter -- ``"national"``, + ``"state"``, ``"cd"``, or ``"all"`` (default). 
+ + Returns: + Tuple of ``(loss_matrix, targets, target_names)`` where: + + - **loss_matrix** -- numpy array, shape + ``(n_households, n_targets)``, dtype float64. + - **targets** -- numpy array, shape ``(n_targets,)``, + dtype float64. + - **target_names** -- list of human-readable labels. + """ + household_ids = sim.calculate( + "household_id", map_to="household" + ).values + n_households = len(household_ids) + + targets_df = self._query_active_targets(geo_level=geo_level) + logger.info( + "Loaded %d active targets from database (geo_level=%s)", + len(targets_df), + geo_level, + ) + + if targets_df.empty: + raise ValueError("No active targets found in database") + + # Filter out targets with zero or null value. + targets_df = targets_df[ + targets_df["value"].notna() + & ~np.isclose(targets_df["value"].values, 0.0, atol=0.1) + ].reset_index(drop=True) + + logger.info( + "%d targets remain after removing zero/null values", + len(targets_df), + ) + + if targets_df.empty: + raise ValueError("All targets have zero or null values") + + # Cache constraints per stratum to avoid repeated DB queries. + constraint_cache: Dict[int, List[dict]] = {} + + # Pre-allocate outputs. + columns: List[np.ndarray] = [] + target_values: List[float] = [] + target_names: List[str] = [] + skipped = 0 + + # Cache baseline income_tax for tax expenditure targets. + _income_tax_baseline_cache: Optional[np.ndarray] = None + + for _, row in targets_df.iterrows(): + stratum_id = int(row["stratum_id"]) + variable = str(row["variable"]) + reform_id = int(row["reform_id"]) + + # Skip tax expenditure targets if requested. + if reform_id > 0 and not include_tax_expenditures: + skipped += 1 + continue + + # Resolve full constraint chain. + if stratum_id not in constraint_cache: + constraint_cache[stratum_id] = self._get_all_constraints( + stratum_id + ) + constraints = constraint_cache[stratum_id] + + # -- Tax expenditure target (reform_id > 0) ---------------- + if reform_id > 0: + logger.info( + "Building tax expenditure target: " + "neutralize '%s' (target_id=%s)", + variable, + row["target_id"], + ) + try: + column = self._compute_tax_expenditure_column( + sim_baseline=sim, + variable=variable, + constraints=constraints, + n_households=n_households, + dataset_class=dataset_class, + ) + except Exception as exc: + logger.warning( + "Skipping tax expenditure target '%s' " + "(target_id=%s): %s", + variable, + row["target_id"], + exc, + ) + skipped += 1 + continue + else: + # -- Baseline target ----------------------------------- + logger.debug( + "Building target: %s (target_id=%s, " "stratum_id=%s)", + variable, + row["target_id"], + stratum_id, + ) + column = self._compute_target_column( + sim, variable, constraints, n_households + ) + + columns.append(column) + target_values.append(float(row["value"])) + target_names.append( + self._make_target_name( + variable, + constraints, + stratum_notes=row.get("stratum_notes", ""), + reform_id=reform_id, + ) + ) + + if skipped: + logger.info("Skipped %d targets", skipped) + + n_targets = len(columns) + logger.info( + "Built matrix: %d households x %d targets", + n_households, + n_targets, + ) + + if n_targets == 0: + raise ValueError( + "No targets could be computed " "(all were skipped or errored)" + ) + + # Stack columns into (n_households, n_targets) matrix. 
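+        # Each column already has length n_households, so the stacked
+        # result is a dense float64 array of shape
+        # (n_households, n_targets).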
+ matrix = np.column_stack(columns) + targets = np.array(target_values, dtype=np.float64) + + return matrix, targets, target_names diff --git a/policyengine_us_data/calibration/puf_impute.py b/policyengine_us_data/calibration/puf_impute.py new file mode 100644 index 000000000..37d77c628 --- /dev/null +++ b/policyengine_us_data/calibration/puf_impute.py @@ -0,0 +1,469 @@ +"""PUF clone and QRF imputation for calibration pipeline. + +Doubles CPS records: one half keeps original values, the other half +gets PUF tax variables imputed via Quantile Random Forest. Geography +(state_fips) is included as a QRF predictor so imputations vary by +state. + +Usage within the calibration pipeline: + 1. Load raw CPS dataset + 2. Clone 10x and assign geography + 3. Call puf_clone_dataset() to double records and impute PUF vars + 4. Save expanded dataset for matrix building +""" + +import gc +import logging +import time +from typing import Dict, List, Optional + +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) + +# Demographic predictors for QRF -- same as ExtendedCPS plus state_fips +DEMOGRAPHIC_PREDICTORS = [ + "age", + "is_male", + "tax_unit_is_joint", + "tax_unit_count_dependents", + "is_tax_unit_head", + "is_tax_unit_spouse", + "is_tax_unit_dependent", + "state_fips", +] + +# PUF variables to impute (from extended_cps.py) +IMPUTED_VARIABLES = [ + "employment_income", + "partnership_s_corp_income", + "social_security", + "taxable_pension_income", + "interest_deduction", + "tax_exempt_pension_income", + "long_term_capital_gains", + "unreimbursed_business_employee_expenses", + "pre_tax_contributions", + "taxable_ira_distributions", + "self_employment_income", + "w2_wages_from_qualified_business", + "unadjusted_basis_qualified_property", + "business_is_sstb", + "short_term_capital_gains", + "qualified_dividend_income", + "charitable_cash_donations", + "self_employed_pension_contribution_ald", + "unrecaptured_section_1250_gain", + "taxable_unemployment_compensation", + "taxable_interest_income", + "domestic_production_ald", + "self_employed_health_insurance_ald", + "rental_income", + "non_qualified_dividend_income", + "cdcc_relevant_expenses", + "tax_exempt_interest_income", + "salt_refund_income", + "foreign_tax_credit", + "estate_income", + "charitable_non_cash_donations", + "american_opportunity_credit", + "miscellaneous_income", + "alimony_expense", + "farm_income", + "partnership_se_income", + "alimony_income", + "health_savings_account_ald", + "non_sch_d_capital_gains", + "general_business_credit", + "energy_efficient_home_improvement_credit", + "traditional_ira_contributions", + "amt_foreign_tax_credit", + "excess_withheld_payroll_tax", + "savers_credit", + "student_loan_interest", + "investment_income_elected_form_4952", + "early_withdrawal_penalty", + "prior_year_minimum_tax_credit", + "farm_rent_income", + "qualified_tuition_expenses", + "educator_expense", + "long_term_capital_gains_on_collectibles", + "other_credits", + "casualty_loss", + "unreported_payroll_tax", + "recapture_of_investment_credit", + "deductible_mortgage_interest", + "qualified_reit_and_ptp_income", + "qualified_bdc_income", + "farm_operations_income", + "estate_income_would_be_qualified", + "farm_operations_income_would_be_qualified", + "farm_rent_income_would_be_qualified", + "partnership_s_corp_income_would_be_qualified", + "rental_income_would_be_qualified", + "self_employment_income_would_be_qualified", +] + +# Variables where PUF values override CPS values in BOTH halves 
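+# (puf_clone_dataset writes the QRF prediction into the CPS half as
+# well as the PUF half for these variables, rather than keeping the
+# original CPS values in the first half.)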
+OVERRIDDEN_IMPUTED_VARIABLES = [ + "partnership_s_corp_income", + "interest_deduction", + "unreimbursed_business_employee_expenses", + "pre_tax_contributions", + "w2_wages_from_qualified_business", + "unadjusted_basis_qualified_property", + "business_is_sstb", + "charitable_cash_donations", + "self_employed_pension_contribution_ald", + "unrecaptured_section_1250_gain", + "taxable_unemployment_compensation", + "domestic_production_ald", + "self_employed_health_insurance_ald", + "cdcc_relevant_expenses", + "salt_refund_income", + "foreign_tax_credit", + "estate_income", + "charitable_non_cash_donations", + "american_opportunity_credit", + "miscellaneous_income", + "alimony_expense", + "health_savings_account_ald", + "non_sch_d_capital_gains", + "general_business_credit", + "energy_efficient_home_improvement_credit", + "amt_foreign_tax_credit", + "excess_withheld_payroll_tax", + "savers_credit", + "student_loan_interest", + "investment_income_elected_form_4952", + "early_withdrawal_penalty", + "prior_year_minimum_tax_credit", + "farm_rent_income", + "qualified_tuition_expenses", + "educator_expense", + "long_term_capital_gains_on_collectibles", + "other_credits", + "casualty_loss", + "unreported_payroll_tax", + "recapture_of_investment_credit", + "deductible_mortgage_interest", + "qualified_reit_and_ptp_income", + "qualified_bdc_income", + "farm_operations_income", + "estate_income_would_be_qualified", + "farm_operations_income_would_be_qualified", + "farm_rent_income_would_be_qualified", + "partnership_s_corp_income_would_be_qualified", + "rental_income_would_be_qualified", +] + + +def puf_clone_dataset( + data: Dict[str, Dict[int, np.ndarray]], + state_fips: np.ndarray, + time_period: int = 2024, + puf_dataset=None, + skip_qrf: bool = False, + dataset_path: Optional[str] = None, +) -> Dict[str, Dict[int, np.ndarray]]: + """Clone CPS data 2x and impute PUF variables on one half. + + The first half keeps CPS values (with OVERRIDDEN vars QRF'd). + The second half gets full PUF QRF imputation. The second half + has household weights set to zero. + + Args: + data: CPS dataset dict {variable: {time_period: array}}. + state_fips: State FIPS per household, shape (n_households,). + time_period: Tax year. + puf_dataset: PUF dataset class or path for QRF training. + If None, skips QRF (same as skip_qrf=True). + skip_qrf: If True, skip QRF imputation (for testing). + dataset_path: Path to CPS h5 file (needed for QRF to + compute demographic predictors via Microsimulation). + + Returns: + New data dict with doubled records. 
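+
+    Example (illustrative; file names and input variables below are
+    placeholders):
+
+        new_data = puf_clone_dataset(
+            data=cps_data,
+            state_fips=assigned_state_fips,
+            time_period=2024,
+            puf_dataset="puf_2024.h5",
+            dataset_path="cps_2024.h5",
+        )
+        # Every array is doubled in length; household weights in the
+        # second half are zero.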
+ """ + # Determine record counts from data + household_ids = data["household_id"][time_period] + n_households = len(household_ids) + + # Find person-level variables by checking array lengths + person_count = len(data["person_id"][time_period]) + persons_per_hh = person_count // n_households + + logger.info( + "PUF clone: %d households, %d persons, " "%d persons/hh", + n_households, + person_count, + persons_per_hh, + ) + + # Run QRF imputation if requested + y_full = None + y_override = None + if not skip_qrf and puf_dataset is not None: + y_full, y_override = _run_qrf_imputation( + data, + state_fips, + time_period, + puf_dataset, + dataset_path=dataset_path, + ) + + # Load a Microsimulation for entity-level mapping of imputed + # variables (QRF imputes at person level, but many PUF vars + # belong to tax_unit entity) + cps_sim = None + tbs = None + if (y_full or y_override) and dataset_path is not None: + from policyengine_us import Microsimulation + + cps_sim = Microsimulation(dataset=dataset_path) + tbs = cps_sim.tax_benefit_system + + def _map_to_entity(pred_values, variable_name): + """Map person-level predictions to the correct entity.""" + if cps_sim is None or tbs is None: + return pred_values + var_meta = tbs.variables.get(variable_name) + if var_meta is None: + return pred_values + entity = var_meta.entity.key + if entity != "person": + return cps_sim.populations[entity].value_from_first_person( + pred_values + ) + return pred_values + + # Build doubled dataset + new_data = {} + for variable, time_dict in data.items(): + values = time_dict[time_period] + + if variable in OVERRIDDEN_IMPUTED_VARIABLES and y_override: + pred = _map_to_entity(y_override[variable], variable) + new_data[variable] = {time_period: np.concatenate([pred, pred])} + elif variable in IMPUTED_VARIABLES and y_full: + pred = _map_to_entity(y_full[variable], variable) + new_data[variable] = {time_period: np.concatenate([values, pred])} + elif variable == "person_id": + new_data[variable] = { + time_period: np.concatenate([values, values + values.max()]) + } + elif "_id" in variable: + new_data[variable] = { + time_period: np.concatenate([values, values + values.max()]) + } + elif "_weight" in variable: + new_data[variable] = { + time_period: np.concatenate([values, values * 0]) + } + else: + new_data[variable] = { + time_period: np.concatenate([values, values]) + } + + # Add state_fips to the dataset (doubled) + new_data["state_fips"] = { + time_period: np.concatenate([state_fips, state_fips]).astype(np.int32) + } + + # Add any IMPUTED_VARIABLES that weren't in the original data + if y_full: + for var in IMPUTED_VARIABLES: + if var not in data: + pred = _map_to_entity(y_full[var], var) + orig = np.zeros_like(pred) + new_data[var] = {time_period: np.concatenate([orig, pred])} + + if cps_sim is not None: + del cps_sim + + logger.info( + "PUF clone complete: %d -> %d households", + n_households, + n_households * 2, + ) + return new_data + + +def _run_qrf_imputation( + data: Dict[str, Dict[int, np.ndarray]], + state_fips: np.ndarray, + time_period: int, + puf_dataset, + dataset_path: Optional[str] = None, +) -> tuple: + """Run QRF imputation for PUF variables. + + Trains QRF on PUF data (with random state assignment) and + predicts on CPS data using demographics + state_fips. + + Args: + data: CPS data dict. + state_fips: State FIPS per household. + time_period: Tax year. + puf_dataset: PUF dataset class or path. + dataset_path: Path to CPS h5 file for computing + demographic predictors via Microsimulation. 
+ + Returns: + Tuple of (y_full_imputations, y_override_imputations) + as dicts of {variable: np.ndarray}. + """ + from microimpute.models.qrf import QRF + from policyengine_us import Microsimulation + + logger.info("Running QRF imputation with state predictor") + + # Load PUF + puf_sim = Microsimulation(dataset=puf_dataset) + # Use all PUF records for QRF training (no subsample) + + # Assign random states to PUF training records + from policyengine_us_data.calibration.clone_and_assign import ( + load_global_block_distribution, + ) + + _, _, puf_states, block_probs = load_global_block_distribution() + rng = np.random.default_rng(seed=99) + n_puf = len(puf_sim.calculate("person_id", map_to="person").values) + puf_state_indices = rng.choice(len(puf_states), size=n_puf, p=block_probs) + puf_state_fips = puf_states[puf_state_indices] + + # Build predictor names (without state for PUF calc, add later) + demo_preds = [p for p in DEMOGRAPHIC_PREDICTORS if p != "state_fips"] + + # Build PUF training DataFrame + X_train_full = puf_sim.calculate_dataframe(demo_preds + IMPUTED_VARIABLES) + X_train_full["state_fips"] = puf_state_fips.astype(np.float32) + + X_train_override = puf_sim.calculate_dataframe( + demo_preds + OVERRIDDEN_IMPUTED_VARIABLES + ) + X_train_override["state_fips"] = puf_state_fips.astype(np.float32) + + # Build CPS test DataFrame using Microsimulation to compute + # demographic predictors (many are calculated, not stored) + n_hh = len(data["household_id"][time_period]) + person_count = len(data["person_id"][time_period]) + + if dataset_path is not None: + cps_sim = Microsimulation(dataset=dataset_path) + X_test = cps_sim.calculate_dataframe(demo_preds) + del cps_sim + else: + # Fallback: extract from data dict (only stored vars) + X_test = pd.DataFrame() + for pred in demo_preds: + if pred in data: + X_test[pred] = data[pred][time_period].astype(np.float32) + + # Map household state_fips to person level + hh_ids_person = data.get("person_household_id", {}).get(time_period) + if hh_ids_person is not None: + hh_ids = data["household_id"][time_period] + hh_to_idx = {int(hh_id): i for i, hh_id in enumerate(hh_ids)} + person_states = np.array( + [state_fips[hh_to_idx[int(hh_id)]] for hh_id in hh_ids_person] + ) + else: + person_states = np.repeat( + state_fips, + person_count // n_hh, + ) + X_test["state_fips"] = person_states.astype(np.float32) + + predictors = DEMOGRAPHIC_PREDICTORS + + # Full imputation + logger.info("Imputing %d PUF variables (full)", len(IMPUTED_VARIABLES)) + y_full = _batch_qrf(X_train_full, X_test, predictors, IMPUTED_VARIABLES) + + # Override imputation + logger.info( + "Imputing %d PUF variables (override)", + len(OVERRIDDEN_IMPUTED_VARIABLES), + ) + y_override = _batch_qrf( + X_train_override, + X_test, + predictors, + OVERRIDDEN_IMPUTED_VARIABLES, + ) + + return y_full, y_override + + +def _batch_qrf( + X_train: pd.DataFrame, + X_test: pd.DataFrame, + predictors: List[str], + output_vars: List[str], + batch_size: int = 10, +) -> Dict[str, np.ndarray]: + """Run QRF in batches to control memory. + + Args: + X_train: Training data with predictors + output vars. + X_test: Test data with predictors only. + predictors: Predictor column names. + output_vars: Output variable names to impute. + batch_size: Variables per batch. + + Returns: + Dict mapping variable name to imputed values. 
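+
+    Note:
+        Each batch fits an independent QRF on ``predictors`` plus the
+        batch's output variables only, and intermediate objects are
+        deleted (with ``gc.collect()``) between batches to bound peak
+        memory.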
+ """ + from microimpute.models.qrf import QRF + + available = [c for c in output_vars if c in X_train.columns] + missing = [c for c in output_vars if c not in X_train.columns] + + if missing: + logger.warning( + "%d variables missing from training: %s", + len(missing), + missing[:5], + ) + + result = {} + + for batch_start in range(0, len(available), batch_size): + batch_end = min(batch_start + batch_size, len(available)) + batch_vars = available[batch_start:batch_end] + + gc.collect() + + qrf = QRF( + log_level="INFO", + memory_efficient=True, + batch_size=10, + cleanup_interval=5, + ) + + batch_X_train = X_train[predictors + batch_vars].copy() + + fitted = qrf.fit( + X_train=batch_X_train, + predictors=predictors, + imputed_variables=batch_vars, + n_jobs=1, + ) + + predictions = fitted.predict(X_test=X_test) + + for var in batch_vars: + result[var] = predictions[var].values + + del fitted, predictions, batch_X_train + gc.collect() + + # Zeros for missing variables + n_test = len(X_test) + for var in missing: + result[var] = np.zeros(n_test) + + return result diff --git a/policyengine_us_data/calibration/source_impute.py b/policyengine_us_data/calibration/source_impute.py new file mode 100644 index 000000000..3a27f47be --- /dev/null +++ b/policyengine_us_data/calibration/source_impute.py @@ -0,0 +1,829 @@ +"""Non-PUF QRF imputations with state_fips as predictor. + +Re-imputes variables from ACS, SIPP, and SCF donor surveys +with state_fips included as a QRF predictor. This runs after +geography assignment so imputations reflect assigned state. + +Sources and variables: + ACS -> rent, real_estate_taxes + SIPP -> tip_income, bank_account_assets, stock_assets, + bond_assets + SCF -> net_worth, auto_loan_balance, auto_loan_interest + +Usage in unified calibration pipeline: + 1. Load raw CPS + 2. Clone Nx, assign geography + 3. impute_source_variables() <-- this module + 4. PUF clone + QRF impute (puf_impute.py) + 5. 
PE simulate, build matrix, calibrate +""" + +import gc +import logging +from typing import Dict, Optional + +import numpy as np +import pandas as pd + +logger = logging.getLogger(__name__) + +# Variables imputed from each source +ACS_IMPUTED_VARIABLES = [ + "rent", + "real_estate_taxes", +] + +SIPP_IMPUTED_VARIABLES = [ + "tip_income", + "bank_account_assets", + "stock_assets", + "bond_assets", +] + +SCF_IMPUTED_VARIABLES = [ + "net_worth", + "auto_loan_balance", + "auto_loan_interest", +] + +ALL_SOURCE_VARIABLES = ( + ACS_IMPUTED_VARIABLES + SIPP_IMPUTED_VARIABLES + SCF_IMPUTED_VARIABLES +) + +# Predictors for each source (state_fips always appended) +ACS_PREDICTORS = [ + "is_household_head", + "age", + "is_male", + "tenure_type", + "employment_income", + "self_employment_income", + "social_security", + "pension_income", + "household_size", +] + +SIPP_TIPS_PREDICTORS = [ + "employment_income", + "age", + "count_under_18", + "count_under_6", +] + +SIPP_ASSETS_PREDICTORS = [ + "employment_income", + "age", + "is_female", + "is_married", + "count_under_18", +] + +SCF_PREDICTORS = [ + "age", + "is_female", + "cps_race", + "is_married", + "own_children_in_household", + "employment_income", + "interest_dividend_income", + "social_security_pension_income", +] + + +def impute_source_variables( + data: Dict[str, Dict[int, np.ndarray]], + state_fips: np.ndarray, + time_period: int = 2024, + dataset_path: Optional[str] = None, + skip_acs: bool = False, + skip_sipp: bool = False, + skip_scf: bool = False, +) -> Dict[str, Dict[int, np.ndarray]]: + """Re-impute ACS/SIPP/SCF variables with state as predictor. + + Overwrites existing imputed values in data with new values + that use assigned state_fips as a QRF predictor. + + Args: + data: CPS dataset dict {variable: {time_period: array}}. + state_fips: State FIPS per household, shape + (n_households,). + time_period: Tax year. + dataset_path: Path to CPS h5 file (for computing + demographic predictors via Microsimulation). + skip_acs: Skip ACS imputation (rent, real_estate_taxes). + skip_sipp: Skip SIPP imputation (tips, assets). + skip_scf: Skip SCF imputation (net_worth, auto_loan). + + Returns: + Updated data dict with re-imputed variables. + """ + # Add state_fips to data (household level) + data["state_fips"] = { + time_period: state_fips.astype(np.int32), + } + + if not skip_acs: + logger.info("Imputing ACS variables with state predictor") + data = _impute_acs(data, state_fips, time_period, dataset_path) + + if not skip_sipp: + logger.info("Imputing SIPP variables with state predictor") + data = _impute_sipp(data, state_fips, time_period, dataset_path) + + if not skip_scf: + logger.info("Imputing SCF variables with state predictor") + data = _impute_scf(data, state_fips, time_period, dataset_path) + + return data + + +def _build_cps_receiver( + data: Dict[str, Dict[int, np.ndarray]], + time_period: int, + dataset_path: Optional[str], + pe_variables: list, +) -> pd.DataFrame: + """Build CPS receiver DataFrame from Microsimulation. + + Uses Microsimulation for standard PE variables, falls back + to data dict for variables not in the PE tax-benefit system. + + Args: + data: CPS data dict. + time_period: Tax year. + dataset_path: Path to CPS h5 for Microsimulation. + pe_variables: List of PE variable names to compute. + + Returns: + DataFrame with requested columns. 
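+
+    Note:
+        Variables unknown to the PolicyEngine tax-benefit system are
+        not simulated; they are copied from ``data`` when present and
+        otherwise omitted from the result.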
+ """ + if dataset_path is not None: + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=dataset_path) + # Only request variables that exist in PE + tbs = sim.tax_benefit_system + valid_vars = [v for v in pe_variables if v in tbs.variables] + if valid_vars: + df = sim.calculate_dataframe(valid_vars) + else: + df = pd.DataFrame(index=range(len(data["person_id"][time_period]))) + del sim + else: + df = pd.DataFrame() + + # Add any remaining variables from data dict + for var in pe_variables: + if var not in df.columns and var in data: + df[var] = data[var][time_period].astype(np.float32) + + return df + + +def _get_variable_entity(variable_name: str) -> str: + """Return the entity key ('person', 'household', etc.) for a PE variable.""" + from policyengine_us import CountryTaxBenefitSystem + + tbs = CountryTaxBenefitSystem() + var = tbs.variables.get(variable_name) + if var is None: + return "person" # Default to person if unknown + return var.entity.key + + +def _person_state_fips( + data: Dict[str, Dict[int, np.ndarray]], + state_fips: np.ndarray, + time_period: int, +) -> np.ndarray: + """Map household-level state_fips to person level. + + Args: + data: CPS data dict. + state_fips: State FIPS per household. + time_period: Tax year. + + Returns: + Person-level state FIPS array. + """ + hh_ids_person = data.get("person_household_id", {}).get(time_period) + if hh_ids_person is not None: + hh_ids = data["household_id"][time_period] + hh_to_idx = {int(hh_id): i for i, hh_id in enumerate(hh_ids)} + return np.array( + [state_fips[hh_to_idx[int(hh_id)]] for hh_id in hh_ids_person] + ) + n_hh = len(data["household_id"][time_period]) + n_persons = len(data["person_id"][time_period]) + return np.repeat(state_fips, n_persons // n_hh) + + +def _impute_acs( + data: Dict[str, Dict[int, np.ndarray]], + state_fips: np.ndarray, + time_period: int, + dataset_path: Optional[str] = None, +) -> Dict[str, Dict[int, np.ndarray]]: + """Impute rent and real_estate_taxes from ACS with state. + + Trains QRF on ACS_2022 with state_fips as predictor, + predicts on CPS household heads. + + Args: + data: CPS data dict. + state_fips: State FIPS per household. + time_period: Tax year. + dataset_path: Path to CPS h5 for Microsimulation. + + Returns: + Updated data dict. 
+ """ + from microimpute.models.qrf import QRF + from policyengine_us import Microsimulation + + from policyengine_us_data.datasets.acs.acs import ACS_2022 + + # Load ACS donor data + acs = Microsimulation(dataset=ACS_2022) + predictors = ACS_PREDICTORS + ["state_fips"] + + # ACS has state — use it directly + acs_df = acs.calculate_dataframe(ACS_PREDICTORS + ACS_IMPUTED_VARIABLES) + acs_df["state_fips"] = acs.calculate( + "state_fips", map_to="person" + ).values.astype(np.float32) + + # Filter to household heads and sample + train_df = acs_df[acs_df.is_household_head].sample(10_000, random_state=42) + # Convert tenure_type to numeric for QRF + if "tenure_type" in train_df.columns: + train_df["tenure_type"] = ( + train_df["tenure_type"] + .astype(str) + .map( + { + "OWNED_WITH_MORTGAGE": 1, + "OWNED_OUTRIGHT": 1, + "RENTED": 2, + "NONE": 0, + } + ) + .fillna(0) + .astype(np.float32) + ) + del acs + + # Build CPS receiver data + if dataset_path is not None: + cps_sim = Microsimulation(dataset=dataset_path) + cps_df = cps_sim.calculate_dataframe(ACS_PREDICTORS) + del cps_sim + else: + cps_df = pd.DataFrame() + for pred in ACS_PREDICTORS: + if pred in data: + cps_df[pred] = data[pred][time_period].astype(np.float32) + + # Convert tenure_type to numeric + if "tenure_type" in cps_df.columns: + cps_df["tenure_type"] = ( + cps_df["tenure_type"] + .astype(str) + .map( + { + "OWNED_WITH_MORTGAGE": 1, + "OWNED_OUTRIGHT": 1, + "RENTED": 2, + "NONE": 0, + } + ) + .fillna(0) + .astype(np.float32) + ) + + # Add person-level state_fips + person_states = _person_state_fips(data, state_fips, time_period) + cps_df["state_fips"] = person_states.astype(np.float32) + + # Filter to household heads + mask = ( + cps_df.is_household_head.values + if "is_household_head" in cps_df.columns + else np.ones(len(cps_df), dtype=bool) + ) + cps_heads = cps_df[mask] + + # Train and predict + qrf = QRF() + logger.info( + "ACS QRF: %d train, %d test, %d predictors", + len(train_df), + len(cps_heads), + len(predictors), + ) + fitted = qrf.fit( + X_train=train_df, + predictors=predictors, + imputed_variables=ACS_IMPUTED_VARIABLES, + ) + predictions = fitted.predict(X_test=cps_heads) + + # Write back (household heads only) + n_persons = len(data["person_id"][time_period]) + for var in ACS_IMPUTED_VARIABLES: + values = np.zeros(n_persons, dtype=np.float32) + values[mask] = predictions[var].values + data[var] = {time_period: values} + # Also set pre_subsidy_rent = rent, housing_assistance = 0 + data["pre_subsidy_rent"] = {time_period: data["rent"][time_period].copy()} + + del fitted, predictions + gc.collect() + + logger.info("ACS imputation complete: rent, real_estate_taxes") + return data + + +def _impute_sipp( + data: Dict[str, Dict[int, np.ndarray]], + state_fips: np.ndarray, + time_period: int, + dataset_path: Optional[str] = None, +) -> Dict[str, Dict[int, np.ndarray]]: + """Impute tip_income and liquid assets from SIPP with state. + + Trains QRF on SIPP 2023 with state_fips as predictor. + Since SIPP doesn't have state, random states are assigned + to donor records. + + Args: + data: CPS data dict. + state_fips: State FIPS per household. + time_period: Tax year. + dataset_path: Path to CPS h5 for Microsimulation. + + Returns: + Updated data dict. 
+ """ + from microimpute.models.qrf import QRF + from policyengine_us import Microsimulation + + from policyengine_us_data.calibration.clone_and_assign import ( + load_global_block_distribution, + ) + + # Load state distribution for random assignment to donor + _, _, donor_states, block_probs = load_global_block_distribution() + rng = np.random.default_rng(seed=88) + + # --- Tips imputation --- + from policyengine_us_data.datasets.sipp.sipp import ( + train_tip_model, + ) + + # We need to retrain with state — can't reuse pickled model. + # Load SIPP tip training data directly. + from policyengine_us_data.storage import STORAGE_FOLDER + from huggingface_hub import hf_hub_download + + hf_hub_download( + repo_id="PolicyEngine/policyengine-us-data", + filename="pu2023_slim.csv", + repo_type="model", + local_dir=STORAGE_FOLDER, + ) + sipp_df = pd.read_csv(STORAGE_FOLDER / "pu2023_slim.csv") + + # Prepare SIPP tip data (matching sipp.py logic) + sipp_df["tip_income"] = ( + sipp_df[sipp_df.columns[sipp_df.columns.str.contains("TXAMT")]] + .fillna(0) + .sum(axis=1) + * 12 + ) + sipp_df["employment_income"] = sipp_df.TPTOTINC * 12 + sipp_df["age"] = sipp_df.TAGE + sipp_df["household_weight"] = sipp_df.WPFINWGT + sipp_df["household_id"] = sipp_df.SSUID + + sipp_df["is_under_18"] = sipp_df.TAGE < 18 + sipp_df["is_under_6"] = sipp_df.TAGE < 6 + sipp_df["count_under_18"] = ( + sipp_df.groupby("SSUID")["is_under_18"] + .sum() + .loc[sipp_df.SSUID.values] + .values + ) + sipp_df["count_under_6"] = ( + sipp_df.groupby("SSUID")["is_under_6"] + .sum() + .loc[sipp_df.SSUID.values] + .values + ) + + tip_cols = [ + "household_id", + "employment_income", + "tip_income", + "count_under_18", + "count_under_6", + "age", + "household_weight", + ] + tip_train = sipp_df[tip_cols].dropna() + tip_train = tip_train.loc[ + rng.choice( + tip_train.index, + size=min(10_000, len(tip_train)), + replace=True, + p=(tip_train.household_weight / tip_train.household_weight.sum()), + ) + ] + + # Assign random states to SIPP donor + tip_state_idx = rng.choice( + len(donor_states), size=len(tip_train), p=block_probs + ) + tip_train["state_fips"] = donor_states[tip_state_idx].astype(np.float32) + + # Build CPS receiver for tips + # count_under_18/6 aren't PE variables — compute from data + cps_tip_df = _build_cps_receiver( + data, time_period, dataset_path, ["employment_income", "age"] + ) + # Compute household child counts from ages + person_ages = data["age"][time_period] + hh_ids_person = data.get("person_household_id", {}).get(time_period) + if hh_ids_person is not None: + age_df = pd.DataFrame({"hh": hh_ids_person, "age": person_ages}) + under_18 = age_df.groupby("hh")["age"].apply(lambda x: (x < 18).sum()) + under_6 = age_df.groupby("hh")["age"].apply(lambda x: (x < 6).sum()) + cps_tip_df["count_under_18"] = under_18.loc[ + hh_ids_person + ].values.astype(np.float32) + cps_tip_df["count_under_6"] = under_6.loc[hh_ids_person].values.astype( + np.float32 + ) + else: + cps_tip_df["count_under_18"] = 0.0 + cps_tip_df["count_under_6"] = 0.0 + + person_states = _person_state_fips(data, state_fips, time_period) + cps_tip_df["state_fips"] = person_states.astype(np.float32) + + # Train and predict tips + tip_predictors = SIPP_TIPS_PREDICTORS + ["state_fips"] + qrf = QRF() + logger.info( + "SIPP tips QRF: %d train, %d test", + len(tip_train), + len(cps_tip_df), + ) + fitted = qrf.fit( + X_train=tip_train, + predictors=tip_predictors, + imputed_variables=["tip_income"], + ) + tip_preds = fitted.predict(X_test=cps_tip_df) + 
data["tip_income"] = { + time_period: tip_preds["tip_income"].values, + } + del fitted, tip_preds + gc.collect() + + logger.info("SIPP tip imputation complete") + + # --- Asset imputation --- + # Reload SIPP for assets (uses full file) + try: + hf_hub_download( + repo_id="PolicyEngine/policyengine-us-data", + filename="pu2023.csv", + repo_type="model", + local_dir=STORAGE_FOLDER, + ) + asset_cols = [ + "SSUID", + "PNUM", + "MONTHCODE", + "WPFINWGT", + "TAGE", + "ESEX", + "EMS", + "TPTOTINC", + "TVAL_BANK", + "TVAL_STMF", + "TVAL_BOND", + ] + asset_df = pd.read_csv( + STORAGE_FOLDER / "pu2023.csv", + delimiter="|", + usecols=asset_cols, + ) + asset_df = asset_df[asset_df.MONTHCODE == 12] + + asset_df["bank_account_assets"] = asset_df["TVAL_BANK"].fillna(0) + asset_df["stock_assets"] = asset_df["TVAL_STMF"].fillna(0) + asset_df["bond_assets"] = asset_df["TVAL_BOND"].fillna(0) + asset_df["age"] = asset_df.TAGE + asset_df["is_female"] = asset_df.ESEX == 2 + asset_df["is_married"] = asset_df.EMS == 1 + asset_df["employment_income"] = asset_df.TPTOTINC * 12 + asset_df["household_weight"] = asset_df.WPFINWGT + asset_df["is_under_18"] = asset_df.TAGE < 18 + asset_df["count_under_18"] = ( + asset_df.groupby("SSUID")["is_under_18"] + .sum() + .loc[asset_df.SSUID.values] + .values + ) + + asset_train_cols = [ + "employment_income", + "bank_account_assets", + "stock_assets", + "bond_assets", + "age", + "is_female", + "is_married", + "count_under_18", + "household_weight", + ] + asset_train = asset_df[asset_train_cols].dropna() + asset_train = asset_train.loc[ + rng.choice( + asset_train.index, + size=min(20_000, len(asset_train)), + replace=True, + p=( + asset_train.household_weight + / asset_train.household_weight.sum() + ), + ) + ] + + # Assign random states to SIPP donor + asset_state_idx = rng.choice( + len(donor_states), + size=len(asset_train), + p=block_probs, + ) + asset_train["state_fips"] = donor_states[asset_state_idx].astype( + np.float32 + ) + + # Build CPS receiver for assets + # is_female, is_married, count_under_18 need special + # handling — is_male is PE, is_married is Family-level + cps_asset_df = _build_cps_receiver( + data, + time_period, + dataset_path, + ["employment_income", "age", "is_male"], + ) + # is_female = NOT is_male + if "is_male" in cps_asset_df.columns: + cps_asset_df["is_female"] = ( + ~cps_asset_df["is_male"].astype(bool) + ).astype(np.float32) + else: + cps_asset_df["is_female"] = 0.0 + # is_married from marital_unit membership + if "is_married" in data: + cps_asset_df["is_married"] = data["is_married"][ + time_period + ].astype(np.float32) + else: + cps_asset_df["is_married"] = 0.0 + # count_under_18 + cps_asset_df["count_under_18"] = ( + cps_tip_df["count_under_18"] + if "count_under_18" in cps_tip_df.columns + else 0.0 + ) + + cps_asset_df["state_fips"] = person_states.astype(np.float32) + + asset_predictors = SIPP_ASSETS_PREDICTORS + ["state_fips"] + asset_vars = [ + "bank_account_assets", + "stock_assets", + "bond_assets", + ] + qrf = QRF() + logger.info( + "SIPP assets QRF: %d train, %d test", + len(asset_train), + len(cps_asset_df), + ) + fitted = qrf.fit( + X_train=asset_train, + predictors=asset_predictors, + imputed_variables=asset_vars, + ) + asset_preds = fitted.predict(X_test=cps_asset_df) + + for var in asset_vars: + data[var] = { + time_period: asset_preds[var].values, + } + del fitted, asset_preds + gc.collect() + + logger.info("SIPP asset imputation complete") + + except Exception as e: + logger.warning( + "SIPP asset imputation failed: %s. 
" "Keeping existing values.", + e, + ) + + return data + + +def _impute_scf( + data: Dict[str, Dict[int, np.ndarray]], + state_fips: np.ndarray, + time_period: int, + dataset_path: Optional[str] = None, +) -> Dict[str, Dict[int, np.ndarray]]: + """Impute net_worth and auto_loan from SCF with state. + + Trains QRF on SCF_2022 with state_fips as predictor. + Since SCF doesn't have state, random states are assigned + to donor records. + + Args: + data: CPS data dict. + state_fips: State FIPS per household. + time_period: Tax year. + dataset_path: Path to CPS h5 for Microsimulation. + + Returns: + Updated data dict. + """ + from microimpute.models.qrf import QRF + from policyengine_us import Microsimulation + + from policyengine_us_data.calibration.clone_and_assign import ( + load_global_block_distribution, + ) + from policyengine_us_data.datasets.scf.scf import SCF_2022 + + # Load state distribution for random assignment + _, _, donor_states, block_probs = load_global_block_distribution() + rng = np.random.default_rng(seed=77) + + # Load SCF donor data + scf_dataset = SCF_2022() + scf_data = scf_dataset.load_dataset() + scf_df = pd.DataFrame({key: scf_data[key] for key in scf_data.keys()}) + + # Assign random states to SCF + scf_state_idx = rng.choice( + len(donor_states), size=len(scf_df), p=block_probs + ) + scf_df["state_fips"] = donor_states[scf_state_idx].astype(np.float32) + + scf_predictors = SCF_PREDICTORS + ["state_fips"] + + # Check which predictors are available + available_preds = [p for p in scf_predictors if p in scf_df.columns] + missing_preds = [p for p in scf_predictors if p not in scf_df.columns] + if missing_preds: + logger.warning("SCF missing predictors: %s", missing_preds) + scf_predictors = available_preds + + scf_vars = SCF_IMPUTED_VARIABLES + # SCF uses 'networth' not 'net_worth' + scf_rename = {} + if "networth" in scf_df.columns and "net_worth" not in scf_df.columns: + scf_df["net_worth"] = scf_df["networth"] + scf_rename["networth"] = "net_worth" + + available_vars = [v for v in scf_vars if v in scf_df.columns] + if not available_vars: + logger.warning("No SCF imputed variables available. 
Skipping.") + return data + + weights = scf_df.get("wgt") + + # Sample SCF for training + donor = scf_df[scf_predictors + available_vars].copy() + if weights is not None: + donor["wgt"] = weights + donor = donor.dropna(subset=scf_predictors) + donor = donor.sample(frac=0.5, random_state=42).reset_index(drop=True) + + # Build CPS receiver — many predictors are derived + # Use PE Microsimulation for what it knows, derive the rest + pe_vars = [ + "age", + "is_male", + "employment_income", + ] + cps_df = _build_cps_receiver(data, time_period, dataset_path, pe_vars) + + # Derive is_female from is_male + if "is_male" in cps_df.columns: + cps_df["is_female"] = (~cps_df["is_male"].astype(bool)).astype( + np.float32 + ) + else: + cps_df["is_female"] = 0.0 + + # Derived predictors from data dict + for var in [ + "cps_race", + "is_married", + "own_children_in_household", + ]: + if var in data: + cps_df[var] = data[var][time_period].astype(np.float32) + else: + cps_df[var] = 0.0 + + # Composite income predictors (matching cps.py SCF logic) + for var in [ + "taxable_interest_income", + "tax_exempt_interest_income", + "qualified_dividend_income", + "non_qualified_dividend_income", + ]: + if var in data: + cps_df[var] = data[var][time_period].astype(np.float32) + cps_df["interest_dividend_income"] = ( + cps_df.get("taxable_interest_income", 0) + + cps_df.get("tax_exempt_interest_income", 0) + + cps_df.get("qualified_dividend_income", 0) + + cps_df.get("non_qualified_dividend_income", 0) + ).astype(np.float32) + + for var in [ + "tax_exempt_private_pension_income", + "taxable_private_pension_income", + "social_security_retirement", + ]: + if var in data: + cps_df[var] = data[var][time_period].astype(np.float32) + cps_df["social_security_pension_income"] = ( + cps_df.get("tax_exempt_private_pension_income", 0) + + cps_df.get("taxable_private_pension_income", 0) + + cps_df.get("social_security_retirement", 0) + ).astype(np.float32) + + person_states = _person_state_fips(data, state_fips, time_period) + cps_df["state_fips"] = person_states.astype(np.float32) + + # Train and predict + qrf = QRF() + logger.info( + "SCF QRF: %d train, %d test, vars=%s", + len(donor), + len(cps_df), + available_vars, + ) + fitted = qrf.fit( + X_train=donor, + predictors=scf_predictors, + imputed_variables=available_vars, + weight_col="wgt" if weights is not None else None, + tune_hyperparameters=False, + ) + preds = fitted.predict(X_test=cps_df) + + # SCF variables (net_worth, auto_loan_*) are household-level, + # but QRF predicts at person level. Aggregate back to household + # by taking the first person's value in each household. 
+ hh_ids = data["household_id"][time_period] + person_hh_ids = data.get("person_household_id", {}).get(time_period) + + for var in available_vars: + person_vals = preds[var].values + entity = _get_variable_entity(var) + if entity == "household" and person_hh_ids is not None: + # Map person-level predictions to household level + hh_vals = np.zeros(len(hh_ids), dtype=np.float32) + hh_to_idx = {int(hid): i for i, hid in enumerate(hh_ids)} + seen = set() + for p_idx, p_hh in enumerate(person_hh_ids): + hh_key = int(p_hh) + if hh_key not in seen: + seen.add(hh_key) + hh_vals[hh_to_idx[hh_key]] = person_vals[p_idx] + data[var] = {time_period: hh_vals} + logger.info( + " %s: person(%d) -> household(%d)", + var, + len(person_vals), + len(hh_vals), + ) + else: + data[var] = {time_period: person_vals} + + del fitted, preds + gc.collect() + + logger.info("SCF imputation complete: %s", available_vars) + return data diff --git a/policyengine_us_data/calibration/unified_calibration.py b/policyengine_us_data/calibration/unified_calibration.py new file mode 100644 index 000000000..682dc89de --- /dev/null +++ b/policyengine_us_data/calibration/unified_calibration.py @@ -0,0 +1,737 @@ +""" +Unified L0 calibration pipeline. + +New pipeline flow: + 1. Load raw CPS (~55K households) + 2. Clone 10x (v1) / 100x (v2) + 3. Assign random geography (census block -> state, county, CD) + 4. QRF imputation: all vars with state as predictor + a. ACS -> rent, real_estate_taxes + b. SIPP -> tip_income, bank/stock/bond_assets + c. SCF -> net_worth, auto_loan_balance/interest + d. PUF clone (2x) -> 67 tax variables + 5. PE simulation (via matrix builder) + 6. Build unified sparse calibration matrix + 7. L0-regularized optimization -> calibrated weights + +Two presets control output size via L0 regularization: +- local: L0=1e-8, ~3-4M records (for local area dataset) +- national: L0=1e-4, ~50K records (for web app) + +Usage: + python -m policyengine_us_data.calibration.unified_calibration \\ + --dataset path/to/cps_2024.h5 \\ + --db-path path/to/policy_data.db \\ + --output path/to/weights.npy \\ + --preset local \\ + --epochs 100 +""" + +import argparse +import builtins +import logging +import sys +from pathlib import Path + +import numpy as np + +# Force line-buffered stdout/stderr so logs appear +# immediately under nohup/redirect. +if not sys.stderr.isatty(): + sys.stderr.reconfigure(line_buffering=True) +if not sys.stdout.isatty(): + sys.stdout.reconfigure(line_buffering=True) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + stream=sys.stderr, +) +logger = logging.getLogger(__name__) + +# L0 presets +PRESETS = { + "local": 1e-8, # ~3-4M records + "national": 1e-4, # ~50K records +} + +# Shared hyperparameters (matching local area calibration) +BETA = 0.35 +GAMMA = -0.1 +ZETA = 1.1 +INIT_KEEP_PROB = 0.999 +LOG_WEIGHT_JITTER_SD = 0.05 +LOG_ALPHA_JITTER_SD = 0.01 +LAMBDA_L2 = 1e-12 +LEARNING_RATE = 0.15 +DEFAULT_EPOCHS = 100 +DEFAULT_N_CLONES = 10 + + +def parse_args(argv=None): + """Parse CLI arguments. + + Args: + argv: Optional list of argument strings. Defaults to + sys.argv if None. + + Returns: + Parsed argparse.Namespace. 
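+
+    Example (illustrative paths):
+
+        args = parse_args([
+            "--dataset", "cps_2024.h5",
+            "--db-path", "policy_data.db",
+            "--preset", "national",
+            "--epochs", "200",
+        ])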
+ """ + parser = argparse.ArgumentParser( + description="Unified L0 calibration pipeline" + ) + parser.add_argument( + "--dataset", + default=None, + help="Path to raw CPS h5 file (or extended CPS)", + ) + parser.add_argument( + "--puf-dataset", + default=None, + help="Path to PUF h5 file for QRF training", + ) + parser.add_argument( + "--db-path", + default=None, + help="Path to policy_data.db", + ) + parser.add_argument( + "--output", + default=None, + help="Path to save calibrated weights (.npy)", + ) + parser.add_argument( + "--n-clones", + type=int, + default=DEFAULT_N_CLONES, + help=f"Number of clones (default: {DEFAULT_N_CLONES})", + ) + parser.add_argument( + "--preset", + choices=list(PRESETS.keys()), + default=None, + help="L0 preset: local (~3-4M) or national (~50K)", + ) + parser.add_argument( + "--lambda-l0", + type=float, + default=None, + help="Custom L0 penalty (overrides preset)", + ) + parser.add_argument( + "--epochs", + type=int, + default=DEFAULT_EPOCHS, + help=f"Training epochs (default: {DEFAULT_EPOCHS})", + ) + parser.add_argument( + "--device", + default="cpu", + choices=["cpu", "cuda"], + help="Device for training", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for geography assignment", + ) + parser.add_argument( + "--skip-puf", + action="store_true", + help="Skip PUF clone + QRF (use raw CPS as-is)", + ) + parser.add_argument( + "--skip-source-impute", + action="store_true", + help="Skip ACS/SIPP/SCF re-imputation with state", + ) + parser.add_argument( + "--stratum-groups", + type=str, + default=None, + help=( + "Comma-separated stratum group IDs to calibrate " + "(e.g. '1,2,3'). Default: all targets with " + "calibrate=1 in DB." + ), + ) + return parser.parse_args(argv) + + +def _build_puf_cloned_dataset( + dataset_path: str, + puf_dataset_path: str, + state_fips: np.ndarray, + time_period: int = 2024, + skip_qrf: bool = False, + skip_source_impute: bool = False, +) -> str: + """Build a PUF-cloned dataset from raw CPS. + + Loads the CPS dataset, runs source imputations (ACS/SIPP/SCF) + with state as predictor, then PUF clone + QRF imputation. + + Args: + dataset_path: Path to raw CPS h5 file. + puf_dataset_path: Path to PUF h5 file. + state_fips: State FIPS per household (from geography + assignment, for the base n_records only). + time_period: Tax year. + skip_qrf: Skip QRF imputation (for testing). + skip_source_impute: Skip ACS/SIPP/SCF imputations. + + Returns: + Path to the PUF-cloned h5 file. + """ + from policyengine_us import Microsimulation + + from policyengine_us_data.calibration.puf_impute import ( + puf_clone_dataset, + ) + + logger.info("Building PUF-cloned dataset from %s", dataset_path) + + # Load CPS data + sim = Microsimulation(dataset=dataset_path) + data = sim.dataset.load_dataset() + + # Convert to the time_period_arrays format expected by puf_clone + data_dict = {} + for var in data: + values = data[var][...] 
+ data_dict[var] = {time_period: values} + + # Source imputations (ACS/SIPP/SCF) with state as predictor + if not skip_source_impute: + from policyengine_us_data.calibration.source_impute import ( + impute_source_variables, + ) + + data_dict = impute_source_variables( + data=data_dict, + state_fips=state_fips, + time_period=time_period, + dataset_path=dataset_path, + ) + + # Determine PUF dataset + puf_dataset = puf_dataset_path if not skip_qrf else None + + # PUF clone + QRF impute + new_data = puf_clone_dataset( + data=data_dict, + state_fips=state_fips, + time_period=time_period, + puf_dataset=puf_dataset, + skip_qrf=skip_qrf, + dataset_path=dataset_path, + ) + + # Save expanded dataset + output_path = str( + Path(dataset_path).parent / f"puf_cloned_{Path(dataset_path).stem}.h5" + ) + + import h5py + + with h5py.File(output_path, "w") as f: + for var, time_dict in new_data.items(): + for tp, values in time_dict.items(): + key = f"{var}/{tp}" + f.create_dataset(key, data=values) + + logger.info("PUF-cloned dataset saved to %s", output_path) + return output_path + + +def log_achievable_targets(X_sparse) -> None: + """Log how many targets are achievable vs impossible. + + Impossible targets have all-zero rows in the matrix — no + record can contribute. They stay in the matrix as constant + error terms so metrics reflect the true picture. + + Args: + X_sparse: Sparse calibration matrix (targets x records). + """ + row_sums = np.array(X_sparse.sum(axis=1)).flatten() + achievable = row_sums > 0 + n_impossible = (~achievable).sum() + logger.info( + "Achievable: %d / %d targets (%d impossible, kept)", + achievable.sum(), + len(achievable), + n_impossible, + ) + + +def fit_l0_weights( + X_sparse, + targets: np.ndarray, + lambda_l0: float, + epochs: int = DEFAULT_EPOCHS, + device: str = "cpu", + verbose_freq: int = None, +) -> np.ndarray: + """Fit L0-regularized calibration weights. + + Args: + X_sparse: Sparse matrix (targets x records). + targets: Target values array. + lambda_l0: L0 regularization strength. + epochs: Training epochs. + device: Torch device. + verbose_freq: How often to print progress. Defaults to + every 10% of epochs. + + Returns: + Weight array of shape (n_records,). + """ + import sys + import time + + try: + from l0.calibration import SparseCalibrationWeights + except ImportError: + raise ImportError( + "l0-python required. Install: pip install l0-python" + ) + + import torch + + n_total = X_sparse.shape[1] + initial_weights = np.ones(n_total) * 100 + + logger.info( + "Starting L0 calibration: %d targets, %d features, " + "lambda_l0=%.1e, epochs=%d", + X_sparse.shape[0], + n_total, + lambda_l0, + epochs, + ) + + model = SparseCalibrationWeights( + n_features=n_total, + beta=BETA, + gamma=GAMMA, + zeta=ZETA, + init_keep_prob=INIT_KEEP_PROB, + init_weights=initial_weights, + log_weight_jitter_sd=LOG_WEIGHT_JITTER_SD, + log_alpha_jitter_sd=LOG_ALPHA_JITTER_SD, + device=device, + ) + + if verbose_freq is None: + verbose_freq = max(1, epochs // 10) + + # Monkey-patch print to flush + log, so epoch output + # isn't lost to stdout buffering under nohup/redirect. 
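+    # (The original print is restored in the finally block below, so
+    # the patch is scoped to this fit call.)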
+ _builtin_print = builtins.print + + def _flushed_print(*args, **kwargs): + _builtin_print(*args, **kwargs) + sys.stdout.flush() + + builtins.print = _flushed_print + + t_fit_start = time.time() + try: + model.fit( + M=X_sparse, + y=targets, + target_groups=None, + lambda_l0=lambda_l0, + lambda_l2=LAMBDA_L2, + lr=LEARNING_RATE, + epochs=epochs, + loss_type="relative", + verbose=True, + verbose_freq=verbose_freq, + ) + finally: + builtins.print = _builtin_print + + t_fit_end = time.time() + logger.info( + "L0 optimization finished in %.1f min (%.1f sec/epoch)", + (t_fit_end - t_fit_start) / 60, + (t_fit_end - t_fit_start) / epochs, + ) + + with torch.no_grad(): + weights = model.get_weights(deterministic=True).cpu().numpy() + + n_nonzero = (weights > 0).sum() + logger.info( + "Calibration complete. Non-zero: %d / %d (%.1f%% sparsity)", + n_nonzero, + n_total, + (1 - n_nonzero / n_total) * 100, + ) + return weights + + +def run_calibration( + dataset_path: str, + db_path: str, + n_clones: int = DEFAULT_N_CLONES, + lambda_l0: float = 1e-8, + epochs: int = DEFAULT_EPOCHS, + device: str = "cpu", + seed: int = 42, + puf_dataset_path: str = None, + skip_puf: bool = False, + skip_source_impute: bool = False, + stratum_group_ids: list = None, +): + """Run unified calibration pipeline. + + New pipeline: + 1. Load raw CPS -> get n_records + 2. Clone n_clones x, assign geography + 3. Source imputations (ACS/SIPP/SCF) with state + 4. PUF clone (2x) + QRF impute with state + 5. Build sparse calibration matrix + 6. L0 calibration + + Args: + dataset_path: Path to raw CPS h5 file. + db_path: Path to policy_data.db. + n_clones: Number of dataset clones. + lambda_l0: L0 regularization strength. + epochs: Training epochs. + device: Torch device. + seed: Random seed. + puf_dataset_path: Path to PUF h5 for QRF training. + skip_puf: Skip PUF clone step. + skip_source_impute: Skip ACS/SIPP/SCF imputations. + stratum_group_ids: Only calibrate to targets in + these stratum groups. None means all targets + with ``calibrate = 1``. + + Returns: + Tuple of (weights, targets_df, X_sparse, + target_names). 
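+
+    Example (illustrative; paths are placeholders):
+
+        weights, targets_df, X_sparse, names = run_calibration(
+            dataset_path="cps_2024.h5",
+            db_path="policy_data.db",
+            n_clones=10,
+            lambda_l0=PRESETS["local"],
+            epochs=100,
+        )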
+ """ + import time + + from policyengine_us import Microsimulation + + from policyengine_us_data.calibration.clone_and_assign import ( + assign_random_geography, + double_geography_for_puf, + ) + from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + ) + + t0 = time.time() + + # Step 1: Load raw CPS and get record count + logger.info("Loading dataset from %s", dataset_path) + sim = Microsimulation(dataset=dataset_path) + n_records = len( + sim.calculate("household_id", map_to="household").values + ) + logger.info("Loaded %d households", n_records) + del sim + + # Step 2: Clone and assign geography + logger.info( + "Assigning geography: %d records x %d clones = %d total", + n_records, + n_clones, + n_records * n_clones, + ) + geography = assign_random_geography( + n_records=n_records, + n_clones=n_clones, + seed=seed, + ) + logger.info( + "Geography assigned in %.1f sec", time.time() - t0 + ) + + # Step 3: PUF clone (2x) + QRF imputation + if not skip_puf: + # Get state_fips for the base records (first clone) + base_states = geography.state_fips[:n_records] + + puf_cloned_path = _build_puf_cloned_dataset( + dataset_path=dataset_path, + puf_dataset_path=puf_dataset_path or "", + state_fips=base_states, + time_period=2024, + skip_qrf=puf_dataset_path is None, + skip_source_impute=skip_source_impute, + ) + + # Double geography to match PUF-cloned records + geography = double_geography_for_puf(geography) + dataset_for_matrix = puf_cloned_path + n_records_for_matrix = n_records * 2 + + logger.info( + "After PUF clone: %d records x %d clones = %d total", + n_records_for_matrix, + n_clones, + n_records_for_matrix * n_clones, + ) + else: + # Even without PUF, run source imputations if requested + if not skip_source_impute: + from policyengine_us import Microsimulation + + from policyengine_us_data.calibration.source_impute import ( + impute_source_variables, + ) + + sim = Microsimulation(dataset=dataset_path) + raw_data = sim.dataset.load_dataset() + data_dict = {} + for var in raw_data: + data_dict[var] = {2024: raw_data[var][...]} + del sim + + base_states = geography.state_fips[:n_records] + data_dict = impute_source_variables( + data=data_dict, + state_fips=base_states, + time_period=2024, + dataset_path=dataset_path, + ) + + # Save updated dataset + import h5py + + source_path = str( + Path(dataset_path).parent + / f"source_imputed_{Path(dataset_path).stem}.h5" + ) + with h5py.File(source_path, "w") as f: + for var, time_dict in data_dict.items(): + for tp, values in time_dict.items(): + f.create_dataset(f"{var}/{tp}", data=values) + dataset_for_matrix = source_path + logger.info( + "Source-imputed dataset saved to %s", + source_path, + ) + else: + dataset_for_matrix = dataset_path + n_records_for_matrix = n_records + + # Step 5: Build sparse calibration matrix + t_matrix_start = time.time() + db_uri = f"sqlite:///{db_path}" + builder = UnifiedMatrixBuilder( + db_uri=db_uri, time_period=2024 + ) + targets_df, X_sparse, target_names = builder.build_matrix( + dataset_path=dataset_for_matrix, + geography=geography, + stratum_group_ids=stratum_group_ids, + ) + t_matrix_end = time.time() + logger.info( + "Matrix build completed in %.1f min", + (t_matrix_end - t_matrix_start) / 60, + ) + + # Report achievable vs impossible targets (keep all) + targets = targets_df["value"].values + log_achievable_targets(X_sparse) + + # Step 6: Run L0 calibration + weights = fit_l0_weights( + X_sparse=X_sparse, + targets=targets, + lambda_l0=lambda_l0, + epochs=epochs, + 
device=device, + ) + + logger.info( + "Total pipeline time: %.1f min", + (time.time() - t0) / 60, + ) + return weights, targets_df, X_sparse, target_names + + +def compute_diagnostics( + weights: np.ndarray, + X_sparse, + targets_df, + target_names: list, +) -> "pd.DataFrame": + """Compute per-target diagnostics from calibrated weights. + + Args: + weights: Calibrated weight array. + X_sparse: Sparse matrix (targets x records). + targets_df: DataFrame with target values. + target_names: List of target name strings. + + Returns: + DataFrame with columns: target, true_value, estimate, + rel_error, abs_rel_error, achievable. + """ + import pandas as pd + + estimates = X_sparse.dot(weights) + true_values = targets_df["value"].values + row_sums = np.array(X_sparse.sum(axis=1)).flatten() + + rel_errors = np.where( + np.abs(true_values) > 0, + (estimates - true_values) / np.abs(true_values), + 0.0, + ) + abs_rel_errors = np.abs(rel_errors) + achievable = row_sums > 0 + + return pd.DataFrame( + { + "target": target_names, + "true_value": true_values, + "estimate": estimates, + "rel_error": rel_errors, + "abs_rel_error": abs_rel_errors, + "achievable": achievable, + } + ) + + +def main(argv=None): + """Entry point for CLI usage. + + Args: + argv: Optional list of argument strings. + """ + import json + import time + + import pandas as pd + + args = parse_args(argv) + + from policyengine_us_data.storage import STORAGE_FOLDER + + dataset_path = args.dataset or str( + STORAGE_FOLDER / "cps_2024_full.h5" + ) + puf_dataset_path = args.puf_dataset or str( + STORAGE_FOLDER / "puf_2024.h5" + ) + db_path = args.db_path or str( + STORAGE_FOLDER / "calibration" / "policy_data.db" + ) + output_path = args.output or str( + STORAGE_FOLDER / "calibration" / "unified_weights.npy" + ) + + # Resolve L0 + if args.lambda_l0 is not None: + lambda_l0 = args.lambda_l0 + elif args.preset is not None: + lambda_l0 = PRESETS[args.preset] + else: + lambda_l0 = PRESETS["local"] + logger.info("No preset/lambda specified, using 'local'") + + # Parse stratum group filter + stratum_group_ids = None + if args.stratum_groups: + stratum_group_ids = [ + int(x.strip()) for x in args.stratum_groups.split(",") + ] + logger.info( + "Filtering to stratum groups: %s", + stratum_group_ids, + ) + + t_start = time.time() + + weights, targets_df, X_sparse, target_names = run_calibration( + dataset_path=dataset_path, + db_path=db_path, + n_clones=args.n_clones, + lambda_l0=lambda_l0, + epochs=args.epochs, + device=args.device, + seed=args.seed, + puf_dataset_path=puf_dataset_path, + skip_puf=args.skip_puf, + skip_source_impute=args.skip_source_impute, + stratum_group_ids=stratum_group_ids, + ) + + t_calibration = time.time() + + # Save weights + np.save(output_path, weights) + logger.info("Weights saved to %s", output_path) + + # Save per-target diagnostics + output_dir = Path(output_path).parent + diag_df = compute_diagnostics( + weights, X_sparse, targets_df, target_names + ) + diag_path = output_dir / "unified_diagnostics.csv" + diag_df.to_csv(diag_path, index=False) + + ach = diag_df[diag_df.achievable] + err_pct = ach.abs_rel_error * 100 + logger.info( + "Diagnostics saved to %s: %d targets, " + "mean_error=%.1f%%, median=%.1f%%, " + "within_10%%=%.1f%%, within_25%%=%.1f%%", + diag_path, + len(ach), + err_pct.mean(), + err_pct.median(), + (err_pct < 10).mean() * 100, + (err_pct < 25).mean() * 100, + ) + + # Save run config + t_end = time.time() + run_config = { + "dataset": dataset_path, + "db_path": db_path, + "n_clones": args.n_clones, + 
"lambda_l0": lambda_l0, + "epochs": args.epochs, + "device": args.device, + "seed": args.seed, + "skip_puf": args.skip_puf, + "skip_source_impute": args.skip_source_impute, + "stratum_group_ids": stratum_group_ids, + "n_targets": len(targets_df), + "n_achievable": int(diag_df.achievable.sum()), + "n_impossible": int((~diag_df.achievable).sum()), + "n_records": X_sparse.shape[1], + "n_nonzero_matrix": int(X_sparse.nnz), + "weight_sum": float(weights.sum()), + "weight_nonzero": int((weights > 0).sum()), + "mean_error_pct": float(err_pct.mean()), + "median_error_pct": float(err_pct.median()), + "within_10_pct": float((err_pct < 10).mean() * 100), + "within_25_pct": float((err_pct < 25).mean() * 100), + "elapsed_seconds": round(t_end - t_start, 1), + "calibration_seconds": round( + t_calibration - t_start, 1 + ), + } + config_path = output_dir / "unified_run_config.json" + with open(config_path, "w") as f: + json.dump(run_config, f, indent=2) + logger.info("Run config saved to %s", config_path) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/calibration/unified_matrix_builder.py b/policyengine_us_data/calibration/unified_matrix_builder.py new file mode 100644 index 000000000..fd0206b88 --- /dev/null +++ b/policyengine_us_data/calibration/unified_matrix_builder.py @@ -0,0 +1,635 @@ +""" +Unified sparse matrix builder for calibration. + +Builds a sparse calibration matrix for cloned+geography-assigned CPS +records. Processes clone-by-clone: for each clone, sets each +record's state_fips to its assigned value, simulates, and extracts +variable values. Every simulation result is used. + +Matrix shape: (n_targets, n_records * n_clones) +Column ordering: index i = clone_idx * n_records + record_idx +""" + +import logging +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +from scipy import sparse +from sqlalchemy import text + +from policyengine_us_data.calibration.base_matrix_builder import ( + BaseMatrixBuilder, +) + +logger = logging.getLogger(__name__) + +# Geographic constraint variables +_GEO_STATE_VARS = {"state_fips", "state_code"} +_GEO_CD_VARS = {"congressional_district_geoid"} +_GEO_VARS = _GEO_STATE_VARS | _GEO_CD_VARS + +# Synthetic constraint variables to skip +_SYNTHETIC_CONSTRAINT_VARS = {"target_category"} + +# Count variables (value = 1 per entity satisfying constraints) +COUNT_VARIABLES = { + "person_count", + "tax_unit_count", + "household_count", + "spm_unit_count", +} + + +class UnifiedMatrixBuilder(BaseMatrixBuilder): + """Build sparse calibration matrix for cloned CPS records. + + Processes clone-by-clone: each clone's 111K records get their + assigned geography, are simulated, and the results fill the + corresponding columns. This ensures state-dependent variables + (state income tax, state benefits) are correct for the assigned + geography. + + Args: + db_uri: SQLAlchemy-style database URI. + time_period: Tax year for the calibration (e.g. 2024). + """ + + def __init__(self, db_uri: str, time_period: int): + super().__init__(db_uri, time_period) + + # ------------------------------------------------------------------ + # Database queries + # ------------------------------------------------------------------ + + def _query_active_targets( + self, + calibrate_only: bool = True, + stratum_group_ids: List[int] = None, + ) -> pd.DataFrame: + """Query active targets for calibration. + + Args: + calibrate_only: If True, only include targets with + ``calibrate = 1``. 
Set False to include all + active targets (e.g. for diagnostics). + stratum_group_ids: If provided, only include targets + whose stratum belongs to one of these group IDs. + + Returns: + DataFrame with columns: target_id, stratum_id, + variable, value, period, reform_id, tolerance, + stratum_group_id, stratum_notes, target_notes. + """ + conditions = ["t.active = 1"] + if calibrate_only: + conditions.append("t.calibrate = 1") + if stratum_group_ids: + ids_str = ",".join(str(i) for i in stratum_group_ids) + conditions.append( + f"s.stratum_group_id IN ({ids_str})" + ) + where_clause = " AND ".join(conditions) + + query = f""" + SELECT t.target_id, + t.stratum_id, + t.variable, + t.value, + t.period, + t.reform_id, + t.tolerance, + t.notes AS target_notes, + s.stratum_group_id, + s.notes AS stratum_notes + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE {where_clause} + ORDER BY t.target_id + """ + with self.engine.connect() as conn: + df = pd.read_sql(query, conn) + + # Filter out zero/null target values. + df = df[ + df["value"].notna() + & ~np.isclose(df["value"].values, 0.0, atol=0.1) + ].reset_index(drop=True) + return df + + def _get_all_constraints(self, stratum_id: int) -> List[dict]: + """Get constraints for stratum and all ancestors. + + Walks up the ``parent_stratum_id`` chain, collecting + constraints from each level. Filters out synthetic + constraints. + + Args: + stratum_id: Starting stratum whose full constraint + chain is needed. + + Returns: + List of constraint dicts with keys ``variable``, + ``operation``, ``value``. + """ + all_constraints: List[dict] = [] + visited: set = set() + current_id: Optional[int] = stratum_id + + while current_id is not None and current_id not in visited: + visited.add(current_id) + constraints = self._get_stratum_constraints(current_id) + all_constraints.extend(constraints) + + query = """ + SELECT parent_stratum_id FROM strata + WHERE stratum_id = :sid + """ + with self.engine.connect() as conn: + row = conn.execute( + text(query), {"sid": int(current_id)} + ).fetchone() + current_id = row[0] if row else None + + return [ + c + for c in all_constraints + if c["variable"] not in _SYNTHETIC_CONSTRAINT_VARS + ] + + # ------------------------------------------------------------------ + # Geographic classification + # ------------------------------------------------------------------ + + def _classify_constraint_geo( + self, constraints: List[dict] + ) -> Tuple[str, Optional[str]]: + """Classify geographic level and ID from constraints. + + CD-level takes priority over state-level (a CD target + always has a parent state constraint too). + + Args: + constraints: Full constraint list for a target. + + Returns: + Tuple of ``(geo_level, geo_id)`` where geo_level is + ``"national"``, ``"state"``, or ``"cd"``, and geo_id + is the FIPS/GEOID string or ``None`` for national. + """ + # Check CD first (highest specificity). + for c in constraints: + if c["variable"] in _GEO_CD_VARS: + return "cd", str(c["value"]) + # Then state. + for c in constraints: + if c["variable"] in _GEO_STATE_VARS: + return "state", str(int(float(c["value"]))) + return "national", None + + # ------------------------------------------------------------------ + # Target name generation + # ------------------------------------------------------------------ + + @staticmethod + def _make_target_name( + variable: str, + constraints: List[dict], + reform_id: int = 0, + ) -> str: + """Generate human-readable target label. + + Args: + variable: Target variable name. 
+ constraints: Resolved constraints for the stratum. + reform_id: Reform identifier (0 = baseline). + + Returns: + A slash-separated label string. + """ + parts: List[str] = [] + + # Geographic level + geo_parts: List[str] = [] + for c in constraints: + if c["variable"] == "state_fips": + geo_parts.append(f"state_{c['value']}") + elif c["variable"] == "state_code": + geo_parts.append(f"state_{c['value']}") + elif c["variable"] == "congressional_district_geoid": + geo_parts.append(f"cd_{c['value']}") + + if geo_parts: + parts.append("/".join(geo_parts)) + else: + parts.append("national") + + if reform_id > 0: + parts.append(f"{variable}_expenditure") + else: + parts.append(variable) + + # Non-geo constraint summary + non_geo = [ + c + for c in constraints + if c["variable"] + not in ( + "state_fips", + "state_code", + "congressional_district_geoid", + ) + ] + if non_geo: + constraint_strs = [ + f"{c['variable']}{c['operation']}{c['value']}" for c in non_geo + ] + parts.append("[" + ",".join(constraint_strs) + "]") + + return "/".join(parts) + + # ------------------------------------------------------------------ + # Clone simulation + # ------------------------------------------------------------------ + + def _simulate_clone( + self, + dataset_path: str, + clone_block_geoid: np.ndarray, + clone_cd_geoid: np.ndarray, + clone_state_fips: np.ndarray, + n_records: int, + variables: set, + constraint_keys: set, + ) -> Tuple[Dict[str, np.ndarray], "Microsimulation"]: + """Simulate one clone with assigned geography. + + Loads the base dataset, overrides all geographic inputs + derived from the assigned census block, clears calculated + variables, and computes all needed target/constraint + variables. + + Args: + dataset_path: Path to the base extended CPS h5 file. + clone_block_geoid: Block GEOID (15-char str) for each + record, shape ``(n_records,)``. + clone_cd_geoid: Congressional district GEOID for each + record, shape ``(n_records,)``. + clone_state_fips: State FIPS for each record, + shape ``(n_records,)``. + n_records: Number of base records. + variables: Set of target variable names to compute. + constraint_keys: Set of constraint variable names + needed for mask evaluation. + + Returns: + Tuple of ``(var_values, sim)`` where var_values maps + variable name to household-level float32 array. + """ + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=dataset_path) + sim.default_calculation_period = self.time_period + + # Override all geography from assigned census block. + sim.set_input( + "block_geoid", + self.time_period, + clone_block_geoid, + ) + sim.set_input( + "state_fips", + self.time_period, + clone_state_fips.astype(np.int32), + ) + sim.set_input( + "congressional_district_geoid", + self.time_period, + clone_cd_geoid.astype(np.int64), + ) + # County FIPS = first 5 chars of block GEOID. + county_fips = np.array([b[:5] for b in clone_block_geoid]) + sim.set_input( + "county_fips", + self.time_period, + county_fips, + ) + + # Calculate all target variables. 
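        # [Editorial sketch -- not part of the patch] ``map_to="household"``
        # collapses person/tax-unit/SPM-unit values to one number per
        # household record, so each extracted vector has length n_records.
        # For a single (hypothetical) variable, the loop below reduces to:
        #
        #     values = sim.calculate(
        #         "snap", 2024, map_to="household"
        #     ).values.astype(np.float32)          # shape: (n_records,)
        #
        # "snap" and the year are illustrative; variables that cannot be
        # computed are skipped with a warning rather than aborting the clone.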
+        var_values: Dict[str, np.ndarray] = {}
+        for var in variables:
+            if var in COUNT_VARIABLES:
+                continue
+            try:
+                var_values[var] = sim.calculate(
+                    var,
+                    self.time_period,
+                    map_to="household",
+                ).values.astype(np.float32)
+            except Exception as exc:
+                logger.warning("Cannot calculate '%s': %s", var, exc)
+
+        return var_values, sim
+
+    # ------------------------------------------------------------------
+    # Main build method
+    # ------------------------------------------------------------------
+
+    def build_matrix(
+        self,
+        dataset_path: str,
+        geography,
+        cache_dir: Optional[str] = None,
+        calibrate_only: bool = True,
+        stratum_group_ids: List[int] = None,
+    ) -> Tuple[pd.DataFrame, sparse.csr_matrix, List[str]]:
+        """Build sparse calibration matrix.
+
+        Two-phase build: (1) simulate each clone and save
+        COO entries to disk, (2) assemble the full sparse
+        matrix from cached files. This keeps memory low during
+        simulation and allows resuming if interrupted.
+
+        Args:
+            dataset_path: Path to the base extended CPS h5 file.
+            geography: Geography assignment object with
+                ``state_fips``, ``cd_geoid``, ``block_geoid``
+                arrays and ``n_records``, ``n_clones``
+                attributes.
+            cache_dir: Directory for per-clone COO caches.
+                If ``None``, COO data is held in memory
+                (suitable for tests only).
+            calibrate_only: Only include targets with
+                ``calibrate = 1``.
+            stratum_group_ids: Only include targets from
+                these stratum groups.
+
+        Returns:
+            Tuple of ``(targets_df, X_sparse, target_names)``
+            where:
+
+            - **targets_df** -- DataFrame of target metadata.
+            - **X_sparse** -- sparse CSR matrix of shape
+              ``(n_targets, n_records * n_clones)``.
+            - **target_names** -- list of human-readable labels.
+        """
+        n_records = geography.n_records
+        n_clones = geography.n_clones
+        n_total = n_records * n_clones
+
+        # Build column index structures from geography.
+        state_to_cols: Dict[int, np.ndarray] = {}
+        cd_to_cols: Dict[str, np.ndarray] = {}
+
+        state_col_lists: Dict[int, list] = defaultdict(list)
+        cd_col_lists: Dict[str, list] = defaultdict(list)
+        for col in range(n_total):
+            state_col_lists[int(geography.state_fips[col])].append(col)
+            cd_col_lists[str(geography.cd_geoid[col])].append(col)
+        state_to_cols = {s: np.array(c) for s, c in state_col_lists.items()}
+        cd_to_cols = {cd: np.array(c) for cd, c in cd_col_lists.items()}
+
+        # Query targets from database.
+        targets_df = self._query_active_targets(
+            calibrate_only=calibrate_only,
+            stratum_group_ids=stratum_group_ids,
+        )
+        n_targets = len(targets_df)
+
+        logger.info(
+            "Building unified matrix: %d targets, %d total "
+            "columns (%d records x %d clones)",
+            n_targets,
+            n_total,
+            n_records,
+            n_clones,
+        )
+
+        # Pre-process targets: resolve constraints, classify geo.
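        # [Editorial sketch -- not part of the patch] Column-layout reminder
        # for the code below: with ``col = clone_idx * n_records + record_idx``
        # (the ordering stated in the module docstring), the indices are
        # recovered as
        #
        #     clone_idx = col // n_records
        #     record_idx = col % n_records
        #
        # which is why the per-clone loop slices columns via
        # ``col_start = clone_idx * n_records`` and maps them back to base
        # records with ``col % n_records``.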
+ constraint_cache: Dict[int, List[dict]] = {} + target_geo_info: List[Tuple[str, Optional[str]]] = [] + target_names: List[str] = [] + non_geo_constraints_list: List[List[dict]] = [] + + for _, row in targets_df.iterrows(): + sid = int(row["stratum_id"]) + if sid not in constraint_cache: + constraint_cache[sid] = self._get_all_constraints(sid) + constraints = constraint_cache[sid] + + geo_level, geo_id = self._classify_constraint_geo(constraints) + target_geo_info.append((geo_level, geo_id)) + + non_geo = [ + c for c in constraints if c["variable"] not in _GEO_VARS + ] + non_geo_constraints_list.append(non_geo) + + target_names.append( + self._make_target_name( + str(row["variable"]), + constraints, + reform_id=int(row["reform_id"]), + ) + ) + + # Collect all variables and constraint variables needed. + unique_variables = set(targets_df["variable"].values) + constraint_vars = set() + for non_geo in non_geo_constraints_list: + for c in non_geo: + constraint_vars.add(c["variable"]) + + # Two-phase build: save COO entries per clone to disk, + # then assemble in one pass. This keeps memory low during + # the expensive simulation phase and allows resume if the + # process is interrupted. + from pathlib import Path + + clone_dir = Path(cache_dir) if cache_dir else None + if clone_dir: + clone_dir.mkdir(parents=True, exist_ok=True) + + # Phase 1: simulate each clone, save COO entries. + for clone_idx in range(n_clones): + # Skip if already saved to disk. + if clone_dir: + coo_path = clone_dir / f"clone_{clone_idx:04d}.npz" + if coo_path.exists(): + logger.info( + "Clone %d / %d already cached, " "skipping.", + clone_idx + 1, + n_clones, + ) + continue + + col_start = clone_idx * n_records + col_end = col_start + n_records + + clone_blocks = geography.block_geoid[col_start:col_end] + clone_cds = geography.cd_geoid[col_start:col_end] + clone_states = geography.state_fips[col_start:col_end] + + logger.info( + "Processing clone %d / %d " + "(cols %d-%d, %d unique states)...", + clone_idx + 1, + n_clones, + col_start, + col_end - 1, + len(np.unique(clone_states)), + ) + + self._entity_rel_cache = None + var_values, sim = self._simulate_clone( + dataset_path=dataset_path, + clone_block_geoid=clone_blocks, + clone_cd_geoid=clone_cds, + clone_state_fips=clone_states, + n_records=n_records, + variables=unique_variables, + constraint_keys=constraint_vars, + ) + + mask_cache: Dict[tuple, np.ndarray] = {} + + def _get_mask( + constraints_list: List[dict], + ) -> np.ndarray: + key = tuple( + (c["variable"], c["operation"], c["value"]) + for c in sorted( + constraints_list, + key=lambda c: c["variable"], + ) + ) + if key not in mask_cache: + mask_cache[key] = self._evaluate_constraints_entity_aware( + sim, constraints_list, n_records + ) + return mask_cache[key] + + # Collect COO entries for this clone. 
+ rows_list: list = [] + cols_list: list = [] + vals_list: list = [] + + for row_idx in range(n_targets): + variable = str(targets_df.iloc[row_idx]["variable"]) + geo_level, geo_id = target_geo_info[row_idx] + non_geo = non_geo_constraints_list[row_idx] + + if geo_level == "cd": + if geo_id not in cd_to_cols: + continue + cd_cols = cd_to_cols[geo_id] + clone_target_cols = cd_cols[ + (cd_cols >= col_start) & (cd_cols < col_end) + ] + elif geo_level == "state": + state_key = int(geo_id) + if state_key not in state_to_cols: + continue + s_cols = state_to_cols[state_key] + clone_target_cols = s_cols[ + (s_cols >= col_start) & (s_cols < col_end) + ] + else: + clone_target_cols = np.arange(col_start, col_end) + + if len(clone_target_cols) == 0: + continue + + mask = _get_mask(non_geo) + + if variable in COUNT_VARIABLES: + values = mask.astype(np.float32) + elif variable in var_values: + values = var_values[variable] * mask + else: + continue + + rec_idx = clone_target_cols % n_records + vals = values[rec_idx] + nonzero = vals != 0 + if nonzero.any(): + rows_list.append( + np.full( + nonzero.sum(), + row_idx, + dtype=np.int32, + ) + ) + cols_list.append( + clone_target_cols[nonzero].astype(np.int32) + ) + vals_list.append(vals[nonzero]) + + # Save COO entries for this clone. + if rows_list: + clone_rows = np.concatenate(rows_list) + clone_cols = np.concatenate(cols_list) + clone_vals = np.concatenate(vals_list) + else: + clone_rows = np.array([], dtype=np.int32) + clone_cols = np.array([], dtype=np.int32) + clone_vals = np.array([], dtype=np.float32) + + if clone_dir: + np.savez_compressed( + str(coo_path), + rows=clone_rows, + cols=clone_cols, + vals=clone_vals, + ) + logger.info( + "Clone %d: saved %d nonzero entries " "to %s", + clone_idx + 1, + len(clone_vals), + coo_path.name, + ) + # Free memory. + del var_values, sim + del rows_list, cols_list, vals_list + else: + # No cache dir: accumulate in memory + # (for tests). + if not hasattr(self, "_coo_parts"): + self._coo_parts = ([], [], []) + self._coo_parts[0].append(clone_rows) + self._coo_parts[1].append(clone_cols) + self._coo_parts[2].append(clone_vals) + + # Phase 2: assemble sparse matrix from COO files. 
+ logger.info("Assembling sparse matrix from %d clones...", n_clones) + + if clone_dir: + all_rows, all_cols, all_vals = [], [], [] + for clone_idx in range(n_clones): + coo_path = clone_dir / f"clone_{clone_idx:04d}.npz" + data = np.load(str(coo_path)) + all_rows.append(data["rows"]) + all_cols.append(data["cols"]) + all_vals.append(data["vals"]) + rows = np.concatenate(all_rows) + cols = np.concatenate(all_cols) + vals = np.concatenate(all_vals) + else: + rows = np.concatenate(self._coo_parts[0]) + cols = np.concatenate(self._coo_parts[1]) + vals = np.concatenate(self._coo_parts[2]) + del self._coo_parts + + X_csr = sparse.csr_matrix( + (vals, (rows, cols)), + shape=(n_targets, n_total), + dtype=np.float32, + ) + logger.info( + "Matrix built: %d targets x %d columns, " "%d nonzero entries", + X_csr.shape[0], + X_csr.shape[1], + X_csr.nnz, + ) + return targets_df, X_csr, target_names diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 385ec1e97..29b3a0995 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -1,143 +1,12 @@ from policyengine_core.data import Dataset -import pandas as pd -from policyengine_us_data.utils import ( - pe_to_soi, - get_soi, - build_loss_matrix, - fmt, - HardConcrete, - print_reweighting_diagnostics, - set_seeds, -) import numpy as np -from tqdm import trange from typing import Type from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.datasets.cps.extended_cps import ( ExtendedCPS_2024, - CPS_2024, ) import logging -try: - import torch -except ImportError: - torch = None - - -def reweight( - original_weights, - loss_matrix, - targets_array, - log_path="calibration_log.csv", - epochs=500, - l0_lambda=2.6445e-07, - init_mean=0.999, # initial proportion with non-zero weights - temperature=0.25, - seed=1456, -): - target_names = np.array(loss_matrix.columns) - is_national = loss_matrix.columns.str.startswith("nation/") - loss_matrix = torch.tensor(loss_matrix.values, dtype=torch.float32) - nation_normalisation_factor = is_national * (1 / is_national.sum()) - state_normalisation_factor = ~is_national * (1 / (~is_national).sum()) - normalisation_factor = np.where( - is_national, nation_normalisation_factor, state_normalisation_factor - ) - normalisation_factor = torch.tensor( - normalisation_factor, dtype=torch.float32 - ) - targets_array = torch.tensor(targets_array, dtype=torch.float32) - - inv_mean_normalisation = 1 / np.mean(normalisation_factor.numpy()) - - def loss(weights): - if torch.isnan(weights).any(): - raise ValueError("Weights contain NaNs") - if torch.isnan(loss_matrix).any(): - raise ValueError("Loss matrix contains NaNs") - estimate = weights @ loss_matrix - if torch.isnan(estimate).any(): - raise ValueError("Estimate contains NaNs") - rel_error = ( - ((estimate - targets_array) + 1) / (targets_array + 1) - ) ** 2 - rel_error_normalized = ( - inv_mean_normalisation * rel_error * normalisation_factor - ) - if torch.isnan(rel_error_normalized).any(): - raise ValueError("Relative error contains NaNs") - return rel_error_normalized.mean() - - logging.info( - f"Sparse optimization using seed {seed}, temp {temperature} " - + f"init_mean {init_mean}, l0_lambda {l0_lambda}" - ) - set_seeds(seed) - - weights = torch.tensor( - np.log(original_weights), requires_grad=True, dtype=torch.float32 - ) - gates = HardConcrete( - len(original_weights), init_mean=init_mean, temperature=temperature - ) - # 
NOTE: Results are pretty sensitve to learning rates - # optimizer breaks down somewhere near .005, does better at above .1 - optimizer = torch.optim.Adam([weights] + list(gates.parameters()), lr=0.2) - start_loss = None - - iterator = trange(epochs * 2) # lower learning rate, harder optimization - performance = pd.DataFrame() - for i in iterator: - optimizer.zero_grad() - masked = torch.exp(weights) * gates() - l_main = loss(masked) - l = l_main + l0_lambda * gates.get_penalty() - if (log_path is not None) and (i % 10 == 0): - gates.eval() - estimates = (torch.exp(weights) * gates()) @ loss_matrix - gates.train() - estimates = estimates.detach().numpy() - df = pd.DataFrame( - { - "target_name": target_names, - "estimate": estimates, - "target": targets_array.detach().numpy(), - } - ) - df["epoch"] = i - df["error"] = df.estimate - df.target - df["rel_error"] = df.error / df.target - df["abs_error"] = df.error.abs() - df["rel_abs_error"] = df.rel_error.abs() - df["loss"] = df.rel_abs_error**2 - performance = pd.concat([performance, df], ignore_index=True) - - if (log_path is not None) and (i % 1000 == 0): - performance.to_csv(log_path, index=False) - if start_loss is None: - start_loss = l.item() - loss_rel_change = (l.item() - start_loss) / start_loss - l.backward() - iterator.set_postfix( - {"loss": l.item(), "loss_rel_change": loss_rel_change} - ) - optimizer.step() - if log_path is not None: - performance.to_csv(log_path, index=False) - - gates.eval() - final_weights_sparse = (torch.exp(weights) * gates()).detach().numpy() - - print_reweighting_diagnostics( - final_weights_sparse, - loss_matrix, - targets_array, - "L0 Sparse Solution", - ) - - return final_weights_sparse - class EnhancedCPS(Dataset): data_format = Dataset.TIME_PERIOD_ARRAYS @@ -156,81 +25,55 @@ def generate(self): 1, 0.1, len(original_weights) ) - bad_targets = [ - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "state/RI/adjusted_gross_income/amount/-inf_1", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "state/RI/adjusted_gross_income/amount/-inf_1", - "nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All", - ] - - # Run the optimization procedure to get (close to) minimum loss weights for year in range(self.start_year, 
self.end_year + 1): - loss_matrix, targets_array = build_loss_matrix( - self.input_dataset, year - ) - zero_mask = np.isclose(targets_array, 0.0, atol=0.1) - bad_mask = loss_matrix.columns.isin(bad_targets) - keep_mask_bool = ~(zero_mask | bad_mask) - keep_idx = np.where(keep_mask_bool)[0] - loss_matrix_clean = loss_matrix.iloc[:, keep_idx] - targets_array_clean = targets_array[keep_idx] - assert loss_matrix_clean.shape[1] == targets_array_clean.size - - optimised_weights = reweight( - original_weights, - loss_matrix_clean, - targets_array_clean, - log_path="calibration_log.csv", - epochs=250, - seed=1456, - ) + optimised_weights = self._calibrate(sim, original_weights, year) data["household_weight"][year] = optimised_weights self.save_dataset(data) + def _calibrate(self, sim, original_weights, year): + """Run national calibration for one year. + + Reads active targets from policy_data.db via + NationalMatrixBuilder, then fits weights using + l0-python's SparseCalibrationWeights. + + Args: + sim: Microsimulation instance. + original_weights: Jittered household weights. + year: Tax year to calibrate. + + Returns: + Optimised weight array (n_households,). + """ + from policyengine_us_data.calibration.fit_national_weights import ( + build_calibration_inputs, + fit_national_weights, + initialize_weights, + ) -class ReweightedCPS_2024(Dataset): - data_format = Dataset.ARRAYS - file_path = STORAGE_FOLDER / "reweighted_cps_2024.h5" - name = "reweighted_cps_2024" - label = "Reweighted CPS 2024" - input_dataset = CPS_2024 - time_period = 2024 - - def generate(self): - from policyengine_us import Microsimulation + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + matrix, targets, names = build_calibration_inputs( + dataset_class=self.input_dataset, + time_period=year, + db_path=str(db_path), + sim=sim, + ) - sim = Microsimulation(dataset=self.input_dataset) - data = sim.dataset.load_dataset() - original_weights = sim.calculate("household_weight") - original_weights = original_weights.values + np.random.normal( - 1, 0.1, len(original_weights) + init_weights = initialize_weights(original_weights) + optimised_weights = fit_national_weights( + matrix=matrix, + targets=targets, + initial_weights=init_weights, + epochs=500, ) - for year in [2024]: - loss_matrix, targets_array = build_loss_matrix( - self.input_dataset, year - ) - optimised_weights = reweight( - original_weights, loss_matrix, targets_array - ) - data["household_weight"] = optimised_weights - self.save_dataset(data) + logging.info( + f"Calibration for {year}: " + f"{len(targets)} targets, " + f"{(optimised_weights > 0).sum():,} non-zero weights" + ) + return optimised_weights class EnhancedCPS_2024(EnhancedCPS): @@ -240,7 +83,7 @@ class EnhancedCPS_2024(EnhancedCPS): name = "enhanced_cps_2024" label = "Enhanced CPS 2024" file_path = STORAGE_FOLDER / "enhanced_cps_2024.h5" - url = "hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5" + url = "hf://policyengine/policyengine-us-data/" "enhanced_cps_2024.h5" if __name__ == "__main__": diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py index 5ffe5474a..cdc2dbad8 100644 --- a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -9,21 +9,23 @@ import logging from collections import defaultdict from typing import Dict, List, Optional, 
Tuple + import numpy as np import pandas as pd from scipy import sparse -from sqlalchemy import create_engine, text logger = logging.getLogger(__name__) +from policyengine_us_data.calibration.base_matrix_builder import ( + BaseMatrixBuilder, +) from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( get_calculated_variables, - apply_op, _get_geo_level, ) -class SparseMatrixBuilder: +class SparseMatrixBuilder(BaseMatrixBuilder): """Build sparse calibration matrices for geo-stacking.""" def __init__( @@ -33,110 +35,9 @@ def __init__( cds_to_calibrate: List[str], dataset_path: Optional[str] = None, ): - self.db_uri = db_uri - self.engine = create_engine(db_uri) - self.time_period = time_period + super().__init__(db_uri, time_period) self.cds_to_calibrate = cds_to_calibrate self.dataset_path = dataset_path - self._entity_rel_cache = None - - def _build_entity_relationship(self, sim) -> pd.DataFrame: - """ - Build entity relationship DataFrame mapping persons to all entity IDs. - - This is used to evaluate constraints at the person level and then - aggregate to household level, handling variables defined at different - entity levels (person, tax_unit, household, spm_unit). - - Returns: - DataFrame with person_id, household_id, tax_unit_id, spm_unit_id - """ - if self._entity_rel_cache is not None: - return self._entity_rel_cache - - self._entity_rel_cache = pd.DataFrame( - { - "person_id": sim.calculate( - "person_id", map_to="person" - ).values, - "household_id": sim.calculate( - "household_id", map_to="person" - ).values, - "tax_unit_id": sim.calculate( - "tax_unit_id", map_to="person" - ).values, - "spm_unit_id": sim.calculate( - "spm_unit_id", map_to="person" - ).values, - } - ) - return self._entity_rel_cache - - def _evaluate_constraints_entity_aware( - self, state_sim, constraints: List[dict], n_households: int - ) -> np.ndarray: - """ - Evaluate non-geographic constraints at person level, aggregate to - household level using .any(). - - This properly handles constraints on variables defined at different - entity levels (e.g., tax_unit_is_filer at tax_unit level). Instead of - summing values at household level (which would give 2, 3, etc. for - households with multiple tax units), we evaluate at person level and - use .any() aggregation ("does this household have at least one person - satisfying all constraints?"). 
- - Args: - state_sim: Microsimulation with state_fips set - constraints: List of constraint dicts with variable, operation, - value keys (geographic constraints should be pre-filtered) - n_households: Number of households - - Returns: - Boolean mask array of length n_households - """ - if not constraints: - return np.ones(n_households, dtype=bool) - - entity_rel = self._build_entity_relationship(state_sim) - n_persons = len(entity_rel) - - person_mask = np.ones(n_persons, dtype=bool) - - for c in constraints: - var = c["variable"] - op = c["operation"] - val = c["value"] - - # Calculate constraint variable at person level - constraint_values = state_sim.calculate( - var, self.time_period, map_to="person" - ).values - - # Apply operation at person level - person_mask &= apply_op(constraint_values, op, val) - - # Aggregate to household level using .any() - # "At least one person in this household satisfies ALL constraints" - entity_rel_with_mask = entity_rel.copy() - entity_rel_with_mask["satisfies"] = person_mask - - household_mask_series = entity_rel_with_mask.groupby("household_id")[ - "satisfies" - ].any() - - # Ensure we return a mask aligned with household order - household_ids = state_sim.calculate( - "household_id", map_to="household" - ).values - household_mask = np.array( - [ - household_mask_series.get(hh_id, False) - for hh_id in household_ids - ] - ) - - return household_mask def _query_targets(self, target_filter: dict) -> pd.DataFrame: """Query targets based on filter criteria using OR logic.""" @@ -177,20 +78,9 @@ def _query_targets(self, target_filter: dict) -> pd.DataFrame: with self.engine.connect() as conn: return pd.read_sql(query, conn) - def _get_constraints(self, stratum_id: int) -> List[dict]: - """Get all constraints for a stratum (including geographic).""" - query = """ - SELECT constraint_variable as variable, operation, value - FROM stratum_constraints - WHERE stratum_id = :stratum_id - """ - with self.engine.connect() as conn: - df = pd.read_sql(query, conn, params={"stratum_id": stratum_id}) - return df.to_dict("records") - def _get_geographic_id(self, stratum_id: int) -> str: """Extract geographic_id from constraints for targets_df.""" - constraints = self._get_constraints(stratum_id) + constraints = self._get_stratum_constraints(stratum_id) for c in constraints: if c["variable"] == "state_fips": return c["value"] @@ -284,7 +174,9 @@ def build_matrix( col_start = cd_idx * n_households for row_idx, (_, target) in enumerate(targets_df.iterrows()): - constraints = self._get_constraints(target["stratum_id"]) + constraints = self._get_stratum_constraints( + target["stratum_id"] + ) geo_constraints = [] non_geo_constraints = [] diff --git a/policyengine_us_data/db/create_database_tables.py b/policyengine_us_data/db/create_database_tables.py index 9485d02e1..b6be0d32e 100644 --- a/policyengine_us_data/db/create_database_tables.py +++ b/policyengine_us_data/db/create_database_tables.py @@ -155,7 +155,12 @@ class Target(SQLModel, table=True): description="Identifier for a policy reform scenario (0 for baseline).", ) value: Optional[float] = Field( - default=None, description="The numerical value of the target variable." 
+ default=None, + description="The numerical value of the target variable.", + ) + raw_value: Optional[float] = Field( + default=None, + description="Original value from the data source, before reconciliation.", ) source_id: Optional[int] = Field( default=None, diff --git a/policyengine_us_data/db/etl_all_targets.py b/policyengine_us_data/db/etl_all_targets.py new file mode 100644 index 000000000..165d7e3ec --- /dev/null +++ b/policyengine_us_data/db/etl_all_targets.py @@ -0,0 +1,169 @@ +"""Orchestrator that loads ALL calibration targets from the legacy +``loss.py`` system into the ``policy_data.db`` database. + +This thin wrapper delegates to focused ETL modules: + - etl_misc_national (census age, EITC, SOI filers, neg. market + income, infant, net worth, Medicaid, + SOI filing-status) + - etl_healthcare_spending (healthcare by age band) + - etl_tax_expenditure (SALT, medical, charitable, interest, QBI) + - etl_state_targets (state pop, real estate taxes, ACA, + 10-yr age, state AGI) + +Individual extract functions are re-exported here so existing +callers (tests, notebooks) continue to work. +""" + +import argparse +import logging + +from sqlmodel import Session, create_engine, select + +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import Stratum + +# -- Focused ETL loaders ------------------------------------------- +from policyengine_us_data.db.etl_misc_national import ( + load_misc_national, +) +from policyengine_us_data.db.etl_healthcare_spending import ( + load_healthcare_spending, +) +# SPM threshold decile targets removed -- we now calculate SPM +# thresholds by metro area, making these obsolete. +from policyengine_us_data.db.etl_tax_expenditure import ( + load_tax_expenditure, +) +from policyengine_us_data.db.etl_state_targets import ( + load_state_targets, +) + +# -- Re-export every extract function for backward compatibility --- +from policyengine_us_data.db.etl_misc_national import ( # noqa: F401 + extract_census_age_populations, + extract_eitc_by_child_count, + extract_soi_filer_counts, + extract_negative_market_income, + extract_infant_count, + extract_net_worth, + extract_state_medicaid_enrollment, + extract_soi_filing_status_targets, +) +from policyengine_us_data.db.etl_healthcare_spending import ( # noqa: F401 + extract_healthcare_by_age, +) +from policyengine_us_data.db.etl_tax_expenditure import ( # noqa: F401 + extract_tax_expenditure_targets, +) +from policyengine_us_data.db.etl_state_targets import ( # noqa: F401 + extract_state_population, + extract_state_real_estate_taxes, + extract_state_aca, + extract_state_10yr_age, + extract_state_agi, +) + +# Re-export shared helpers under their old private names +from policyengine_us_data.db.etl_helpers import ( # noqa: F401 + fmt as _fmt, + get_or_create_stratum as _get_or_create_stratum, + upsert_target as _upsert_target, + FILING_STATUS_MAP as _FILING_STATUS_MAP, +) + +logger = logging.getLogger(__name__) + +DEFAULT_DATASET = ( + "hf://policyengine/policyengine-us-data/" + "calibration/stratified_extended_cps.h5" +) + + +# ------------------------------------------------------------------ +# Orchestrator +# ------------------------------------------------------------------ + + +def load_all_targets( + engine, + time_period: int, + root_stratum_id: int, +): + """Load every target category into the database. + + Parameters + ---------- + engine : sqlalchemy.Engine + Database engine (can be in-memory for tests). + time_period : int + Year for the targets (e.g. 
2024). + root_stratum_id : int + ID of the national root stratum. + """ + load_misc_national(engine, time_period, root_stratum_id) + load_healthcare_spending(engine, time_period, root_stratum_id) + load_tax_expenditure(engine, time_period, root_stratum_id) + load_state_targets(engine, time_period, root_stratum_id) + logger.info("All legacy targets loaded successfully.") + + +# ------------------------------------------------------------------ +# CLI entry point +# ------------------------------------------------------------------ + + +def main(): + parser = argparse.ArgumentParser( + description=( + "ETL: migrate ALL calibration targets " + "from legacy loss.py into the database" + ) + ) + parser.add_argument( + "--dataset", + default=DEFAULT_DATASET, + help="Source dataset. Default: %(default)s", + ) + args = parser.parse_args() + + from policyengine_us import Microsimulation + + print(f"Loading dataset: {args.dataset}") + sim = Microsimulation(dataset=args.dataset) + time_period = int(sim.default_calculation_period) + print(f"Derived time_period={time_period}") + + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + engine = create_engine(f"sqlite:///{db_path}") + from sqlmodel import SQLModel + + SQLModel.metadata.create_all(engine) + + with Session(engine) as sess: + root = sess.exec( + select(Stratum).where( + Stratum.parent_stratum_id == None # noqa: E711 + ) + ).first() + if not root: + root = Stratum( + definition_hash="root_national", + parent_stratum_id=None, + stratum_group_id=1, + notes="United States", + ) + sess.add(root) + sess.commit() + sess.refresh(root) + root_id = root.stratum_id + + load_all_targets( + engine=engine, + time_period=time_period, + root_stratum_id=root_id, + ) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/etl_healthcare_spending.py b/policyengine_us_data/db/etl_healthcare_spending.py new file mode 100644 index 000000000..7df66f767 --- /dev/null +++ b/policyengine_us_data/db/etl_healthcare_spending.py @@ -0,0 +1,148 @@ +"""ETL: Healthcare spending by age band (9 bands x 4 expense types). + +Migrated from category 4 of the legacy ``etl_all_targets.py``. 
+""" + +import argparse +import logging + +import pandas as pd +from sqlmodel import Session, create_engine, select + +from policyengine_us_data.storage import ( + STORAGE_FOLDER, + CALIBRATION_FOLDER, +) +from policyengine_us_data.db.create_database_tables import ( + SourceType, + Stratum, +) +from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.db.etl_helpers import ( + get_or_create_stratum, + upsert_target, +) + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Extract +# ------------------------------------------------------------------ + + +def extract_healthcare_by_age(): + """Return list of 9 dicts (one per 10-year age band).""" + df = pd.read_csv(CALIBRATION_FOLDER / "healthcare_spending.csv") + expense_cols = [ + "health_insurance_premiums_without_medicare_part_b", + "over_the_counter_health_expenses", + "other_medical_expenses", + "medicare_part_b_premiums", + ] + records = [] + for _, row in df.iterrows(): + age_lower = int(row["age_10_year_lower_bound"]) + expenses = {c: float(row[c]) for c in expense_cols} + records.append({"age_lower": age_lower, "expenses": expenses}) + return records + + +# ------------------------------------------------------------------ +# Load +# ------------------------------------------------------------------ + + +def load_healthcare_spending(engine, time_period, root_stratum_id): + """Load healthcare spending targets into the database.""" + with Session(engine) as session: + source = get_or_create_source( + session, + name="Legacy loss.py calibration targets", + source_type=SourceType.HARDCODED, + vintage=str(time_period), + description=( + "Comprehensive calibration targets migrated from " + "the legacy build_loss_matrix() in loss.py" + ), + ) + sid = source.source_id + + hc_records = extract_healthcare_by_age() + for rec in hc_records: + age_lo = rec["age_lower"] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": "age", + "operation": ">=", + "value": str(age_lo), + }, + { + "constraint_variable": "age", + "operation": "<", + "value": str(age_lo + 10), + }, + ], + stratum_group_id=13, + notes=f"Healthcare age {age_lo}-{age_lo + 9}", + category_tag="healthcare", + ) + for var_name, amount in rec["expenses"].items(): + upsert_target( + session, + stratum.stratum_id, + var_name, + time_period, + amount, + sid, + notes=( + f"Healthcare {var_name} " f"age {age_lo}-{age_lo + 9}" + ), + ) + + session.commit() + logger.info("Healthcare spending targets loaded.") + + +# ------------------------------------------------------------------ +# CLI +# ------------------------------------------------------------------ + + +def main(): + parser = argparse.ArgumentParser( + description="ETL: healthcare spending by age band" + ) + parser.add_argument( + "--time-period", + type=int, + default=2024, + help="Target year (default: %(default)s)", + ) + args = parser.parse_args() + + from sqlmodel import SQLModel + + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + engine = create_engine(f"sqlite:///{db_path}") + SQLModel.metadata.create_all(engine) + + with Session(engine) as sess: + root = sess.exec( + select(Stratum).where( + Stratum.parent_stratum_id == None # noqa: E711 + ) + ).first() + if not root: + raise RuntimeError("Root stratum not found.") + root_id = root.stratum_id + + load_healthcare_spending(engine, args.time_period, root_id) + print("Done.") + + +if __name__ == 
"__main__": + main() diff --git a/policyengine_us_data/db/etl_helpers.py b/policyengine_us_data/db/etl_helpers.py new file mode 100644 index 000000000..20bee4a7d --- /dev/null +++ b/policyengine_us_data/db/etl_helpers.py @@ -0,0 +1,140 @@ +"""Shared helpers for legacy-target ETL modules. + +Functions extracted from the original monolithic ``etl_all_targets.py`` +so that every focused ETL module can reuse them without duplication. +""" + +import numpy as np +from sqlmodel import select + +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, + Target, +) + +# ------------------------------------------------------------------ +# Format AGI bounds (mirrors loss.py ``fmt``) +# ------------------------------------------------------------------ + + +def fmt(x): + """Human-readable label for an AGI bound value.""" + if x == -np.inf: + return "-inf" + if x == np.inf: + return "inf" + if x < 1e3: + return f"{x:.0f}" + if x < 1e6: + return f"{x/1e3:.0f}k" + if x < 1e9: + return f"{x/1e6:.0f}m" + return f"{x/1e9:.1f}bn" + + +# ------------------------------------------------------------------ +# Stratum upsert +# ------------------------------------------------------------------ + + +def get_or_create_stratum( + session, + parent_id, + constraints, + stratum_group_id, + notes, + category_tag=None, +): + """Find an existing stratum by notes + parent, or create one. + + Parameters + ---------- + category_tag : str, optional + If given, an extra ``target_category == `` constraint + is appended so the definition hash is unique across + categories that otherwise share the same constraints. + """ + existing = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == parent_id, + Stratum.notes == notes, + ) + ).first() + if existing: + return existing + + all_constraints = list(constraints) + if category_tag: + all_constraints.append( + { + "constraint_variable": "target_category", + "operation": "==", + "value": category_tag, + } + ) + + stratum = Stratum( + parent_stratum_id=parent_id, + stratum_group_id=stratum_group_id, + notes=notes, + ) + stratum.constraints_rel = [StratumConstraint(**c) for c in all_constraints] + session.add(stratum) + session.flush() + return stratum + + +# ------------------------------------------------------------------ +# Target upsert +# ------------------------------------------------------------------ + + +def upsert_target( + session, + stratum_id, + variable, + period, + value, + source_id, + notes, + reform_id=0, +): + """Insert or update a target row.""" + existing = session.exec( + select(Target).where( + Target.stratum_id == stratum_id, + Target.variable == variable, + Target.period == period, + Target.reform_id == reform_id, + ) + ).first() + if existing: + existing.value = value + existing.notes = notes + existing.source_id = source_id + else: + t = Target( + stratum_id=stratum_id, + variable=variable, + period=period, + value=value, + source_id=source_id, + reform_id=reform_id, + active=True, + notes=notes, + ) + session.add(t) + + +# ------------------------------------------------------------------ +# Filing-status mapping (SOI names -> PE enum values) +# ------------------------------------------------------------------ + +FILING_STATUS_MAP = { + "Single": "SINGLE", + "Married Filing Jointly/Surviving Spouse": "JOINT", + "Head of Household": "HEAD_OF_HOUSEHOLD", + "Married Filing Separately": "SEPARATE", + "All": None, +} diff --git a/policyengine_us_data/db/etl_misc_national.py 
b/policyengine_us_data/db/etl_misc_national.py new file mode 100644 index 000000000..28ba637bc --- /dev/null +++ b/policyengine_us_data/db/etl_misc_national.py @@ -0,0 +1,529 @@ +"""ETL: Miscellaneous national calibration targets. + +Combines several categories from the legacy ``etl_all_targets.py`` +that are national-level and not covered by other focused ETLs: + 1. Census single-year age populations (86 bins) + 2. EITC by child count (returns + spending) + 3. SOI filer counts by AGI band (7 bands) + 6. Negative household market income (total + count) + 7. Infant count + 8. Net worth + 13. State Medicaid enrollment (51 rows) + 16. SOI filing-status x AGI bin targets +""" + +import argparse +import logging + +import numpy as np +import pandas as pd +from sqlmodel import Session, create_engine, select + +from policyengine_us_data.storage import ( + STORAGE_FOLDER, + CALIBRATION_FOLDER, +) +from policyengine_us_data.db.create_database_tables import ( + SourceType, + Stratum, +) +from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.db.etl_helpers import ( + fmt, + get_or_create_stratum, + upsert_target, + FILING_STATUS_MAP, +) + +logger = logging.getLogger(__name__) + +# ------------------------------------------------------------------ +# Constants +# ------------------------------------------------------------------ + +SOI_FILER_COUNTS_2015 = { + (-np.inf, 0): 2_072_066, + (0, 5_000): 10_134_703, + (5_000, 10_000): 11_398_595, + (10_000, 25_000): 23_447_927, + (25_000, 50_000): 23_727_745, + (50_000, 100_000): 32_801_908, + (100_000, np.inf): 25_120_985, +} + +INFANTS_2023 = 3_491_679 +INFANTS_2022 = 3_437_933 +NET_WORTH_2024 = 160e12 + +# Map SOI variable names to PE variable names +_SOI_VAR_MAP = { + "count": "tax_unit_count", +} + + +# ------------------------------------------------------------------ +# Extract helpers +# ------------------------------------------------------------------ + + +def extract_census_age_populations(time_period: int): + """Return list of 86 dicts with keys ``age`` and ``value``.""" + populations = pd.read_csv(CALIBRATION_FOLDER / "np2023_d5_mid.csv") + populations = populations[ + (populations.SEX == 0) & (populations.RACE_HISP == 0) + ] + pop_cols = [f"POP_{i}" for i in range(86)] + year_pops = ( + populations.groupby("YEAR").sum()[pop_cols].T[time_period].values + ) + return [{"age": i, "value": float(year_pops[i])} for i in range(86)] + + +def extract_eitc_by_child_count(): + """Return list of 4 dicts (one per child bucket).""" + df = pd.read_csv(CALIBRATION_FOLDER / "eitc.csv") + return [ + { + "count_children": int(row["count_children"]), + "eitc_returns": float(row["eitc_returns"]), + "eitc_total": float(row["eitc_total"]), + } + for _, row in df.iterrows() + ] + + +def extract_soi_filer_counts(): + """Return list of 7 dicts (one per AGI band).""" + return [ + { + "agi_lower": lo, + "agi_upper": hi, + "filer_count_2015": count, + } + for (lo, hi), count in SOI_FILER_COUNTS_2015.items() + ] + + +def extract_negative_market_income(): + """Return dict with total and count.""" + return {"total": -138e9, "count": 3e6} + + +def extract_infant_count(): + """Return projected infant count.""" + return INFANTS_2023 * (INFANTS_2023 / INFANTS_2022) + + +def extract_net_worth(): + """Return total household net worth.""" + return NET_WORTH_2024 + + +def extract_state_medicaid_enrollment(): + """Return list of 51 dicts with state and enrollment.""" + df = pd.read_csv(CALIBRATION_FOLDER / "medicaid_enrollment_2024.csv") + 
return [ + { + "state": row["state"], + "enrollment": float(row["enrollment"]), + } + for _, row in df.iterrows() + ] + + +def _get_pe_variables(): + """Return the set of valid PolicyEngine US variable names.""" + try: + from policyengine_us import CountryTaxBenefitSystem + + system = CountryTaxBenefitSystem() + return set(system.variables.keys()) + except Exception: + return None + + +def extract_soi_filing_status_targets(): + """Return filtered list of SOI filing-status x AGI bin records.""" + df = pd.read_csv(CALIBRATION_FOLDER / "soi_targets.csv") + filtered = df[ + (df["Taxable only"] == True) # noqa: E712 + & (df["AGI upper bound"] > 10_000) + ] + + pe_vars = _get_pe_variables() + + records = [] + for _, row in filtered.iterrows(): + var = row["Variable"] + mapped = _SOI_VAR_MAP.get(var, var) + if pe_vars is not None and mapped not in pe_vars: + continue + records.append( + { + "variable": mapped, + "filing_status": row["Filing status"], + "agi_lower": float(row["AGI lower bound"]), + "agi_upper": float(row["AGI upper bound"]), + "is_count": bool(row["Count"]), + "taxable_only": bool(row["Taxable only"]), + "value": float(row["Value"]), + } + ) + return records + + +# ------------------------------------------------------------------ +# Load +# ------------------------------------------------------------------ + + +def load_misc_national(engine, time_period, root_stratum_id): + """Load all miscellaneous national targets into the database.""" + with Session(engine) as session: + source = get_or_create_source( + session, + name="Legacy loss.py calibration targets", + source_type=SourceType.HARDCODED, + vintage=str(time_period), + description=( + "Comprehensive calibration targets migrated from " + "the legacy build_loss_matrix() in loss.py" + ), + ) + sid = source.source_id + + # -- 1. Census single-year age populations ------ + age_pops = extract_census_age_populations(time_period) + age_strata = {} + for rec in age_pops: + age = rec["age"] + notes_str = f"Census age bin {age}" + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": "age", + "operation": ">=", + "value": str(age), + }, + { + "constraint_variable": "age", + "operation": "<", + "value": str(age + 1), + }, + ], + stratum_group_id=10, + notes=notes_str, + ) + age_strata[age] = stratum + upsert_target( + session, + stratum.stratum_id, + "person_count", + time_period, + rec["value"], + sid, + notes=notes_str, + ) + + # -- 2. EITC by child count -------------------- + eitc_records = extract_eitc_by_child_count() + for rec in eitc_records: + cc = rec["count_children"] + if cc < 2: + op, val = "==", str(cc) + else: + op, val = ">=", str(cc) + + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": "eitc_child_count", + "operation": op, + "value": val, + }, + { + "constraint_variable": "eitc", + "operation": ">", + "value": "0", + }, + ], + stratum_group_id=11, + notes=f"EITC {cc} children", + ) + upsert_target( + session, + stratum.stratum_id, + "tax_unit_count", + time_period, + rec["eitc_returns"], + sid, + notes=f"EITC returns, {cc} children", + ) + upsert_target( + session, + stratum.stratum_id, + "eitc", + time_period, + rec["eitc_total"], + sid, + notes=f"EITC spending, {cc} children", + ) + + # -- 3. 
SOI filer counts by AGI band ------------ + filer_counts = extract_soi_filer_counts() + for rec in filer_counts: + lo, hi = rec["agi_lower"], rec["agi_upper"] + label = f"{fmt(lo)}_{fmt(hi)}" + constraints = [ + { + "constraint_variable": "tax_unit_is_filer", + "operation": "==", + "value": "1", + }, + { + "constraint_variable": "adjusted_gross_income", + "operation": ">=", + "value": str(lo), + }, + { + "constraint_variable": "adjusted_gross_income", + "operation": "<", + "value": str(hi), + }, + ] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=constraints, + stratum_group_id=12, + notes=f"SOI filer count AGI {label}", + ) + upsert_target( + session, + stratum.stratum_id, + "tax_unit_count", + time_period, + rec["filer_count_2015"], + sid, + notes=f"SOI filer count AGI {label}", + ) + + # -- 6. Negative household market income -------- + nmi = extract_negative_market_income() + nmi_stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": ("household_market_income"), + "operation": "<", + "value": "0", + }, + ], + stratum_group_id=15, + notes="Negative household market income", + ) + upsert_target( + session, + nmi_stratum.stratum_id, + "household_market_income", + time_period, + nmi["total"], + sid, + notes="Negative household market income total", + ) + upsert_target( + session, + nmi_stratum.stratum_id, + "household_count", + time_period, + nmi["count"], + sid, + notes="Negative household market income count", + ) + + # -- 7. Infant count ---------------------------- + infant_count = extract_infant_count() + infant_stratum = age_strata[0] + upsert_target( + session, + infant_stratum.stratum_id, + "person_count", + time_period, + infant_count, + sid, + notes="Census age bin 0", + ) + + # -- 8. Net worth ------------------------------- + nw_value = extract_net_worth() + upsert_target( + session, + root_stratum_id, + "net_worth", + time_period, + nw_value, + sid, + notes="Total household net worth (Fed Reserve)", + ) + + # -- 13. State Medicaid enrollment -------------- + med_records = extract_state_medicaid_enrollment() + for rec in med_records: + st = rec["state"] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": "state_code", + "operation": "==", + "value": st, + }, + { + "constraint_variable": "medicaid_enrolled", + "operation": "==", + "value": "True", + }, + ], + stratum_group_id=21, + notes=f"State Medicaid {st}", + category_tag="medicaid", + ) + upsert_target( + session, + stratum.stratum_id, + "person_count", + time_period, + rec["enrollment"], + sid, + notes=f"State Medicaid enrollment {st}", + ) + + # -- 16. 
SOI filing-status x AGI bin targets ---- + soi_records = extract_soi_filing_status_targets() + soi_strata_cache = {} + for rec in soi_records: + lo = rec["agi_lower"] + hi = rec["agi_upper"] + fs = rec["filing_status"] + var = rec["variable"] + is_count = rec["is_count"] + + cache_key = (fs, lo, hi) + if cache_key not in soi_strata_cache: + constraints = [ + { + "constraint_variable": "tax_unit_is_filer", + "operation": "==", + "value": "1", + }, + { + "constraint_variable": ("adjusted_gross_income"), + "operation": ">=", + "value": str(lo), + }, + { + "constraint_variable": ("adjusted_gross_income"), + "operation": "<", + "value": str(hi), + }, + { + "constraint_variable": "income_tax", + "operation": ">", + "value": "0", + }, + ] + + pe_fs = FILING_STATUS_MAP.get(fs) + if pe_fs is not None: + constraints.append( + { + "constraint_variable": "filing_status", + "operation": "==", + "value": pe_fs, + } + ) + + stratum_notes = ( + f"SOI filing-status {fs} " f"AGI {fmt(lo)}-{fmt(hi)}" + ) + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=constraints, + stratum_group_id=24, + notes=stratum_notes, + ) + soi_strata_cache[cache_key] = stratum + else: + stratum = soi_strata_cache[cache_key] + + target_variable = var + if is_count and var != "count": + target_variable = "tax_unit_count" + elif var == "count": + target_variable = "tax_unit_count" + + count_label = "count" if is_count else "total" + upsert_target( + session, + stratum.stratum_id, + target_variable, + time_period, + rec["value"], + sid, + notes=( + f"SOI filing-status {fs} " + f"{var} {count_label} " + f"AGI {fmt(lo)}-{fmt(hi)}" + ), + ) + + session.commit() + logger.info("Misc national targets loaded successfully.") + + +# ------------------------------------------------------------------ +# CLI +# ------------------------------------------------------------------ + + +def main(): + parser = argparse.ArgumentParser( + description=("ETL: miscellaneous national calibration targets") + ) + parser.add_argument( + "--time-period", + type=int, + default=2024, + help="Target year (default: %(default)s)", + ) + args = parser.parse_args() + + from sqlmodel import SQLModel + + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + engine = create_engine(f"sqlite:///{db_path}") + SQLModel.metadata.create_all(engine) + + with Session(engine) as sess: + root = sess.exec( + select(Stratum).where( + Stratum.parent_stratum_id == None # noqa: E711 + ) + ).first() + if not root: + raise RuntimeError("Root stratum not found.") + root_id = root.stratum_id + + load_misc_national(engine, args.time_period, root_id) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/etl_spm_threshold.py b/policyengine_us_data/db/etl_spm_threshold.py new file mode 100644 index 000000000..a94abd04b --- /dev/null +++ b/policyengine_us_data/db/etl_spm_threshold.py @@ -0,0 +1,151 @@ +"""ETL: AGI by SPM threshold decile (10 deciles x 2 metrics). + +Migrated from category 5 of the legacy ``etl_all_targets.py``. 
+""" + +import argparse +import logging + +import pandas as pd +from sqlmodel import Session, create_engine, select + +from policyengine_us_data.storage import ( + STORAGE_FOLDER, + CALIBRATION_FOLDER, +) +from policyengine_us_data.db.create_database_tables import ( + SourceType, + Stratum, +) +from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.db.etl_helpers import ( + get_or_create_stratum, + upsert_target, +) + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------ +# Extract +# ------------------------------------------------------------------ + + +def extract_spm_threshold_agi(): + """Return list of 10 dicts (one per decile).""" + df = pd.read_csv(CALIBRATION_FOLDER / "spm_threshold_agi.csv") + return [ + { + "decile": int(row["decile"]), + "lower_spm_threshold": float(row["lower_spm_threshold"]), + "upper_spm_threshold": float(row["upper_spm_threshold"]), + "adjusted_gross_income": float(row["adjusted_gross_income"]), + "count": float(row["count"]), + } + for _, row in df.iterrows() + ] + + +# ------------------------------------------------------------------ +# Load +# ------------------------------------------------------------------ + + +def load_spm_threshold(engine, time_period, root_stratum_id): + """Load SPM-threshold-decile targets into the database.""" + with Session(engine) as session: + source = get_or_create_source( + session, + name="Legacy loss.py calibration targets", + source_type=SourceType.HARDCODED, + vintage=str(time_period), + description=( + "Comprehensive calibration targets migrated from " + "the legacy build_loss_matrix() in loss.py" + ), + ) + sid = source.source_id + + spm_records = extract_spm_threshold_agi() + for rec in spm_records: + d = rec["decile"] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": ("spm_unit_spm_threshold"), + "operation": ">=", + "value": str(rec["lower_spm_threshold"]), + }, + { + "constraint_variable": ("spm_unit_spm_threshold"), + "operation": "<", + "value": str(rec["upper_spm_threshold"]), + }, + ], + stratum_group_id=14, + notes=f"SPM threshold decile {d}", + ) + upsert_target( + session, + stratum.stratum_id, + "adjusted_gross_income", + time_period, + rec["adjusted_gross_income"], + sid, + notes=f"SPM threshold decile {d} AGI", + ) + upsert_target( + session, + stratum.stratum_id, + "spm_unit_count", + time_period, + rec["count"], + sid, + notes=f"SPM threshold decile {d} count", + ) + + session.commit() + logger.info("SPM threshold targets loaded.") + + +# ------------------------------------------------------------------ +# CLI +# ------------------------------------------------------------------ + + +def main(): + parser = argparse.ArgumentParser( + description="ETL: AGI by SPM threshold decile" + ) + parser.add_argument( + "--time-period", + type=int, + default=2024, + help="Target year (default: %(default)s)", + ) + args = parser.parse_args() + + from sqlmodel import SQLModel + + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + engine = create_engine(f"sqlite:///{db_path}") + SQLModel.metadata.create_all(engine) + + with Session(engine) as sess: + root = sess.exec( + select(Stratum).where( + Stratum.parent_stratum_id == None # noqa: E711 + ) + ).first() + if not root: + raise RuntimeError("Root stratum not found.") + root_id = root.stratum_id + + load_spm_threshold(engine, args.time_period, root_id) + print("Done.") + + +if __name__ == 
"__main__": + main() diff --git a/policyengine_us_data/db/etl_state_targets.py b/policyengine_us_data/db/etl_state_targets.py new file mode 100644 index 000000000..4745a4fff --- /dev/null +++ b/policyengine_us_data/db/etl_state_targets.py @@ -0,0 +1,409 @@ +"""ETL: State-level calibration targets. + +Combines several categories from the legacy ``etl_all_targets.py``: + 9. State population (total + under-5) + 11. State real estate taxes (51 rows) + 12. State ACA spending and enrollment (51 x 2) + 14. State 10-year age targets (50 states x 18 ranges) + 15. State AGI targets (918 rows) +""" + +import argparse +import logging + +import pandas as pd +from sqlmodel import Session, create_engine, select + +from policyengine_us_data.storage import ( + STORAGE_FOLDER, + CALIBRATION_FOLDER, +) +from policyengine_us_data.db.create_database_tables import ( + SourceType, + Stratum, +) +from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.db.etl_helpers import ( + get_or_create_stratum, + upsert_target, +) + +logger = logging.getLogger(__name__) + +HARD_CODED_NATIONAL_TOTAL = 500e9 # real_estate_taxes national +ACA_SPENDING_2024 = 9.8e10 + + +# ------------------------------------------------------------------ +# Extract helpers +# ------------------------------------------------------------------ + + +def extract_state_population(): + """Return list of dicts with state, population, population_under_5.""" + df = pd.read_csv(CALIBRATION_FOLDER / "population_by_state.csv") + return [ + { + "state": row["state"], + "population": float(row["population"]), + "population_under_5": float(row["population_under_5"]), + } + for _, row in df.iterrows() + ] + + +def extract_state_real_estate_taxes(): + """Return list of 51 dicts with state_code and scaled value.""" + df = pd.read_csv(CALIBRATION_FOLDER / "real_estate_taxes_by_state_acs.csv") + state_sum = df["real_estate_taxes_bn"].sum() * 1e9 + scale = HARD_CODED_NATIONAL_TOTAL / state_sum + return [ + { + "state_code": row["state_code"], + "value": float(row["real_estate_taxes_bn"] * scale * 1e9), + } + for _, row in df.iterrows() + ] + + +def extract_state_aca(): + """Return list of dicts with state, spending (scaled), enrollment.""" + df = pd.read_csv( + CALIBRATION_FOLDER / "aca_spending_and_enrollment_2024.csv" + ) + df["spending_annual"] = df["spending"] * 12 + spending_scale = ACA_SPENDING_2024 / df["spending_annual"].sum() + df["spending_scaled"] = df["spending_annual"] * spending_scale + return [ + { + "state": row["state"], + "spending": float(row["spending_scaled"]), + "enrollment": float(row["enrollment"]), + } + for _, row in df.iterrows() + ] + + +def extract_state_10yr_age(): + """Return list of dicts with state, age_range, value.""" + df = pd.read_csv(CALIBRATION_FOLDER / "age_state.csv") + records = [] + for _, row in df.iterrows(): + state = row["GEO_NAME"] + for col in df.columns[2:]: # skip GEO_ID, GEO_NAME + records.append( + { + "state": state, + "age_range": col, + "value": float(row[col]), + } + ) + return records + + +def extract_state_agi(): + """Return list of 918 dicts with state AGI data.""" + df = pd.read_csv(CALIBRATION_FOLDER / "agi_state.csv") + return [ + { + "geo_name": row["GEO_NAME"], + "agi_lower": float(row["AGI_LOWER_BOUND"]), + "agi_upper": float(row["AGI_UPPER_BOUND"]), + "value": float(row["VALUE"]), + "is_count": bool(row["IS_COUNT"]), + "variable": row["VARIABLE"], + } + for _, row in df.iterrows() + ] + + +# ------------------------------------------------------------------ 
+# Load +# ------------------------------------------------------------------ + + +def load_state_targets(engine, time_period, root_stratum_id): + """Load all state-level targets into the database.""" + with Session(engine) as session: + source = get_or_create_source( + session, + name="Legacy loss.py calibration targets", + source_type=SourceType.HARDCODED, + vintage=str(time_period), + description=( + "Comprehensive calibration targets migrated from " + "the legacy build_loss_matrix() in loss.py" + ), + ) + sid = source.source_id + + # -- 9. State population (total + under-5) ------ + state_pops = extract_state_population() + for rec in state_pops: + st = rec["state"] + st_stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": "state_code", + "operation": "==", + "value": st, + }, + ], + stratum_group_id=17, + notes=f"State {st} population", + category_tag="state_population", + ) + upsert_target( + session, + st_stratum.stratum_id, + "person_count", + time_period, + rec["population"], + sid, + notes=f"State {st} total population", + ) + + under5_stratum = get_or_create_stratum( + session, + parent_id=st_stratum.stratum_id, + constraints=[ + { + "constraint_variable": "state_code", + "operation": "==", + "value": st, + }, + { + "constraint_variable": "age", + "operation": "<", + "value": "5", + }, + ], + stratum_group_id=17, + notes=f"State {st} under 5", + ) + upsert_target( + session, + under5_stratum.stratum_id, + "person_count", + time_period, + rec["population_under_5"], + sid, + notes=f"State {st} population under 5", + ) + + # -- 11. State real estate taxes ----------------- + ret_records = extract_state_real_estate_taxes() + for rec in ret_records: + sc = rec["state_code"] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": "state_code", + "operation": "==", + "value": sc, + }, + ], + stratum_group_id=19, + notes=f"State real estate taxes {sc}", + category_tag="real_estate_tax", + ) + upsert_target( + session, + stratum.stratum_id, + "real_estate_taxes", + time_period, + rec["value"], + sid, + notes=f"State real estate taxes {sc}", + ) + + # -- 12. State ACA spending + enrollment --------- + aca_records = extract_state_aca() + for rec in aca_records: + st = rec["state"] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[ + { + "constraint_variable": "state_code", + "operation": "==", + "value": st, + }, + ], + stratum_group_id=20, + notes=f"State ACA {st}", + category_tag="aca", + ) + upsert_target( + session, + stratum.stratum_id, + "aca_ptc", + time_period, + rec["spending"], + sid, + notes=f"ACA spending {st}", + ) + upsert_target( + session, + stratum.stratum_id, + "person_count", + time_period, + rec["enrollment"], + sid, + notes=f"ACA enrollment {st}", + ) + + # -- 14. 
State 10-year age targets --------------- + age_records = extract_state_10yr_age() + for rec in age_records: + st = rec["state"] + ar = rec["age_range"] + if "+" in ar: + age_lo = int(ar.replace("+", "")) + constraints = [ + { + "constraint_variable": "state_code", + "operation": "==", + "value": st, + }, + { + "constraint_variable": "age", + "operation": ">=", + "value": str(age_lo), + }, + ] + else: + parts = ar.split("-") + age_lo = int(parts[0]) + age_hi = int(parts[1]) + constraints = [ + { + "constraint_variable": "state_code", + "operation": "==", + "value": st, + }, + { + "constraint_variable": "age", + "operation": ">=", + "value": str(age_lo), + }, + { + "constraint_variable": "age", + "operation": "<=", + "value": str(age_hi), + }, + ] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=constraints, + stratum_group_id=22, + notes=f"State 10yr age {st} {ar}", + category_tag="state_age", + ) + upsert_target( + session, + stratum.stratum_id, + "person_count", + time_period, + rec["value"], + sid, + notes=f"State 10yr age {st} {ar}", + ) + + # -- 15. State AGI targets ----------------------- + agi_records = extract_state_agi() + for rec in agi_records: + gn = rec["geo_name"] + lo = rec["agi_lower"] + hi = rec["agi_upper"] + stratum_notes = f"State AGI {gn} {lo}-{hi}" + + constraints = [ + { + "constraint_variable": "state_code", + "operation": "==", + "value": gn, + }, + { + "constraint_variable": "adjusted_gross_income", + "operation": ">", + "value": str(lo), + }, + { + "constraint_variable": "adjusted_gross_income", + "operation": "<=", + "value": str(hi), + }, + ] + stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=constraints, + stratum_group_id=23, + notes=stratum_notes, + category_tag="state_agi", + ) + variable = ( + "tax_unit_count" + if rec["is_count"] + else "adjusted_gross_income" + ) + upsert_target( + session, + stratum.stratum_id, + variable, + time_period, + rec["value"], + sid, + notes=(f"State AGI {gn} " f"{lo}-{hi} {rec['variable']}"), + ) + + session.commit() + logger.info("State targets loaded successfully.") + + +# ------------------------------------------------------------------ +# CLI +# ------------------------------------------------------------------ + + +def main(): + parser = argparse.ArgumentParser( + description="ETL: state-level calibration targets" + ) + parser.add_argument( + "--time-period", + type=int, + default=2024, + help="Target year (default: %(default)s)", + ) + args = parser.parse_args() + + from sqlmodel import SQLModel + + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + engine = create_engine(f"sqlite:///{db_path}") + SQLModel.metadata.create_all(engine) + + with Session(engine) as sess: + root = sess.exec( + select(Stratum).where( + Stratum.parent_stratum_id == None # noqa: E711 + ) + ).first() + if not root: + raise RuntimeError("Root stratum not found.") + root_id = root.stratum_id + + load_state_targets(engine, args.time_period, root_id) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/etl_tax_expenditure.py b/policyengine_us_data/db/etl_tax_expenditure.py new file mode 100644 index 000000000..b19affbc4 --- /dev/null +++ b/policyengine_us_data/db/etl_tax_expenditure.py @@ -0,0 +1,133 @@ +"""ETL: Tax expenditure targets (SALT, medical, charitable, +interest, QBI). + +Migrated from category 10 of the legacy ``etl_all_targets.py``. 
+""" + +import argparse +import logging + +from sqlmodel import Session, create_engine, select + +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import ( + SourceType, + Stratum, +) +from policyengine_us_data.utils.db_metadata import get_or_create_source +from policyengine_us_data.db.etl_helpers import ( + get_or_create_stratum, + upsert_target, +) + +logger = logging.getLogger(__name__) + +ITEMIZED_DEDUCTIONS = { + "salt_deduction": 21.247e9, + "medical_expense_deduction": 11.4e9, + "charitable_deduction": 65.301e9, + "interest_deduction": 24.8e9, + "qualified_business_income_deduction": 63.1e9, +} + + +# ------------------------------------------------------------------ +# Extract +# ------------------------------------------------------------------ + + +def extract_tax_expenditure_targets(): + """Return list of 5 dicts (one per deduction type).""" + return [ + {"variable": var, "value": val} + for var, val in ITEMIZED_DEDUCTIONS.items() + ] + + +# ------------------------------------------------------------------ +# Load +# ------------------------------------------------------------------ + + +def load_tax_expenditure(engine, time_period, root_stratum_id): + """Load tax-expenditure targets into the database.""" + with Session(engine) as session: + source = get_or_create_source( + session, + name="Legacy loss.py calibration targets", + source_type=SourceType.HARDCODED, + vintage=str(time_period), + description=( + "Comprehensive calibration targets migrated from " + "the legacy build_loss_matrix() in loss.py" + ), + ) + sid = source.source_id + + te_records = extract_tax_expenditure_targets() + te_stratum = get_or_create_stratum( + session, + parent_id=root_stratum_id, + constraints=[], + stratum_group_id=18, + notes="Tax expenditure targets (counterfactual)", + category_tag="tax_expenditure", + ) + for rec in te_records: + upsert_target( + session, + te_stratum.stratum_id, + rec["variable"], + time_period, + rec["value"], + sid, + notes=( + f"Tax expenditure: {rec['variable']} " + "(JCT 2024, requires counterfactual sim)" + ), + reform_id=1, + ) + + session.commit() + logger.info("Tax expenditure targets loaded.") + + +# ------------------------------------------------------------------ +# CLI +# ------------------------------------------------------------------ + + +def main(): + parser = argparse.ArgumentParser( + description="ETL: tax expenditure targets" + ) + parser.add_argument( + "--time-period", + type=int, + default=2024, + help="Target year (default: %(default)s)", + ) + args = parser.parse_args() + + from sqlmodel import SQLModel + + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + engine = create_engine(f"sqlite:///{db_path}") + SQLModel.metadata.create_all(engine) + + with Session(engine) as sess: + root = sess.exec( + select(Stratum).where( + Stratum.parent_stratum_id == None # noqa: E711 + ) + ).first() + if not root: + raise RuntimeError("Root stratum not found.") + root_id = root.stratum_id + + load_tax_expenditure(engine, args.time_period, root_id) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/reconcile_targets.py b/policyengine_us_data/db/reconcile_targets.py new file mode 100644 index 000000000..ccc663976 --- /dev/null +++ b/policyengine_us_data/db/reconcile_targets.py @@ -0,0 +1,304 @@ +"""Two-pass proportional rescaling of geographic targets. 
+ +Ensures that child-level targets (state, congressional district) sum +to their parent-level target for each (variable, period, reform_id) +group. Original source values are preserved in ``raw_value``. + +Pass 1: scale state targets so they sum to the national target. +Pass 2: scale CD targets so they sum to their (corrected) state target. +""" + +import logging +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +from sqlmodel import Session, create_engine, select + +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, + Target, +) +from policyengine_us_data.storage import STORAGE_FOLDER + +logger = logging.getLogger(__name__) + +# Type alias for the grouping key +GroupKey = Tuple[str, int, int] # (variable, period, reform_id) + + +def _resolve_geo_ancestor( + stratum: Stratum, + stratum_cache: Dict[int, Stratum], +) -> Optional[int]: + """Walk up the parent chain to find the nearest geo ancestor. + + A geographic stratum has ``stratum_group_id == 1``. + + Returns: + The ``stratum_id`` of the geographic ancestor, or ``None`` + if the stratum itself is geographic (group 1) or has no + geographic ancestor. + """ + if stratum.stratum_group_id == 1: + return stratum.stratum_id + + current = stratum + while current.parent_stratum_id is not None: + parent = stratum_cache.get(current.parent_stratum_id) + if parent is None: + break + if parent.stratum_group_id == 1: + return parent.stratum_id + current = parent + return None + + +def _classify_geo_level( + stratum: Stratum, + stratum_cache: Dict[int, Stratum], +) -> Optional[str]: + """Classify a geographic stratum as national/state/district. + + Works by inspecting the constraint variables on the stratum. + """ + if stratum.stratum_group_id != 1: + return None + + constraint_vars = { + c.constraint_variable for c in (stratum.constraints_rel or []) + } + + if "congressional_district_geoid" in constraint_vars: + return "district" + elif "state_fips" in constraint_vars: + return "state" + elif stratum.parent_stratum_id is None: + return "national" + + return None + + +def _get_state_fips(stratum: Stratum) -> Optional[int]: + """Extract state_fips from a stratum's constraints.""" + for c in stratum.constraints_rel or []: + if c.constraint_variable == "state_fips": + return int(c.value) + return None + + +def _get_cd_geoid(stratum: Stratum) -> Optional[int]: + """Extract congressional_district_geoid from constraints.""" + for c in stratum.constraints_rel or []: + if c.constraint_variable == "congressional_district_geoid": + return int(c.value) + return None + + +def reconcile_targets(session: Session) -> Dict[str, int]: + """Run two-pass proportional rescaling on all active targets. + + Args: + session: An open SQLModel session. + + Returns: + A dict with counts: ``scaled_state``, ``scaled_cd``, + ``skipped_zero_sum``. 
+ """ + stats = { + "scaled_state": 0, + "scaled_cd": 0, + "skipped_zero_sum": 0, + } + + # Load all strata into a cache + all_strata = session.exec(select(Stratum)).unique().all() + stratum_cache: Dict[int, Stratum] = {s.stratum_id: s for s in all_strata} + + # Build geo-stratum lookup: geo_stratum_id -> (level, state_fips) + geo_info: Dict[int, dict] = {} + for s in all_strata: + level = _classify_geo_level(s, stratum_cache) + if level is not None: + info = {"level": level} + if level == "state": + info["state_fips"] = _get_state_fips(s) + elif level == "district": + info["cd_geoid"] = _get_cd_geoid(s) + # Derive state_fips from CD geoid (first 1-2 digits) + cd = _get_cd_geoid(s) + if cd is not None: + info["state_fips"] = cd // 100 + geo_info[s.stratum_id] = info + + # Load all active targets + stmt = select(Target).where(Target.active == True) # noqa: E712 + all_targets = session.exec(stmt).all() + + # Group targets by (variable, period, reform_id) + groups: Dict[GroupKey, List[Target]] = defaultdict(list) + for t in all_targets: + key = (t.variable, t.period, t.reform_id) + groups[key].append(t) + + for key, targets in groups.items(): + variable, period, reform_id = key + + # Resolve each target to its geographic ancestor + national_targets = [] + state_targets: Dict[int, List[Target]] = defaultdict(list) + district_targets: Dict[int, List[Target]] = defaultdict(list) + + for t in targets: + stratum = stratum_cache.get(t.stratum_id) + if stratum is None: + continue + + geo_ancestor_id = _resolve_geo_ancestor(stratum, stratum_cache) + if geo_ancestor_id is None: + continue + + info = geo_info.get(geo_ancestor_id) + if info is None: + continue + + level = info["level"] + if level == "national": + national_targets.append(t) + elif level == "state": + fips = info.get("state_fips") + if fips is not None: + state_targets[fips].append(t) + elif level == "district": + fips = info.get("state_fips") + if fips is not None: + district_targets[fips].append(t) + + # Pass 1: scale states -> national + if national_targets and state_targets: + national_value = sum( + (t.raw_value if t.raw_value is not None else t.value) + for t in national_targets + if t.value is not None + ) + + all_state_targets = [] + for fips_targets in state_targets.values(): + all_state_targets.extend(fips_targets) + + state_sum = sum( + (t.raw_value if t.raw_value is not None else t.value) + for t in all_state_targets + if t.value is not None + ) + + if state_sum == 0: + logger.warning( + "State sum is zero for %s/%s/%s; skipping", + variable, + period, + reform_id, + ) + stats["skipped_zero_sum"] += 1 + elif national_value != 0: + scale = national_value / state_sum + if abs(scale - 1.0) > 1e-9: + logger.info( + "Scaling %d state targets for %s/%s: " + "factor=%.6f (state_sum=%.2f, " + "national=%.2f)", + len(all_state_targets), + variable, + period, + scale, + state_sum, + national_value, + ) + for t in all_state_targets: + if t.value is None: + continue + base = ( + t.raw_value if t.raw_value is not None else t.value + ) + t.raw_value = base + t.value = base * scale + stats["scaled_state"] += 1 + + # Pass 2: scale CDs -> state + # After pass 1, state target .value is already corrected. 
+ for fips, cd_targets_for_state in district_targets.items(): + state_targets_for_fips = state_targets.get(fips, []) + if not state_targets_for_fips: + continue + + state_value = sum( + t.value for t in state_targets_for_fips if t.value is not None + ) + + cd_sum = sum( + (t.raw_value if t.raw_value is not None else t.value) + for t in cd_targets_for_state + if t.value is not None + ) + + if cd_sum == 0: + logger.warning( + "CD sum is zero for state %s, %s/%s/%s; " "skipping", + fips, + variable, + period, + reform_id, + ) + stats["skipped_zero_sum"] += 1 + elif state_value != 0: + scale = state_value / cd_sum + if abs(scale - 1.0) > 1e-9: + logger.info( + "Scaling %d CD targets for state %s, " + "%s/%s: factor=%.6f", + len(cd_targets_for_state), + fips, + variable, + period, + scale, + ) + for t in cd_targets_for_state: + if t.value is None: + continue + base = ( + t.raw_value if t.raw_value is not None else t.value + ) + t.raw_value = base + t.value = base * scale + stats["scaled_cd"] += 1 + + session.commit() + + logger.info( + "Reconciliation complete: %d state targets scaled, " + "%d CD targets scaled, %d groups skipped (zero sum)", + stats["scaled_state"], + stats["scaled_cd"], + stats["skipped_zero_sum"], + ) + + return stats + + +def main(): + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + db_uri = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}" + engine = create_engine(db_uri) + + with Session(engine) as session: + stats = reconcile_targets(session) + + logger.info("Final stats: %s", stats) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/validate_database.py b/policyengine_us_data/db/validate_database.py index 2fa819f29..c4e499dfb 100644 --- a/policyengine_us_data/db/validate_database.py +++ b/policyengine_us_data/db/validate_database.py @@ -4,11 +4,19 @@ the overall correctness of data after a full pipeline run with production data. 
""" +import logging import sqlite3 +import numpy as np import pandas as pd from policyengine_us.system import system +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + conn = sqlite3.connect( "policyengine_us_data/storage/calibration/policy_data.db" ) @@ -17,9 +25,189 @@ targets_df = pd.read_sql("SELECT * FROM targets", conn) for var_name in set(targets_df["variable"]): - if not var_name in system.variables.keys(): + if var_name not in system.variables.keys(): raise ValueError(f"{var_name} not a policyengine-us variable") for var_name in set(stratum_constraints_df["constraint_variable"]): - if not var_name in system.variables.keys(): + if var_name not in system.variables.keys(): raise ValueError(f"{var_name} not a policyengine-us variable") + + +# ------------------------------------------------------------------ +# Validate geographic target reconciliation +# ------------------------------------------------------------------ + +strata_df = pd.read_sql("SELECT * FROM strata", conn) + +# Build parent chain for each stratum to find geo ancestor +geo_strata = strata_df[strata_df["stratum_group_id"] == 1].copy() + +# Classify geo strata +sc_df = stratum_constraints_df.copy() +geo_constraints = sc_df[ + sc_df["stratum_id"].isin(geo_strata["stratum_id"]) +].copy() + +national_ids = set( + geo_strata[geo_strata["parent_stratum_id"].isna()]["stratum_id"] +) + +state_constraints = geo_constraints[ + geo_constraints["constraint_variable"] == "state_fips" +] +state_stratum_to_fips = dict( + zip( + state_constraints["stratum_id"], + state_constraints["value"].astype(int), + ) +) + +cd_constraints = geo_constraints[ + geo_constraints["constraint_variable"] == "congressional_district_geoid" +] +cd_stratum_to_geoid = dict( + zip( + cd_constraints["stratum_id"], + cd_constraints["value"].astype(int), + ) +) + +# Build a lookup: stratum_id -> geo ancestor stratum_id +parent_map = dict(zip(strata_df["stratum_id"], strata_df["parent_stratum_id"])) +group_map = dict(zip(strata_df["stratum_id"], strata_df["stratum_group_id"])) + + +def find_geo_ancestor(sid): + """Walk up parent chain to find geo stratum (group_id==1).""" + if group_map.get(sid) == 1: + return sid + current = sid + visited = set() + while current is not None and current not in visited: + visited.add(current) + p = parent_map.get(current) + if p is None or np.isnan(p) if isinstance(p, float) else False: + return None + p = int(p) + if group_map.get(p) == 1: + return p + current = p + return None + + +active_targets = targets_df[targets_df["active"] == 1].copy() +active_targets["geo_ancestor"] = active_targets["stratum_id"].apply( + find_geo_ancestor +) + +# Drop targets with no geo ancestor +geo_targets = active_targets.dropna(subset=["geo_ancestor"]).copy() +geo_targets["geo_ancestor"] = geo_targets["geo_ancestor"].astype(int) + + +# Classify each target's geo level +def classify_geo(geo_sid): + if geo_sid in national_ids: + return "national" + if geo_sid in state_stratum_to_fips: + return "state" + if geo_sid in cd_stratum_to_geoid: + return "district" + return "unknown" + + +geo_targets["geo_level"] = geo_targets["geo_ancestor"].apply(classify_geo) + +# For state targets, get state_fips +geo_targets["state_fips"] = geo_targets["geo_ancestor"].map( + state_stratum_to_fips +) + +# For district targets, derive state_fips from geoid +geo_targets["cd_geoid"] = geo_targets["geo_ancestor"].map(cd_stratum_to_geoid) 
+geo_targets.loc[geo_targets["geo_level"] == "district", "state_fips"] = ( + geo_targets.loc[geo_targets["geo_level"] == "district", "cd_geoid"] // 100 +) + +# Check: for each (variable, period, reform_id), sum(state) ≈ national +RTOL = 1e-6 +reconciliation_failures = [] + +for (var, period, reform), group in geo_targets.groupby( + ["variable", "period", "reform_id"] +): + nat = group[group["geo_level"] == "national"] + states = group[group["geo_level"] == "state"] + + if nat.empty or states.empty: + continue + + nat_val = nat["value"].sum() + state_sum = states["value"].sum() + + if nat_val != 0 and state_sum != 0: + ratio = abs(state_sum - nat_val) / abs(nat_val) + if ratio > RTOL: + reconciliation_failures.append( + { + "variable": var, + "period": period, + "reform_id": reform, + "level": "state->national", + "parent_value": nat_val, + "child_sum": state_sum, + "ratio": ratio, + } + ) + + # Check CDs sum to state + districts = group[group["geo_level"] == "district"] + if districts.empty: + continue + + for fips, state_group in states.groupby("state_fips"): + state_val = state_group["value"].sum() + cd_group = districts[districts["state_fips"] == fips] + if cd_group.empty: + continue + cd_sum = cd_group["value"].sum() + + if state_val != 0 and cd_sum != 0: + ratio = abs(cd_sum - state_val) / abs(state_val) + if ratio > RTOL: + reconciliation_failures.append( + { + "variable": var, + "period": period, + "reform_id": reform, + "level": f"cd->state(fips={fips})", + "parent_value": state_val, + "child_sum": cd_sum, + "ratio": ratio, + } + ) + +if reconciliation_failures: + logger.warning( + "Found %d geographic reconciliation mismatches:", + len(reconciliation_failures), + ) + for f in reconciliation_failures[:10]: + logger.warning( + " %s %s/%s: %s parent=%.2f child_sum=%.2f " "ratio=%.6f", + f["variable"], + f["period"], + f["reform_id"], + f["level"], + f["parent_value"], + f["child_sum"], + f["ratio"], + ) + raise ValueError( + f"{len(reconciliation_failures)} geographic target groups " + f"are not reconciled (rtol={RTOL})" + ) + +logger.info("All geographic target reconciliation checks passed.") + +conn.close() diff --git a/policyengine_us_data/storage/calibration/clones/clone_0000.npz b/policyengine_us_data/storage/calibration/clones/clone_0000.npz new file mode 100644 index 000000000..ee32cdbd2 Binary files /dev/null and b/policyengine_us_data/storage/calibration/clones/clone_0000.npz differ diff --git a/policyengine_us_data/storage/calibration/l0_sweep.log b/policyengine_us_data/storage/calibration/l0_sweep.log new file mode 100644 index 000000000..e43016a6c --- /dev/null +++ b/policyengine_us_data/storage/calibration/l0_sweep.log @@ -0,0 +1,39 @@ +2026-02-09 05:43:03,399 - INFO - Loading dataset from /Users/maxghenis/policyengine-us-data/policyengine_us_data/storage/extended_cps_2024.h5 +2026-02-09 05:43:03,657 - INFO - Loaded 111524 households +2026-02-09 05:43:03,657 - INFO - Assigning geography: 111524 x 130 = 14498120 +2026-02-09 05:43:19,761 - INFO - Building unified matrix: 37749 targets, 14498120 total columns (111524 records x 130 clones) +2026-02-09 05:43:57,849 - INFO - Processing clone 1 / 130 (cols 0-111523, 50 unique states)... +2026-02-09 05:45:00,370 - INFO - Clone 1: saved 7027961 nonzero entries to clone_0000.npz +2026-02-09 05:45:00,380 - INFO - Processing clone 2 / 130 (cols 111524-223047, 50 unique states)... 
+2026-02-09 05:46:01,789 - INFO - Clone 2: saved 7027439 nonzero entries to clone_0001.npz +2026-02-09 05:46:01,801 - INFO - Processing clone 3 / 130 (cols 223048-334571, 50 unique states)... +2026-02-09 05:47:03,059 - INFO - Clone 3: saved 7025963 nonzero entries to clone_0002.npz +2026-02-09 05:47:03,069 - INFO - Processing clone 4 / 130 (cols 334572-446095, 50 unique states)... +2026-02-09 05:48:04,501 - INFO - Clone 4: saved 7026664 nonzero entries to clone_0003.npz +2026-02-09 05:48:04,512 - INFO - Processing clone 5 / 130 (cols 446096-557619, 50 unique states)... +2026-02-09 05:49:04,899 - INFO - Clone 5: saved 7028450 nonzero entries to clone_0004.npz +2026-02-09 05:49:04,911 - INFO - Processing clone 6 / 130 (cols 557620-669143, 50 unique states)... +2026-02-09 05:50:07,800 - INFO - Clone 6: saved 7027970 nonzero entries to clone_0005.npz +2026-02-09 05:50:07,813 - INFO - Processing clone 7 / 130 (cols 669144-780667, 50 unique states)... +2026-02-09 05:51:10,541 - INFO - Clone 7: saved 7026497 nonzero entries to clone_0006.npz +2026-02-09 05:51:10,553 - INFO - Processing clone 8 / 130 (cols 780668-892191, 50 unique states)... +2026-02-09 05:52:14,777 - INFO - Clone 8: saved 7025503 nonzero entries to clone_0007.npz +2026-02-09 05:52:14,792 - INFO - Processing clone 9 / 130 (cols 892192-1003715, 50 unique states)... +2026-02-09 05:53:20,975 - INFO - Clone 9: saved 7027479 nonzero entries to clone_0008.npz +2026-02-09 05:53:20,988 - INFO - Processing clone 10 / 130 (cols 1003716-1115239, 50 unique states)... +2026-02-09 05:54:27,704 - INFO - Clone 10: saved 7027762 nonzero entries to clone_0009.npz +2026-02-09 05:54:27,720 - INFO - Processing clone 11 / 130 (cols 1115240-1226763, 50 unique states)... +2026-02-09 05:55:34,861 - INFO - Clone 11: saved 7027847 nonzero entries to clone_0010.npz +2026-02-09 05:55:34,884 - INFO - Processing clone 12 / 130 (cols 1226764-1338287, 50 unique states)... +2026-02-09 05:56:39,280 - INFO - Clone 12: saved 7027713 nonzero entries to clone_0011.npz +2026-02-09 05:56:39,296 - INFO - Processing clone 13 / 130 (cols 1338288-1449811, 50 unique states)... +2026-02-09 05:57:45,508 - INFO - Clone 13: saved 7028085 nonzero entries to clone_0012.npz +2026-02-09 05:57:45,522 - INFO - Processing clone 14 / 130 (cols 1449812-1561335, 50 unique states)... +2026-02-09 05:58:50,721 - INFO - Clone 14: saved 7026974 nonzero entries to clone_0013.npz +2026-02-09 05:58:50,738 - INFO - Processing clone 15 / 130 (cols 1561336-1672859, 50 unique states)... +2026-02-09 05:59:56,953 - INFO - Clone 15: saved 7027944 nonzero entries to clone_0014.npz +2026-02-09 05:59:56,966 - INFO - Processing clone 16 / 130 (cols 1672860-1784383, 50 unique states)... +2026-02-09 06:01:02,245 - INFO - Clone 16: saved 7026584 nonzero entries to clone_0015.npz +2026-02-09 06:01:02,261 - INFO - Processing clone 17 / 130 (cols 1784384-1895907, 50 unique states)... +2026-02-09 06:02:02,757 - INFO - Clone 17: saved 7025779 nonzero entries to clone_0016.npz +2026-02-09 06:02:02,768 - INFO - Processing clone 18 / 130 (cols 1895908-2007431, 50 unique states)... 
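The sweep log above records the unified calibration matrix being built one clone at a time: 130 clones of 111,524 households give 14,498,120 columns, and each clone's roughly 7M nonzero entries are written to a separate `clone_XXXX.npz` whose columns occupy the range shown in the log (`cols 0-111523`, `cols 111524-223047`, ...), i.e. global column = `clone_idx * n_records + record_idx`, the same convention asserted later in `test_column_ordering`. The sketch below illustrates one way such a block could be persisted and re-expanded into global coordinates; the npz key names and helper functions are assumptions for illustration, not the pipeline's actual storage format.

```python
# Illustrative sketch only: the npz keys ("row", "col", "data") and the
# helper names are assumptions, not the real clone_XXXX.npz layout.
import numpy as np
from scipy import sparse

N_RECORDS = 111_524  # households per clone (from the log)
N_TARGETS = 37_749   # calibration targets / matrix rows (from the log)
N_CLONES = 130


def save_clone_block(path, block: sparse.coo_matrix) -> None:
    """Store only the nonzero entries of one clone's (targets x records) block."""
    np.savez_compressed(path, row=block.row, col=block.col, data=block.data)


def load_clone_block(path, clone_idx: int) -> sparse.coo_matrix:
    """Reload a block, shifting local record indices to global column indices."""
    npz = np.load(path)
    # clone_idx is 0-based: clone_idx=1 covers cols 111524-223047, as in the log.
    global_cols = clone_idx * N_RECORDS + npz["col"]
    return sparse.coo_matrix(
        (npz["data"], (npz["row"], global_cols)),
        shape=(N_TARGETS, N_RECORDS * N_CLONES),
    )
```

Loading blocks one at a time this way keeps peak memory bounded by a single clone's nonzero entries rather than the full 14.5M-column matrix.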
diff --git a/policyengine_us_data/tests/test_calibration/__init__.py b/policyengine_us_data/tests/test_calibration/__init__.py new file mode 100644 index 000000000..3ac75e097 --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/__init__.py @@ -0,0 +1 @@ +"""Tests for the unified national calibration pipeline.""" diff --git a/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py new file mode 100644 index 000000000..2dbb469d4 --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_clone_and_assign.py @@ -0,0 +1,358 @@ +"""Tests for clone_and_assign module. + +Uses mock CSV data so tests don't require the real +block_cd_distributions.csv.gz file. +""" + +import numpy as np +import pandas as pd +import pytest +from unittest.mock import patch +from pathlib import Path + +from policyengine_us_data.calibration.clone_and_assign import ( + GeographyAssignment, + load_global_block_distribution, + assign_random_geography, + double_geography_for_puf, +) + +# ------------------------------------------------------------------ +# Mock data: 3 CDs with known blocks +# ------------------------------------------------------------------ + +MOCK_BLOCKS = pd.DataFrame( + { + "cd_geoid": [ + 101, + 101, + 101, + 102, + 102, + 103, + 103, + 103, + 103, + ], + "block_geoid": [ + "010010001001001", # AL CD-01 + "010010001001002", + "010010001001003", + "020010001001001", # AK CD-02 + "020010001001002", + "360100001001001", # NY CD-10 + "360100001001002", + "360100001001003", + "360100001001004", + ], + "probability": [ + 0.4, + 0.3, + 0.3, + 0.6, + 0.4, + 0.25, + 0.25, + 0.25, + 0.25, + ], + } +) + + +@pytest.fixture(autouse=True) +def _clear_lru_cache(): + """Clear the lru_cache between tests.""" + load_global_block_distribution.cache_clear() + yield + load_global_block_distribution.cache_clear() + + +# ------------------------------------------------------------------ +# Tests +# ------------------------------------------------------------------ + + +class TestLoadGlobalBlockDistribution: + """Tests for load_global_block_distribution.""" + + def _write_mock_csv(self, tmp_path): + """Write MOCK_BLOCKS as a gzipped CSV into tmp_path.""" + csv_path = tmp_path / "block_cd_distributions.csv.gz" + MOCK_BLOCKS.to_csv(csv_path, index=False, compression="gzip") + return tmp_path + + def test_loads_and_normalizes(self, tmp_path): + """Probabilities sum to 1 globally after loading.""" + storage = self._write_mock_csv(tmp_path) + + with patch( + "policyengine_us_data.calibration" + ".clone_and_assign.STORAGE_FOLDER", + storage, + ): + blocks, cds, states, probs = ( + load_global_block_distribution.__wrapped__() + ) + + assert len(blocks) == 9 + assert len(probs) == 9 + np.testing.assert_almost_equal(probs.sum(), 1.0) + + def test_state_fips_extracted(self, tmp_path): + """State FIPS extracted correctly from block GEOID.""" + storage = self._write_mock_csv(tmp_path) + + with patch( + "policyengine_us_data.calibration" + ".clone_and_assign.STORAGE_FOLDER", + storage, + ): + blocks, cds, states, probs = ( + load_global_block_distribution.__wrapped__() + ) + + # First 3 blocks start with "01" -> state 1 + assert states[0] == 1 + assert states[1] == 1 + assert states[2] == 1 + # Next 2 start with "02" -> state 2 + assert states[3] == 2 + assert states[4] == 2 + # Last 4 start with "36" -> state 36 + assert states[5] == 36 + assert states[6] == 36 + assert states[7] == 36 + assert states[8] == 36 + + +class 
TestAssignRandomGeography: + """Tests for assign_random_geography.""" + + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_assign_shape(self, mock_load): + """Output arrays have length n_records * n_clones.""" + blocks = MOCK_BLOCKS["block_geoid"].values + cds = MOCK_BLOCKS["cd_geoid"].astype(str).values + states = np.array([int(b[:2]) for b in blocks]) + probs = MOCK_BLOCKS["probability"].values.astype(np.float64) + probs = probs / probs.sum() + mock_load.return_value = (blocks, cds, states, probs) + + result = assign_random_geography(n_records=10, n_clones=3, seed=42) + + assert len(result.block_geoid) == 30 + assert len(result.cd_geoid) == 30 + assert len(result.state_fips) == 30 + assert result.n_records == 10 + assert result.n_clones == 3 + + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_assign_deterministic(self, mock_load): + """Same seed produces identical results.""" + blocks = MOCK_BLOCKS["block_geoid"].values + cds = MOCK_BLOCKS["cd_geoid"].astype(str).values + states = np.array([int(b[:2]) for b in blocks]) + probs = MOCK_BLOCKS["probability"].values.astype(np.float64) + probs = probs / probs.sum() + mock_load.return_value = (blocks, cds, states, probs) + + r1 = assign_random_geography(n_records=10, n_clones=3, seed=99) + r2 = assign_random_geography(n_records=10, n_clones=3, seed=99) + + np.testing.assert_array_equal(r1.block_geoid, r2.block_geoid) + np.testing.assert_array_equal(r1.cd_geoid, r2.cd_geoid) + np.testing.assert_array_equal(r1.state_fips, r2.state_fips) + + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_assign_different_seeds(self, mock_load): + """Different seeds produce different results.""" + blocks = MOCK_BLOCKS["block_geoid"].values + cds = MOCK_BLOCKS["cd_geoid"].astype(str).values + states = np.array([int(b[:2]) for b in blocks]) + probs = MOCK_BLOCKS["probability"].values.astype(np.float64) + probs = probs / probs.sum() + mock_load.return_value = (blocks, cds, states, probs) + + r1 = assign_random_geography(n_records=100, n_clones=3, seed=1) + r2 = assign_random_geography(n_records=100, n_clones=3, seed=2) + + # Extremely unlikely to be identical with different seeds + assert not np.array_equal(r1.block_geoid, r2.block_geoid) + + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_state_from_block(self, mock_load): + """state_fips[i] == int(block_geoid[i][:2]) for all i.""" + blocks = MOCK_BLOCKS["block_geoid"].values + cds = MOCK_BLOCKS["cd_geoid"].astype(str).values + states = np.array([int(b[:2]) for b in blocks]) + probs = MOCK_BLOCKS["probability"].values.astype(np.float64) + probs = probs / probs.sum() + mock_load.return_value = (blocks, cds, states, probs) + + result = assign_random_geography(n_records=20, n_clones=5, seed=42) + + for i in range(len(result.block_geoid)): + expected = int(result.block_geoid[i][:2]) + assert result.state_fips[i] == expected, ( + f"Index {i}: state_fips={result.state_fips[i]}" + f" but block starts with" + f" {result.block_geoid[i][:2]}" + ) + + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_cd_from_block(self, mock_load): + """cd_geoid matches the block's CD from distribution.""" + blocks = MOCK_BLOCKS["block_geoid"].values + cds = MOCK_BLOCKS["cd_geoid"].astype(str).values + states = 
np.array([int(b[:2]) for b in blocks]) + probs = MOCK_BLOCKS["probability"].values.astype(np.float64) + probs = probs / probs.sum() + mock_load.return_value = (blocks, cds, states, probs) + + # Build a lookup from mock data + block_to_cd = dict( + zip( + MOCK_BLOCKS["block_geoid"], + MOCK_BLOCKS["cd_geoid"].astype(str), + ) + ) + + result = assign_random_geography(n_records=20, n_clones=5, seed=42) + + for i in range(len(result.block_geoid)): + blk = result.block_geoid[i] + expected_cd = block_to_cd[blk] + assert result.cd_geoid[i] == expected_cd, ( + f"Index {i}: cd_geoid={result.cd_geoid[i]}" + f" but block {blk} belongs to" + f" CD {expected_cd}" + ) + + @patch( + "policyengine_us_data.calibration.clone_and_assign" + ".load_global_block_distribution" + ) + def test_column_ordering(self, mock_load): + """Verify clone_idx = i // n_records, etc.""" + blocks = MOCK_BLOCKS["block_geoid"].values + cds = MOCK_BLOCKS["cd_geoid"].astype(str).values + states = np.array([int(b[:2]) for b in blocks]) + probs = MOCK_BLOCKS["probability"].values.astype(np.float64) + probs = probs / probs.sum() + mock_load.return_value = (blocks, cds, states, probs) + + n_records = 10 + n_clones = 3 + result = assign_random_geography( + n_records=n_records, + n_clones=n_clones, + seed=42, + ) + + n_total = n_records * n_clones + assert len(result.block_geoid) == n_total + + # Verify the dataclass stores dimensions correctly + assert result.n_records == n_records + assert result.n_clones == n_clones + + # Verify indexing convention: i maps to + # clone_idx = i // n_records + # record_idx = i % n_records + for i in range(n_total): + clone_idx = i // n_records + record_idx = i % n_records + assert 0 <= clone_idx < n_clones + assert 0 <= record_idx < n_records + + def test_missing_file_raises(self, tmp_path): + """FileNotFoundError if CSV doesn't exist.""" + fake_storage = tmp_path / "nonexistent_storage" + fake_storage.mkdir() + + with patch( + "policyengine_us_data.calibration.clone_and_assign" + ".STORAGE_FOLDER", + fake_storage, + ): + with pytest.raises(FileNotFoundError): + load_global_block_distribution.__wrapped__() + + +class TestDoubleGeographyForPuf: + """Tests for double_geography_for_puf.""" + + def test_doubles_n_records(self): + """n_records doubles, n_clones stays the same.""" + geo = GeographyAssignment( + block_geoid=np.array(["010010001001001", "020010001001001"] * 3), + cd_geoid=np.array(["101", "202"] * 3), + state_fips=np.array([1, 2] * 3), + n_records=2, + n_clones=3, + ) + result = double_geography_for_puf(geo) + assert result.n_records == 4 + assert result.n_clones == 3 + assert len(result.block_geoid) == 12 # 4 * 3 + + def test_array_length(self): + """Output arrays have length n_records * 2 * n_clones.""" + geo = GeographyAssignment( + block_geoid=np.array(["010010001001001"] * 15), + cd_geoid=np.array(["101"] * 15), + state_fips=np.array([1] * 15), + n_records=5, + n_clones=3, + ) + result = double_geography_for_puf(geo) + assert len(result.block_geoid) == 30 + assert len(result.cd_geoid) == 30 + assert len(result.state_fips) == 30 + + def test_puf_half_matches_cps_half(self): + """Each clone's PUF half has same geography as CPS half.""" + geo = GeographyAssignment( + block_geoid=np.array( + [ + "010010001001001", + "020010001001001", + "360100001001001", + "060100001001001", + "480100001001001", + "120100001001001", + ] + ), + cd_geoid=np.array(["101", "202", "1036", "653", "4831", "1227"]), + state_fips=np.array([1, 2, 36, 6, 48, 12]), + n_records=3, + n_clones=2, + ) + result = 
double_geography_for_puf(geo) + n_new = result.n_records # 6 + + for c in range(result.n_clones): + start = c * n_new + mid = start + n_new // 2 + end = start + n_new + # CPS half + cps_states = result.state_fips[start:mid] + # PUF half + puf_states = result.state_fips[mid:end] + np.testing.assert_array_equal(cps_states, puf_states) diff --git a/policyengine_us_data/tests/test_calibration/test_fit_national_weights.py b/policyengine_us_data/tests/test_calibration/test_fit_national_weights.py new file mode 100644 index 000000000..a18cf9b67 --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_fit_national_weights.py @@ -0,0 +1,956 @@ +""" +Tests for national L0 calibration script (fit_national_weights.py). + +Uses TDD: tests written first, then implementation. +Mocks heavy dependencies (torch, l0, Microsimulation) to keep +tests fast. +""" + +import sys +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import h5py +import numpy as np +import pytest + +# ------------------------------------------------------------------- +# Import tests +# ------------------------------------------------------------------- + + +class TestImports: + """Test that the module can be imported.""" + + def test_module_imports(self): + import importlib + + mod = importlib.import_module( + "policyengine_us_data.calibration.fit_national_weights" + ) + assert hasattr(mod, "fit_national_weights") + + def test_public_functions_exist(self): + import importlib + + mod = importlib.import_module( + "policyengine_us_data.calibration.fit_national_weights" + ) + + for name in [ + "fit_national_weights", + "initialize_weights", + "build_calibration_inputs", + "compute_diagnostics", + "save_weights_to_h5", + "run_validation", + "parse_args", + "main", + ]: + assert hasattr(mod, name), f"Missing: {name}" + + def test_constants_defined(self): + from policyengine_us_data.calibration.fit_national_weights import ( + LAMBDA_L0, + LAMBDA_L2, + LEARNING_RATE, + DEFAULT_EPOCHS, + BETA, + GAMMA, + ZETA, + INIT_KEEP_PROB, + LOG_WEIGHT_JITTER_SD, + LOG_ALPHA_JITTER_SD, + ) + + assert LAMBDA_L0 == 1e-6 + assert LAMBDA_L2 == 1e-12 + assert LEARNING_RATE == 0.15 + assert DEFAULT_EPOCHS == 1000 + assert BETA == 0.35 + assert GAMMA == -0.1 + assert ZETA == 1.1 + assert INIT_KEEP_PROB == 0.999 + assert LOG_WEIGHT_JITTER_SD == 0.05 + assert LOG_ALPHA_JITTER_SD == 0.01 + + def test_weight_floor_constant(self): + from policyengine_us_data.calibration.fit_national_weights import ( + _WEIGHT_FLOOR, + ) + + assert _WEIGHT_FLOOR > 0 + assert _WEIGHT_FLOOR < 1 + + +# ------------------------------------------------------------------- +# initialize_weights tests +# ------------------------------------------------------------------- + + +class TestInitializeWeights: + """Test weight initialization from household_weight.""" + + def test_returns_correct_shape(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + ) + + original = np.array([100.0, 200.0, 50.0, 0.5, 300.0]) + result = initialize_weights(original) + assert result.shape == original.shape + + def test_returns_float64(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + ) + + original = np.array([100.0, 200.0], dtype=np.float32) + result = initialize_weights(original) + assert result.dtype == np.float64 + + def test_zero_weights_get_floor(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + _WEIGHT_FLOOR, 
+ ) + + original = np.array([100.0, 0.0, 50.0, -1.0]) + result = initialize_weights(original) + assert np.all(result > 0) + assert result[1] == _WEIGHT_FLOOR + assert result[3] == _WEIGHT_FLOOR + + def test_preserves_positive_weights(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + ) + + original = np.array([100.0, 200.0, 50.0]) + result = initialize_weights(original) + np.testing.assert_array_almost_equal(result, original) + + def test_does_not_mutate_input(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + ) + + original = np.array([100.0, 0.0, -5.0]) + original_copy = original.copy() + initialize_weights(original) + np.testing.assert_array_equal(original, original_copy) + + def test_all_zero_weights(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + _WEIGHT_FLOOR, + ) + + original = np.zeros(10) + result = initialize_weights(original) + assert np.all(result == _WEIGHT_FLOOR) + + def test_all_negative_weights(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + ) + + original = np.array([-1.0, -100.0, -0.001]) + result = initialize_weights(original) + assert np.all(result > 0) + + def test_single_element(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + ) + + result = initialize_weights(np.array([42.0])) + assert result[0] == pytest.approx(42.0) + + def test_large_array(self): + from policyengine_us_data.calibration.fit_national_weights import ( + initialize_weights, + ) + + original = np.random.uniform(-10, 1000, 200_000) + result = initialize_weights(original) + assert result.shape == (200_000,) + assert np.all(result > 0) + + +# ------------------------------------------------------------------- +# build_calibration_inputs tests +# ------------------------------------------------------------------- + + +class TestBuildCalibrationInputs: + """Test building the calibration matrix and targets via DB.""" + + def test_returns_float32_matrix(self): + """Matrix should be float32 for memory efficiency.""" + from policyengine_us_data.calibration.fit_national_weights import ( + build_calibration_inputs, + ) + + n_hh, n_targets = 10, 5 + mock_matrix = np.random.rand(n_hh, n_targets).astype(np.float64) + mock_targets = np.array([1e9, 2e9, 3e9, 4e9, 5e9]) + mock_names = [f"t_{i}" for i in range(n_targets)] + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + mock_matrix, + mock_targets, + mock_names, + ) + + mock_sim = MagicMock() + + with patch( + "policyengine_us_data.calibration." + "national_matrix_builder.NationalMatrixBuilder", + return_value=mock_builder, + ): + matrix, _, _ = build_calibration_inputs( + dataset_class=MagicMock, + time_period=2024, + db_path="/fake/db.sqlite", + sim=mock_sim, + ) + + assert matrix.dtype == np.float32 + + def test_returns_float64_targets(self): + """Targets should be float64 for precision.""" + from policyengine_us_data.calibration.fit_national_weights import ( + build_calibration_inputs, + ) + + n_hh, n_targets = 10, 5 + mock_matrix = np.random.rand(n_hh, n_targets) + mock_targets = np.array([1e9, 2e9, 3e9, 4e9, 5e9]) + mock_names = [f"t_{i}" for i in range(n_targets)] + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + mock_matrix, + mock_targets, + mock_names, + ) + mock_sim = MagicMock() + + with patch( + "policyengine_us_data.calibration." 
+ "national_matrix_builder.NationalMatrixBuilder", + return_value=mock_builder, + ): + _, targets, _ = build_calibration_inputs( + dataset_class=MagicMock, + time_period=2024, + db_path="/fake/db.sqlite", + sim=mock_sim, + ) + + assert targets.dtype == np.float64 + + def test_passes_geo_level(self): + """geo_level is forwarded to build_matrix.""" + from policyengine_us_data.calibration.fit_national_weights import ( + build_calibration_inputs, + ) + + n_hh, n_targets = 10, 3 + mock_matrix = np.random.rand(n_hh, n_targets) + mock_targets = np.random.rand(n_targets) * 1e9 + mock_names = [f"t_{i}" for i in range(n_targets)] + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + mock_matrix, + mock_targets, + mock_names, + ) + mock_sim = MagicMock() + + with patch( + "policyengine_us_data.calibration." + "national_matrix_builder.NationalMatrixBuilder", + return_value=mock_builder, + ): + build_calibration_inputs( + dataset_class=MagicMock, + time_period=2024, + db_path="/fake/db.sqlite", + sim=mock_sim, + geo_level="national", + ) + + call_kwargs = mock_builder.build_matrix.call_args[1] + assert call_kwargs["geo_level"] == "national" + + def test_matrix_and_targets_consistent_shape(self): + """Matrix columns must equal targets length.""" + from policyengine_us_data.calibration.fit_national_weights import ( + build_calibration_inputs, + ) + + n_hh, n_targets = 50, 20 + mock_matrix = np.random.rand(n_hh, n_targets) + mock_targets = np.random.rand(n_targets) * 1e9 + mock_names = [f"t_{i}" for i in range(n_targets)] + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + mock_matrix, + mock_targets, + mock_names, + ) + mock_sim = MagicMock() + + with patch( + "policyengine_us_data.calibration." + "national_matrix_builder.NationalMatrixBuilder", + return_value=mock_builder, + ): + matrix, targets, names = build_calibration_inputs( + dataset_class=MagicMock, + time_period=2024, + db_path="/fake/db.sqlite", + sim=mock_sim, + ) + + assert matrix.shape[1] == len(targets) + assert matrix.shape[1] == len(names) + + +# ------------------------------------------------------------------- +# compute_diagnostics tests +# ------------------------------------------------------------------- + + +class TestComputeDiagnostics: + """Test diagnostic output formatting.""" + + def test_basic_diagnostics(self): + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.array([1000.0, 2000.0, 3000.0, 4000.0]) + estimates = np.array([1050.0, 1500.0, 3100.0, 4000.0]) + names = ["pop", "income", "snap", "ssi"] + + diag = compute_diagnostics(targets, estimates, names) + + assert "pct_within_10" in diag + assert "worst_targets" in diag + assert isinstance(diag["pct_within_10"], float) + assert 0 <= diag["pct_within_10"] <= 100 + + def test_worst_targets_sorted(self): + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.array([1000.0, 2000.0, 3000.0]) + # income has 25% error (worst), pop has 5% + estimates = np.array([1050.0, 1500.0, 3050.0]) + names = ["pop", "income", "snap"] + + diag = compute_diagnostics(targets, estimates, names) + worst = diag["worst_targets"] + + # First worst target should be "income" (25% error) + assert worst[0][0] == "income" + + def test_perfect_match(self): + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.array([1000.0, 2000.0, 3000.0]) + estimates = targets.copy() + names = 
["a", "b", "c"] + + diag = compute_diagnostics(targets, estimates, names) + assert diag["pct_within_10"] == pytest.approx(100.0) + + def test_all_outside_threshold(self): + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.array([100.0, 200.0]) + estimates = np.array([200.0, 400.0]) # 100% error + names = ["a", "b"] + + diag = compute_diagnostics(targets, estimates, names) + assert diag["pct_within_10"] == pytest.approx(0.0) + + def test_custom_threshold(self): + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.array([100.0, 200.0]) + estimates = np.array([110.0, 230.0]) + names = ["a", "b"] + + # With 20% threshold, both should be within + diag = compute_diagnostics(targets, estimates, names, threshold=0.20) + assert diag["pct_within_10"] == pytest.approx(100.0) + + def test_custom_n_worst(self): + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.ones(50) * 1000 + estimates = np.arange(50, dtype=float) * 100 + names = [f"t_{i}" for i in range(50)] + + diag = compute_diagnostics(targets, estimates, names, n_worst=5) + assert len(diag["worst_targets"]) == 5 + + def test_near_zero_targets_handled(self): + """Targets near zero should not produce inf/nan errors.""" + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.array([0.0, 1e-10, 1000.0]) + estimates = np.array([5.0, 10.0, 1050.0]) + names = ["zero", "tiny", "normal"] + + diag = compute_diagnostics(targets, estimates, names) + assert not np.isnan(diag["pct_within_10"]) + assert not any(np.isnan(err) for _, err in diag["worst_targets"]) + + def test_single_target(self): + from policyengine_us_data.calibration.fit_national_weights import ( + compute_diagnostics, + ) + + targets = np.array([1000.0]) + estimates = np.array([950.0]) + names = ["only"] + + diag = compute_diagnostics(targets, estimates, names) + assert diag["pct_within_10"] == pytest.approx(100.0) + assert len(diag["worst_targets"]) == 1 + + +# ------------------------------------------------------------------- +# fit_national_weights tests +# ------------------------------------------------------------------- + + +class TestFitNationalWeights: + """Test the main fitting function with mocked L0.""" + + def _mock_l0(self, n_households, return_weights=None): + """Create mock l0.calibration.SparseCalibrationWeights.""" + if return_weights is None: + return_weights = np.ones(n_households) * 95.0 + + mock_model = MagicMock() + mock_model.get_weights.return_value = MagicMock( + cpu=MagicMock( + return_value=MagicMock( + numpy=MagicMock(return_value=return_weights) + ) + ) + ) + + mock_l0_module = MagicMock() + mock_l0_module.SparseCalibrationWeights = MagicMock( + return_value=mock_model + ) + + return mock_l0_module, mock_model + + def test_returns_weights_array(self): + """fit_national_weights returns array with correct shape.""" + from policyengine_us_data.calibration.fit_national_weights import ( + fit_national_weights, + ) + + n_households = 50 + n_targets = 10 + matrix = np.random.rand(n_households, n_targets).astype(np.float32) + targets = np.random.rand(n_targets).astype(np.float32) * 1e6 + initial_weights = np.ones(n_households) * 100.0 + + mock_l0_module, _ = self._mock_l0(n_households) + + with patch.dict( + sys.modules, + { + "l0": MagicMock(), + "l0.calibration": mock_l0_module, + }, + ): + weights = fit_national_weights( + 
matrix=matrix, + targets=targets, + initial_weights=initial_weights, + epochs=5, + ) + + assert weights.shape == (n_households,) + assert np.all(weights > 0) + + def test_calls_sparse_calibration_with_correct_args(self): + """SparseCalibrationWeights is called with the right + hyperparameters.""" + from policyengine_us_data.calibration.fit_national_weights import ( + fit_national_weights, + BETA, + GAMMA, + ZETA, + INIT_KEEP_PROB, + ) + + n = 20 + matrix = np.random.rand(n, 5).astype(np.float32) + targets = np.random.rand(5) * 1e6 + initial_weights = np.ones(n) * 100.0 + + mock_l0_module, mock_model = self._mock_l0(n) + + with patch.dict( + sys.modules, + { + "l0": MagicMock(), + "l0.calibration": mock_l0_module, + }, + ): + fit_national_weights( + matrix=matrix, + targets=targets, + initial_weights=initial_weights, + epochs=10, + ) + + # Verify SparseCalibrationWeights constructor was called + constructor = mock_l0_module.SparseCalibrationWeights + assert constructor.called + call_kwargs = constructor.call_args[1] + assert call_kwargs["n_features"] == n + assert call_kwargs["beta"] == BETA + assert call_kwargs["gamma"] == GAMMA + assert call_kwargs["zeta"] == ZETA + assert call_kwargs["init_keep_prob"] == INIT_KEEP_PROB + + def test_calls_fit_with_correct_args(self): + """model.fit() is called with the correct parameters.""" + from policyengine_us_data.calibration.fit_national_weights import ( + fit_national_weights, + ) + + n = 20 + matrix = np.random.rand(n, 5).astype(np.float32) + targets = np.random.rand(5) * 1e6 + initial_weights = np.ones(n) * 100.0 + epochs = 42 + lambda_l0 = 1e-5 + lambda_l2 = 1e-10 + lr = 0.1 + + mock_l0_module, mock_model = self._mock_l0(n) + + with patch.dict( + sys.modules, + { + "l0": MagicMock(), + "l0.calibration": mock_l0_module, + }, + ): + fit_national_weights( + matrix=matrix, + targets=targets, + initial_weights=initial_weights, + epochs=epochs, + lambda_l0=lambda_l0, + lambda_l2=lambda_l2, + learning_rate=lr, + ) + + mock_model.fit.assert_called_once() + call_kwargs = mock_model.fit.call_args[1] + assert call_kwargs["epochs"] == epochs + assert call_kwargs["lambda_l0"] == lambda_l0 + assert call_kwargs["lambda_l2"] == lambda_l2 + assert call_kwargs["lr"] == lr + assert call_kwargs["loss_type"] == "relative" + + def test_raises_on_missing_l0(self): + """ImportError raised if l0 is not installed.""" + from policyengine_us_data.calibration.fit_national_weights import ( + fit_national_weights, + ) + + n = 10 + matrix = np.random.rand(n, 3).astype(np.float32) + targets = np.random.rand(3) * 1e6 + initial_weights = np.ones(n) * 100.0 + + # Remove l0 from modules so import fails + with patch.dict( + sys.modules, + {"l0": None, "l0.calibration": None}, + ): + with pytest.raises(ImportError, match="l0-python"): + fit_national_weights( + matrix=matrix, + targets=targets, + initial_weights=initial_weights, + epochs=1, + ) + + def test_deterministic_weights(self): + """get_weights is called with deterministic=True.""" + from policyengine_us_data.calibration.fit_national_weights import ( + fit_national_weights, + ) + + n = 10 + matrix = np.random.rand(n, 3).astype(np.float32) + targets = np.random.rand(3) * 1e6 + initial_weights = np.ones(n) * 100.0 + + mock_l0_module, mock_model = self._mock_l0(n) + + with patch.dict( + sys.modules, + { + "l0": MagicMock(), + "l0.calibration": mock_l0_module, + }, + ): + fit_national_weights( + matrix=matrix, + targets=targets, + initial_weights=initial_weights, + epochs=1, + ) + + 
mock_model.get_weights.assert_called_once_with(deterministic=True) + + +# ------------------------------------------------------------------- +# save_weights_to_h5 tests +# ------------------------------------------------------------------- + + +class TestSaveWeights: + """Test saving weights to h5 file.""" + + def test_save_to_h5(self): + from policyengine_us_data.calibration.fit_national_weights import ( + save_weights_to_h5, + ) + + n = 100 + weights = np.random.rand(n) * 200 + + with tempfile.NamedTemporaryFile(suffix=".h5") as tmp: + with h5py.File(tmp.name, "w") as f: + f.create_dataset("household_weight/2024", data=np.ones(n)) + f.create_dataset("person_id/2024", data=np.arange(n)) + + save_weights_to_h5(tmp.name, weights, year=2024) + + with h5py.File(tmp.name, "r") as f: + saved = f["household_weight/2024"][:] + np.testing.assert_array_almost_equal(saved, weights) + + def test_save_preserves_other_data(self): + from policyengine_us_data.calibration.fit_national_weights import ( + save_weights_to_h5, + ) + + n = 50 + weights = np.random.rand(n) * 200 + other_data = np.arange(n) + + with tempfile.NamedTemporaryFile(suffix=".h5") as tmp: + with h5py.File(tmp.name, "w") as f: + f.create_dataset("household_weight/2024", data=np.ones(n)) + f.create_dataset("person_id/2024", data=other_data) + + save_weights_to_h5(tmp.name, weights, year=2024) + + with h5py.File(tmp.name, "r") as f: + np.testing.assert_array_equal( + f["person_id/2024"][:], other_data + ) + + def test_save_creates_key_if_absent(self): + """Saving to an h5 file that has no existing weights + key should create it.""" + from policyengine_us_data.calibration.fit_national_weights import ( + save_weights_to_h5, + ) + + n = 30 + weights = np.random.rand(n) * 100 + + with tempfile.NamedTemporaryFile(suffix=".h5") as tmp: + # Create an h5 file with only other data + with h5py.File(tmp.name, "w") as f: + f.create_dataset("person_id/2024", data=np.arange(n)) + + save_weights_to_h5(tmp.name, weights, year=2024) + + with h5py.File(tmp.name, "r") as f: + assert "household_weight/2024" in f + np.testing.assert_array_almost_equal( + f["household_weight/2024"][:], weights + ) + + def test_save_different_year(self): + """Saving for a different year does not overwrite + other years.""" + from policyengine_us_data.calibration.fit_national_weights import ( + save_weights_to_h5, + ) + + n = 20 + weights_2024 = np.ones(n) * 100 + weights_2025 = np.ones(n) * 200 + + with tempfile.NamedTemporaryFile(suffix=".h5") as tmp: + with h5py.File(tmp.name, "w") as f: + f.create_dataset( + "household_weight/2024", + data=weights_2024, + ) + + save_weights_to_h5(tmp.name, weights_2025, year=2025) + + with h5py.File(tmp.name, "r") as f: + np.testing.assert_array_almost_equal( + f["household_weight/2024"][:], + weights_2024, + ) + np.testing.assert_array_almost_equal( + f["household_weight/2025"][:], + weights_2025, + ) + + def test_overwrite_existing_weights(self): + """Saving weights overwrites existing data at the same + year key.""" + from policyengine_us_data.calibration.fit_national_weights import ( + save_weights_to_h5, + ) + + n = 15 + old_weights = np.ones(n) * 50 + new_weights = np.ones(n) * 150 + + with tempfile.NamedTemporaryFile(suffix=".h5") as tmp: + with h5py.File(tmp.name, "w") as f: + f.create_dataset( + "household_weight/2024", + data=old_weights, + ) + + save_weights_to_h5(tmp.name, new_weights, year=2024) + + with h5py.File(tmp.name, "r") as f: + np.testing.assert_array_almost_equal( + f["household_weight/2024"][:], + new_weights, + ) + 
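+# Editor's note: the TestSaveWeights cases above pin down the contract of
+# save_weights_to_h5 (write the array to "household_weight/<year>", creating or
+# replacing that key while leaving every other dataset untouched). The sketch
+# below is an illustration of that contract under those assumptions only, not
+# the implementation in fit_national_weights.py; the hypothetical helper name
+# _reference_save_weights_to_h5 is not used by any test.
+def _reference_save_weights_to_h5(path, weights, year):
+    """Illustrative only: replace household_weight/<year> with new weights."""
+    key = f"household_weight/{year}"
+    # Mode "a" opens read/write without truncating, so unrelated datasets survive.
+    with h5py.File(path, "a") as f:
+        if key in f:
+            # An existing dataset cannot simply be reassigned, so delete it
+            # and recreate it with the new data.
+            del f[key]
+        f.create_dataset(key, data=np.asarray(weights))
+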
+ +# ------------------------------------------------------------------- +# run_validation tests +# ------------------------------------------------------------------- + + +class TestRunValidation: + """Test the run_validation convenience function.""" + + def test_does_not_raise(self): + """run_validation should not raise on valid input.""" + from policyengine_us_data.calibration.fit_national_weights import ( + run_validation, + ) + + n = 10 + weights = np.ones(n) * 100 + matrix = np.random.rand(n, 5) + targets = weights @ matrix + names = [f"t_{i}" for i in range(5)] + + # Should not raise + run_validation(weights, matrix, targets, names) + + def test_handles_mismatched_estimates(self): + """run_validation should handle large errors gracefully.""" + from policyengine_us_data.calibration.fit_national_weights import ( + run_validation, + ) + + n = 5 + weights = np.ones(n) + matrix = np.eye(n, 3) + targets = np.array([1e12, 2e12, 3e12]) + names = ["a", "b", "c"] + + # Should not raise + run_validation(weights, matrix, targets, names) + + +# ------------------------------------------------------------------- +# CLI tests +# ------------------------------------------------------------------- + + +class TestCLI: + """Test CLI argument parsing.""" + + def test_parse_args_defaults(self): + from policyengine_us_data.calibration.fit_national_weights import ( + parse_args, + ) + + args = parse_args([]) + assert args.epochs == 1000 + assert args.lambda_l0 == 1e-6 + assert args.device == "cpu" + assert args.dataset is None + assert args.db_path is None + assert args.output is None + assert args.geo_level == "all" + + def test_parse_args_custom(self): + from policyengine_us_data.calibration.fit_national_weights import ( + parse_args, + ) + + args = parse_args( + [ + "--epochs", + "500", + "--lambda-l0", + "1e-5", + "--device", + "cuda", + "--dataset", + "/tmp/data.h5", + "--db-path", + "/tmp/db.sqlite", + "--output", + "/tmp/out.h5", + ] + ) + assert args.epochs == 500 + assert args.lambda_l0 == 1e-5 + assert args.device == "cuda" + assert args.dataset == "/tmp/data.h5" + assert args.db_path == "/tmp/db.sqlite" + assert args.output == "/tmp/out.h5" + + def test_parse_args_invalid_device(self): + from policyengine_us_data.calibration.fit_national_weights import ( + parse_args, + ) + + with pytest.raises(SystemExit): + parse_args(["--device", "tpu"]) + + def test_parse_args_negative_epochs(self): + """Negative epochs is accepted by argparse (validation + is elsewhere).""" + from policyengine_us_data.calibration.fit_national_weights import ( + parse_args, + ) + + args = parse_args(["--epochs", "-1"]) + assert args.epochs == -1 + + +# ------------------------------------------------------------------- +# Integration test: EnhancedCPS.generate interface +# ------------------------------------------------------------------- + + +class TestEnhancedCPSIntegration: + """Test that EnhancedCPS.generate calls the calibration pipeline.""" + + def test_generate_calls_pipeline(self): + """EnhancedCPS.generate invokes + build_calibration_inputs, initialize_weights, + and fit_national_weights in sequence.""" + from policyengine_us_data.datasets.cps.enhanced_cps import ( + EnhancedCPS, + ) + + # Create a minimal subclass for testing + class TestEnhancedCPS(EnhancedCPS): + input_dataset = MagicMock() + start_year = 2024 + end_year = 2024 + name = "test_enhanced_cps" + label = "Test Enhanced CPS" + file_path = "/tmp/test.h5" + url = None + + instance = TestEnhancedCPS() + + n_hh = 50 + mock_weights = np.ones(n_hh) * 100 + 
mock_matrix = np.random.rand(n_hh, 10) + mock_targets = np.random.rand(10) * 1e9 + mock_names = [f"t_{i}" for i in range(10)] + calibrated = np.ones(n_hh) * 95 + + mock_sim = MagicMock() + mock_calc_result = MagicMock(values=mock_weights) + mock_calc_result.__len__ = lambda self: len(mock_weights) + mock_sim.calculate.return_value = mock_calc_result + mock_sim.dataset.load_dataset.return_value = {"household_weight": {}} + + with ( + patch( + "policyengine_us.Microsimulation", + return_value=mock_sim, + ), + patch( + "policyengine_us_data.calibration." + "fit_national_weights." + "build_calibration_inputs", + return_value=( + mock_matrix, + mock_targets, + mock_names, + ), + ) as mock_build, + patch( + "policyengine_us_data.calibration." + "fit_national_weights." + "initialize_weights", + return_value=mock_weights.copy(), + ) as mock_init, + patch( + "policyengine_us_data.calibration." + "fit_national_weights." + "fit_national_weights", + return_value=calibrated, + ) as mock_fit, + patch.object(instance, "save_dataset"), + ): + instance.generate() + + mock_build.assert_called_once() + mock_init.assert_called_once() + mock_fit.assert_called_once() + + # Check fit was called with 500 epochs (default) + fit_kwargs = mock_fit.call_args[1] + assert fit_kwargs["epochs"] == 500 diff --git a/policyengine_us_data/tests/test_calibration/test_national_matrix_builder.py b/policyengine_us_data/tests/test_calibration/test_national_matrix_builder.py new file mode 100644 index 000000000..76bb763d9 --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_national_matrix_builder.py @@ -0,0 +1,1232 @@ +""" +Tests for NationalMatrixBuilder. + +Uses a mock in-memory SQLite database with representative targets +and a mock Microsimulation to avoid heavy dependencies. +""" + +import numpy as np +import pytest +from unittest.mock import MagicMock, patch +from sqlalchemy import create_engine +from sqlmodel import Session, SQLModel + +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, + Target, + Source, + SourceType, +) +from policyengine_us_data.calibration.national_matrix_builder import ( + NationalMatrixBuilder, + COUNT_VARIABLES, + PERSON_LEVEL_VARIABLES, + SPM_UNIT_VARIABLES, +) + +# ------------------------------------------------------------------- +# Helper: build a mock Microsimulation with controllable data +# ------------------------------------------------------------------- + + +def _make_mock_sim(n_households=5, n_persons=10): + """ + Create a mock Microsimulation with controllable data. 
+ + Household layout: + hh0: persons 0,1 (tax_unit 0, spm_unit 0) + hh1: persons 2,3 (tax_unit 1, spm_unit 1) + hh2: persons 4,5 (tax_unit 2, spm_unit 2) + hh3: persons 6,7 (tax_unit 3, spm_unit 3) + hh4: persons 8,9 (tax_unit 4, spm_unit 4) + """ + sim = MagicMock() + + person_hh_ids = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) + person_tu_ids = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) + person_spm_ids = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]) + hh_ids = np.arange(n_households) + + # Person-level data + person_medicaid = np.array( + [100, 0, 200, 0, 0, 0, 300, 0, 0, 0], dtype=float + ) + person_state_fips = np.array([6, 6, 6, 6, 36, 36, 36, 36, 6, 6], dtype=int) + # Tax-unit-level data (5 tax units) + tu_is_filer = np.array([1, 1, 0, 1, 1], dtype=float) + tu_agi = np.array([75000, 120000, 30000, 60000, 200000], dtype=float) + tu_income_tax_positive = np.array( + [8000, 15000, 0, 5000, 40000], dtype=float + ) + # Household-level data + hh_medicaid = np.array([100, 200, 0, 300, 0], dtype=float) + hh_net_worth = np.array( + [500000, 1000000, 200000, 50000, 3000000], dtype=float + ) + hh_snap = np.array([0, 5000, 0, 3000, 0], dtype=float) + hh_person_count = np.array([2, 2, 2, 2, 2], dtype=float) + + # Map tax_unit vars to person level + person_is_filer = tu_is_filer[person_tu_ids] + person_agi = tu_agi[person_tu_ids] + person_income_tax_positive = tu_income_tax_positive[person_tu_ids] + + def calculate_side_effect(var, period=None, map_to=None): + """Mock sim.calculate() returning appropriate arrays.""" + result = MagicMock() + + if map_to == "person": + mapping = { + "household_id": person_hh_ids, + "tax_unit_id": person_tu_ids, + "spm_unit_id": person_spm_ids, + "person_id": np.arange(n_persons), + "medicaid": person_medicaid, + "state_fips": person_state_fips, + "tax_unit_is_filer": person_is_filer, + "adjusted_gross_income": person_agi, + "income_tax_positive": person_income_tax_positive, + "person_count": np.ones(n_persons, dtype=float), + "snap": hh_snap[person_hh_ids].astype(float), + "net_worth": hh_net_worth[person_hh_ids].astype(float), + } + result.values = mapping.get(var, np.zeros(n_persons, dtype=float)) + elif map_to == "household": + mapping = { + "household_id": hh_ids, + "medicaid": hh_medicaid, + "net_worth": hh_net_worth, + "snap": hh_snap, + "income_tax_positive": np.array( + [8000, 15000, 0, 5000, 40000], dtype=float + ), + "person_count": hh_person_count, + "adjusted_gross_income": np.array( + [75000, 120000, 30000, 60000, 200000], + dtype=float, + ), + "tax_unit_is_filer": tu_is_filer, + "state_fips": np.array([6, 6, 36, 36, 6], dtype=int), + } + result.values = mapping.get( + var, np.zeros(n_households, dtype=float) + ) + else: + # Default: tax_unit level + mapping = { + "tax_unit_is_filer": tu_is_filer, + "adjusted_gross_income": tu_agi, + "income_tax_positive": tu_income_tax_positive, + } + result.values = mapping.get(var, np.zeros(5, dtype=float)) + + return result + + sim.calculate = calculate_side_effect + + def map_result_side_effect(values, from_entity, to_entity, how=None): + """Mock sim.map_result() for person->household sum.""" + if from_entity == "person" and to_entity == "household": + result = np.zeros(n_households, dtype=float) + for i in range(n_persons): + result[person_hh_ids[i]] += float(values[i]) + return result + elif from_entity == "tax_unit" and to_entity == "household": + return np.array(values, dtype=float)[:n_households] + return values + + sim.map_result = map_result_side_effect + + return sim + + +# 
------------------------------------------------------------------- +# Fixtures +# ------------------------------------------------------------------- + + +def _seed_db(engine, include_inactive=False, include_geo=True): + """Populate an in-memory SQLite DB with test targets. + + Returns the engine for convenience. + """ + with Session(engine) as session: + source = Source( + name="Test", + type=SourceType.HARDCODED, + ) + session.add(source) + session.flush() + + # --- National stratum (no constraints) --- + us_stratum = Stratum( + stratum_group_id=0, + notes="United States", + ) + us_stratum.constraints_rel = [] + session.add(us_stratum) + session.flush() + + # Target 1: national medicaid sum + session.add( + Target( + stratum_id=us_stratum.stratum_id, + variable="medicaid", + period=2024, + value=871.7e9, + source_id=source.source_id, + active=True, + ) + ) + + # Target 2: national net_worth sum + session.add( + Target( + stratum_id=us_stratum.stratum_id, + variable="net_worth", + period=2024, + value=160e12, + source_id=source.source_id, + active=True, + ) + ) + + # --- National filer stratum --- + filer_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=2, + notes="United States - Tax Filers", + ) + filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + session.add(filer_stratum) + session.flush() + + # Target 3: income_tax_positive on filer stratum + session.add( + Target( + stratum_id=filer_stratum.stratum_id, + variable="income_tax_positive", + period=2024, + value=2.5e12, + source_id=source.source_id, + active=True, + ) + ) + + # --- AGI band stratum --- + agi_stratum = Stratum( + parent_stratum_id=filer_stratum.stratum_id, + stratum_group_id=3, + notes="National filers, AGI 50k-100k", + ) + agi_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation=">=", + value="50000", + ), + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation="<", + value="100000", + ), + ] + session.add(agi_stratum) + session.flush() + + # Target 4: person_count in AGI band + session.add( + Target( + stratum_id=agi_stratum.stratum_id, + variable="person_count", + period=2024, + value=32_801_908, + source_id=source.source_id, + active=True, + ) + ) + + # --- Medicaid enrollment stratum --- + medicaid_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=5, + notes="National Medicaid Enrollment", + ) + medicaid_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="medicaid", + operation=">", + value="0", + ) + ] + session.add(medicaid_stratum) + session.flush() + + # Target 5: person_count with medicaid > 0 + session.add( + Target( + stratum_id=medicaid_stratum.stratum_id, + variable="person_count", + period=2024, + value=72_429_055, + source_id=source.source_id, + active=True, + ) + ) + + if include_geo: + # --- State geographic stratum (California) --- + ca_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=1, + notes="State FIPS 6 - California", + ) + ca_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="6", + ) + ] + session.add(ca_stratum) + session.flush() + + # Target 6: snap in CA + session.add( + Target( + stratum_id=ca_stratum.stratum_id, + variable="snap", + period=2024, + value=10e9, + source_id=source.source_id, + active=True, + ) + ) + + if include_inactive: + # Inactive 
target that should be excluded + session.add( + Target( + stratum_id=us_stratum.stratum_id, + variable="ssi", + period=2024, + value=60e9, + source_id=source.source_id, + active=False, + ) + ) + + session.commit() + + return engine + + +@pytest.fixture +def mock_db(): + """Create an in-memory SQLite DB with representative targets.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + return _seed_db(engine) + + +@pytest.fixture +def mock_db_with_inactive(): + """DB that includes both active and inactive targets.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + return _seed_db(engine, include_inactive=True) + + +@pytest.fixture +def empty_db(): + """An empty DB with tables but no rows.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + return engine + + +@pytest.fixture +def mock_sim(): + """Standard mock Microsimulation with 5 households / 10 persons.""" + return _make_mock_sim() + + +def _builder_with_engine(engine, time_period=2024): + """Create a NationalMatrixBuilder and inject the test engine.""" + builder = NationalMatrixBuilder( + db_uri="sqlite://", + time_period=time_period, + ) + builder.engine = engine + return builder + + +# ------------------------------------------------------------------- +# Module-level constants tests +# ------------------------------------------------------------------- + + +class TestModuleConstants: + """Verify that module-level constant sets are defined.""" + + def test_count_variables_not_empty(self): + assert len(COUNT_VARIABLES) > 0 + + def test_person_count_is_count_variable(self): + assert "person_count" in COUNT_VARIABLES + + def test_person_count_is_person_level(self): + assert "person_count" in PERSON_LEVEL_VARIABLES + + def test_spm_unit_variables_defined(self): + assert isinstance(SPM_UNIT_VARIABLES, set) + + +# ------------------------------------------------------------------- +# Database query tests +# ------------------------------------------------------------------- + + +class TestQueryAllTargets: + """Test _query_active_targets reads the right rows.""" + + def test_returns_all_active_targets(self, mock_db): + builder = _builder_with_engine(mock_db) + df = builder._query_active_targets() + + # 6 active targets were inserted (no inactive) + assert len(df) == 6 + + def test_excludes_inactive_targets(self, mock_db_with_inactive): + builder = _builder_with_engine(mock_db_with_inactive) + df = builder._query_active_targets() + + # Still 6 active targets; the inactive "ssi" should be excluded + assert len(df) == 6 + assert "ssi" not in df["variable"].values + + def test_required_columns_present(self, mock_db): + builder = _builder_with_engine(mock_db) + df = builder._query_active_targets() + + for col in [ + "target_id", + "stratum_id", + "variable", + "value", + "period", + ]: + assert col in df.columns + + def test_empty_db_returns_empty(self, empty_db): + builder = _builder_with_engine(empty_db) + df = builder._query_active_targets() + assert len(df) == 0 + + +class TestGetConstraints: + """Test _get_all_constraints retrieves constraint rows.""" + + def test_no_constraints_for_national(self, mock_db): + builder = _builder_with_engine(mock_db) + targets_df = builder._query_active_targets() + + # National medicaid has stratum with no constraints + national_row = targets_df[targets_df["variable"] == "medicaid"].iloc[0] + constraints = builder._get_all_constraints(national_row["stratum_id"]) + assert constraints == [] + + def 
test_single_constraint_filer(self, mock_db): + builder = _builder_with_engine(mock_db) + targets_df = builder._query_active_targets() + + filer_row = targets_df[ + targets_df["variable"] == "income_tax_positive" + ].iloc[0] + constraints = builder._get_all_constraints(filer_row["stratum_id"]) + assert len(constraints) == 1 + assert constraints[0]["variable"] == "tax_unit_is_filer" + assert constraints[0]["operation"] == "==" + assert constraints[0]["value"] == "1" + + def test_multiple_constraints_agi_band(self, mock_db): + builder = _builder_with_engine(mock_db) + targets_df = builder._query_active_targets() + + # person_count target with AGI band constraints + agi_row = targets_df[ + (targets_df["variable"] == "person_count") + & (targets_df["value"] == 32_801_908) + ].iloc[0] + constraints = builder._get_all_constraints(agi_row["stratum_id"]) + assert len(constraints) == 3 + + var_names = {c["variable"] for c in constraints} + assert "tax_unit_is_filer" in var_names + assert "adjusted_gross_income" in var_names + + def test_nonexistent_stratum_returns_empty(self, mock_db): + builder = _builder_with_engine(mock_db) + constraints = builder._get_all_constraints(99999) + assert constraints == [] + + +# ------------------------------------------------------------------- +# Constraint evaluation tests +# ------------------------------------------------------------------- + + +class TestEvaluateConstraints: + """Test _evaluate_constraints mask computation.""" + + def test_empty_constraints_returns_all_true(self, mock_sim): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + mask = builder._evaluate_constraints_entity_aware( + mock_sim, [], n_households=5 + ) + assert mask.shape == (5,) + assert np.all(mask) + + def test_filer_constraint_mask(self, mock_sim): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "tax_unit_is_filer", + "operation": "==", + "value": "1", + } + ] + mask = builder._evaluate_constraints_entity_aware( + mock_sim, constraints, n_households=5 + ) + # tu_is_filer = [1, 1, 0, 1, 1] + # hh2 should be False (non-filer) + assert mask.dtype == bool + assert mask[0] is np.True_ + assert mask[1] is np.True_ + assert mask[2] is np.False_ + assert mask[3] is np.True_ + assert mask[4] is np.True_ + + def test_compound_constraints_agi_band(self, mock_sim): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "tax_unit_is_filer", + "operation": "==", + "value": "1", + }, + { + "variable": "adjusted_gross_income", + "operation": ">=", + "value": "50000", + }, + { + "variable": "adjusted_gross_income", + "operation": "<", + "value": "100000", + }, + ] + mask = builder._evaluate_constraints_entity_aware( + mock_sim, constraints, n_households=5 + ) + # AGI: [75k, 120k, 30k, 60k, 200k] + # Filer: [1, 1, 0, 1, 1] + # In band AND filer: hh0(75k), hh3(60k) + assert mask[0] is np.True_ + assert mask[1] is np.False_ # AGI too high + assert mask[2] is np.False_ # not filer + assert mask[3] is np.True_ + assert mask[4] is np.False_ # AGI too high + + def test_geographic_constraint(self, mock_sim): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "state_fips", + "operation": "==", + "value": "6", + } + ] + mask = builder._evaluate_constraints_entity_aware( + mock_sim, constraints, n_households=5 + ) + # person_state_fips = [6,6, 6,6, 36,36, 36,36, 6,6] + # hh0(CA), hh1(CA), hh2(NY), hh3(NY), hh4(CA) + assert 
mask[0] is np.True_ + assert mask[1] is np.True_ + assert mask[2] is np.False_ + assert mask[3] is np.False_ + assert mask[4] is np.True_ + + +# ------------------------------------------------------------------- +# Entity relationship cache tests +# ------------------------------------------------------------------- + + +class TestEntityRelationshipCache: + """Test that entity relationship is built once and cached.""" + + def test_cache_is_populated_on_first_call(self, mock_sim): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + assert builder._entity_rel_cache is None + + df = builder._build_entity_relationship(mock_sim) + assert builder._entity_rel_cache is not None + assert len(df) == 10 # n_persons + + def test_cache_is_reused(self, mock_sim): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + + df1 = builder._build_entity_relationship(mock_sim) + df2 = builder._build_entity_relationship(mock_sim) + + # Should be the exact same object (cached) + assert df1 is df2 + + def test_entity_rel_has_required_columns(self, mock_sim): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + df = builder._build_entity_relationship(mock_sim) + for col in [ + "person_id", + "household_id", + "tax_unit_id", + "spm_unit_id", + ]: + assert col in df.columns + + +# ------------------------------------------------------------------- +# Target column computation tests +# ------------------------------------------------------------------- + + +class TestComputeTargetColumn: + """Test _compute_target_column for different variable types.""" + + def test_sum_variable_no_constraints(self, mock_sim): + """Unconstrained sum variable returns raw household values.""" + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + col = builder._compute_target_column( + mock_sim, "medicaid", [], n_households=5 + ) + expected = np.array([100, 200, 0, 300, 0], dtype=float) + np.testing.assert_array_almost_equal(col, expected) + + def test_sum_variable_with_constraint(self, mock_sim): + """Constrained sum variable zeros out non-matching.""" + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "state_fips", + "operation": "==", + "value": "6", + } + ] + col = builder._compute_target_column( + mock_sim, "snap", constraints, n_households=5 + ) + # hh_snap = [0, 5000, 0, 3000, 0] + # CA mask = [T, T, F, F, T] + expected = np.array([0, 5000, 0, 0, 0], dtype=float) + np.testing.assert_array_almost_equal(col, expected) + + def test_person_count_no_constraints(self, mock_sim): + """person_count with no constraints counts all persons.""" + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + col = builder._compute_target_column( + mock_sim, "person_count", [], n_households=5 + ) + # Each household has 2 persons + expected = np.array([2, 2, 2, 2, 2], dtype=float) + np.testing.assert_array_almost_equal(col, expected) + + def test_person_count_with_constraints(self, mock_sim): + """person_count with medicaid > 0 counts qualifying persons + per household.""" + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "medicaid", + "operation": ">", + "value": "0", + } + ] + col = builder._compute_target_column( + mock_sim, "person_count", constraints, n_households=5 + ) + # person_medicaid = [100,0, 200,0, 0,0, 300,0, 0,0] + # hh0: 1 person, hh1: 1 person, hh2: 0, hh3: 1, hh4: 0 + expected = np.array([1, 1, 0, 1, 0], dtype=float) + 
np.testing.assert_array_almost_equal(col, expected) + + def test_tax_unit_count_returns_mask(self, mock_sim): + """tax_unit_count returns the household-level mask as float.""" + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "tax_unit_is_filer", + "operation": "==", + "value": "1", + } + ] + col = builder._compute_target_column( + mock_sim, + "tax_unit_count", + constraints, + n_households=5, + ) + # Filer mask: [T, T, F, T, T] + expected = np.array([1, 1, 0, 1, 1], dtype=float) + np.testing.assert_array_almost_equal(col, expected) + + def test_column_dtype_is_float64(self, mock_sim): + """All columns should be float64.""" + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + col = builder._compute_target_column( + mock_sim, "medicaid", [], n_households=5 + ) + assert col.dtype == np.float64 + + +# ------------------------------------------------------------------- +# Target name generation tests +# ------------------------------------------------------------------- + + +class TestMakeTargetName: + """Test _make_target_name label generation.""" + + def test_national_unconstrained(self): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + name = builder._make_target_name("medicaid", [], "United States") + assert "national" in name + assert "medicaid" in name + + def test_geographic_constraint_in_name(self): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "state_fips", + "operation": "==", + "value": "6", + } + ] + name = builder._make_target_name("snap", constraints, "California") + assert "state_6" in name + assert "snap" in name + # Should NOT have "national" prefix + assert "national" not in name + + def test_non_geo_constraints_in_brackets(self): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "tax_unit_is_filer", + "operation": "==", + "value": "1", + } + ] + name = builder._make_target_name( + "income_tax_positive", + constraints, + "Filers", + ) + assert "national" in name + assert "income_tax_positive" in name + assert "[" in name + assert "tax_unit_is_filer" in name + + def test_mixed_geo_and_non_geo(self): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "state_fips", + "operation": "==", + "value": "6", + }, + { + "variable": "tax_unit_is_filer", + "operation": "==", + "value": "1", + }, + ] + name = builder._make_target_name("eitc", constraints, "CA filers") + assert "state_6" in name + assert "eitc" in name + assert "tax_unit_is_filer" in name + + def test_congressional_district_geo(self): + builder = NationalMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "congressional_district_geoid", + "operation": "==", + "value": "0601", + }, + ] + name = builder._make_target_name("snap", constraints, "CD 0601") + assert "cd_0601" in name + + +# ------------------------------------------------------------------- +# Full build_matrix integration tests +# ------------------------------------------------------------------- + + +class TestBuildMatrix: + """Test the main build_matrix method end-to-end.""" + + def test_matrix_shape(self, mock_db, mock_sim): + """Matrix shape is (n_households, n_active_targets).""" + builder = _builder_with_engine(mock_db) + matrix, targets, names = builder.build_matrix(mock_sim) + + assert matrix.shape == (5, 6) + assert targets.shape == (6,) + 
assert len(names) == 6 + + def test_target_values_match_db(self, mock_db, mock_sim): + """Target values array matches what was inserted.""" + builder = _builder_with_engine(mock_db) + _, targets, names = builder.build_matrix(mock_sim) + + expected_values = { + 871.7e9, + 160e12, + 2.5e12, + 32_801_908, + 72_429_055, + 10e9, + } + actual_values = set(targets) + assert actual_values == expected_values + + def test_all_names_non_empty(self, mock_db, mock_sim): + """Every target name is a non-empty string.""" + builder = _builder_with_engine(mock_db) + _, _, names = builder.build_matrix(mock_sim) + + for name in names: + assert isinstance(name, str) + assert len(name) > 0 + + def test_no_nan_in_matrix(self, mock_db, mock_sim): + """Matrix and targets contain no NaN values.""" + builder = _builder_with_engine(mock_db) + matrix, targets, _ = builder.build_matrix(mock_sim) + + assert not np.any(np.isnan(matrix)) + assert not np.any(np.isnan(targets)) + + def test_matrix_dtype_is_float64(self, mock_db, mock_sim): + """Matrix dtype should be float64.""" + builder = _builder_with_engine(mock_db) + matrix, targets, _ = builder.build_matrix(mock_sim) + + assert matrix.dtype == np.float64 + assert targets.dtype == np.float64 + + def test_unconstrained_medicaid_column(self, mock_db, mock_sim): + """National medicaid column = raw household medicaid.""" + builder = _builder_with_engine(mock_db) + matrix, targets, names = builder.build_matrix(mock_sim) + + # Find medicaid column (not the person_count one) + idx = next( + i + for i, n in enumerate(names) + if "medicaid" in n and "person_count" not in n + ) + expected = np.array([100, 200, 0, 300, 0], dtype=float) + np.testing.assert_array_almost_equal(matrix[:, idx], expected) + assert targets[idx] == pytest.approx(871.7e9) + + def test_filer_masked_income_tax(self, mock_db, mock_sim): + """income_tax_positive column zeros out non-filer hh.""" + builder = _builder_with_engine(mock_db) + matrix, _, names = builder.build_matrix(mock_sim) + + idx = next( + i for i, n in enumerate(names) if "income_tax_positive" in n + ) + col = matrix[:, idx] + # hh2 is non-filer -> 0 + assert col[2] == 0 + # filer households have positive values + assert col[0] > 0 + assert col[1] > 0 + assert col[3] > 0 + assert col[4] > 0 + + def test_agi_band_person_count(self, mock_db, mock_sim): + """person_count in AGI 50k-100k band.""" + builder = _builder_with_engine(mock_db) + matrix, targets, names = builder.build_matrix(mock_sim) + + # Find the person_count target with value 32_801_908 + idx = next( + i + for i, n in enumerate(names) + if "person_count" in n and targets[i] == pytest.approx(32_801_908) + ) + col = matrix[:, idx] + # hh0 (AGI 75k, filer) -> in band, 2 persons qualifying + # hh3 (AGI 60k, filer) -> in band, 2 persons qualifying + assert col[0] > 0 + assert col[3] > 0 + # hh1 (AGI 120k), hh2 (not filer), hh4 (AGI 200k) -> 0 + assert col[1] == 0 + assert col[2] == 0 + assert col[4] == 0 + + def test_medicaid_enrollment_count(self, mock_db, mock_sim): + """person_count with medicaid > 0 counts correctly.""" + builder = _builder_with_engine(mock_db) + matrix, targets, names = builder.build_matrix(mock_sim) + + idx = next( + i + for i, n in enumerate(names) + if "person_count" in n and targets[i] == pytest.approx(72_429_055) + ) + col = matrix[:, idx] + # person_medicaid = [100,0, 200,0, 0,0, 300,0, 0,0] + # hh0:1, hh1:1, hh2:0, hh3:1, hh4:0 + assert col[0] == pytest.approx(1.0) + assert col[1] == pytest.approx(1.0) + assert col[2] == pytest.approx(0.0) + assert col[3] == 
pytest.approx(1.0) + assert col[4] == pytest.approx(0.0) + + def test_geographic_snap_target(self, mock_db, mock_sim): + """snap in CA only includes CA households.""" + builder = _builder_with_engine(mock_db) + matrix, _, names = builder.build_matrix(mock_sim) + + idx = next(i for i, n in enumerate(names) if "snap" in n) + col = matrix[:, idx] + # CA mask = [T, T, F, F, T] + # hh_snap = [0, 5000, 0, 3000, 0] + assert col[0] == 0 # CA, no snap + assert col[1] == 5000 # CA, has snap + assert col[2] == 0 # not CA + assert col[3] == 0 # not CA (snap zeroed) + assert col[4] == 0 # CA, no snap + + def test_constraint_cache_populated(self, mock_db, mock_sim): + """After build_matrix, constraint cache should be non-empty.""" + builder = _builder_with_engine(mock_db) + # Call build_matrix to populate internal state + builder.build_matrix(mock_sim) + + # The entity relationship cache should be set + assert builder._entity_rel_cache is not None + + def test_empty_db_raises_value_error(self, empty_db, mock_sim): + """Empty database raises ValueError.""" + builder = _builder_with_engine(empty_db) + with pytest.raises(ValueError, match="No active targets"): + builder.build_matrix(mock_sim) + + +# ------------------------------------------------------------------- +# Inactive target filtering test +# ------------------------------------------------------------------- + + +class TestInactiveTargetFiltering: + """Verify that active=False targets are excluded.""" + + def test_inactive_excluded_from_matrix( + self, mock_db_with_inactive, mock_sim + ): + builder = _builder_with_engine(mock_db_with_inactive) + matrix, targets, names = builder.build_matrix(mock_sim) + + # Should still have 6 active targets, not 7 + assert matrix.shape[1] == 6 + assert len(targets) == 6 + + # "ssi" should not appear in any name + for name in names: + assert "ssi" not in name.lower() + + +# ------------------------------------------------------------------- +# Edge case tests +# ------------------------------------------------------------------- + + +class TestEdgeCases: + """Edge cases and boundary conditions.""" + + def test_single_target_db(self, mock_sim): + """DB with only one active target produces 1-column matrix.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + stratum = Stratum(stratum_group_id=0, notes="US") + stratum.constraints_rel = [] + session.add(stratum) + session.flush() + + session.add( + Target( + stratum_id=stratum.stratum_id, + variable="medicaid", + period=2024, + value=500e9, + source_id=source.source_id, + active=True, + ) + ) + session.commit() + + builder = _builder_with_engine(engine) + matrix, targets, names = builder.build_matrix(mock_sim) + + assert matrix.shape == (5, 1) + assert len(targets) == 1 + assert targets[0] == pytest.approx(500e9) + + def test_all_households_masked_out(self, mock_sim): + """A constraint that no household satisfies produces + a zero column.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + stratum = Stratum(stratum_group_id=0, notes="Impossible") + stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="99", # no household has this + ) + ] + session.add(stratum) + 
session.flush() + + session.add( + Target( + stratum_id=stratum.stratum_id, + variable="snap", + period=2024, + value=1e9, + source_id=source.source_id, + active=True, + ) + ) + session.commit() + + builder = _builder_with_engine(engine) + matrix, _, _ = builder.build_matrix(mock_sim) + + # All values in the column should be 0 + np.testing.assert_array_equal(matrix[:, 0], np.zeros(5)) + + def test_large_target_value_preserved(self, mock_sim): + """Very large target values are not corrupted.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + big_value = 1.5e15 + + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + stratum = Stratum(stratum_group_id=0, notes="US") + stratum.constraints_rel = [] + session.add(stratum) + session.flush() + + session.add( + Target( + stratum_id=stratum.stratum_id, + variable="net_worth", + period=2024, + value=big_value, + source_id=source.source_id, + active=True, + ) + ) + session.commit() + + builder = _builder_with_engine(engine) + _, targets, _ = builder.build_matrix(mock_sim) + assert targets[0] == pytest.approx(big_value) + + def test_zero_target_value_filtered_out(self, mock_sim): + """A target with value=0 is filtered out (not useful for + calibration), raising ValueError when it's the only target.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + stratum = Stratum(stratum_group_id=0, notes="US") + stratum.constraints_rel = [] + session.add(stratum) + session.flush() + + session.add( + Target( + stratum_id=stratum.stratum_id, + variable="medicaid", + period=2024, + value=0.0, + source_id=source.source_id, + active=True, + ) + ) + session.commit() + + builder = _builder_with_engine(engine) + with pytest.raises(ValueError, match="zero or null"): + builder.build_matrix(mock_sim) + + def test_multiple_targets_same_stratum(self, mock_sim): + """Multiple targets sharing one stratum produce separate + columns with the same mask.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + stratum = Stratum(stratum_group_id=0, notes="US") + stratum.constraints_rel = [] + session.add(stratum) + session.flush() + + session.add( + Target( + stratum_id=stratum.stratum_id, + variable="medicaid", + period=2024, + value=100e9, + source_id=source.source_id, + active=True, + ) + ) + session.add( + Target( + stratum_id=stratum.stratum_id, + variable="snap", + period=2024, + value=50e9, + source_id=source.source_id, + active=True, + ) + ) + session.commit() + + builder = _builder_with_engine(engine) + matrix, targets, names = builder.build_matrix(mock_sim) + + assert matrix.shape == (5, 2) + assert len(targets) == 2 + + +# ------------------------------------------------------------------- +# apply_op utility tests (used by constraint evaluation) +# ------------------------------------------------------------------- + + +class TestApplyOp: + """Test the apply_op utility used for constraint evaluation.""" + + def test_equality(self): + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, + ) + + vals = np.array([1, 2, 3, 4, 5]) + result = apply_op(vals, "==", "3") 
+ expected = np.array([False, False, True, False, False]) + np.testing.assert_array_equal(result, expected) + + def test_greater_than(self): + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, + ) + + vals = np.array([1, 2, 3, 4, 5]) + result = apply_op(vals, ">", "3") + expected = np.array([False, False, False, True, True]) + np.testing.assert_array_equal(result, expected) + + def test_less_than(self): + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, + ) + + vals = np.array([1, 2, 3, 4, 5]) + result = apply_op(vals, "<", "3") + expected = np.array([True, True, False, False, False]) + np.testing.assert_array_equal(result, expected) + + def test_gte(self): + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, + ) + + vals = np.array([1, 2, 3, 4, 5]) + result = apply_op(vals, ">=", "3") + expected = np.array([False, False, True, True, True]) + np.testing.assert_array_equal(result, expected) + + def test_lte(self): + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, + ) + + vals = np.array([1, 2, 3, 4, 5]) + result = apply_op(vals, "<=", "3") + expected = np.array([True, True, True, False, False]) + np.testing.assert_array_equal(result, expected) + + def test_not_equal(self): + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, + ) + + vals = np.array([1, 2, 3, 4, 5]) + result = apply_op(vals, "!=", "3") + expected = np.array([True, True, False, True, True]) + np.testing.assert_array_equal(result, expected) + + def test_float_value_parsing(self): + from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + apply_op, + ) + + vals = np.array([1.5, 2.5, 3.5, 4.5]) + result = apply_op(vals, ">", "2.5") + expected = np.array([False, False, True, True]) + np.testing.assert_array_equal(result, expected) diff --git a/policyengine_us_data/tests/test_calibration/test_puf_impute.py b/policyengine_us_data/tests/test_calibration/test_puf_impute.py new file mode 100644 index 000000000..e5ee313d7 --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_puf_impute.py @@ -0,0 +1,175 @@ +"""Tests for puf_impute module. + +Verifies PUF clone + QRF imputation logic using mock data +so tests don't require real CPS/PUF datasets. +""" + +import numpy as np +import pandas as pd +import pytest +from unittest.mock import patch, MagicMock + +from policyengine_us_data.calibration.puf_impute import ( + puf_clone_dataset, + DEMOGRAPHIC_PREDICTORS, + IMPUTED_VARIABLES, +) + +# ------------------------------------------------------------------ +# Mock helpers +# ------------------------------------------------------------------ + + +def _make_mock_data(n_persons=20, n_households=5, time_period=2024): + """Build a minimal mock CPS data dict. + + Returns a dict of {variable: {time_period: array}} matching + the Dataset.TIME_PERIOD_ARRAYS format. 
+ """ + # Person-level IDs and demographics + person_ids = np.arange(1, n_persons + 1) + # 4 persons per household + household_ids_person = np.repeat( + np.arange(1, n_households + 1), n_persons // n_households + ) + tax_unit_ids_person = household_ids_person.copy() + spm_unit_ids_person = household_ids_person.copy() + + ages = np.random.default_rng(42).integers(18, 80, size=n_persons) + is_male = np.random.default_rng(42).integers(0, 2, size=n_persons) + + data = { + "person_id": {time_period: person_ids}, + "household_id": {time_period: np.arange(1, n_households + 1)}, + "tax_unit_id": {time_period: np.arange(1, n_households + 1)}, + "spm_unit_id": {time_period: np.arange(1, n_households + 1)}, + "person_household_id": {time_period: household_ids_person}, + "person_tax_unit_id": {time_period: tax_unit_ids_person}, + "person_spm_unit_id": {time_period: spm_unit_ids_person}, + "age": {time_period: ages.astype(np.float32)}, + "is_male": {time_period: is_male.astype(np.float32)}, + "household_weight": {time_period: np.ones(n_households) * 1000}, + "employment_income": { + time_period: np.random.default_rng(42).uniform( + 0, 100000, n_persons + ) + }, + } + return data + + +# ------------------------------------------------------------------ +# Tests +# ------------------------------------------------------------------ + + +class TestPufCloneDataset: + """Tests for puf_clone_dataset.""" + + def test_doubles_records(self): + """Output has 2x the records of input.""" + data = _make_mock_data(n_persons=20, n_households=5) + state_fips = np.array([1, 2, 36, 6, 48]) # per household + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + # Households should double + assert len(result["household_id"][2024]) == 10 + # Persons should double + assert len(result["person_id"][2024]) == 40 + + def test_ids_are_unique(self): + """Person and household IDs are unique across both halves.""" + data = _make_mock_data(n_persons=20, n_households=5) + state_fips = np.array([1, 2, 36, 6, 48]) + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + person_ids = result["person_id"][2024] + household_ids = result["household_id"][2024] + assert len(np.unique(person_ids)) == len(person_ids) + assert len(np.unique(household_ids)) == len(household_ids) + + def test_puf_half_weight_zero(self): + """PUF half has zero household weights.""" + data = _make_mock_data(n_persons=20, n_households=5) + state_fips = np.array([1, 2, 36, 6, 48]) + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + weights = result["household_weight"][2024] + # First half: original weights + assert np.all(weights[:5] > 0) + # Second half: zero weights (PUF copy) + assert np.all(weights[5:] == 0) + + def test_state_fips_preserved(self): + """State FIPS doubles along with records.""" + data = _make_mock_data(n_persons=20, n_households=5) + state_fips = np.array([1, 2, 36, 6, 48]) + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + result_states = result["state_fips"][2024] + # Both halves should have the same state assignments + np.testing.assert_array_equal(result_states[:5], state_fips) + np.testing.assert_array_equal(result_states[5:], state_fips) + + def test_demographics_shared(self): + """Both halves share the same demographic values.""" + data = _make_mock_data(n_persons=20, n_households=5) + state_fips = 
np.array([1, 2, 36, 6, 48]) + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + ages = result["age"][2024] + n = len(ages) // 2 + np.testing.assert_array_equal(ages[:n], ages[n:]) + + def test_n_records_output(self): + """Returns correct n_households_out count.""" + data = _make_mock_data(n_persons=20, n_households=5) + state_fips = np.array([1, 2, 36, 6, 48]) + + result = puf_clone_dataset( + data=data, + state_fips=state_fips, + time_period=2024, + skip_qrf=True, + ) + + # Should have 10 households total + assert len(result["household_id"][2024]) == 10 + + def test_demographic_predictors_list(self): + """DEMOGRAPHIC_PREDICTORS includes state_fips.""" + assert "state_fips" in DEMOGRAPHIC_PREDICTORS + + def test_imputed_variables_not_empty(self): + """IMPUTED_VARIABLES list is populated.""" + assert len(IMPUTED_VARIABLES) > 0 diff --git a/policyengine_us_data/tests/test_calibration/test_source_impute.py b/policyengine_us_data/tests/test_calibration/test_source_impute.py new file mode 100644 index 000000000..ada09760f --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_source_impute.py @@ -0,0 +1,240 @@ +""" +Tests for source_impute module — ACS/SIPP/SCF imputations +with state_fips as QRF predictor. + +Uses mocks to avoid loading real donor data (ACS, SIPP, SCF). +""" + +import sys +from unittest.mock import MagicMock, patch + +import numpy as np +import pandas as pd +import pytest + +# ------------------------------------------------------------------- +# Constants tests +# ------------------------------------------------------------------- + + +class TestConstants: + """Test that module constants are defined correctly.""" + + def test_acs_variables_defined(self): + from policyengine_us_data.calibration.source_impute import ( + ACS_IMPUTED_VARIABLES, + ) + + assert "rent" in ACS_IMPUTED_VARIABLES + assert "real_estate_taxes" in ACS_IMPUTED_VARIABLES + + def test_sipp_variables_defined(self): + from policyengine_us_data.calibration.source_impute import ( + SIPP_IMPUTED_VARIABLES, + ) + + assert "tip_income" in SIPP_IMPUTED_VARIABLES + assert "bank_account_assets" in SIPP_IMPUTED_VARIABLES + assert "stock_assets" in SIPP_IMPUTED_VARIABLES + assert "bond_assets" in SIPP_IMPUTED_VARIABLES + + def test_scf_variables_defined(self): + from policyengine_us_data.calibration.source_impute import ( + SCF_IMPUTED_VARIABLES, + ) + + assert "net_worth" in SCF_IMPUTED_VARIABLES + assert "auto_loan_balance" in SCF_IMPUTED_VARIABLES + assert "auto_loan_interest" in SCF_IMPUTED_VARIABLES + + def test_all_source_variables_defined(self): + from policyengine_us_data.calibration.source_impute import ( + ALL_SOURCE_VARIABLES, + ACS_IMPUTED_VARIABLES, + SIPP_IMPUTED_VARIABLES, + SCF_IMPUTED_VARIABLES, + ) + + expected = ( + ACS_IMPUTED_VARIABLES + + SIPP_IMPUTED_VARIABLES + + SCF_IMPUTED_VARIABLES + ) + assert ALL_SOURCE_VARIABLES == expected + + +# ------------------------------------------------------------------- +# impute_source_variables tests +# ------------------------------------------------------------------- + + +class TestImputeSourceVariables: + """Test main entry point.""" + + def _make_data_dict(self, n_persons=100, time_period=2024): + """Create a minimal data dict for testing.""" + n_hh = n_persons // 2 + rng = np.random.default_rng(42) + return { + "person_id": { + time_period: np.arange(n_persons), + }, + "household_id": { + time_period: np.arange(n_hh), + }, + "person_household_id": { + time_period: 
np.repeat(np.arange(n_hh), 2), + }, + "age": { + time_period: rng.integers(18, 80, n_persons).astype( + np.float32 + ), + }, + "employment_income": { + time_period: rng.uniform(0, 100000, n_persons).astype( + np.float32 + ), + }, + # Variables that will be overwritten + "rent": { + time_period: np.zeros(n_persons), + }, + "real_estate_taxes": { + time_period: np.zeros(n_persons), + }, + "tip_income": { + time_period: np.zeros(n_persons), + }, + "bank_account_assets": { + time_period: np.zeros(n_persons), + }, + "stock_assets": { + time_period: np.zeros(n_persons), + }, + "bond_assets": { + time_period: np.zeros(n_persons), + }, + "net_worth": { + time_period: np.zeros(n_persons), + }, + "auto_loan_balance": { + time_period: np.zeros(n_persons), + }, + "auto_loan_interest": { + time_period: np.zeros(n_persons), + }, + } + + def test_function_exists(self): + """impute_source_variables is importable.""" + from policyengine_us_data.calibration.source_impute import ( + impute_source_variables, + ) + + assert callable(impute_source_variables) + + def test_returns_dict(self): + """Returns a dict with same keys as input.""" + from policyengine_us_data.calibration.source_impute import ( + impute_source_variables, + ) + + data = self._make_data_dict(n_persons=20) + state_fips = np.ones(10, dtype=np.int32) * 6 + + result = impute_source_variables( + data=data, + state_fips=state_fips, + time_period=2024, + skip_acs=True, + skip_sipp=True, + skip_scf=True, + ) + assert isinstance(result, dict) + # All original keys preserved + for key in data: + assert key in result + + def test_skip_flags_work(self): + """When all skip flags True, data unchanged.""" + from policyengine_us_data.calibration.source_impute import ( + impute_source_variables, + ) + + data = self._make_data_dict(n_persons=20) + state_fips = np.ones(10, dtype=np.int32) * 6 + + result = impute_source_variables( + data=data, + state_fips=state_fips, + time_period=2024, + skip_acs=True, + skip_sipp=True, + skip_scf=True, + ) + + for var in [ + "rent", + "real_estate_taxes", + "tip_income", + "net_worth", + ]: + np.testing.assert_array_equal(result[var][2024], data[var][2024]) + + def test_state_fips_added_to_data(self): + """state_fips is added to the returned data dict.""" + from policyengine_us_data.calibration.source_impute import ( + impute_source_variables, + ) + + data = self._make_data_dict(n_persons=20) + state_fips = np.ones(10, dtype=np.int32) * 6 + + result = impute_source_variables( + data=data, + state_fips=state_fips, + time_period=2024, + skip_acs=True, + skip_sipp=True, + skip_scf=True, + ) + + assert "state_fips" in result + + +# ------------------------------------------------------------------- +# Individual source imputation tests +# ------------------------------------------------------------------- + + +class TestACSImputation: + """Test _impute_acs function.""" + + def test_function_exists(self): + from policyengine_us_data.calibration.source_impute import ( + _impute_acs, + ) + + assert callable(_impute_acs) + + +class TestSIPPImputation: + """Test _impute_sipp function.""" + + def test_function_exists(self): + from policyengine_us_data.calibration.source_impute import ( + _impute_sipp, + ) + + assert callable(_impute_sipp) + + +class TestSCFImputation: + """Test _impute_scf function.""" + + def test_function_exists(self): + from policyengine_us_data.calibration.source_impute import ( + _impute_scf, + ) + + assert callable(_impute_scf) diff --git a/policyengine_us_data/tests/test_calibration/test_unified_calibration.py 
b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py new file mode 100644 index 000000000..9bf1daf1c --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_unified_calibration.py @@ -0,0 +1,608 @@ +""" +Tests for unified L0 calibration pipeline +(unified_calibration.py). + +Uses TDD: tests written first, then implementation. +Mocks heavy dependencies (torch, l0, Microsimulation, +clone_and_assign, UnifiedMatrixBuilder) to keep tests fast. +""" + +import sys +from unittest.mock import MagicMock, patch, call + +import numpy as np +import pandas as pd +import pytest +import scipy.sparse + +# ------------------------------------------------------------------- +# CLI argument parsing tests +# ------------------------------------------------------------------- + + +class TestParseArgs: + """Test CLI argument parsing for unified calibration.""" + + def test_parse_args_defaults(self): + """Default values when no arguments are passed.""" + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + DEFAULT_EPOCHS, + DEFAULT_N_CLONES, + ) + + args = parse_args([]) + assert args.dataset is None + assert args.db_path is None + assert args.output is None + assert args.n_clones == DEFAULT_N_CLONES + assert args.preset is None + assert args.lambda_l0 is None + assert args.epochs == DEFAULT_EPOCHS + assert args.device == "cpu" + assert args.seed == 42 + + def test_parse_args_preset_local(self): + """--preset local is accepted.""" + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + PRESETS, + ) + + args = parse_args(["--preset", "local"]) + assert args.preset == "local" + # Verify the preset maps to expected lambda + assert PRESETS["local"] == 1e-8 + + def test_parse_args_preset_national(self): + """--preset national is accepted.""" + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + PRESETS, + ) + + args = parse_args(["--preset", "national"]) + assert args.preset == "national" + assert PRESETS["national"] == 1e-4 + + def test_parse_args_lambda_overrides_preset(self): + """--lambda-l0 should be available alongside --preset. + + The actual override logic is in main(), not parse_args. + Here we verify both values are stored independently. 
+ """ + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + args = parse_args(["--preset", "local", "--lambda-l0", "5e-6"]) + assert args.preset == "local" + assert args.lambda_l0 == 5e-6 + + def test_parse_args_all_options(self): + """All CLI options are parsed correctly.""" + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + args = parse_args( + [ + "--dataset", + "/tmp/data.h5", + "--db-path", + "/tmp/db.sqlite", + "--output", + "/tmp/weights.npy", + "--n-clones", + "50", + "--preset", + "national", + "--lambda-l0", + "1e-3", + "--epochs", + "200", + "--device", + "cuda", + "--seed", + "99", + ] + ) + assert args.dataset == "/tmp/data.h5" + assert args.db_path == "/tmp/db.sqlite" + assert args.output == "/tmp/weights.npy" + assert args.n_clones == 50 + assert args.preset == "national" + assert args.lambda_l0 == 1e-3 + assert args.epochs == 200 + assert args.device == "cuda" + assert args.seed == 99 + + def test_parse_args_invalid_device(self): + """Invalid device choice raises SystemExit.""" + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + with pytest.raises(SystemExit): + parse_args(["--device", "tpu"]) + + def test_parse_args_invalid_preset(self): + """Invalid preset choice raises SystemExit.""" + from policyengine_us_data.calibration.unified_calibration import ( + parse_args, + ) + + with pytest.raises(SystemExit): + parse_args(["--preset", "invalid"]) + + +# ------------------------------------------------------------------- +# Helpers for mocking run_calibration dependencies +# ------------------------------------------------------------------- + + +def _make_mock_l0(n_total, return_weights=None): + """Create a mock SparseCalibrationWeights model. + + Args: + n_total: Number of features (records * clones). + return_weights: Optional weight array to return. + + Returns: + Tuple of (mock_l0_module, mock_model). + """ + if return_weights is None: + return_weights = np.ones(n_total) * 95.0 + + mock_model = MagicMock() + mock_model.get_weights.return_value = MagicMock( + cpu=MagicMock( + return_value=MagicMock( + numpy=MagicMock(return_value=return_weights) + ) + ) + ) + + mock_l0_module = MagicMock() + mock_l0_module.SparseCalibrationWeights = MagicMock( + return_value=mock_model + ) + + return mock_l0_module, mock_model + + +def _make_mock_sim(n_records): + """Create a mock Microsimulation that returns n_records + household IDs. + + Args: + n_records: Number of households. + + Returns: + Mock Microsimulation instance. + """ + mock_sim = MagicMock() + mock_calc_result = MagicMock() + mock_calc_result.values = np.arange(n_records) + mock_sim.calculate.return_value = mock_calc_result + return mock_sim + + +def _make_mock_geography(n_records, n_clones): + """Create a mock GeographyAssignment. + + Args: + n_records: Number of base records. + n_clones: Number of clones. + + Returns: + Mock GeographyAssignment. + """ + n_total = n_records * n_clones + geo = MagicMock() + geo.n_records = n_records + geo.n_clones = n_clones + geo.state_fips = np.ones(n_total, dtype=int) * 6 + geo.cd_geoid = np.array(["0601"] * n_total) + geo.block_geoid = np.array(["060014001001000"] * n_total) + return geo + + +def _make_mock_builder_result(n_targets, n_total): + """Create mock build_matrix return values. + + Args: + n_targets: Number of calibration targets. + n_total: Number of total records (records * clones). + + Returns: + Tuple of (targets_df, X_sparse, target_names). 
+ """ + targets_df = pd.DataFrame( + { + "variable": [f"var_{i}" for i in range(n_targets)], + "value": np.random.rand(n_targets) * 1e6 + 1e3, + } + ) + X_sparse = scipy.sparse.random( + n_targets, n_total, density=0.3, format="csr" + ) + # Ensure no all-zero rows by default + for i in range(n_targets): + if X_sparse[i, :].nnz == 0: + X_sparse[i, 0] = 1.0 + target_names = [f"target_{i}" for i in range(n_targets)] + return targets_df, X_sparse, target_names + + +def _setup_module_mocks( + mock_sim, + mock_geo, + mock_builder, + mock_l0_module, +): + """Build sys.modules dict for patching imports inside + run_calibration. + + Since run_calibration uses local imports, we mock the + source modules in sys.modules so that ``from X import Y`` + resolves to our mocks. + + Args: + mock_sim: Mock Microsimulation instance. + mock_geo: Mock GeographyAssignment to return. + mock_builder: Mock UnifiedMatrixBuilder instance. + mock_l0_module: Mock l0.calibration module. + + Returns: + Dict suitable for patch.dict(sys.modules, ...). + """ + # Mock policyengine_us.Microsimulation + mock_pe_us = MagicMock() + mock_pe_us.Microsimulation = MagicMock(return_value=mock_sim) + + # Mock clone_and_assign module + mock_clone_mod = MagicMock() + mock_clone_mod.assign_random_geography = MagicMock(return_value=mock_geo) + mock_clone_mod.double_geography_for_puf = MagicMock(return_value=mock_geo) + + # Mock unified_matrix_builder module + mock_builder_mod = MagicMock() + mock_builder_mod.UnifiedMatrixBuilder = MagicMock( + return_value=mock_builder + ) + + # Mock torch + mock_torch = MagicMock() + mock_torch.no_grad.return_value.__enter__ = MagicMock(return_value=None) + mock_torch.no_grad.return_value.__exit__ = MagicMock(return_value=False) + + return { + "policyengine_us": mock_pe_us, + "policyengine_us.Microsimulation": mock_pe_us, + "policyengine_us_data.calibration.clone_and_assign": (mock_clone_mod), + "policyengine_us_data.calibration." + "unified_matrix_builder": mock_builder_mod, + "l0": MagicMock(), + "l0.calibration": mock_l0_module, + "torch": mock_torch, + } + + +# ------------------------------------------------------------------- +# run_calibration tests +# ------------------------------------------------------------------- + + +class TestRunCalibration: + """Test run_calibration with fully mocked dependencies.""" + + def test_returns_weights_array(self): + """run_calibration returns a numpy array of correct + shape.""" + # We need to reimport after patching sys.modules + n_records = 10 + n_clones = 5 + n_total = n_records * n_clones + n_targets = 8 + + expected_weights = np.ones(n_total) * 95.0 + mock_l0_module, mock_model = _make_mock_l0(n_total, expected_weights) + mock_sim = _make_mock_sim(n_records) + mock_geo = _make_mock_geography(n_records, n_clones) + targets_df, X_sparse, target_names = _make_mock_builder_result( + n_targets, n_total + ) + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + targets_df, + X_sparse, + target_names, + ) + + modules = _setup_module_mocks( + mock_sim, mock_geo, mock_builder, mock_l0_module + ) + + with patch.dict(sys.modules, modules): + # Reimport to pick up patched modules + import importlib + + mod = importlib.import_module( + "policyengine_us_data.calibration." 
"unified_calibration" + ) + importlib.reload(mod) + + weights = mod.run_calibration( + dataset_path="/fake/data.h5", + db_path="/fake/db.sqlite", + n_clones=n_clones, + lambda_l0=1e-8, + epochs=5, + skip_puf=True, + skip_source_impute=True, + ) + + assert isinstance(weights, np.ndarray) + assert weights.shape == (n_total,) + np.testing.assert_array_equal(weights, expected_weights) + + def test_achievable_target_filtering(self): + """Targets with all-zero rows in X_sparse are removed + before L0 fitting.""" + n_records = 5 + n_clones = 2 + n_total = n_records * n_clones + + expected_weights = np.ones(n_total) * 90.0 + mock_l0_module, mock_model = _make_mock_l0(n_total, expected_weights) + mock_sim = _make_mock_sim(n_records) + mock_geo = _make_mock_geography(n_records, n_clones) + + # Create sparse matrix with 2 achievable and 2 zero + # rows + targets_df = pd.DataFrame( + { + "variable": [ + "achievable_0", + "zero_0", + "achievable_1", + "zero_1", + ], + "value": [1e6, 2e6, 3e6, 4e6], + } + ) + X_sparse = scipy.sparse.lil_matrix((4, n_total)) + X_sparse[0, 0] = 1.0 # achievable + X_sparse[1, :] = 0.0 # all-zero row + X_sparse[2, 1] = 2.0 # achievable + X_sparse[3, :] = 0.0 # all-zero row + X_sparse = X_sparse.tocsr() + + target_names = [ + "achievable_0", + "zero_0", + "achievable_1", + "zero_1", + ] + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + targets_df, + X_sparse, + target_names, + ) + + # Track what gets passed to model.fit + fit_call_args = {} + + def capture_fit(**kwargs): + fit_call_args.update(kwargs) + + mock_model.fit.side_effect = capture_fit + + modules = _setup_module_mocks( + mock_sim, mock_geo, mock_builder, mock_l0_module + ) + + with patch.dict(sys.modules, modules): + import importlib + + mod = importlib.import_module( + "policyengine_us_data.calibration." "unified_calibration" + ) + importlib.reload(mod) + + mod.run_calibration( + dataset_path="/fake/data.h5", + db_path="/fake/db.sqlite", + n_clones=n_clones, + lambda_l0=1e-8, + epochs=5, + skip_puf=True, + skip_source_impute=True, + ) + + # model.fit should have been called with only 2 + # achievable targets, not 4 + y_passed = fit_call_args["y"] + M_passed = fit_call_args["M"] + assert len(y_passed) == 2 + assert M_passed.shape[0] == 2 + # Values should be the achievable ones + np.testing.assert_array_almost_equal(y_passed, [1e6, 3e6]) + + def test_calls_assign_random_geography(self): + """assign_random_geography is called with correct + arguments.""" + n_records = 8 + n_clones = 3 + n_total = n_records * n_clones + n_targets = 5 + + mock_l0_module, _ = _make_mock_l0(n_total) + mock_sim = _make_mock_sim(n_records) + mock_geo = _make_mock_geography(n_records, n_clones) + targets_df, X_sparse, target_names = _make_mock_builder_result( + n_targets, n_total + ) + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + targets_df, + X_sparse, + target_names, + ) + + modules = _setup_module_mocks( + mock_sim, mock_geo, mock_builder, mock_l0_module + ) + + # Get a handle to the mock assign function so we can + # assert on it after the context manager exits + mock_assign = modules[ + "policyengine_us_data.calibration.clone_and_assign" + ].assign_random_geography + + with patch.dict(sys.modules, modules): + import importlib + + mod = importlib.import_module( + "policyengine_us_data.calibration." 
"unified_calibration" + ) + importlib.reload(mod) + + mod.run_calibration( + dataset_path="/fake/data.h5", + db_path="/fake/db.sqlite", + n_clones=n_clones, + lambda_l0=1e-8, + epochs=5, + seed=123, + skip_puf=True, + skip_source_impute=True, + ) + + mock_assign.assert_called_once_with( + n_records=n_records, + n_clones=n_clones, + seed=123, + ) + + def test_l0_import_error(self): + """ImportError raised when l0-python is missing.""" + n_records = 5 + n_clones = 2 + n_total = n_records * n_clones + n_targets = 3 + + mock_sim = _make_mock_sim(n_records) + mock_geo = _make_mock_geography(n_records, n_clones) + targets_df, X_sparse, target_names = _make_mock_builder_result( + n_targets, n_total + ) + + mock_builder = MagicMock() + mock_builder.build_matrix.return_value = ( + targets_df, + X_sparse, + target_names, + ) + + # Mock policyengine_us + mock_pe_us = MagicMock() + mock_pe_us.Microsimulation = MagicMock(return_value=mock_sim) + + # Mock clone_and_assign module + mock_clone_mod = MagicMock() + mock_clone_mod.assign_random_geography = MagicMock( + return_value=mock_geo + ) + mock_clone_mod.double_geography_for_puf = MagicMock( + return_value=mock_geo + ) + + # Mock unified_matrix_builder module + mock_builder_mod = MagicMock() + mock_builder_mod.UnifiedMatrixBuilder = MagicMock( + return_value=mock_builder + ) + + modules = { + "policyengine_us": mock_pe_us, + "policyengine_us_data.calibration." + "clone_and_assign": mock_clone_mod, + "policyengine_us_data.calibration." + "unified_matrix_builder": mock_builder_mod, + "l0": None, + "l0.calibration": None, + } + + with patch.dict(sys.modules, modules): + import importlib + + mod = importlib.import_module( + "policyengine_us_data.calibration." "unified_calibration" + ) + importlib.reload(mod) + + with pytest.raises(ImportError, match="l0-python"): + mod.run_calibration( + dataset_path="/fake/data.h5", + db_path="/fake/db.sqlite", + n_clones=n_clones, + lambda_l0=1e-8, + epochs=5, + skip_puf=True, + skip_source_impute=True, + ) + + +# ------------------------------------------------------------------- +# Constants tests +# ------------------------------------------------------------------- + + +class TestConstants: + """Test that module constants are defined correctly.""" + + def test_presets_defined(self): + from policyengine_us_data.calibration.unified_calibration import ( + PRESETS, + ) + + assert "local" in PRESETS + assert "national" in PRESETS + assert PRESETS["local"] == 1e-8 + assert PRESETS["national"] == 1e-4 + + def test_hyperparameters_defined(self): + from policyengine_us_data.calibration.unified_calibration import ( + BETA, + GAMMA, + ZETA, + INIT_KEEP_PROB, + LOG_WEIGHT_JITTER_SD, + LOG_ALPHA_JITTER_SD, + LAMBDA_L2, + LEARNING_RATE, + ) + + assert BETA == 0.35 + assert GAMMA == -0.1 + assert ZETA == 1.1 + assert INIT_KEEP_PROB == 0.999 + assert LOG_WEIGHT_JITTER_SD == 0.05 + assert LOG_ALPHA_JITTER_SD == 0.01 + assert LAMBDA_L2 == 1e-12 + assert LEARNING_RATE == 0.15 + + def test_default_constants(self): + from policyengine_us_data.calibration.unified_calibration import ( + DEFAULT_EPOCHS, + DEFAULT_N_CLONES, + ) + + assert DEFAULT_EPOCHS == 100 + assert DEFAULT_N_CLONES == 10 diff --git a/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py new file mode 100644 index 000000000..32a0ef002 --- /dev/null +++ b/policyengine_us_data/tests/test_calibration/test_unified_matrix_builder.py @@ -0,0 +1,1014 @@ +""" +Tests for 
UnifiedMatrixBuilder. + +Uses a mock in-memory SQLite database with representative targets +and mocked Microsimulation to avoid heavy dependencies. +""" + +import numpy as np +import pytest +from collections import namedtuple +from unittest.mock import MagicMock, patch +from sqlalchemy import create_engine +from sqlmodel import Session, SQLModel + +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, + Target, + Source, + SourceType, +) +from policyengine_us_data.calibration.unified_matrix_builder import ( + UnifiedMatrixBuilder, + COUNT_VARIABLES, + _GEO_STATE_VARS, + _GEO_CD_VARS, + _GEO_VARS, +) + +# ------------------------------------------------------------------- +# Helper: lightweight GeographyAssignment stand-in +# ------------------------------------------------------------------- + + +class MockGeography: + """Minimal geography assignment for testing. + + Attributes: + n_records: Number of base records (households). + n_clones: Number of clones. + block_geoid: Array of length n_records * n_clones with + 15-char census block GEOID strings. + state_fips: Array of length n_records * n_clones with + state FIPS for each column. + cd_geoid: Array of length n_records * n_clones with CD + GEOID strings for each column. + """ + + def __init__( + self, n_records, n_clones, state_fips, cd_geoid, block_geoid=None + ): + self.n_records = n_records + self.n_clones = n_clones + self.state_fips = np.asarray(state_fips) + self.cd_geoid = np.asarray(cd_geoid, dtype=object) + if block_geoid is not None: + self.block_geoid = np.asarray(block_geoid, dtype=object) + else: + # Generate synthetic block GEOIDs from state_fips + self.block_geoid = np.array( + [f"{int(s):02d}0010001001001" for s in self.state_fips], + dtype=object, + ) + + +# ------------------------------------------------------------------- +# Helper: build a mock Microsimulation +# ------------------------------------------------------------------- + + +def _make_mock_sim(n_households=4, n_persons=8): + """Create a mock Microsimulation with controllable data. 
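+
+    The mock's calculate() branches on map_to ("person" vs
+    "household"), and map_result() sums person-level values up to
+    households, giving a simplified stand-in for the entity
+    aggregation performed by the real Microsimulation.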
+ + Household layout: + hh0: persons 0,1 (income=1000, snap=100) + hh1: persons 2,3 (income=2000, snap=200) + hh2: persons 4,5 (income=3000, snap=0) + hh3: persons 6,7 (income=4000, snap=0) + """ + sim = MagicMock() + + person_hh_ids = np.array([0, 0, 1, 1, 2, 2, 3, 3]) + person_tu_ids = np.array([0, 0, 1, 1, 2, 2, 3, 3]) + person_spm_ids = np.array([0, 0, 1, 1, 2, 2, 3, 3]) + hh_ids = np.arange(n_households) + + hh_income = np.array([1000, 2000, 3000, 4000], dtype=float) + hh_snap = np.array([100, 200, 0, 0], dtype=float) + # Person-level ages: hh0 has adult+child, rest all adults + person_age = np.array([35, 10, 40, 45, 50, 55, 60, 65], dtype=float) + + def calculate_side_effect(var, period=None, map_to=None): + result = MagicMock() + if map_to == "person": + mapping = { + "household_id": person_hh_ids, + "tax_unit_id": person_tu_ids, + "spm_unit_id": person_spm_ids, + "person_id": np.arange(n_persons), + "age": person_age, + "income": hh_income[person_hh_ids], + "snap": hh_snap[person_hh_ids], + "person_count": np.ones(n_persons, dtype=float), + } + result.values = mapping.get(var, np.zeros(n_persons, dtype=float)) + elif map_to == "household": + mapping = { + "household_id": hh_ids, + "income": hh_income, + "snap": hh_snap, + "person_count": np.array([2, 2, 2, 2], dtype=float), + "household_count": np.ones(n_households, dtype=float), + } + result.values = mapping.get( + var, np.zeros(n_households, dtype=float) + ) + else: + result.values = np.zeros(n_households, dtype=float) + return result + + sim.calculate = calculate_side_effect + + def map_result_side_effect(values, from_entity, to_entity, how=None): + if from_entity == "person" and to_entity == "household": + result = np.zeros(n_households, dtype=float) + for i in range(n_persons): + result[person_hh_ids[i]] += float(values[i]) + return result + return values + + sim.map_result = map_result_side_effect + + # Mock set_input and delete_arrays as no-ops + sim.set_input = MagicMock() + sim.delete_arrays = MagicMock() + + return sim + + +# ------------------------------------------------------------------- +# Fixtures: in-memory SQLite databases +# ------------------------------------------------------------------- + + +def _seed_db(engine): + """Populate test DB with targets at national, state, and CD levels. 
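+
+    The CD 0601 stratum is nested under the CA stratum, so
+    constraint-chain walking should surface both the
+    congressional_district_geoid and the parent state_fips
+    constraints for the CD-level target (exercised by
+    TestGetAllConstraints.test_cd_walks_to_parent_state).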
+ + Creates: + - National stratum -> income target (value=1e9) + - CA stratum (state_fips=6) -> snap target (value=5e8) + - NY stratum (state_fips=36) -> snap target (value=3e8) + - CD 0601 stratum (CA CD) -> income target (value=1e8) + - National -> household_count target (value=1e6) + """ + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + # --- National stratum (no constraints) --- + us_stratum = Stratum( + stratum_group_id=0, + notes="United States", + ) + us_stratum.constraints_rel = [] + session.add(us_stratum) + session.flush() + + # Target 1: national income sum + session.add( + Target( + stratum_id=us_stratum.stratum_id, + variable="income", + period=2024, + value=1e9, + source_id=source.source_id, + active=True, + ) + ) + + # Target 2: national household_count + session.add( + Target( + stratum_id=us_stratum.stratum_id, + variable="household_count", + period=2024, + value=1e6, + source_id=source.source_id, + active=True, + ) + ) + + # --- California stratum --- + ca_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=1, + notes="California", + ) + ca_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="6", + ) + ] + session.add(ca_stratum) + session.flush() + + # Target 3: CA snap + session.add( + Target( + stratum_id=ca_stratum.stratum_id, + variable="snap", + period=2024, + value=5e8, + source_id=source.source_id, + active=True, + ) + ) + + # --- New York stratum --- + ny_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=1, + notes="New York", + ) + ny_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="36", + ) + ] + session.add(ny_stratum) + session.flush() + + # Target 4: NY snap + session.add( + Target( + stratum_id=ny_stratum.stratum_id, + variable="snap", + period=2024, + value=3e8, + source_id=source.source_id, + active=True, + ) + ) + + # --- CD 0601 stratum (under CA) --- + cd_stratum = Stratum( + parent_stratum_id=ca_stratum.stratum_id, + stratum_group_id=2, + notes="CA-01", + ) + cd_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value="0601", + ) + ] + session.add(cd_stratum) + session.flush() + + # Target 5: CD 0601 income + session.add( + Target( + stratum_id=cd_stratum.stratum_id, + variable="income", + period=2024, + value=1e8, + source_id=source.source_id, + active=True, + ) + ) + + session.commit() + + return engine + + +def _seed_constrained_db(engine): + """Populate test DB with a constrained target. 
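+
+    Because the "CA adults" stratum is a child of the CA stratum,
+    its effective constraint set is state_fips == 6 AND age >= 18,
+    which is what TestConstraintMaskApplied relies on.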
+ + Creates a state-level target with an age constraint: + - CA stratum with age >= 18 -> income target + """ + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + # National parent (no constraints) + us_stratum = Stratum( + stratum_group_id=0, + notes="United States", + ) + us_stratum.constraints_rel = [] + session.add(us_stratum) + session.flush() + + # CA stratum + ca_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=1, + notes="California", + ) + ca_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value="6", + ) + ] + session.add(ca_stratum) + session.flush() + + # CA adults stratum (child of CA) + ca_adults = Stratum( + parent_stratum_id=ca_stratum.stratum_id, + stratum_group_id=3, + notes="CA adults", + ) + ca_adults.constraints_rel = [ + StratumConstraint( + constraint_variable="age", + operation=">=", + value="18", + ) + ] + session.add(ca_adults) + session.flush() + + # Target: CA adult income + session.add( + Target( + stratum_id=ca_adults.stratum_id, + variable="income", + period=2024, + value=7e8, + source_id=source.source_id, + active=True, + ) + ) + + session.commit() + + return engine + + +@pytest.fixture +def mock_db(): + """In-memory SQLite DB with representative targets.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + return _seed_db(engine) + + +@pytest.fixture +def constrained_db(): + """In-memory SQLite DB with a constrained (age) target.""" + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + return _seed_constrained_db(engine) + + +@pytest.fixture +def mock_sim(): + """Standard mock Microsimulation with 4 households / 8 persons.""" + return _make_mock_sim() + + +def _builder_with_engine(engine, time_period=2024): + """Create a UnifiedMatrixBuilder and inject the test engine.""" + builder = UnifiedMatrixBuilder( + db_uri="sqlite://", + time_period=time_period, + ) + builder.engine = engine + return builder + + +# ------------------------------------------------------------------- +# Geography setup helper +# ------------------------------------------------------------------- + + +def _make_geography(n_records, n_clones, assignments): + """Build a MockGeography from a list of (state_fips, cd_geoid) + tuples, one per clone. + + Column ordering: clone_idx * n_records + record_idx + + Args: + n_records: Number of base records. + n_clones: Number of clones. + assignments: List of (state_fips, cd_geoid) of length + n_clones. Each assignment applies to ALL records + cloned into that clone slot. 
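+
+    Example (illustrative): _make_geography(n_records=2, n_clones=2,
+    assignments=[(6, "0601"), (36, "3601")]) produces a clone-major
+    layout: state_fips == [6, 6, 36, 36] and
+    cd_geoid == ["0601", "0601", "3601", "3601"], i.e. columns 0-1
+    are the two base records assigned to CA CD 0601 and columns 2-3
+    are the same records assigned to NY CD 3601.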
+ """ + assert len(assignments) == n_clones + n_total = n_records * n_clones + state_fips = np.zeros(n_total, dtype=int) + cd_geoid = np.empty(n_total, dtype=object) + + for clone_idx, (sfips, cd) in enumerate(assignments): + start = clone_idx * n_records + end = start + n_records + state_fips[start:end] = sfips + cd_geoid[start:end] = cd + + return MockGeography(n_records, n_clones, state_fips, cd_geoid) + + +# ------------------------------------------------------------------- +# Tests: module constants +# ------------------------------------------------------------------- + + +class TestModuleConstants: + def test_count_variables_not_empty(self): + assert len(COUNT_VARIABLES) > 0 + + def test_household_count_in_count_variables(self): + assert "household_count" in COUNT_VARIABLES + + def test_geo_vars_include_state_and_cd(self): + assert "state_fips" in _GEO_VARS + assert "congressional_district_geoid" in _GEO_VARS + + +# ------------------------------------------------------------------- +# Tests: _classify_constraint_geo +# ------------------------------------------------------------------- + + +class TestClassifyConstraintGeo: + def test_national_when_no_geo_constraints(self): + builder = UnifiedMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "age", + "operation": ">=", + "value": "18", + } + ] + level, geo_id = builder._classify_constraint_geo(constraints) + assert level == "national" + assert geo_id is None + + def test_state_level(self): + builder = UnifiedMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "state_fips", + "operation": "==", + "value": "6", + } + ] + level, geo_id = builder._classify_constraint_geo(constraints) + assert level == "state" + assert geo_id == "6" + + def test_cd_level(self): + builder = UnifiedMatrixBuilder(db_uri="sqlite://", time_period=2024) + constraints = [ + { + "variable": "state_fips", + "operation": "==", + "value": "6", + }, + { + "variable": "congressional_district_geoid", + "operation": "==", + "value": "0601", + }, + ] + level, geo_id = builder._classify_constraint_geo(constraints) + assert level == "cd" + assert geo_id == "0601" + + def test_empty_constraints_returns_national(self): + builder = UnifiedMatrixBuilder(db_uri="sqlite://", time_period=2024) + level, geo_id = builder._classify_constraint_geo([]) + assert level == "national" + assert geo_id is None + + +# ------------------------------------------------------------------- +# Tests: _make_target_name +# ------------------------------------------------------------------- + + +class TestMakeTargetName: + def test_national_unconstrained(self): + name = UnifiedMatrixBuilder._make_target_name( + "income", [], reform_id=0 + ) + assert "national" in name + assert "income" in name + + def test_state_constraint_in_name(self): + constraints = [ + { + "variable": "state_fips", + "operation": "==", + "value": "6", + } + ] + name = UnifiedMatrixBuilder._make_target_name( + "snap", constraints, reform_id=0 + ) + assert "state_6" in name + assert "snap" in name + + def test_cd_constraint_in_name(self): + constraints = [ + { + "variable": "congressional_district_geoid", + "operation": "==", + "value": "0601", + } + ] + name = UnifiedMatrixBuilder._make_target_name( + "income", constraints, reform_id=0 + ) + assert "cd_0601" in name + + def test_reform_appends_expenditure(self): + name = UnifiedMatrixBuilder._make_target_name( + "salt_deduction", [], reform_id=1 + ) + assert "expenditure" in name + + +# 
------------------------------------------------------------------- +# Tests: build_matrix +# ------------------------------------------------------------------- + + +class TestMatrixShape: + """Verify the output matrix has the correct shape.""" + + @patch("policyengine_us.Microsimulation") + def test_matrix_shape(self, MockMicrosim, mock_db, mock_sim): + """Matrix shape is (n_targets, n_records * n_clones).""" + MockMicrosim.return_value = mock_sim + + n_records = 4 + n_clones = 3 + geography = _make_geography( + n_records, + n_clones, + [ + (6, "0601"), # clone 0 -> CA + (36, "3601"), # clone 1 -> NY + (6, "0602"), # clone 2 -> CA + ], + ) + + builder = _builder_with_engine(mock_db) + targets_df, X, names = builder.build_matrix("dummy_path.h5", geography) + + # 5 targets in mock_db + assert X.shape == (5, n_records * n_clones) + assert len(names) == 5 + assert len(targets_df) == 5 + + +class TestStateTargetFillsOnlyStateColumns: + """State-level target should only have nonzero values in columns + assigned to that state.""" + + @patch("policyengine_us.Microsimulation") + def test_state_target_fills_only_state_columns( + self, MockMicrosim, mock_db, mock_sim + ): + MockMicrosim.return_value = mock_sim + + n_records = 4 + n_clones = 3 + geography = _make_geography( + n_records, + n_clones, + [ + (6, "0601"), # clone 0: CA cols 0-3 + (36, "3601"), # clone 1: NY cols 4-7 + (6, "0602"), # clone 2: CA cols 8-11 + ], + ) + + builder = _builder_with_engine(mock_db) + targets_df, X, names = builder.build_matrix("dummy_path.h5", geography) + X_dense = X.toarray() + + # Find the CA snap target (state_fips=6, variable=snap) + ca_snap_idx = None + for i, name in enumerate(names): + if "state_6" in name and "snap" in name: + ca_snap_idx = i + break + assert ( + ca_snap_idx is not None + ), f"Could not find CA snap target in names: {names}" + + ca_row = X_dense[ca_snap_idx] + # Cols 0-3 (clone 0, CA) and 8-11 (clone 2, CA) can be nonzero + # Cols 4-7 (clone 1, NY) must be zero + ny_cols = np.arange(4, 8) + np.testing.assert_array_equal( + ca_row[ny_cols], + np.zeros(len(ny_cols)), + err_msg="CA snap target has nonzero values in NY columns", + ) + + # At least some CA columns should be nonzero (snap=[100,200,0,0]) + ca_cols = np.concatenate([np.arange(0, 4), np.arange(8, 12)]) + assert np.any( + ca_row[ca_cols] != 0 + ), "CA snap target has all zeros in CA columns" + + +class TestCdTargetFillsOnlyCdColumns: + """CD-level target should only fill columns assigned to that CD.""" + + @patch("policyengine_us.Microsimulation") + def test_cd_target_fills_only_cd_columns( + self, MockMicrosim, mock_db, mock_sim + ): + MockMicrosim.return_value = mock_sim + + n_records = 4 + n_clones = 3 + geography = _make_geography( + n_records, + n_clones, + [ + (6, "0601"), # clone 0: CA CD-01 cols 0-3 + (36, "3601"), # clone 1: NY CD-01 cols 4-7 + (6, "0602"), # clone 2: CA CD-02 cols 8-11 + ], + ) + + builder = _builder_with_engine(mock_db) + targets_df, X, names = builder.build_matrix("dummy_path.h5", geography) + X_dense = X.toarray() + + # Find the CD 0601 income target + cd_income_idx = None + for i, name in enumerate(names): + if "cd_0601" in name and "income" in name: + cd_income_idx = i + break + assert ( + cd_income_idx is not None + ), f"Could not find CD 0601 income target in names: {names}" + + cd_row = X_dense[cd_income_idx] + # Only cols 0-3 (clone 0, CD 0601) should be nonzero + # Cols 4-7 (NY) and 8-11 (CA CD 0602) must be zero + non_cd_cols = np.arange(4, 12) + np.testing.assert_array_equal( + 
cd_row[non_cd_cols], + np.zeros(len(non_cd_cols)), + err_msg=("CD 0601 target has nonzero in non-CD-0601 columns"), + ) + + cd_cols = np.arange(0, 4) + assert np.any( + cd_row[cd_cols] != 0 + ), "CD 0601 income target is all zeros in its own columns" + + +class TestNationalTargetFillsAllColumns: + """National target fills columns across all states.""" + + @patch("policyengine_us.Microsimulation") + def test_national_target_fills_all_columns( + self, MockMicrosim, mock_db, mock_sim + ): + MockMicrosim.return_value = mock_sim + + n_records = 4 + n_clones = 3 + geography = _make_geography( + n_records, + n_clones, + [ + (6, "0601"), + (36, "3601"), + (6, "0602"), + ], + ) + + builder = _builder_with_engine(mock_db) + targets_df, X, names = builder.build_matrix("dummy_path.h5", geography) + X_dense = X.toarray() + + # Find national income target (no state_ or cd_ prefix) + national_income_idx = None + for i, name in enumerate(names): + if "national" in name and "income" in name and "cd_" not in name: + national_income_idx = i + break + assert ( + national_income_idx is not None + ), f"Could not find national income target in: {names}" + + nat_row = X_dense[national_income_idx] + # hh_income = [1000, 2000, 3000, 4000] -- all nonzero + # Should have nonzero values in every clone's columns + for clone_idx in range(n_clones): + start = clone_idx * n_records + end = start + n_records + clone_slice = nat_row[start:end] + assert np.any(clone_slice != 0), ( + f"National income target is all zeros in clone " + f"{clone_idx} (cols {start}-{end-1})" + ) + + +class TestColumnValuesUseCorrectRecord: + """Column i should use values from record i % n_records.""" + + @patch("policyengine_us.Microsimulation") + def test_column_values_use_correct_record( + self, MockMicrosim, mock_db, mock_sim + ): + MockMicrosim.return_value = mock_sim + + n_records = 4 + n_clones = 3 + geography = _make_geography( + n_records, + n_clones, + [ + (6, "0601"), + (36, "3601"), + (6, "0602"), + ], + ) + + builder = _builder_with_engine(mock_db) + targets_df, X, names = builder.build_matrix("dummy_path.h5", geography) + X_dense = X.toarray() + + # Find national income target + national_income_idx = None + for i, name in enumerate(names): + if "national" in name and "income" in name and "cd_" not in name: + national_income_idx = i + break + assert national_income_idx is not None + + nat_row = X_dense[national_income_idx] + # hh_income = [1000, 2000, 3000, 4000] + expected_values = np.array([1000, 2000, 3000, 4000], dtype=np.float32) + + # Each clone should replicate the same base record values + for clone_idx in range(n_clones): + start = clone_idx * n_records + end = start + n_records + np.testing.assert_array_almost_equal( + nat_row[start:end], + expected_values, + err_msg=(f"Clone {clone_idx} has wrong record mapping"), + ) + + +class TestConstraintMaskApplied: + """Non-geographic constraints filter which records contribute.""" + + @patch("policyengine_us.Microsimulation") + def test_constraint_mask_applied( + self, + MockMicrosim, + constrained_db, + mock_sim, + ): + MockMicrosim.return_value = mock_sim + + n_records = 4 + n_clones = 2 + geography = _make_geography( + n_records, + n_clones, + [ + (6, "0601"), # clone 0: CA + (36, "3601"), # clone 1: NY + ], + ) + + builder = _builder_with_engine(constrained_db) + targets_df, X, names = builder.build_matrix("dummy_path.h5", geography) + X_dense = X.toarray() + + # The only target is CA adults income (age >= 18) + assert X_dense.shape[0] == 1 + row = X_dense[0] + + # NY columns 
(clone 1, cols 4-7) must be zero because + # this is a state=6 target + ny_cols = np.arange(4, 8) + np.testing.assert_array_equal(row[ny_cols], np.zeros(len(ny_cols))) + + # CA columns (clone 0, cols 0-3): + # The age constraint (age >= 18) is evaluated per person: + # person_age = [35,10, 40,45, 50,55, 60,65] + # hh0 has one adult (person 0, age 35) -> mask True + # (any person satisfies -> household passes) + # hh1: both adults -> True + # hh2: both adults -> True + # hh3: both adults -> True + # So all CA households pass the constraint, and income + # values are [1000, 2000, 3000, 4000]. + ca_cols = np.arange(0, 4) + expected = np.array([1000, 2000, 3000, 4000], dtype=np.float32) + np.testing.assert_array_almost_equal(row[ca_cols], expected) + + +class TestCountVariableHandling: + """Count variables should produce 1.0 per qualifying household.""" + + @patch("policyengine_us.Microsimulation") + def test_household_count_is_one_per_household( + self, MockMicrosim, mock_db, mock_sim + ): + MockMicrosim.return_value = mock_sim + + n_records = 4 + n_clones = 2 + geography = _make_geography( + n_records, + n_clones, + [ + (6, "0601"), + (36, "3601"), + ], + ) + + builder = _builder_with_engine(mock_db) + targets_df, X, names = builder.build_matrix("dummy_path.h5", geography) + X_dense = X.toarray() + + # Find household_count target + hh_count_idx = None + for i, name in enumerate(names): + if "household_count" in name: + hh_count_idx = i + break + assert hh_count_idx is not None + + row = X_dense[hh_count_idx] + # household_count is a count variable with no constraints + # -> mask is all True -> values = mask.astype(float32) = 1.0 + # for all columns + expected = np.ones(n_records * n_clones, dtype=np.float32) + np.testing.assert_array_almost_equal(row, expected) + + +class TestQueryActiveTargets: + """Test that _query_active_targets returns correct data.""" + + def test_returns_all_active(self, mock_db): + builder = _builder_with_engine(mock_db) + df = builder._query_active_targets() + assert len(df) == 5 + + def test_filters_zero_values(self): + engine = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + source = Source(name="Test", type=SourceType.HARDCODED) + session.add(source) + session.flush() + + stratum = Stratum(stratum_group_id=0, notes="US") + stratum.constraints_rel = [] + session.add(stratum) + session.flush() + + session.add( + Target( + stratum_id=stratum.stratum_id, + variable="income", + period=2024, + value=0.0, + source_id=source.source_id, + active=True, + ) + ) + session.commit() + + builder = _builder_with_engine(engine) + df = builder._query_active_targets() + # Zero-value target filtered out + assert len(df) == 0 + + +class TestGetAllConstraints: + """Test constraint chain walking.""" + + def test_no_constraints_for_national(self, mock_db): + builder = _builder_with_engine(mock_db) + targets_df = builder._query_active_targets() + national_row = targets_df[ + (targets_df["variable"] == "income") & (targets_df["value"] > 5e8) + ].iloc[0] + constraints = builder._get_all_constraints(national_row["stratum_id"]) + assert constraints == [] + + def test_state_constraint_present(self, mock_db): + builder = _builder_with_engine(mock_db) + targets_df = builder._query_active_targets() + ca_row = targets_df[ + (targets_df["variable"] == "snap") & (targets_df["value"] == 5e8) + ].iloc[0] + constraints = builder._get_all_constraints(ca_row["stratum_id"]) + var_names = {c["variable"] for c in constraints} + assert 
"state_fips" in var_names + + def test_cd_walks_to_parent_state(self, mock_db): + builder = _builder_with_engine(mock_db) + targets_df = builder._query_active_targets() + cd_row = targets_df[targets_df["value"] == 1e8].iloc[0] + constraints = builder._get_all_constraints(cd_row["stratum_id"]) + var_names = {c["variable"] for c in constraints} + assert "congressional_district_geoid" in var_names + assert "state_fips" in var_names + + +class TestExtendedCPSHasNoCalculatedVars: + """The extended CPS h5 should contain only input variables. + + The unified calibration pipeline assigns new geography and + then invokes PE to compute all derived variables from + scratch. If the h5 includes variables with PE formulas, + those stored values could conflict with what PE would + compute fresh. This test ensures the h5 only stores + true survey inputs. + """ + + # Variables that have PE formulas but are stored in the + # h5 as survey-reported or imputed input values. These + # are acceptable because PE's set_input mechanism means + # the stored value takes precedence over the formula. + # Each entry should have a comment explaining why it's + # allowed. + _ALLOWED_FORMULA_VARS = { + # CPS/PUF-reported values with PE fallback formulas + "employment_income", + "self_employment_income", + "weekly_hours_worked", + # PUF-imputed tax credits (PE has formulas but we + # trust the imputed values from the tax model) + "american_opportunity_credit", + "foreign_tax_credit", + "savers_credit", + "energy_efficient_home_improvement_credit", + "cdcc_relevant_expenses", + "taxable_unemployment_compensation", + # Derived from other h5 inputs, not geography + "rent", + "person_id", + "employment_income_last_year", + "immigration_status", + } + + @pytest.mark.xfail( + reason="in_nyc should be removed from extended CPS h5", + strict=True, + ) + def test_no_formula_vars_in_h5(self): + """H5 should not contain PE formula variables. + + Any variable with a PE formula that's stored in + the h5 risks providing stale values (especially + after geography reassignment). Only explicitly + allowed exceptions are permitted. + + Currently xfail because in_nyc is in the h5 and + needs to be removed from the dataset build. + """ + import h5py + from pathlib import Path + + h5_path = Path("policyengine_us_data/storage/extended_cps_2024.h5") + if not h5_path.exists(): + pytest.skip("extended_cps_2024.h5 not available") + + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=str(h5_path)) + + with h5py.File(h5_path, "r") as f: + h5_vars = set(f.keys()) + + unexpected = set() + for var_name in h5_vars: + if var_name not in sim.tax_benefit_system.variables: + continue + var = sim.tax_benefit_system.variables[var_name] + has_formula = hasattr(var, "formulas") and len(var.formulas) > 0 + if has_formula and var_name not in self._ALLOWED_FORMULA_VARS: + unexpected.add(var_name) + + assert unexpected == set(), ( + f"Extended CPS h5 contains {len(unexpected)} " + f"variable(s) with PE formulas that are not in " + f"the allowlist. 
Either remove them from the " + f"h5 or add to _ALLOWED_FORMULA_VARS with a " + f"justification: {sorted(unexpected)}" + ) diff --git a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py index 96e4a996c..13d533337 100644 --- a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py @@ -6,13 +6,11 @@ import numpy as np from policyengine_core.data import Dataset -from policyengine_core.reforms import Reform from policyengine_us import Microsimulation -from policyengine_us_data.utils import ( - build_loss_matrix, - print_reweighting_diagnostics, +from policyengine_us_data.storage import ( + STORAGE_FOLDER, + CALIBRATION_FOLDER, ) -from policyengine_us_data.storage import STORAGE_FOLDER, CALIBRATION_FOLDER @pytest.fixture(scope="session") @@ -25,134 +23,21 @@ def sim(data): return Microsimulation(dataset=data) -@pytest.mark.filterwarnings("ignore:DataFrame is highly fragmented") -@pytest.mark.filterwarnings("ignore:The distutils package is deprecated") -@pytest.mark.filterwarnings( - "ignore:Series.__getitem__ treating keys as positions is deprecated" -) -@pytest.mark.filterwarnings( - "ignore:Setting an item of incompatible dtype is deprecated" -) -@pytest.mark.filterwarnings( - "ignore:Boolean Series key will be reindexed to match DataFrame index." -) -def test_sparse_ecps(sim): - data = sim.dataset.load_dataset() - optimised_weights = data["household_weight"]["2024"] - - bad_targets = [ - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "state/RI/adjusted_gross_income/amount/-inf_1", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "state/RI/adjusted_gross_income/amount/-inf_1", - "nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All", - ] - - loss_matrix, targets_array = build_loss_matrix(sim.dataset, 2024) - zero_mask = np.isclose(targets_array, 0.0, atol=0.1) - bad_mask = loss_matrix.columns.isin(bad_targets) - keep_mask_bool = ~(zero_mask | bad_mask) - keep_idx = np.where(keep_mask_bool)[0] - loss_matrix_clean = loss_matrix.iloc[:, keep_idx] - targets_array_clean = targets_array[keep_idx] - assert 
loss_matrix_clean.shape[1] == targets_array_clean.size - - percent_within_10 = print_reweighting_diagnostics( - optimised_weights, - loss_matrix_clean, - targets_array_clean, - "Sparse Solutions", - ) - assert percent_within_10 > 60.0 - - def test_sparse_ecps_has_mortgage_interest(sim): assert sim.calculate("deductible_mortgage_interest").sum() > 1 def test_sparse_ecps_has_tips(sim): # Ensure we impute at least $40 billion in tip income. - # We currently target $38 billion * 1.4 = $53.2 billion. TIP_INCOME_MINIMUM = 40e9 assert sim.calculate("tip_income").sum() > TIP_INCOME_MINIMUM -def test_sparse_ecps_replicates_jct_tax_expenditures(): - calibration_log = pd.read_csv( - "calibration_log.csv", - ) - - jct_rows = calibration_log[ - (calibration_log["target_name"].str.contains("jct/")) - & (calibration_log["epoch"] == calibration_log["epoch"].max()) - ] - - assert ( - jct_rows.rel_abs_error.max() < 0.5 - ), "JCT tax expenditure targets not met (see the calibration log for details). Max relative error: {:.2%}".format( - jct_rows.rel_abs_error.max() - ) - - -def deprecated_test_sparse_ecps_replicates_jct_tax_expenditures_full(sim): - - # JCT tax expenditure targets - EXPENDITURE_TARGETS = { - "salt_deduction": 21.247e9, - "medical_expense_deduction": 11.4e9, - "charitable_deduction": 65.301e9, - "interest_deduction": 24.8e9, - } - - baseline = sim - income_tax_b = baseline.calculate( - "income_tax", period=2024, map_to="household" - ) - - for deduction, target in EXPENDITURE_TARGETS.items(): - # Create reform that neutralizes the deduction - class RepealDeduction(Reform): - def apply(self): - self.neutralize_variable(deduction) - - # Run reform simulation - reformed = Microsimulation(reform=RepealDeduction, dataset=sim.dataset) - income_tax_r = reformed.calculate( - "income_tax", period=2024, map_to="household" - ) - - # Calculate tax expenditure - tax_expenditure = (income_tax_r - income_tax_b).sum() - pct_error = abs((tax_expenditure - target) / target) - TOLERANCE = 0.4 - - logging.info( - f"{deduction} tax expenditure {tax_expenditure/1e9:.1f}bn " - f"differs from target {target/1e9:.1f}bn by {pct_error:.2%}" - ) - assert pct_error < TOLERANCE, deduction - - def test_sparse_ssn_card_type_none_target(sim): TARGET_COUNT = 13e6 TOLERANCE = 0.2 # Allow 20% error - # Calculate the number of individuals with ssn_card_type == "NONE" ssn_type_none_mask = sim.calculate("ssn_card_type") == "NONE" count = ssn_type_none_mask.sum() @@ -168,7 +53,8 @@ def test_sparse_ssn_card_type_none_target(sim): def test_sparse_aca_calibration(sim): TARGETS_PATH = Path( - "policyengine_us_data/storage/calibration_targets/aca_spending_and_enrollment_2024.csv" + "policyengine_us_data/storage/calibration_targets/" + "aca_spending_and_enrollment_2024.csv" ) targets = pd.read_csv(TARGETS_PATH) # Monthly to yearly @@ -206,7 +92,8 @@ def test_sparse_aca_calibration(sim): def test_sparse_medicaid_calibration(sim): TARGETS_PATH = Path( - "policyengine_us_data/storage/calibration_targets/medicaid_enrollment_2024.csv" + "policyengine_us_data/storage/calibration_targets/" + "medicaid_enrollment_2024.csv" ) targets = pd.read_csv(TARGETS_PATH) diff --git a/policyengine_us_data/tests/test_db/__init__.py b/policyengine_us_data/tests/test_db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/policyengine_us_data/tests/test_db/test_etl_all_targets.py b/policyengine_us_data/tests/test_db/test_etl_all_targets.py new file mode 100644 index 000000000..31e6a2c17 --- /dev/null +++ 
b/policyengine_us_data/tests/test_db/test_etl_all_targets.py @@ -0,0 +1,474 @@ +"""Tests for the comprehensive ETL that migrates all legacy loss.py +targets into the database.""" + +import numpy as np +import pandas as pd +import pytest +from sqlmodel import Session, SQLModel, create_engine, select + +from policyengine_us_data.db.create_database_tables import ( + Source, + SourceType, + Stratum, + StratumConstraint, + Target, +) +from policyengine_us_data.db.etl_all_targets import ( + extract_census_age_populations, + extract_eitc_by_child_count, + extract_healthcare_by_age, + extract_soi_filer_counts, + extract_spm_threshold_agi, + extract_negative_market_income, + extract_infant_count, + extract_net_worth, + extract_state_population, + extract_tax_expenditure_targets, + extract_state_real_estate_taxes, + extract_state_aca, + extract_state_medicaid_enrollment, + extract_state_10yr_age, + extract_state_agi, + extract_soi_filing_status_targets, + load_all_targets, +) +from policyengine_us_data.storage import CALIBRATION_FOLDER + +# ------------------------------------------------------------------ # +# Fixtures # +# ------------------------------------------------------------------ # + + +@pytest.fixture +def engine(): + """Create an in-memory SQLite engine with all tables.""" + eng = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(eng) + return eng + + +@pytest.fixture +def session(engine): + """Provide a session that auto-rolls back.""" + with Session(engine) as sess: + yield sess + + +@pytest.fixture +def root_stratum(session): + """Create the root (national) stratum required by the ETL.""" + root = Stratum( + definition_hash="root_national", + parent_stratum_id=None, + stratum_group_id=1, + notes="United States", + ) + session.add(root) + session.commit() + session.refresh(root) + return root + + +# ------------------------------------------------------------------ # +# Extract-layer tests # +# ------------------------------------------------------------------ # + + +class TestExtractCensusAgePopulations: + def test_returns_86_bins(self): + records = extract_census_age_populations(time_period=2024) + assert len(records) == 86 + + def test_values_are_positive(self): + records = extract_census_age_populations(time_period=2024) + for r in records: + assert r["value"] > 0 + assert r["age"] == r["age"] # not NaN + + def test_first_bin_is_age_0(self): + records = extract_census_age_populations(time_period=2024) + assert records[0]["age"] == 0 + + +class TestExtractEitcByChildCount: + def test_returns_4_child_buckets(self): + records = extract_eitc_by_child_count() + child_counts = {r["count_children"] for r in records} + assert child_counts == {0, 1, 2, 3} + + def test_has_returns_and_total(self): + records = extract_eitc_by_child_count() + for r in records: + assert "eitc_returns" in r + assert "eitc_total" in r + assert r["eitc_returns"] > 0 + assert r["eitc_total"] > 0 + + +class TestExtractSoiFilerCounts: + def test_returns_7_bands(self): + records = extract_soi_filer_counts() + assert len(records) == 7 + + def test_bands_cover_full_range(self): + records = extract_soi_filer_counts() + lowers = [r["agi_lower"] for r in records] + uppers = [r["agi_upper"] for r in records] + assert -np.inf in lowers + assert np.inf in uppers + + +class TestExtractHealthcareByAge: + def test_returns_9_age_bands(self): + records = extract_healthcare_by_age() + assert len(records) == 9 + + def test_has_4_expense_types(self): + records = extract_healthcare_by_age() + for r in records: + 
assert ( + "health_insurance_premiums_without_medicare_part_b" + in r["expenses"] + ) + assert "over_the_counter_health_expenses" in r["expenses"] + assert "other_medical_expenses" in r["expenses"] + assert "medicare_part_b_premiums" in r["expenses"] + + +class TestExtractSpmThresholdAgi: + def test_returns_10_deciles(self): + records = extract_spm_threshold_agi() + assert len(records) == 10 + + def test_has_agi_and_count(self): + records = extract_spm_threshold_agi() + for r in records: + assert "adjusted_gross_income" in r + assert "count" in r + assert r["decile"] >= 1 + + +class TestExtractNegativeMarketIncome: + def test_has_total_and_count(self): + result = extract_negative_market_income() + assert result["total"] == -138e9 + assert result["count"] == 3e6 + + +class TestExtractInfantCount: + def test_returns_positive(self): + result = extract_infant_count() + assert result > 3e6 + + +class TestExtractNetWorth: + def test_returns_160_trillion(self): + result = extract_net_worth() + assert result == 160e12 + + +class TestExtractStatePopulation: + def test_returns_51_or_52_rows(self): + records = extract_state_population() + assert len(records) in (51, 52) # 50 states + DC (+ PR) + + def test_has_under_5(self): + records = extract_state_population() + for r in records: + assert "population_under_5" in r + + +class TestExtractTaxExpenditureTargets: + def test_returns_5_deductions(self): + records = extract_tax_expenditure_targets() + assert len(records) == 5 + names = {r["variable"] for r in records} + assert "salt_deduction" in names + assert "qualified_business_income_deduction" in names + + +class TestExtractStateRealEstateTaxes: + def test_returns_51_states(self): + records = extract_state_real_estate_taxes() + assert len(records) == 51 + + def test_sums_to_national_target(self): + records = extract_state_real_estate_taxes() + total = sum(r["value"] for r in records) + assert abs(total - 500e9) < 1e6 # within $1M + + +class TestExtractStateAca: + def test_returns_spending_and_enrollment(self): + records = extract_state_aca() + assert len(records) > 0 + first = records[0] + assert "spending" in first + assert "enrollment" in first + + +class TestExtractStateMedicaidEnrollment: + def test_returns_51_rows(self): + records = extract_state_medicaid_enrollment() + assert len(records) == 51 + + +class TestExtractState10yrAge: + def test_returns_50_states(self): + records = extract_state_10yr_age() + states = {r["state"] for r in records} + assert len(states) == 50 + + def test_has_18_age_ranges(self): + records = extract_state_10yr_age() + first_state = records[0]["state"] + state_records = [r for r in records if r["state"] == first_state] + assert len(state_records) == 18 + + +class TestExtractStateAgi: + def test_returns_918_rows(self): + records = extract_state_agi() + assert len(records) == 918 + + +class TestExtractSoiFilingStatusTargets: + def test_returns_filtered_rows(self): + records = extract_soi_filing_status_targets() + # Only PE-valid variables are kept after filtering + assert len(records) == 532 + for r in records: + assert r["taxable_only"] is True + assert r["agi_upper"] > 10_000 + + +# ------------------------------------------------------------------ # +# Load-layer integration test # +# ------------------------------------------------------------------ # + + +class TestLoadAllTargets: + def test_load_creates_targets(self, engine, root_stratum): + """Run the full load and verify target counts.""" + load_all_targets( + engine=engine, + time_period=2024, + 
root_stratum_id=root_stratum.stratum_id, + ) + + with Session(engine) as sess: + total_targets = sess.exec(select(Target)).all() + total_strata = sess.exec(select(Stratum)).unique().all() + total_sources = sess.exec(select(Source)).all() + + # Must have created at least 1 source + assert len(total_sources) >= 1 + + # Must have created many strata beyond the root + assert len(total_strata) > 10 + + # Must have created many targets + assert len(total_targets) > 100 + + def test_census_age_targets_present(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + age_targets = sess.exec( + select(Target).where( + Target.variable == "person_count", + Target.notes.contains("Census age"), + ) + ).all() + assert len(age_targets) == 86 + + def test_eitc_targets_present(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + # 4 child counts x 2 (returns + spending) = 8 + eitc_targets = sess.exec( + select(Target).where( + Target.notes.contains("EITC"), + ) + ).all() + assert len(eitc_targets) == 8 + + def test_soi_filer_count_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + filer_targets = sess.exec( + select(Target).where( + Target.notes.contains("SOI filer count"), + ) + ).all() + assert len(filer_targets) == 7 + + def test_healthcare_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + hc_targets = sess.exec( + select(Target).where( + Target.notes.contains("Healthcare"), + ) + ).all() + # 9 age bands x 4 expense types = 36 + assert len(hc_targets) == 36 + + def test_spm_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + spm_targets = sess.exec( + select(Target).where( + Target.notes.contains("SPM threshold"), + ) + ).all() + # 10 deciles x 2 (agi + count) = 20 + assert len(spm_targets) == 20 + + def test_negative_market_income_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + nmi_targets = sess.exec( + select(Target).where( + Target.notes.contains("Negative household market"), + ) + ).all() + assert len(nmi_targets) == 2 + + def test_state_real_estate_tax_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + ret_targets = sess.exec( + select(Target).where( + Target.notes.contains("State real estate"), + ) + ).all() + assert len(ret_targets) == 51 + + def test_idempotent(self, engine, root_stratum): + """Running twice should not duplicate targets.""" + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + count_1 = len(sess.exec(select(Target)).all()) + + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + count_2 = len(sess.exec(select(Target)).all()) + + assert count_1 == count_2 + + def 
test_soi_filing_status_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + soi_targets = sess.exec( + select(Target).where( + Target.notes.contains("SOI filing-status"), + ) + ).all() + # Only PE-valid variables are loaded + # PE-valid vars, deduplicated by (variable, stratum) + assert len(soi_targets) == 289 + + def test_state_aca_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + # 51 states x 2 (spending + enrollment) = 102 + # but actually stored as separate targets + spending = sess.exec( + select(Target).where( + Target.notes.contains("ACA spending"), + ) + ).all() + enrollment = sess.exec( + select(Target).where( + Target.notes.contains("ACA enrollment"), + ) + ).all() + assert len(spending) >= 50 + assert len(enrollment) >= 50 + + def test_state_medicaid_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + med = sess.exec( + select(Target).where( + Target.notes.contains("State Medicaid"), + ) + ).all() + assert len(med) == 51 + + def test_state_age_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + age = sess.exec( + select(Target).where( + Target.notes.contains("State 10yr age"), + ) + ).all() + # 50 states x 18 age ranges = 900 + assert len(age) == 900 + + def test_state_agi_targets(self, engine, root_stratum): + load_all_targets( + engine=engine, + time_period=2024, + root_stratum_id=root_stratum.stratum_id, + ) + with Session(engine) as sess: + agi = sess.exec( + select(Target).where( + Target.notes.contains("State AGI"), + ) + ).all() + assert len(agi) == 918 diff --git a/policyengine_us_data/tests/test_db/test_legacy_targets_in_production_db.py b/policyengine_us_data/tests/test_db/test_legacy_targets_in_production_db.py new file mode 100644 index 000000000..c1697780f --- /dev/null +++ b/policyengine_us_data/tests/test_db/test_legacy_targets_in_production_db.py @@ -0,0 +1,180 @@ +"""Tests that the production policy_data.db contains all legacy +calibration targets from loss.py. + +These targets were loaded by running ``load_all_targets()`` from +``etl_all_targets.py`` against the production database. The tests +verify minimum target counts by category to guard against regressions. +""" + +import sqlite3 + +import pytest + +from policyengine_us_data.storage import STORAGE_FOLDER + +DB_PATH = STORAGE_FOLDER / "calibration" / "policy_data.db" + + +@pytest.fixture +def cursor(): + """Provide a read-only cursor to the production DB.""" + conn = sqlite3.connect(str(DB_PATH)) + cur = conn.cursor() + yield cur + conn.close() + + +def _count_targets_by_notes(cursor, pattern, period=2024): + """Count targets whose notes contain the given pattern.""" + cursor.execute( + "SELECT COUNT(*) FROM targets " "WHERE notes LIKE ? AND period = ?", + (f"%{pattern}%", period), + ) + return cursor.fetchone()[0] + + +def _count_targets_by_stratum_group(cursor, group_id, period=2024): + """Count targets whose stratum has the given group_id.""" + cursor.execute( + "SELECT COUNT(*) FROM targets t " + "JOIN strata s ON t.stratum_id = s.stratum_id " + "WHERE s.stratum_group_id = ? 
AND t.period = ?", + (group_id, period), + ) + return cursor.fetchone()[0] + + +class TestLegacyTargetsPresent: + """Verify all legacy loss.py target categories are present.""" + + def test_census_single_year_age(self, cursor): + """86 single-year age bins (ages 0-85).""" + count = _count_targets_by_notes(cursor, "Census age bin") + assert count == 86, f"Expected 86 age targets, got {count}" + + def test_eitc_by_child_count(self, cursor): + """4 child buckets x 2 metrics (returns + spending) + 1 + national EITC. + """ + count = _count_targets_by_notes(cursor, "EITC") + # 4 child counts x 2 = 8, plus 1 national = 9 + assert count >= 8, f"Expected >= 8 EITC targets, got {count}" + + def test_soi_filer_counts(self, cursor): + """7 AGI bands for SOI filer counts.""" + count = _count_targets_by_notes(cursor, "SOI filer count") + assert count == 7, f"Expected 7 filer count targets, got {count}" + + def test_healthcare_spending_by_age(self, cursor): + """9 age bands x 4 expense types = 36.""" + count = _count_targets_by_notes(cursor, "Healthcare") + assert count == 36, f"Expected 36 healthcare targets, got {count}" + + def test_spm_threshold_deciles(self, cursor): + """10 deciles x 2 (AGI + count) = 20.""" + count = _count_targets_by_notes(cursor, "SPM threshold") + assert count == 20, f"Expected 20 SPM threshold targets, got {count}" + + def test_negative_market_income(self, cursor): + """2 targets: total + count.""" + count = _count_targets_by_notes(cursor, "Negative household market") + assert ( + count == 2 + ), f"Expected 2 negative market income targets, got {count}" + + def test_tax_expenditures(self, cursor): + """5 deductions with reform_id=1.""" + cursor.execute( + "SELECT COUNT(*) FROM targets " + "WHERE reform_id = 1 AND period = 2024", + ) + count = cursor.fetchone()[0] + assert ( + count >= 5 + ), f"Expected >= 5 tax expenditure targets, got {count}" + + def test_state_population(self, cursor): + """51 state totals + 51 under-5 = 102 minimum.""" + count = _count_targets_by_notes(cursor, "State") + assert count >= 100, f"Expected >= 100 state targets, got {count}" + + def test_state_real_estate_taxes(self, cursor): + """51 states.""" + count = _count_targets_by_notes(cursor, "State real estate") + assert count == 51, f"Expected 51 real estate tax targets, got {count}" + + def test_state_aca(self, cursor): + """51 states x 2 (spending + enrollment) = 102.""" + spending = _count_targets_by_notes(cursor, "ACA spending") + enrollment = _count_targets_by_notes(cursor, "ACA enrollment") + assert ( + spending >= 50 + ), f"Expected >= 50 ACA spending targets, got {spending}" + assert ( + enrollment >= 50 + ), f"Expected >= 50 ACA enrollment targets, got {enrollment}" + + def test_state_medicaid_enrollment(self, cursor): + """51 states.""" + count = _count_targets_by_notes(cursor, "State Medicaid") + assert count == 51, f"Expected 51 Medicaid targets, got {count}" + + def test_state_10yr_age(self, cursor): + """50 states x 18 ranges = 900.""" + count = _count_targets_by_notes(cursor, "State 10yr age") + assert count == 900, f"Expected 900 state age targets, got {count}" + + def test_state_agi(self, cursor): + """918 state AGI targets.""" + count = _count_targets_by_notes(cursor, "State AGI") + assert count == 918, f"Expected 918 state AGI targets, got {count}" + + def test_soi_filing_status(self, cursor): + """SOI filing-status x AGI bin targets.""" + count = _count_targets_by_notes(cursor, "SOI filing-status") + assert ( + count >= 280 + ), f"Expected >= 280 SOI filing-status targets, got 
{count}" + + def test_net_worth(self, cursor): + """National net worth target.""" + cursor.execute( + "SELECT value FROM targets " + "WHERE variable = 'net_worth' AND period = 2024 " + "AND stratum_id = 1", + ) + row = cursor.fetchone() + assert row is not None, "Net worth target not found" + assert row[0] == 160e12 + + def test_total_2024_targets_minimum(self, cursor): + """Overall count: the old loss.py had ~350+ national targets. + With state/CD targets, we expect well over 9000 total. + """ + cursor.execute("SELECT COUNT(*) FROM targets WHERE period = 2024") + count = cursor.fetchone()[0] + assert ( + count >= 12000 + ), f"Expected >= 12000 total 2024 targets, got {count}" + + def test_income_tax_constraint_uses_valid_variable(self, cursor): + """Verify that SOI filing-status strata use 'income_tax' + (not the invalid 'total_income_tax') as a constraint + variable. + """ + cursor.execute( + "SELECT COUNT(*) FROM stratum_constraints " + "WHERE constraint_variable = 'total_income_tax'" + ) + count = cursor.fetchone()[0] + assert count == 0, ( + f"Found {count} constraints using invalid " + f"'total_income_tax' -- should be 'income_tax'" + ) + + cursor.execute( + "SELECT COUNT(*) FROM stratum_constraints " + "WHERE constraint_variable = 'income_tax'" + ) + count = cursor.fetchone()[0] + assert count > 0, "No constraints using 'income_tax' found" diff --git a/policyengine_us_data/tests/test_db/test_reconcile_targets.py b/policyengine_us_data/tests/test_db/test_reconcile_targets.py new file mode 100644 index 000000000..2bb49e9cf --- /dev/null +++ b/policyengine_us_data/tests/test_db/test_reconcile_targets.py @@ -0,0 +1,439 @@ +"""Tests for the two-pass geographic target reconciliation.""" + +import logging + +import pytest +from sqlmodel import Session, SQLModel, create_engine, select + +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, + Target, +) +from policyengine_us_data.db.reconcile_targets import reconcile_targets + +# ------------------------------------------------------------------ # +# Fixtures # +# ------------------------------------------------------------------ # + + +@pytest.fixture +def engine(): + """Create an in-memory SQLite engine with all tables.""" + eng = create_engine("sqlite:///:memory:") + SQLModel.metadata.create_all(eng) + return eng + + +@pytest.fixture +def session(engine): + """Provide a session.""" + with Session(engine) as sess: + yield sess + + +def _make_geo_hierarchy(session): + """Build a small national -> 3 states -> CDs hierarchy. + + Returns dict with stratum IDs. 
+ """ + # National + national = Stratum( + parent_stratum_id=None, + stratum_group_id=1, + notes="United States", + ) + national.constraints_rel = [] + session.add(national) + session.flush() + + # States + state_ids = {} + for fips in [1, 2, 3]: + s = Stratum( + parent_stratum_id=national.stratum_id, + stratum_group_id=1, + notes=f"State {fips}", + ) + s.constraints_rel = [ + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(fips), + ) + ] + session.add(s) + session.flush() + state_ids[fips] = s.stratum_id + + # CDs under state 1 + cd_ids = {} + for cd_geoid in [101, 102]: + cd = Stratum( + parent_stratum_id=state_ids[1], + stratum_group_id=1, + notes=f"CD {cd_geoid}", + ) + cd.constraints_rel = [ + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(cd_geoid), + ) + ] + session.add(cd) + session.flush() + cd_ids[cd_geoid] = cd.stratum_id + + session.commit() + return { + "national": national.stratum_id, + "states": state_ids, + "cds": cd_ids, + } + + +def _add_target(session, stratum_id, variable, period, value): + """Helper to add a target.""" + t = Target( + stratum_id=stratum_id, + variable=variable, + period=period, + value=value, + active=True, + ) + session.add(t) + session.flush() + return t + + +# ------------------------------------------------------------------ # +# Tests # +# ------------------------------------------------------------------ # + + +class TestPassOneStateToNational: + def test_states_scaled_to_national(self, session): + """National=100, states=[30,40,50] -> scaled to sum=100.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "real_estate_taxes", 2022, 100) + _add_target(session, ids["states"][1], "real_estate_taxes", 2022, 30) + _add_target(session, ids["states"][2], "real_estate_taxes", 2022, 40) + _add_target(session, ids["states"][3], "real_estate_taxes", 2022, 50) + session.commit() + + stats = reconcile_targets(session) + + # Verify state targets scaled + assert stats["scaled_state"] == 3 + + state_targets = session.exec( + select(Target).where( + Target.variable == "real_estate_taxes", + Target.stratum_id.in_(list(ids["states"].values())), + ) + ).all() + total = sum(t.value for t in state_targets) + assert abs(total - 100) < 1e-6 + + def test_raw_value_preserved(self, session): + """raw_value should hold the original source value.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "real_estate_taxes", 2022, 100) + _add_target(session, ids["states"][1], "real_estate_taxes", 2022, 30) + _add_target(session, ids["states"][2], "real_estate_taxes", 2022, 40) + _add_target(session, ids["states"][3], "real_estate_taxes", 2022, 50) + session.commit() + + reconcile_targets(session) + + state_targets = session.exec( + select(Target).where( + Target.variable == "real_estate_taxes", + Target.stratum_id.in_(list(ids["states"].values())), + ) + ).all() + raw_values = sorted(t.raw_value for t in state_targets) + assert raw_values == [30, 40, 50] + + def test_proportions_preserved(self, session): + """Relative proportions across states should be unchanged.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "real_estate_taxes", 2022, 100) + _add_target(session, ids["states"][1], "real_estate_taxes", 2022, 30) + _add_target(session, ids["states"][2], "real_estate_taxes", 2022, 40) + _add_target(session, ids["states"][3], "real_estate_taxes", 2022, 50) + session.commit() + + 
reconcile_targets(session) + + state_targets = session.exec( + select(Target).where( + Target.variable == "real_estate_taxes", + Target.stratum_id.in_(list(ids["states"].values())), + ) + ).all() + values = {t.stratum_id: t.value for t in state_targets} + # 30/40 == 3/4 should be preserved + ratio = values[ids["states"][1]] / values[ids["states"][2]] + assert abs(ratio - 30 / 40) < 1e-9 + + +class TestPassTwoCdToState: + def test_two_pass_reconciliation(self, session): + """Full two-pass: national -> states -> CDs.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "income_tax", 2022, 1000) + # States sum to 500 (not 1000) + _add_target(session, ids["states"][1], "income_tax", 2022, 200) + _add_target(session, ids["states"][2], "income_tax", 2022, 150) + _add_target(session, ids["states"][3], "income_tax", 2022, 150) + # CDs under state 1 sum to 100 (not 200) + _add_target(session, ids["cds"][101], "income_tax", 2022, 60) + _add_target(session, ids["cds"][102], "income_tax", 2022, 40) + session.commit() + + stats = reconcile_targets(session) + + # States should sum to national (1000) + state_targets = session.exec( + select(Target).where( + Target.variable == "income_tax", + Target.stratum_id.in_(list(ids["states"].values())), + ) + ).all() + state_sum = sum(t.value for t in state_targets) + assert abs(state_sum - 1000) < 1e-6 + + # State 1 should be 200 * (1000/500) = 400 + state1_val = next( + t.value for t in state_targets if t.stratum_id == ids["states"][1] + ) + assert abs(state1_val - 400) < 1e-6 + + # CDs should sum to corrected state 1 (400) + cd_targets = session.exec( + select(Target).where( + Target.variable == "income_tax", + Target.stratum_id.in_(list(ids["cds"].values())), + ) + ).all() + cd_sum = sum(t.value for t in cd_targets) + assert abs(cd_sum - 400) < 1e-6 + + assert stats["scaled_state"] == 3 + assert stats["scaled_cd"] == 2 + + +class TestNoNationalTarget: + def test_states_unchanged_without_national(self, session): + """Without a national target, state values stay the same.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["states"][1], "real_estate_taxes", 2022, 30) + _add_target(session, ids["states"][2], "real_estate_taxes", 2022, 40) + session.commit() + + stats = reconcile_targets(session) + + assert stats["scaled_state"] == 0 + state_targets = session.exec( + select(Target).where( + Target.variable == "real_estate_taxes", + ) + ).all() + values = sorted(t.value for t in state_targets) + assert values == [30, 40] + + +class TestZeroChildSum: + def test_zero_state_sum_skipped(self, session, caplog): + """Zero state sum should log warning, not divide by zero.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "income_tax", 2022, 100) + _add_target(session, ids["states"][1], "income_tax", 2022, 0) + _add_target(session, ids["states"][2], "income_tax", 2022, 0) + _add_target(session, ids["states"][3], "income_tax", 2022, 0) + session.commit() + + with caplog.at_level(logging.WARNING): + stats = reconcile_targets(session) + + assert stats["skipped_zero_sum"] >= 1 + assert "zero" in caplog.text.lower() + + +class TestIdempotency: + def test_running_twice_same_result(self, session): + """Running reconciliation twice should produce same result.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "real_estate_taxes", 2022, 100) + _add_target(session, ids["states"][1], "real_estate_taxes", 2022, 30) + _add_target(session, ids["states"][2], "real_estate_taxes", 
2022, 40) + _add_target(session, ids["states"][3], "real_estate_taxes", 2022, 50) + session.commit() + + reconcile_targets(session) + + # Capture values after first run + first_run = { + t.target_id: (t.value, t.raw_value) + for t in session.exec(select(Target)).all() + } + + reconcile_targets(session) + + # Values should be unchanged + second_run = { + t.target_id: (t.value, t.raw_value) + for t in session.exec(select(Target)).all() + } + + for tid in first_run: + v1, r1 = first_run[tid] + v2, r2 = second_run[tid] + if v1 is not None: + assert abs(v1 - v2) < 1e-9 + if r1 is not None: + assert abs(r1 - r2) < 1e-9 + + +class TestRawValuePreservation: + def test_re_run_uses_raw_value_as_base(self, session): + """Re-running should use raw_value, not already-scaled value.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "salt", 2022, 200) + _add_target(session, ids["states"][1], "salt", 2022, 50) + _add_target(session, ids["states"][2], "salt", 2022, 50) + session.commit() + + # First run: scale factor = 200/100 = 2 + reconcile_targets(session) + + state_targets = session.exec( + select(Target).where( + Target.variable == "salt", + Target.stratum_id.in_(list(ids["states"].values())), + ) + ).all() + for t in state_targets: + assert t.raw_value == 50 + assert abs(t.value - 100) < 1e-6 + + # Second run: same raw_value * same scale = same result + reconcile_targets(session) + + state_targets = session.exec( + select(Target).where( + Target.variable == "salt", + Target.stratum_id.in_(list(ids["states"].values())), + ) + ).all() + for t in state_targets: + assert t.raw_value == 50 # raw_value unchanged + assert abs(t.value - 100) < 1e-6 # not compounded + + +class TestNonGeographicSubStrata: + def test_filer_substrata_grouped_by_geo_ancestor(self, session): + """Targets on non-geo strata resolve to their geo ancestor.""" + ids = _make_geo_hierarchy(session) + + # Create filer sub-strata under geo strata (group_id=2) + nat_filer = Stratum( + parent_stratum_id=ids["national"], + stratum_group_id=2, + notes="National filers", + ) + nat_filer.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + session.add(nat_filer) + session.flush() + + state_filers = {} + for fips in [1, 2, 3]: + sf = Stratum( + parent_stratum_id=ids["states"][fips], + stratum_group_id=2, + notes=f"State {fips} filers", + ) + sf.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(fips), + ), + ] + session.add(sf) + session.flush() + state_filers[fips] = sf.stratum_id + + session.commit() + + # Add targets on the filer strata + _add_target(session, nat_filer.stratum_id, "income_tax", 2022, 500) + _add_target(session, state_filers[1], "income_tax", 2022, 100) + _add_target(session, state_filers[2], "income_tax", 2022, 200) + _add_target(session, state_filers[3], "income_tax", 2022, 300) + session.commit() + + stats = reconcile_targets(session) + + # States sum was 600, national is 500 -> scale = 5/6 + assert stats["scaled_state"] == 3 + + state_targets = session.exec( + select(Target).where( + Target.variable == "income_tax", + Target.stratum_id.in_(list(state_filers.values())), + ) + ).all() + total = sum(t.value for t in state_targets) + assert abs(total - 500) < 1e-6 + + +class TestAlreadyReconciled: + def test_no_scaling_when_already_matching(self, session): + """When states 
already sum to national, no scaling occurs.""" + ids = _make_geo_hierarchy(session) + + _add_target(session, ids["national"], "income_tax", 2022, 120) + _add_target(session, ids["states"][1], "income_tax", 2022, 30) + _add_target(session, ids["states"][2], "income_tax", 2022, 40) + _add_target(session, ids["states"][3], "income_tax", 2022, 50) + session.commit() + + stats = reconcile_targets(session) + + # Scale is 1.0, so no targets are actually modified + assert stats["scaled_state"] == 0 + + state_targets = session.exec( + select(Target).where( + Target.variable == "income_tax", + Target.stratum_id.in_(list(ids["states"].values())), + ) + ).all() + # Values should be unchanged + values = sorted(t.value for t in state_targets) + assert values == [30, 40, 50] + # raw_value should still be None (no scaling happened) + for t in state_targets: + assert t.raw_value is None diff --git a/policyengine_us_data/utils/__init__.py b/policyengine_us_data/utils/__init__.py index 2b93ecbfb..6bad628ef 100644 --- a/policyengine_us_data/utils/__init__.py +++ b/policyengine_us_data/utils/__init__.py @@ -1,5 +1 @@ -from .soi import * from .uprating import * -from .loss import * -from .l0 import * -from .seed import * diff --git a/policyengine_us_data/utils/l0.py b/policyengine_us_data/utils/l0.py deleted file mode 100644 index 3dd9e0145..000000000 --- a/policyengine_us_data/utils/l0.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math -import logging - - -class HardConcrete(nn.Module): - """HardConcrete distribution for L0 regularization.""" - - def __init__( - self, - input_dim, - output_dim=None, - temperature=0.5, - stretch=0.1, - init_mean=0.5, - ): - super().__init__() - if output_dim is None: - self.gate_size = (input_dim,) - else: - self.gate_size = (input_dim, output_dim) - self.qz_logits = nn.Parameter(torch.zeros(self.gate_size)) - self.temperature = temperature - self.stretch = stretch - self.gamma = -0.1 - self.zeta = 1.1 - self.init_mean = init_mean - self.reset_parameters() - - def reset_parameters(self): - if self.init_mean is not None: - init_val = math.log(self.init_mean / (1 - self.init_mean)) - self.qz_logits.data.fill_(init_val) - - def forward(self, input_shape=None): - if self.training: - gates = self._sample_gates() - else: - gates = self._deterministic_gates() - if input_shape is not None and len(input_shape) > len(gates.shape): - gates = gates.unsqueeze(-1).unsqueeze(-1) - return gates - - def _sample_gates(self): - u = torch.zeros_like(self.qz_logits).uniform_(1e-8, 1.0 - 1e-8) - s = torch.log(u) - torch.log(1 - u) + self.qz_logits - s = torch.sigmoid(s / self.temperature) - s = s * (self.zeta - self.gamma) + self.gamma - gates = torch.clamp(s, 0, 1) - return gates - - def _deterministic_gates(self): - probs = torch.sigmoid(self.qz_logits) - gates = probs * (self.zeta - self.gamma) + self.gamma - return torch.clamp(gates, 0, 1) - - def get_penalty(self): - logits_shifted = self.qz_logits - self.temperature * math.log( - -self.gamma / self.zeta - ) - prob_active = torch.sigmoid(logits_shifted) - return prob_active.sum() - - def get_active_prob(self): - logits_shifted = self.qz_logits - self.temperature * math.log( - -self.gamma / self.zeta - ) - return torch.sigmoid(logits_shifted) - - -class L0Linear(nn.Module): - """Linear layer with L0 regularization using HardConcrete gates.""" - - def __init__( - self, - in_features, - out_features, - bias=True, - temperature=0.5, - init_sparsity=0.5, - ): - super().__init__() - 
self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_features)) - else: - self.register_parameter("bias", None) - self.weight_gates = HardConcrete( - out_features, - in_features, - temperature=temperature, - init_mean=init_sparsity, - ) - self.reset_parameters() - - def reset_parameters(self): - nn.init.kaiming_normal_(self.weight, mode="fan_out") - if self.bias is not None: - nn.init.zeros_(self.bias) - - def forward(self, input): - gates = self.weight_gates() - masked_weight = self.weight * gates - return F.linear(input, masked_weight, self.bias) - - def get_l0_penalty(self): - return self.weight_gates.get_penalty() - - def get_sparsity(self): - with torch.no_grad(): - prob_active = self.weight_gates.get_active_prob() - return 1.0 - prob_active.mean().item() - - -class SparseMLP(nn.Module): - """Example MLP with L0 regularization on all layers""" - - def __init__( - self, - input_dim=784, - hidden_dim=256, - output_dim=10, - init_sparsity=0.5, - temperature=0.5, - ): - super().__init__() - self.fc1 = L0Linear( - input_dim, - hidden_dim, - init_sparsity=init_sparsity, - temperature=temperature, - ) - self.fc2 = L0Linear( - hidden_dim, - hidden_dim, - init_sparsity=init_sparsity, - temperature=temperature, - ) - self.fc3 = L0Linear( - hidden_dim, - output_dim, - init_sparsity=init_sparsity, - temperature=temperature, - ) - - def forward(self, x): - x = x.view(x.size(0), -1) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - - def get_l0_loss(self): - l0_loss = 0 - for module in self.modules(): - if isinstance(module, L0Linear): - l0_loss += module.get_l0_penalty() - return l0_loss - - def get_sparsity_stats(self): - stats = {} - for name, module in self.named_modules(): - if isinstance(module, L0Linear): - stats[name] = { - "sparsity": module.get_sparsity(), - "active_params": module.get_l0_penalty().item(), - } - return stats - - -def train_with_l0(model, train_loader, epochs=10, l0_lambda=1e-3): - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - criterion = nn.CrossEntropyLoss() - for epoch in range(epochs): - total_loss = 0 - total_l0 = 0 - for batch_idx, (data, target) in enumerate(train_loader): - optimizer.zero_grad() - output = model(data) - ce_loss = criterion(output, target) - l0_loss = model.get_l0_loss() - loss = ce_loss + l0_lambda * l0_loss - loss.backward() - optimizer.step() - total_loss += ce_loss.item() - total_l0 += l0_loss.item() - if epoch % 1 == 0: - sparsity_stats = model.get_sparsity_stats() - logging.info( - f"Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, L0={total_l0/len(train_loader):.4f}" - ) - for layer, stats in sparsity_stats.items(): - logging.info( - f" {layer}: {stats['sparsity']*100:.1f}% sparse, {stats['active_params']:.1f} active params" - ) - - -def prune_model(model, threshold=0.05): - for module in model.modules(): - if isinstance(module, L0Linear): - with torch.no_grad(): - prob_active = module.weight_gates.get_active_prob() - mask = (prob_active > threshold).float() - module.weight.data *= mask - return model diff --git a/policyengine_us_data/utils/loss.py b/policyengine_us_data/utils/loss.py deleted file mode 100644 index bb3f8cb88..000000000 --- a/policyengine_us_data/utils/loss.py +++ /dev/null @@ -1,961 +0,0 @@ -import pandas as pd -import numpy as np -import logging - -from policyengine_us_data.storage import STORAGE_FOLDER, CALIBRATION_FOLDER -from 
policyengine_us_data.storage.calibration_targets.pull_soi_targets import ( - STATE_ABBR_TO_FIPS, -) -from policyengine_core.reforms import Reform -from policyengine_us_data.utils.soi import pe_to_soi, get_soi - -# National calibration targets consumed by build_loss_matrix(). -# These are duplicated in db/etl_national_targets.py which loads them -# into policy_data.db. A future PR should wire build_loss_matrix() -# to read from the database so this dict can be deleted. See PR #488. - -HARD_CODED_TOTALS = { - "health_insurance_premiums_without_medicare_part_b": 385e9, - "other_medical_expenses": 278e9, - "medicare_part_b_premiums": 112e9, - "over_the_counter_health_expenses": 72e9, - "spm_unit_spm_threshold": 3_945e9, - "child_support_expense": 33e9, - "child_support_received": 33e9, - "spm_unit_capped_work_childcare_expenses": 348e9, - "spm_unit_capped_housing_subsidy": 35e9, - "tanf": 9e9, - # Alimony could be targeted via SOI - "alimony_income": 13e9, - "alimony_expense": 13e9, - # Rough estimate, not CPS derived - "real_estate_taxes": 500e9, # Rough estimate between 350bn and 600bn total property tax collections - "rent": 735e9, # ACS total uprated by CPI - # Table 5A from https://www.irs.gov/statistics/soi-tax-stats-individual-information-return-form-w2-statistics - # shows $38,316,190,000 in Box 7: Social security tips (2018) - # Wages and salaries grew 32% from 2018 to 2023: https://fred.stlouisfed.org/graph/?g=1J0CC - # Assume 40% through 2024 - "tip_income": 38e9 * 1.4, - # SSA benefit-type totals for 2024, derived from: - # - Total OASDI: $1,452B (CBO projection) - # - OASI trust fund: $1,227.4B in 2023 - # https://www.ssa.gov/OACT/STATS/table4a3.html - # - DI trust fund: $151.9B in 2023 - # https://www.ssa.gov/OACT/STATS/table4a3.html - # - SSA 2024 fact sheet type shares: retired+deps=78.5%, - # survivors=11.0%, disabled+deps=10.5% - # https://www.ssa.gov/OACT/FACTS/ - # - SSA Annual Statistical Supplement Table 5.A1 - # https://www.ssa.gov/policy/docs/statcomps/supplement/2024/5a.html - "social_security_retirement": 1_060e9, # ~73% of total - "social_security_disability": 148e9, # ~10.2% (disabled workers) - "social_security_survivors": 160e9, # ~11.0% (widows, children of deceased) - "social_security_dependents": 84e9, # ~5.8% (spouses/children of retired+disabled) - # IRA contribution totals from IRS SOI IRA accumulation tables. - # Tax year 2022: ~5M taxpayers x $4,510 avg = ~$22.5B traditional; - # ~10M taxpayers x $3,482 avg = ~$34.8B Roth. - # Uprated ~12% to 2024 for limit increases ($6k->$7k) and - # wage growth. 
- # https://www.irs.gov/statistics/soi-tax-stats-accumulation-and-distribution-of-individual-retirement-arrangements - "traditional_ira_contributions": 25e9, - "roth_ira_contributions": 39e9, -} - - -def fmt(x): - if x == -np.inf: - return "-inf" - if x == np.inf: - return "inf" - if x < 1e3: - return f"{x:.0f}" - if x < 1e6: - return f"{x/1e3:.0f}k" - if x < 1e9: - return f"{x/1e6:.0f}m" - return f"{x/1e9:.1f}bn" - - -def build_loss_matrix(dataset: type, time_period): - loss_matrix = pd.DataFrame() - df = pe_to_soi(dataset, time_period) - agi = df["adjusted_gross_income"].values - filer = df["is_tax_filer"].values - taxable = df["total_income_tax"].values > 0 - soi_subset = get_soi(time_period) - targets_array = [] - agi_level_targeted_variables = [ - "adjusted_gross_income", - "count", - "employment_income", - "business_net_profits", - "capital_gains_gross", - "ordinary_dividends", - "partnership_and_s_corp_income", - "qualified_dividends", - "taxable_interest_income", - "total_pension_income", - "total_social_security", - ] - aggregate_level_targeted_variables = [ - "business_net_losses", - "capital_gains_distributions", - "capital_gains_losses", - "estate_income", - "estate_losses", - "exempt_interest", - "ira_distributions", - "partnership_and_s_corp_losses", - "rent_and_royalty_net_income", - "rent_and_royalty_net_losses", - "taxable_pension_income", - "taxable_social_security", - "unemployment_compensation", - ] - aggregate_level_targeted_variables = [ - variable - for variable in aggregate_level_targeted_variables - if variable in df.columns - ] - soi_subset = soi_subset[ - soi_subset.Variable.isin(agi_level_targeted_variables) - | ( - soi_subset.Variable.isin(aggregate_level_targeted_variables) - & (soi_subset["AGI lower bound"] == -np.inf) - & (soi_subset["AGI upper bound"] == np.inf) - ) - ] - for _, row in soi_subset.iterrows(): - if not row["Taxable only"]: - continue # exclude non "taxable returns" statistics - - if row["AGI upper bound"] <= 10_000: - continue - - mask = ( - (agi >= row["AGI lower bound"]) - * (agi < row["AGI upper bound"]) - * filer - ) > 0 - - if row["Filing status"] == "Single": - mask *= df["filing_status"].values == "SINGLE" - elif row["Filing status"] == "Married Filing Jointly/Surviving Spouse": - mask *= df["filing_status"].values == "JOINT" - elif row["Filing status"] == "Head of Household": - mask *= df["filing_status"].values == "HEAD_OF_HOUSEHOLD" - elif row["Filing status"] == "Married Filing Separately": - mask *= df["filing_status"].values == "SEPARATE" - - values = df[row["Variable"]].values - - if row["Taxable only"]: - mask *= taxable - - if row["Count"]: - values = (values > 0).astype(float) - - agi_range_label = ( - f"{fmt(row['AGI lower bound'])}-{fmt(row['AGI upper bound'])}" - ) - taxable_label = ( - "taxable" if row["Taxable only"] else "all" + " returns" - ) - filing_status_label = row["Filing status"] - - variable_label = row["Variable"].replace("_", " ") - - if row["Count"] and not row["Variable"] == "count": - label = ( - f"nation/irs/{variable_label}/count/AGI in " - f"{agi_range_label}/{taxable_label}/{filing_status_label}" - ) - elif row["Variable"] == "count": - label = ( - f"nation/irs/{variable_label}/count/AGI in " - f"{agi_range_label}/{taxable_label}/{filing_status_label}" - ) - else: - label = ( - f"nation/irs/{variable_label}/total/AGI in " - f"{agi_range_label}/{taxable_label}/{filing_status_label}" - ) - - if label not in loss_matrix.columns: - loss_matrix[label] = mask * values - targets_array.append(row["Value"]) - - 
# Convert tax-unit level df to household-level df - - from policyengine_us import Microsimulation - - sim = Microsimulation(dataset=dataset) - sim.default_calculation_period = time_period - hh_id = sim.calculate("household_id", map_to="person") - tax_unit_hh_id = sim.map_result( - hh_id, "person", "tax_unit", how="value_from_first_person" - ) - - loss_matrix = loss_matrix.groupby(tax_unit_hh_id).sum() - - hh_id = sim.calculate("household_id").values - loss_matrix = loss_matrix.loc[hh_id] - - # Census single-year age population projections - - populations = pd.read_csv(CALIBRATION_FOLDER / "np2023_d5_mid.csv") - populations = populations[populations.SEX == 0][populations.RACE_HISP == 0] - populations = ( - populations.groupby("YEAR") - .sum()[[f"POP_{i}" for i in range(0, 86)]] - .T[time_period] - .values - ) # Array of [age_0_pop, age_1_pop, ...] for the given year - age = sim.calculate("age").values - for year in range(len(populations)): - label = f"nation/census/population_by_age/{year}" - loss_matrix[label] = sim.map_result( - (age >= year) * (age < year + 1), "person", "household" - ) - targets_array.append(populations[year]) - - # CBO projections - # Note: income_tax_positive matches CBO's receipts definition where - # refundable credit payments in excess of liability are classified as - # outlays, not negative receipts. See: https://www.cbo.gov/publication/43767 - - CBO_PROGRAMS = [ - "income_tax_positive", - "snap", - "social_security", - "ssi", - "unemployment_compensation", - ] - - # Mapping from variable name to CBO parameter name (when different) - CBO_PARAM_NAME_MAP = { - "income_tax_positive": "income_tax", - } - - for variable_name in CBO_PROGRAMS: - label = f"nation/cbo/{variable_name}" - loss_matrix[label] = sim.calculate( - variable_name, map_to="household" - ).values - if any(loss_matrix[label].isna()): - raise ValueError(f"Missing values for {label}") - param_name = CBO_PARAM_NAME_MAP.get(variable_name, variable_name) - targets_array.append( - sim.tax_benefit_system.parameters( - time_period - ).calibration.gov.cbo._children[param_name] - ) - - # 1. Medicaid Spending - label = "nation/hhs/medicaid_spending" - loss_matrix[label] = sim.calculate("medicaid", map_to="household").values - MEDICAID_SPENDING_2024 = 9e11 - targets_array.append(MEDICAID_SPENDING_2024) - - # 2. 
Medicaid Enrollment - label = "nation/hhs/medicaid_enrollment" - on_medicaid = ( - sim.calculate( - "medicaid", # or your enrollee flag - map_to="person", - period=time_period, - ).values - > 0 - ).astype(int) - loss_matrix[label] = sim.map_result(on_medicaid, "person", "household") - MEDICAID_ENROLLMENT_2024 = 72_429_055 # target lives (not thousands) - targets_array.append(MEDICAID_ENROLLMENT_2024) - - # National ACA Spending - label = "nation/gov/aca_spending" - loss_matrix[label] = sim.calculate( - "aca_ptc", map_to="household", period=2025 - ).values - ACA_SPENDING_2024 = 9.8e10 # 2024 outlays on PTC - targets_array.append(ACA_SPENDING_2024) - - # National ACA Enrollment (people receiving a PTC) - label = "nation/gov/aca_enrollment" - on_ptc = ( - sim.calculate("aca_ptc", map_to="person", period=2025).values > 0 - ).astype(int) - loss_matrix[label] = sim.map_result(on_ptc, "person", "household") - - ACA_PTC_ENROLLMENT_2024 = 19_743_689 # people enrolled - targets_array.append(ACA_PTC_ENROLLMENT_2024) - - # Treasury EITC - - loss_matrix["nation/treasury/eitc"] = sim.calculate( - "eitc", map_to="household" - ).values - eitc_spending = ( - sim.tax_benefit_system.parameters.calibration.gov.treasury.tax_expenditures.eitc - ) - targets_array.append(eitc_spending(time_period)) - - # IRS EITC filers and totals by child counts - eitc_stats = pd.read_csv(CALIBRATION_FOLDER / "eitc.csv") - - eitc_spending_uprating = eitc_spending(time_period) / eitc_spending(2021) - population = ( - sim.tax_benefit_system.parameters.calibration.gov.census.populations.total - ) - population_uprating = population(time_period) / population(2021) - - for _, row in eitc_stats.iterrows(): - returns_label = ( - f"nation/irs/eitc/returns/count_children_{row['count_children']}" - ) - eitc_eligible_children = sim.calculate("eitc_child_count").values - eitc = sim.calculate("eitc").values - if row["count_children"] < 2: - meets_child_criteria = ( - eitc_eligible_children == row["count_children"] - ) - else: - meets_child_criteria = ( - eitc_eligible_children >= row["count_children"] - ) - loss_matrix[returns_label] = sim.map_result( - (eitc > 0) * meets_child_criteria, - "tax_unit", - "household", - ) - targets_array.append(row["eitc_returns"] * population_uprating) - - spending_label = ( - f"nation/irs/eitc/spending/count_children_{row['count_children']}" - ) - loss_matrix[spending_label] = sim.map_result( - eitc * meets_child_criteria, - "tax_unit", - "household", - ) - targets_array.append(row["eitc_total"] * eitc_spending_uprating) - - # Tax filer counts by AGI band (SOI Table 1.1) - # This calibrates total filers (not just taxable returns) including - # low-AGI filers who are important for income distribution accuracy - SOI_FILER_COUNTS_2015 = { - # (agi_lower, agi_upper): total_returns - (-np.inf, 0): 2_072_066, - (0, 5_000): 10_134_703, - (5_000, 10_000): 11_398_595, - (10_000, 25_000): 23_447_927, - (25_000, 50_000): 23_727_745, - (50_000, 100_000): 32_801_908, - (100_000, np.inf): 25_120_985, - } - - # Get AGI and filer status at tax unit level, mapped to household - agi_tu = sim.calculate("adjusted_gross_income").values - is_filer_tu = sim.calculate("tax_unit_is_filer").values > 0 - - for ( - agi_lower, - agi_upper, - ), filer_count_2015 in SOI_FILER_COUNTS_2015.items(): - in_band = (agi_tu >= agi_lower) & (agi_tu < agi_upper) - label = f"nation/soi/filer_count/agi_{fmt(agi_lower)}_{fmt(agi_upper)}" - loss_matrix[label] = sim.map_result( - (is_filer_tu & in_band).astype(float), - "tax_unit", - "household", - ) - # 
Uprate from 2015 to current year using population growth - uprated_target = filer_count_2015 * population_uprating - targets_array.append(uprated_target) - - # Hard-coded totals - for variable_name, target in HARD_CODED_TOTALS.items(): - label = f"nation/census/{variable_name}" - loss_matrix[label] = sim.calculate( - variable_name, map_to="household" - ).values - if any(loss_matrix[label].isna()): - raise ValueError(f"Missing values for {label}") - targets_array.append(target) - - # Negative household market income total rough estimate from the IRS SOI PUF - - market_income = sim.calculate("household_market_income").values - loss_matrix["nation/irs/negative_household_market_income_total"] = ( - market_income * (market_income < 0) - ) - targets_array.append(-138e9) - - loss_matrix["nation/irs/negative_household_market_income_count"] = ( - market_income < 0 - ).astype(float) - targets_array.append(3e6) - - # Healthcare spending by age - - healthcare = pd.read_csv(CALIBRATION_FOLDER / "healthcare_spending.csv") - - for _, row in healthcare.iterrows(): - age_lower_bound = int(row["age_10_year_lower_bound"]) - in_age_range = (age >= age_lower_bound) * (age < age_lower_bound + 10) - for expense_type in [ - "health_insurance_premiums_without_medicare_part_b", - "over_the_counter_health_expenses", - "other_medical_expenses", - "medicare_part_b_premiums", - ]: - label = f"nation/census/{expense_type}/age_{age_lower_bound}_to_{age_lower_bound+9}" - value = sim.calculate(expense_type).values - loss_matrix[label] = sim.map_result( - in_age_range * value, "person", "household" - ) - targets_array.append(row[expense_type]) - - # AGI by SPM threshold totals - - spm_threshold_agi = pd.read_csv( - CALIBRATION_FOLDER / "spm_threshold_agi.csv" - ) - - for _, row in spm_threshold_agi.iterrows(): - spm_unit_agi = sim.calculate( - "adjusted_gross_income", map_to="spm_unit" - ).values - spm_threshold = sim.calculate("spm_unit_spm_threshold").values - in_threshold_range = (spm_threshold >= row["lower_spm_threshold"]) * ( - spm_threshold < row["upper_spm_threshold"] - ) - label = ( - f"nation/census/agi_in_spm_threshold_decile_{int(row['decile'])}" - ) - loss_matrix[label] = sim.map_result( - in_threshold_range * spm_unit_agi, "spm_unit", "household" - ) - targets_array.append(row["adjusted_gross_income"]) - - label = ( - f"nation/census/count_in_spm_threshold_decile_{int(row['decile'])}" - ) - loss_matrix[label] = sim.map_result( - in_threshold_range, "spm_unit", "household" - ) - targets_array.append(row["count"]) - - # Population by state and population under 5 by state - - state_population = pd.read_csv( - CALIBRATION_FOLDER / "population_by_state.csv" - ) - - for _, row in state_population.iterrows(): - in_state = sim.calculate("state_code", map_to="person") == row["state"] - label = f"state/census/population_by_state/{row['state']}" - loss_matrix[label] = sim.map_result(in_state, "person", "household") - targets_array.append(row["population"]) - - under_5 = sim.calculate("age").values < 5 - in_state_under_5 = in_state * under_5 - label = f"state/census/population_under_5_by_state/{row['state']}" - loss_matrix[label] = sim.map_result( - in_state_under_5, "person", "household" - ) - targets_array.append(row["population_under_5"]) - - age = sim.calculate("age").values - infants = (age >= 0) & (age < 1) - label = "nation/census/infants" - loss_matrix[label] = sim.map_result(infants, "person", "household") - # Total number of infants in the 1 Year ACS - INFANTS_2023 = 3_491_679 - INFANTS_2022 = 3_437_933 - # 
Assume infant population grows at the same rate from 2023. - infants_2024 = INFANTS_2023 * (INFANTS_2023 / INFANTS_2022) - targets_array.append(infants_2024) - - networth = sim.calculate("net_worth").values - label = "nation/net_worth/total" - loss_matrix[label] = networth - # Federal Reserve estimate of $160 trillion in 2024Q4 - # https://fred.stlouisfed.org/series/BOGZ1FL192090005Q - NET_WORTH_2024 = 160e12 - targets_array.append(NET_WORTH_2024) - - # SALT tax expenditure targeting - - _add_tax_expenditure_targets( - dataset, time_period, sim, loss_matrix, targets_array - ) - - if any(loss_matrix.isna().sum() > 0): - raise ValueError("Some targets are missing from the loss matrix") - - if any(pd.isna(targets_array)): - raise ValueError("Some targets are missing from the targets array") - - # SSN Card Type calibration - for card_type_str in ["NONE"]: # SSN card types as strings - ssn_type_mask = sim.calculate("ssn_card_type").values == card_type_str - - # Overall count by SSN card type - label = f"nation/ssa/ssn_card_type_{card_type_str.lower()}_count" - loss_matrix[label] = sim.map_result( - ssn_type_mask, "person", "household" - ) - - # Target undocumented population by year based on various sources - if card_type_str == "NONE": - undocumented_targets = { - 2022: 11.0e6, # Official DHS Office of Homeland Security Statistics estimate for 1 Jan 2022 - # https://ohss.dhs.gov/sites/default/files/2024-06/2024_0418_ohss_estimates-of-the-unauthorized-immigrant-population-residing-in-the-united-states-january-2018%25E2%2580%2593january-2022.pdf - 2023: 12.2e6, # Center for Migration Studies ACS-based residual estimate (published May 2025) - # https://cmsny.org/publications/the-undocumented-population-in-the-united-states-increased-to-12-million-in-2023/ - 2024: 13.0e6, # Reuters synthesis of experts ahead of 2025 change ("~13-14 million") - central value - # https://www.reuters.com/data/who-are-immigrants-who-could-be-targeted-trumps-mass-deportation-plans-2024-12-18/ - 2025: 13.0e6, # Same midpoint carried forward - CBP data show 95% drop in border apprehensions - } - if time_period <= 2022: - target_count = 11.0e6 # Use 2022 value for earlier years - elif time_period >= 2025: - target_count = 13.0e6 # Use 2025 value for later years - else: - target_count = undocumented_targets[time_period] - - targets_array.append(target_count) - - # ACA spending by state - spending_by_state = pd.read_csv( - CALIBRATION_FOLDER / "aca_spending_and_enrollment_2024.csv" - ) - # Monthly to yearly - spending_by_state["spending"] = spending_by_state["spending"] * 12 - # Adjust to match national target - spending_by_state["spending"] = spending_by_state["spending"] * ( - ACA_SPENDING_2024 / spending_by_state["spending"].sum() - ) - - for _, row in spending_by_state.iterrows(): - # Households located in this state - in_state = ( - sim.calculate("state_code", map_to="household").values - == row["state"] - ) - - # ACA PTC amounts for every household (2025) - aca_value = sim.calculate( - "aca_ptc", map_to="household", period=2025 - ).values - - # Add a loss-matrix entry and matching target - label = f"nation/irs/aca_spending/{row['state'].lower()}" - loss_matrix[label] = aca_value * in_state - annual_target = row["spending"] - if any(loss_matrix[label].isna()): - raise ValueError(f"Missing values for {label}") - targets_array.append(annual_target) - - # Marketplace enrollment by state (targets in thousands) - enrollment_by_state = pd.read_csv( - CALIBRATION_FOLDER / "aca_spending_and_enrollment_2024.csv" - ) - - # 
One-time pulls so we don’t re-compute inside the loop - state_person = sim.calculate("state_code", map_to="person").values - - # Flag people in households that actually receive any PTC (> 0) - in_tax_unit_with_aca = ( - sim.calculate("aca_ptc", map_to="person", period=2025).values > 0 - ) - is_aca_eligible = sim.calculate( - "is_aca_ptc_eligible", map_to="person", period=2025 - ).values - is_enrolled = in_tax_unit_with_aca & is_aca_eligible - - for _, row in enrollment_by_state.iterrows(): - # People who both live in the state and have marketplace coverage - in_state = state_person == row["state"] - in_state_enrolled = in_state & is_enrolled - - label = f"state/irs/aca_enrollment/{row['state'].lower()}" - loss_matrix[label] = sim.map_result( - in_state_enrolled, "person", "household" - ) - if any(loss_matrix[label].isna()): - raise ValueError(f"Missing values for {label}") - - # Convert to thousands for the target - targets_array.append(row["enrollment"]) - - # Medicaid enrollment by state - - enrollment_by_state = pd.read_csv( - CALIBRATION_FOLDER / "medicaid_enrollment_2024.csv" - ) - - # One-time pulls so we don’t re-compute inside the loop - state_person = sim.calculate("state_code", map_to="person").values - - # Flag people in households that actually receive medicaid - has_medicaid = sim.calculate( - "medicaid_enrolled", map_to="person", period=2025 - ) - is_medicaid_eligible = sim.calculate( - "is_medicaid_eligible", map_to="person", period=2025 - ).values - is_enrolled = has_medicaid & is_medicaid_eligible - - for _, row in enrollment_by_state.iterrows(): - # People who both live in the state and have marketplace coverage - in_state = state_person == row["state"] - in_state_enrolled = in_state & is_enrolled - - label = f"irs/medicaid_enrollment/{row['state'].lower()}" - loss_matrix[label] = sim.map_result( - in_state_enrolled, "person", "household" - ) - if any(loss_matrix[label].isna()): - raise ValueError(f"Missing values for {label}") - - # Convert to thousands for the target - targets_array.append(row["enrollment"]) - - logging.info( - f"Targeting Medicaid enrollment for {row['state']} " - f"with target {row['enrollment']:.0f}k" - ) - - # State 10-year age targets - - age_targets = pd.read_csv(CALIBRATION_FOLDER / "age_state.csv") - - for state in age_targets.GEO_NAME.unique(): - state_mask = state_person == state - for age_range in age_targets.columns[2:]: - if "+" in age_range: - # Handle the "85+" case - age_lower_bound = int(age_range.replace("+", "")) - age_upper_bound = np.inf - else: - age_lower_bound, age_upper_bound = map( - int, age_range.split("-") - ) - - age_mask = (age >= age_lower_bound) & (age <= age_upper_bound) - label = f"state/census/age/{state}/{age_range}" - loss_matrix[label] = sim.map_result( - state_mask * age_mask, "person", "household" - ) - target_value = age_targets.loc[ - age_targets.GEO_NAME == state, age_range - ].values[0] - targets_array.append(target_value) - - agi_state_target_names, agi_state_targets = _add_agi_state_targets() - targets_array.extend(agi_state_targets) - loss_matrix = _add_agi_metric_columns(loss_matrix, sim) - - targets_array, loss_matrix = _add_state_real_estate_taxes( - loss_matrix, targets_array, sim - ) - - snap_state_target_names, snap_state_targets = _add_snap_state_targets(sim) - targets_array.extend(snap_state_targets) - loss_matrix = _add_snap_metric_columns(loss_matrix, sim) - - return loss_matrix, np.array(targets_array) - - -def _add_tax_expenditure_targets( - dataset, - time_period, - baseline_simulation, - 
loss_matrix: pd.DataFrame, - targets_array: list, -): - from policyengine_us import Microsimulation - - income_tax_b = baseline_simulation.calculate( - "income_tax", map_to="household" - ).values - - # Dictionary of itemized deductions and their target values - # (in billions for 2024, per the 2024 JCT Tax Expenditures report) - # https://www.jct.gov/publications/2024/jcx-48-24/ - ITEMIZED_DEDUCTIONS = { - "salt_deduction": 21.247e9, - "medical_expense_deduction": 11.4e9, - "charitable_deduction": 65.301e9, - "interest_deduction": 24.8e9, - "qualified_business_income_deduction": 63.1e9, - } - - def make_repeal_class(deduction_var): - # Create a custom Reform subclass that neutralizes the given deduction. - class RepealDeduction(Reform): - def apply(self): - self.neutralize_variable(deduction_var) - - return RepealDeduction - - for deduction, target in ITEMIZED_DEDUCTIONS.items(): - # Generate the custom repeal class for the current deduction. - RepealDeduction = make_repeal_class(deduction) - - # Run the microsimulation using the repeal reform. - simulation = Microsimulation(dataset=dataset, reform=RepealDeduction) - simulation.default_calculation_period = time_period - - # Calculate the baseline and reform income tax values. - income_tax_r = simulation.calculate( - "income_tax", map_to="household" - ).values - - # Compute the tax expenditure (TE) values. - te_values = income_tax_r - income_tax_b - - # Record the TE difference and the corresponding target value. - loss_matrix[f"nation/jct/{deduction}_expenditure"] = te_values - targets_array.append(target) - - -def get_agi_band_label(lower: float, upper: float) -> str: - """Get the label for the AGI band based on lower and upper bounds.""" - if lower <= 0: - return f"-inf_{int(upper)}" - elif np.isposinf(upper): - return f"{int(lower)}_inf" - else: - return f"{int(lower)}_{int(upper)}" - - -def _add_agi_state_targets(): - """ - Create an aggregate target matrix for the appropriate geographic area - """ - - soi_targets = pd.read_csv(CALIBRATION_FOLDER / "agi_state.csv") - - soi_targets["target_name"] = ( - "state/" - + soi_targets["GEO_NAME"] - + "/" - + soi_targets["VARIABLE"] - + "/" - + soi_targets.apply( - lambda r: get_agi_band_label( - r["AGI_LOWER_BOUND"], r["AGI_UPPER_BOUND"] - ), - axis=1, - ) - ) - - target_names = soi_targets["target_name"].tolist() - target_values = soi_targets["VALUE"].astype(float).tolist() - return target_names, target_values - - -def _add_agi_metric_columns( - loss_matrix: pd.DataFrame, - sim, -): - """ - Add AGI metric columns to the loss_matrix. - """ - soi_targets = pd.read_csv(CALIBRATION_FOLDER / "agi_state.csv") - - agi = sim.calculate("adjusted_gross_income").values - state = sim.calculate("state_code", map_to="person").values - state = sim.map_result( - state, "person", "tax_unit", how="value_from_first_person" - ) - - for _, r in soi_targets.iterrows(): - lower, upper = r.AGI_LOWER_BOUND, r.AGI_UPPER_BOUND - band = get_agi_band_label(lower, upper) - - in_state = state == r.GEO_NAME - in_band = (agi > lower) & (agi <= upper) - - if r.IS_COUNT: - metric = (in_state & in_band & (agi > 0)).astype(float) - else: - metric = np.where(in_state & in_band, agi, 0.0) - - metric = sim.map_result(metric, "tax_unit", "household") - - col_name = f"state/{r.GEO_NAME}/{r.VARIABLE}/{band}" - loss_matrix[col_name] = metric - - return loss_matrix - - -def _add_state_real_estate_taxes(loss_matrix, targets_list, sim): - """ - Add state real estate taxes to the loss matrix and targets list. 
- """ - # Read the real estate taxes data - real_estate_taxes_targets = pd.read_csv( - CALIBRATION_FOLDER / "real_estate_taxes_by_state_acs.csv" - ) - national_total = HARD_CODED_TOTALS["real_estate_taxes"] - state_sum = real_estate_taxes_targets["real_estate_taxes_bn"].sum() * 1e9 - national_to_state_diff = national_total / state_sum - real_estate_taxes_targets["real_estate_taxes_bn"] *= national_to_state_diff - real_estate_taxes_targets["real_estate_taxes_bn"] = ( - real_estate_taxes_targets["real_estate_taxes_bn"] * 1e9 - ) - - assert np.isclose( - real_estate_taxes_targets["real_estate_taxes_bn"].sum(), - national_total, - rtol=1e-8, - ), "Real estate tax totals do not sum to national target" - - targets_list.extend( - real_estate_taxes_targets["real_estate_taxes_bn"].tolist() - ) - - real_estate_taxes = sim.calculate( - "real_estate_taxes", map_to="household" - ).values - state = sim.calculate("state_code", map_to="household").values - - for _, r in real_estate_taxes_targets.iterrows(): - in_state = (state == r["state_code"]).astype(float) - label = f"state/real_estate_taxes/{r['state_code']}" - loss_matrix[label] = real_estate_taxes * in_state - - return targets_list, loss_matrix - - -def _add_snap_state_targets(sim): - """ - Add snap targets at the state level, adjusted in aggregate to the sim - """ - snap_targets = pd.read_csv(CALIBRATION_FOLDER / "snap_state.csv") - time_period = sim.default_calculation_period - - national_cost_target = sim.tax_benefit_system.parameters( - time_period - ).calibration.gov.cbo._children["snap"] - ratio = snap_targets[["Cost"]].sum().values[0] / national_cost_target - snap_targets[["CostAdj"]] = snap_targets[["Cost"]] / ratio - assert ( - np.round(snap_targets[["CostAdj"]].sum().values[0]) - == national_cost_target - ) - - cost_targets = snap_targets.copy()[["GEO_ID", "CostAdj"]] - cost_targets["target_name"] = ( - cost_targets["GEO_ID"].str[-4:] + "/snap-cost" - ) - - hh_targets = snap_targets.copy()[["GEO_ID", "Households"]] - hh_targets["target_name"] = snap_targets["GEO_ID"].str[-4:] + "/snap-hhs" - - target_names = ( - cost_targets["target_name"].tolist() - + hh_targets["target_name"].tolist() - ) - target_values = ( - cost_targets["CostAdj"].astype(float).tolist() - + hh_targets["Households"].astype(float).tolist() - ) - return target_names, target_values - - -def _add_snap_metric_columns( - loss_matrix: pd.DataFrame, - sim, -): - """ - Add SNAP metric columns to the loss_matrix. 
- """ - snap_targets = pd.read_csv(CALIBRATION_FOLDER / "snap_state.csv") - - snap_cost = sim.calculate("snap_reported", map_to="household").values - snap_hhs = ( - sim.calculate("snap_reported", map_to="household").values > 0 - ).astype(int) - - state = sim.calculate("state_code", map_to="person").values - state = sim.map_result( - state, "person", "household", how="value_from_first_person" - ) - STATE_ABBR_TO_FIPS["DC"] = 11 - state_fips = pd.Series(state).apply(lambda s: STATE_ABBR_TO_FIPS[s]) - - for _, r in snap_targets.iterrows(): - in_state = state_fips == r.GEO_ID[-2:] - metric = np.where(in_state, snap_cost, 0.0) - col_name = f"{r.GEO_ID[-4:]}/snap-cost" - loss_matrix[col_name] = metric - - for _, r in snap_targets.iterrows(): - in_state = state_fips == r.GEO_ID[-2:] - metric = np.where(in_state, snap_hhs, 0.0) - col_name = f"{r.GEO_ID[-4:]}/snap-hhs" - loss_matrix[col_name] = metric - - return loss_matrix - - -def print_reweighting_diagnostics( - optimised_weights, loss_matrix, targets_array, label -): - # Convert all inputs to NumPy arrays right at the start - optimised_weights_np = ( - optimised_weights.numpy() - if hasattr(optimised_weights, "numpy") - else np.asarray(optimised_weights) - ) - loss_matrix_np = ( - loss_matrix.numpy() - if hasattr(loss_matrix, "numpy") - else np.asarray(loss_matrix) - ) - targets_array_np = ( - targets_array.numpy() - if hasattr(targets_array, "numpy") - else np.asarray(targets_array) - ) - - logging.info(f"\n\n---{label}: reweighting quick diagnostics----\n") - logging.info( - f"{np.sum(optimised_weights_np == 0)} are zero, " - f"{np.sum(optimised_weights_np != 0)} weights are nonzero" - ) - - # All subsequent calculations use the guaranteed NumPy versions - estimate = optimised_weights_np @ loss_matrix_np - - rel_error = ( - ((estimate - targets_array_np) + 1) / (targets_array_np + 1) - ) ** 2 - within_10_percent_mask = np.abs(estimate - targets_array_np) <= ( - 0.10 * np.abs(targets_array_np) - ) - percent_within_10 = np.mean(within_10_percent_mask) * 100 - logging.info( - f"rel_error: min: {np.min(rel_error):.2f}\n" - f"max: {np.max(rel_error):.2f}\n" - f"mean: {np.mean(rel_error):.2f}\n" - f"median: {np.median(rel_error):.2f}\n" - f"Within 10% of target: {percent_within_10:.2f}%" - ) - logging.info("Relative error over 100% for:") - for i in np.where(rel_error > 1)[0]: - # Keep this check, as Tensors won't have a .columns attribute - if hasattr(loss_matrix, "columns"): - logging.info(f"target_name: {loss_matrix.columns[i]}") - else: - logging.info(f"target_index: {i}") - - logging.info(f"target_value: {targets_array_np[i]}") - logging.info(f"estimate_value: {estimate[i]}") - logging.info(f"has rel_error: {rel_error[i]:.2f}\n") - logging.info("---End of reweighting quick diagnostics------") - return percent_within_10 diff --git a/policyengine_us_data/utils/seed.py b/policyengine_us_data/utils/seed.py deleted file mode 100644 index e5fa7669a..000000000 --- a/policyengine_us_data/utils/seed.py +++ /dev/null @@ -1,21 +0,0 @@ -import random -import numpy as np - -try: - import torch -except ImportError: - torch = None - - -def set_seeds(seed: int) -> None: - """Seed Python, NumPy and PyTorch for reproducible behavior.""" - random.seed(seed) - np.random.seed(seed) - if torch is not None: - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - try: - torch.use_deterministic_algorithms(True, warn_only=True) - except Exception: - pass diff --git a/policyengine_us_data/utils/soi.py 
b/policyengine_us_data/utils/soi.py deleted file mode 100644 index d9538addb..000000000 --- a/policyengine_us_data/utils/soi.py +++ /dev/null @@ -1,274 +0,0 @@ -import pandas as pd -import numpy as np -from .uprating import create_policyengine_uprating_factors_table -from policyengine_us_data.storage import STORAGE_FOLDER, CALIBRATION_FOLDER - - -def pe_to_soi(pe_dataset, year): - from policyengine_us import Microsimulation - - pe_sim = Microsimulation(dataset=pe_dataset) - pe_sim.default_calculation_period = year - df = pd.DataFrame() - - pe = lambda variable: np.array( - pe_sim.calculate(variable, map_to="tax_unit") - ) - - df["adjusted_gross_income"] = pe("adjusted_gross_income") - df["exemption"] = pe("exemptions") - df["itemded"] = pe("itemized_taxable_income_deductions") - df["income_tax_after_credits"] = pe("income_tax") - df["total_income_tax"] = pe("income_tax_before_credits") - df["taxable_income"] = pe("taxable_income") - df["business_net_profits"] = pe("self_employment_income") * ( - pe("self_employment_income") > 0 - ) - df["business_net_losses"] = -pe("self_employment_income") * ( - pe("self_employment_income") < 0 - ) - df["capital_gains_distributions"] = pe("non_sch_d_capital_gains") - df["capital_gains_gross"] = pe("loss_limited_net_capital_gains") * ( - pe("loss_limited_net_capital_gains") > 0 - ) - df["capital_gains_losses"] = -pe("loss_limited_net_capital_gains") * ( - pe("loss_limited_net_capital_gains") < 0 - ) - df["estate_income"] = pe("estate_income") * (pe("estate_income") > 0) - df["estate_losses"] = -pe("estate_income") * (pe("estate_income") < 0) - df["exempt_interest"] = pe("tax_exempt_interest_income") - df["ira_distributions"] = pe("taxable_ira_distributions") - df["count_of_exemptions"] = pe("exemptions_count") - df["ordinary_dividends"] = pe("non_qualified_dividend_income") + pe( - "qualified_dividend_income" - ) - df["partnership_and_s_corp_income"] = pe("partnership_s_corp_income") * ( - pe("partnership_s_corp_income") > 0 - ) - df["partnership_and_s_corp_losses"] = -pe("partnership_s_corp_income") * ( - pe("partnership_s_corp_income") < 0 - ) - df["total_pension_income"] = pe("pension_income") - df["taxable_pension_income"] = pe("taxable_pension_income") - df["qualified_dividends"] = pe("qualified_dividend_income") - df["rent_and_royalty_net_income"] = pe("rental_income") * ( - pe("rental_income") > 0 - ) - df["rent_and_royalty_net_losses"] = -pe("rental_income") * ( - pe("rental_income") < 0 - ) - df["total_social_security"] = pe("social_security") - df["taxable_social_security"] = pe("taxable_social_security") - df["income_tax_before_credits"] = pe("income_tax_before_credits") - df["taxable_interest_income"] = pe("taxable_interest_income") - df["unemployment_compensation"] = pe("taxable_unemployment_compensation") - df["employment_income"] = pe("irs_employment_income") - df["qualified_business_income_deduction"] = pe( - "qualified_business_income_deduction" - ) - df["charitable_contributions_deduction"] = pe("charitable_deduction") - df["interest_paid_deductions"] = pe("interest_deduction") - df["medical_expense_deductions_uncapped"] = pe("medical_expense_deduction") - df["state_and_local_tax_deductions"] = pe("salt_deduction") - df["itemized_state_income_and_sales_tax_deductions"] = pe( - "state_and_local_sales_or_income_tax" - ) - df["itemized_real_estate_tax_deductions"] = pe("real_estate_taxes") - df["is_tax_filer"] = pe("tax_unit_is_filer") - df["count"] = 1 - - df["filing_status"] = pe("filing_status") - df["weight"] = pe("tax_unit_weight") - 
df["household_id"] = pe("household_id") - - return df - - -def puf_to_soi(puf, year): - df = pd.DataFrame() - - df["adjusted_gross_income"] = puf.E00100 - df["total_income_tax"] = puf.E06500 - df["employment_income"] = puf.E00200 - df["capital_gains_distributions"] = puf.E01100 - df["capital_gains_gross"] = puf["E01000"] * (puf["E01000"] > 0) - df["capital_gains_losses"] = -puf["E01000"] * (puf["E01000"] < 0) - df["estate_income"] = puf.E26390 - df["estate_losses"] = puf.E26400 - df["exempt_interest"] = puf.E00400 - df["ira_distributions"] = puf.E01400 - df["count_of_exemptions"] = puf.XTOT - df["ordinary_dividends"] = puf.E00600 - df["partnership_and_s_corp_income"] = puf.E26270 * (puf.E26270 > 0) - df["partnership_and_s_corp_losses"] = -puf.E26270 * (puf.E26270 < 0) - df["total_pension_income"] = puf.E01500 - df["taxable_pension_income"] = puf.E01700 - df["qualified_dividends"] = puf.E00650 - df["rent_and_royalty_net_income"] = puf.E25850 - df["rent_and_royalty_net_losses"] = puf.E25860 - df["total_social_security"] = puf.E02400 - df["taxable_social_security"] = puf.E02500 - df["income_tax_before_credits"] = puf.E06500 - df["taxable_interest_income"] = puf.E00300 - df["unemployment_compensation"] = puf.E02300 - df["employment_income"] = puf.E00200 - df["charitable_contributions_deduction"] = puf.E19700 - df["interest_paid_deductions"] = puf.E19200 - df["medical_expense_deductions_uncapped"] = puf.E17500 - df["itemized_state_income_and_sales_tax_deductions"] = puf.E18400 - df["itemized_real_estate_tax_deductions"] = puf.E18500 - df["state_and_local_tax_deductions"] = puf.E18400 + puf.E18500 - df["income_tax_after_credits"] = puf.E08800 - df["business_net_profits"] = puf.E00900 * (puf.E00900 > 0) - df["business_net_losses"] = -puf.E00900 * (puf.E00900 < 0) - df["taxable_income"] = puf.E04800 - df["is_tax_filer"] = True - df["count"] = 1 - df["filing_status"] = puf.MARS.map( - { - 0: "SINGLE", # Assume the aggregate record is single - 1: "SINGLE", - 2: "JOINT", - 3: "SEPARATE", - 4: "HEAD_OF_HOUSEHOLD", - } - ) - - df["weight"] = puf["S006"] / 100 - - return df - - -def get_soi(year: int) -> pd.DataFrame: - uprating = create_policyengine_uprating_factors_table() - - uprating_map = { - "adjusted_gross_income": "adjusted_gross_income", - "count": "population", - "employment_income": "employment_income", - "business_net_profits": "self_employment_income", - "capital_gains_gross": "long_term_capital_gains", - "ordinary_dividends": "non_qualified_dividend_income", - "partnership_and_s_corp_income": "partnership_s_corp_income", - "qualified_dividends": "qualified_dividend_income", - "taxable_interest_income": "taxable_interest_income", - "total_pension_income": "pension_income", - "total_social_security": "social_security", - "business_net_losses": "self_employment_income", - "capital_gains_distributions": "long_term_capital_gains", - "capital_gains_losses": "long_term_capital_gains", - "estate_income": "estate_income", - "estate_losses": "estate_income", - "exempt_interest": "tax_exempt_interest_income", - "ira_distributions": "taxable_ira_distributions", - "partnership_and_s_corp_losses": "partnership_s_corp_income", - "rent_and_royalty_net_income": "rental_income", - "rent_and_royalty_net_losses": "rental_income", - "taxable_pension_income": "taxable_pension_income", - "taxable_social_security": "taxable_social_security", - "unemployment_compensation": "unemployment_compensation", - } - soi = pd.read_csv(CALIBRATION_FOLDER / "soi_targets.csv") - soi = soi[soi.Year == soi.Year.max()] - - 
uprating_factors = {} - for variable in uprating_map: - pe_name = uprating_map.get(variable) - if pe_name in uprating.index: - uprating_factors[variable] = ( - uprating.loc[pe_name, year] - / uprating.loc[pe_name, soi.Year.max()] - ) - else: - uprating_factors[variable] = ( - uprating.loc["employment_income", year] - / uprating.loc["employment_income", soi.Year.max()] - ) - - for variable, uprating_factor in uprating_factors.items(): - soi.loc[soi.Variable == variable, "Value"] *= uprating_factor - - return soi - - -def compare_soi_replication_to_soi(df, soi): - variables = [] - filing_statuses = [] - agi_lower_bounds = [] - agi_upper_bounds = [] - counts = [] - taxables = [] - full_pops = [] - values = [] - soi_values = [] - - for i, row in soi.iterrows(): - if row.Variable not in df.columns: - continue - - subset = df[df.adjusted_gross_income >= row["AGI lower bound"]][ - df.adjusted_gross_income < row["AGI upper bound"] - ] - - variable = row["Variable"] - - fs = row["Filing status"] - if fs == "Single": - subset = subset[subset.filing_status == "SINGLE"] - elif fs == "Head of Household": - subset = subset[subset.filing_status == "HEAD_OF_HOUSEHOLD"] - elif fs == "Married Filing Jointly/Surviving Spouse": - subset = subset[ - subset.filing_status.isin(["JOINT", "SURVIVING_SPOUSE"]) - ] - elif fs == "Married Filing Separately": - subset = subset[subset.filing_status == "SEPARATE"] - - if row["Taxable only"]: - subset = subset[subset.total_income_tax > 0] - else: - subset = subset[subset.is_tax_filer.values > 0] - - if row["Count"]: - value = subset[subset[variable] > 0].weight.sum() - else: - value = (subset[variable] * subset.weight).sum() - - variables.append(row["Variable"]) - filing_statuses.append(row["Filing status"]) - agi_lower_bounds.append(row["AGI lower bound"]) - agi_upper_bounds.append(row["AGI upper bound"]) - counts.append(row["Count"] or (row["Variable"] == "count")) - taxables.append(row["Taxable only"]) - full_pops.append(row["Full population"]) - values.append(value) - soi_values.append(row["Value"]) - - soi_replication = pd.DataFrame( - { - "Variable": variables, - "Filing status": filing_statuses, - "AGI lower bound": agi_lower_bounds, - "AGI upper bound": agi_upper_bounds, - "Count": counts, - "Taxable only": taxables, - "Full population": full_pops, - "Value": values, - "SOI Value": soi_values, - } - ) - - soi_replication["Error"] = ( - soi_replication["Value"] - soi_replication["SOI Value"] - ) - soi_replication["Absolute error"] = soi_replication["Error"].abs() - soi_replication["Relative error"] = ( - (soi_replication["Error"] / soi_replication["SOI Value"]) - .replace([np.inf, -np.inf], np.nan) - .fillna(0) - ) - soi_replication["Absolute relative error"] = soi_replication[ - "Relative error" - ].abs() - - return soi_replication diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py deleted file mode 100644 index 1ec097a7b..000000000 --- a/tests/test_reproducibility.py +++ /dev/null @@ -1,273 +0,0 @@ -""" -Reproducibility tests for Enhanced CPS generation. - -These tests ensure the pipeline produces consistent results -and can be reproduced in different environments. 
-""" - -import pytest -import numpy as np -import pandas as pd -from pathlib import Path -import hashlib -import json - - -class TestReproducibility: - """Test suite for reproducibility validation.""" - - def test_environment_setup(self): - """Test that required packages are installed.""" - required_packages = [ - "policyengine_us", - "policyengine_us_data", - "quantile_forest", - "pandas", - "numpy", - "torch", - ] - - for package in required_packages: - try: - __import__(package.replace("-", "_")) - except ImportError: - pytest.fail(f"Required package '{package}' not installed") - - def test_deterministic_imputation(self): - """Test that imputation produces deterministic results with fixed seed.""" - from policyengine_us_data.datasets.cps.enhanced_cps.imputation import ( - QuantileRegressionForestImputer, - ) - - # Create small test data - n_samples = 100 - predictors = pd.DataFrame( - { - "age": np.random.randint(18, 80, n_samples), - "sex": np.random.choice([1, 2], n_samples), - "filing_status": np.random.choice([1, 2], n_samples), - } - ) - - target = pd.Series(np.random.lognormal(10, 1, n_samples)) - - # Run imputation twice with same seed - imputer1 = QuantileRegressionForestImputer(random_state=42) - imputer1.fit(predictors, target) - result1 = imputer1.predict(predictors) - - imputer2 = QuantileRegressionForestImputer(random_state=42) - imputer2.fit(predictors, target) - result2 = imputer2.predict(predictors) - - # Results should be identical - np.testing.assert_array_almost_equal(result1, result2) - - def test_weight_optimization_convergence(self): - """Test that weight optimization converges consistently.""" - from policyengine_us_data.datasets.cps.enhanced_cps.reweight import ( - optimize_weights, - ) - - # Create test loss matrix - n_households = 100 - n_targets = 10 - - loss_matrix = np.random.rand(n_households, n_targets) - targets = np.random.rand(n_targets) * 1e6 - initial_weights = np.ones(n_households) - - # Run optimization twice - weights1, loss1 = optimize_weights( - loss_matrix, - targets, - initial_weights, - n_iterations=100, - dropout_rate=0.05, - seed=42, - ) - - weights2, loss2 = optimize_weights( - loss_matrix, - targets, - initial_weights, - n_iterations=100, - dropout_rate=0.05, - seed=42, - ) - - # Results should be very close - np.testing.assert_allclose(weights1, weights2, rtol=1e-5) - assert abs(loss1 - loss2) < 1e-6 - - def test_validation_metrics_stable(self): - """Test that validation metrics are stable across runs.""" - # This would load actual data in practice - # For now, test with synthetic data - - metrics = { - "gini_coefficient": 0.521, - "top_10_share": 0.472, - "top_1_share": 0.198, - "poverty_rate": 0.116, - } - - # In practice, would calculate from data - # Here we verify expected ranges - assert 0.50 <= metrics["gini_coefficient"] <= 0.55 - assert 0.45 <= metrics["top_10_share"] <= 0.50 - assert 0.18 <= metrics["top_1_share"] <= 0.22 - assert 0.10 <= metrics["poverty_rate"] <= 0.13 - - def test_output_checksums(self): - """Test that output files match expected checksums.""" - test_data_dir = Path("data/test") - - if not test_data_dir.exists(): - pytest.skip("Test data not generated") - - checksum_file = test_data_dir / "checksums.txt" - if not checksum_file.exists(): - pytest.skip("Checksum file not found") - - # Read expected checksums - expected_checksums = {} - with open(checksum_file) as f: - for line in f: - if line.strip(): - filename, checksum = line.strip().split(": ") - expected_checksums[filename] = checksum - - # Verify files - 
for filename, expected_checksum in expected_checksums.items(): - file_path = test_data_dir / filename - if file_path.exists() and filename != "checksums.txt": - with open(file_path, "rb") as f: - actual_checksum = hashlib.sha256(f.read()).hexdigest() - assert ( - actual_checksum == expected_checksum - ), f"Checksum mismatch for {filename}" - - def test_memory_usage(self): - """Test that memory usage stays within bounds.""" - import psutil - import os - - process = psutil.Process(os.getpid()) - memory_before = process.memory_info().rss / 1024 / 1024 # MB - - # Run a small imputation task - n_samples = 10000 - data = pd.DataFrame( - { - "age": np.random.randint(18, 80, n_samples), - "income": np.random.lognormal(10, 1, n_samples), - } - ) - - # Process data - data["income_bracket"] = pd.qcut(data["income"], 10) - - memory_after = process.memory_info().rss / 1024 / 1024 # MB - memory_used = memory_after - memory_before - - # Should use less than 500MB for this small task - assert memory_used < 500, f"Used {memory_used:.1f}MB, expected <500MB" - - def test_platform_independence(self): - """Test that code works across platforms.""" - import platform - - system = platform.system() - assert system in [ - "Linux", - "Darwin", - "Windows", - ], f"Unsupported platform: {system}" - - # Test path handling - test_path = Path("data") / "test" / "file.csv" - assert str(test_path).replace("\\", "/") == "data/test/file.csv" - - def test_api_credentials_documented(self): - """Test that API credential requirements are documented.""" - readme_path = Path("REPRODUCTION.md") - assert readme_path.exists(), "REPRODUCTION.md not found" - - content = readme_path.read_text() - - # Check for credential documentation - required_sections = [ - "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN", - "CENSUS_API_KEY", - "PUF Data Access", - ] - - for section in required_sections: - assert section in content, f"Missing documentation for '{section}'" - - def test_synthetic_data_generation(self): - """Test that synthetic data can be generated for testing.""" - from scripts.generate_test_data import ( - generate_synthetic_cps, - generate_synthetic_puf, - ) - - # Generate small datasets - households, persons = generate_synthetic_cps(n_households=10) - puf = generate_synthetic_puf(n_returns=50) - - # Verify structure - assert len(households) == 10 - assert len(persons) > 10 # Multiple persons per household - assert len(puf) == 50 - - # Verify required columns - assert "household_id" in households.columns - assert "person_id" in persons.columns - assert "wages" in puf.columns - - def test_smoke_test_pipeline(self): - """Run a minimal version of the full pipeline.""" - # This test would be marked as slow and only run in CI - pytest.skip("Full pipeline test - run with --runslow") - - # Would include: - # 1. Load test data - # 2. Run imputation on subset - # 3. Run reweighting with few targets - # 4. 
Validate outputs exist - - def test_documentation_completeness(self): - """Test that all necessary documentation exists.""" - required_docs = [ - "README.md", - "REPRODUCTION.md", - "CLAUDE.md", - "docs/methodology.md", - "docs/data.md", - ] - - for doc in required_docs: - doc_path = Path(doc) - assert doc_path.exists(), f"Missing documentation: {doc}" - - # Check not empty - content = doc_path.read_text() - assert len(content) > 100, f"Documentation too short: {doc}" - - -@pytest.mark.slow -class TestFullReproduction: - """Full reproduction tests (run with --runslow flag).""" - - def test_full_pipeline_subset(self): - """Test full pipeline on data subset.""" - # This would run the complete pipeline on a small subset - # Taking ~10 minutes instead of hours - pass - - def test_validation_dashboard(self): - """Test that validation dashboard can be generated.""" - # Would test dashboard generation - pass
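The deleted `_add_tax_expenditure_targets` helper follows a repeal-and-difference pattern: neutralize one deduction via a reform, re-run the microsimulation, and record the change in household income tax as the expenditure column. Below is a minimal sketch of that pattern, not the module's exact wiring; the `policyengine_core.reforms` import path and the `dataset`/`time_period` arguments are assumptions for illustration.

```python
# Sketch only: assumes policyengine_us / policyengine_core are installed and
# that `dataset` and `time_period` are supplied as in the deleted module.
from policyengine_core.reforms import Reform  # assumed import path
from policyengine_us import Microsimulation


def tax_expenditure(dataset, time_period, deduction_var: str) -> float:
    """Return the revenue effect of repealing one deduction variable."""

    class RepealDeduction(Reform):
        def apply(self):
            # Zero out the chosen deduction, mirroring the deleted helper.
            self.neutralize_variable(deduction_var)

    baseline = Microsimulation(dataset=dataset)
    baseline.default_calculation_period = time_period
    reformed = Microsimulation(dataset=dataset, reform=RepealDeduction)
    reformed.default_calculation_period = time_period

    income_tax_b = baseline.calculate("income_tax", map_to="household").values
    income_tax_r = reformed.calculate("income_tax", map_to="household").values
    # Repealing a deduction raises tax, so the difference is the expenditure.
    return float((income_tax_r - income_tax_b).sum())
```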
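The AGI targets in `_add_agi_metric_columns` pair a band label with either a count-style or an amount-style column. A small self-contained sketch of how those columns are built, using placeholder state `CA`, placeholder variable name `agi`, and toy AGI values:

```python
import numpy as np


def get_agi_band_label(lower: float, upper: float) -> str:
    # Same banding rule as the deleted helper: open-ended bounds use "inf".
    if lower <= 0:
        return f"-inf_{int(upper)}"
    elif np.isposinf(upper):
        return f"{int(lower)}_inf"
    return f"{int(lower)}_{int(upper)}"


# Hypothetical tax-unit AGI values and a state membership mask.
agi = np.array([15_000.0, 40_000.0, 120_000.0, -5_000.0])
in_state = np.array([True, True, False, True])

lower, upper = 25_000.0, 50_000.0
in_band = (agi > lower) & (agi <= upper)

# Count-style metric: one per in-scope return with positive AGI;
# amount-style metric: AGI itself, zeroed outside the state/band.
count_metric = (in_state & in_band & (agi > 0)).astype(float)
amount_metric = np.where(in_state & in_band, agi, 0.0)

print(f"state/CA/agi/{get_agi_band_label(lower, upper)}")  # state/CA/agi/25000_50000
print(count_metric)   # [0. 1. 0. 0.]
print(amount_metric)  # [    0. 40000.     0.     0.]
```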
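`_add_state_real_estate_taxes` rescales the ACS state totals proportionally so they sum exactly to the hard-coded national total before appending them as targets. The arithmetic, with hypothetical numbers:

```python
import numpy as np

# Hypothetical state totals (in $bn) and a national control total (in $).
state_bn = np.array([10.0, 20.0, 30.0])
national_total = 90e9

# Scale each state by the same factor so the states sum to the national
# target, mirroring the rescaling in the deleted helper.
scale = national_total / (state_bn.sum() * 1e9)
state_dollars = state_bn * scale * 1e9

assert np.isclose(state_dollars.sum(), national_total)
print(state_dollars)  # [1.5e+10 3.0e+10 4.5e+10]
```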
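`print_reweighting_diagnostics` reduces to a weighted matrix product plus a smoothed relative-error check. A toy NumPy sketch of the same calculations, with all values hypothetical:

```python
import numpy as np

# One row per household, one column per calibration target (toy values).
weights = np.array([1.0, 2.0, 0.0, 3.5])
loss_matrix = np.array(
    [
        [1.0, 100.0],
        [1.0, 250.0],
        [1.0, 400.0],
        [1.0, 900.0],
    ]
)
targets = np.array([6.0, 3800.0])

# Weighted estimate of each target, as in the diagnostics above.
estimate = weights @ loss_matrix

# Smoothed squared relative error (the +1 terms guard against zero targets).
rel_error = ((estimate - targets + 1) / (targets + 1)) ** 2

# Share of targets whose estimate lands within 10% of the target value.
within_10 = np.abs(estimate - targets) <= 0.10 * np.abs(targets)
print(estimate, rel_error, within_10.mean() * 100)
```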