diff --git a/.github/workflows/validate_district_calibration.yml b/.github/workflows/validate_district_calibration.yml new file mode 100644 index 00000000..adf3b6a3 --- /dev/null +++ b/.github/workflows/validate_district_calibration.yml @@ -0,0 +1,104 @@ +name: Validate District-Level Calibration + +on: + push: + branches: + - new-cd-var + workflow_dispatch: + inputs: + gcs_date: + description: 'GCS date prefix (e.g., 2025-10-22-1721)' + required: true + type: string + +jobs: + validate-and-upload: + runs-on: ubuntu-latest + permissions: + contents: read + id-token: write + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + + - name: Install dependencies + run: uv pip install -e .[dev] --system + + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + workload_identity_provider: "projects/322898545428/locations/global/workloadIdentityPools/policyengine-research-id-pool/providers/prod-github-provider" + service_account: "policyengine-research@policyengine-research.iam.gserviceaccount.com" + + - name: Set up Cloud SDK + uses: google-github-actions/setup-gcloud@v2 + + - name: Download weights from GCS + run: | + GCS_DATE="${{ inputs.gcs_date || '2025-10-22-1721' }}" + echo "Downloading weights from gs://policyengine-calibration/$GCS_DATE/outputs/" + mkdir -p policyengine_us_data/storage/calibration + gsutil ls gs://policyengine-calibration/$GCS_DATE/outputs/**/w_cd.npy | head -1 | xargs -I {} gsutil cp {} policyengine_us_data/storage/calibration/w_cd.npy + echo "Downloaded w_cd.npy" + + - name: Download prerequisite datasets + run: | + GCS_DATE="${{ inputs.gcs_date || '2025-10-22-1721' }}" + echo "Downloading stratified dataset and database from calibration run..." + mkdir -p policyengine_us_data/storage + gsutil cp gs://policyengine-calibration/$GCS_DATE/inputs/stratified_extended_cps_2023.h5 policyengine_us_data/storage/ + gsutil cp gs://policyengine-calibration/$GCS_DATE/inputs/policy_data.db policyengine_us_data/storage/ + + - name: Verify downloaded files + run: | + echo "Verifying downloaded files exist..." + if [ ! -f policyengine_us_data/storage/stratified_extended_cps_2023.h5 ]; then + echo "ERROR: stratified_extended_cps_2023.h5 not found" + exit 1 + fi + if [ ! -f policyengine_us_data/storage/policy_data.db ]; then + echo "ERROR: policy_data.db not found" + exit 1 + fi + echo "All required files present:" + ls -lh policyengine_us_data/storage/stratified_extended_cps_2023.h5 + ls -lh policyengine_us_data/storage/policy_data.db + + - name: Create state files + run: | + echo "Creating state-level .h5 files..." + python -m policyengine_us_data.datasets.cps.geo_stacking_calibration.create_sparse_cd_stacked \ + --weights-path policyengine_us_data/storage/calibration/w_cd.npy \ + --dataset-path policyengine_us_data/storage/stratified_extended_cps_2023.h5 \ + --db-path policyengine_us_data/storage/policy_data.db \ + --output-dir policyengine_us_data/storage/cd_states + + - name: Run district-level validation tests + run: | + echo "Running validation tests..." + pytest -m "district_level_validation" -v + + - name: Upload state files to GCS + if: success() + run: | + GCS_DATE="${{ inputs.gcs_date || '2025-10-22-1721' }}" + echo "Tests passed! Uploading state files to GCS..." 
+ gsutil -m cp policyengine_us_data/storage/cd_states/*.h5 gs://policyengine-calibration/$GCS_DATE/state_files/ + gsutil -m cp policyengine_us_data/storage/cd_states/*_household_mapping.csv gs://policyengine-calibration/$GCS_DATE/state_files/ + echo "" + echo "✅ State files uploaded to gs://policyengine-calibration/$GCS_DATE/state_files/" + + - name: Report validation failure + if: failure() + run: | + echo "❌ District-level calibration validation FAILED" + echo "Check the test output above for details" + echo "State files were NOT uploaded to GCS" + exit 1 diff --git a/.gitignore b/.gitignore index 36301c6f..34925779 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,7 @@ node_modules !policyengine_us_data/storage/social_security_aux.csv !policyengine_us_data/storage/SSPopJul_TR2024.csv docs/.ipynb_checkpoints/ + +# Geo-stacking pipeline outputs +policyengine_us_data/storage/calibration/ +policyengine_us_data/storage/cd_states/ diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 804b82f7..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,64 +0,0 @@ -# CLAUDE.md - Guidelines for PolicyEngine US Data - -## Build Commands -- `make install` - Install dependencies and dev environment -- `make build` - Build the package using Python build -- `make data` - Generate project datasets - -## Testing -- `pytest` - Run all tests -- `pytest path/to/test_file.py::test_function` - Run a specific test -- `make test` - Also runs all tests - -## Formatting -- `make format` - Format all code using Black with 79 char line length -- `black . -l 79 --check` - Check formatting without changing files - -## Code Style Guidelines -- **Imports**: Standard libraries first, then third-party, then internal -- **Type Hints**: Use for all function parameters and return values -- **Naming**: Classes: PascalCase, Functions/Variables: snake_case, Constants: UPPER_SNAKE_CASE -- **Documentation**: Google-style docstrings with Args and Returns sections -- **Error Handling**: Use validation checks with specific error messages -- **Line Length**: 79 characters max (Black configured in pyproject.toml) -- **Python Version**: Targeting Python 3.11 - -## Git and PR Guidelines -- **CRITICAL**: NEVER create PRs from personal forks - ALL PRs MUST be created from branches pushed to the upstream PolicyEngine repository -- CI requires access to secrets that are not available to fork PRs for security reasons -- Fork PRs will fail on data download steps and cannot be merged -- Always create branches directly on the upstream repository: - ```bash - git checkout main - git pull upstream main - git checkout -b your-branch-name - git push -u upstream your-branch-name - ``` -- Use descriptive branch names like `fix-issue-123` or `add-feature-name` -- Always run `make format` before committing - -## CRITICAL RULES FOR ACADEMIC INTEGRITY - -### NEVER FABRICATE DATA OR RESULTS -- **NEVER make up numbers, statistics, or results** - This is academic malpractice -- **NEVER invent performance metrics, error rates, or validation results** -- **NEVER create fictional poverty rates, income distributions, or demographic statistics** -- **NEVER fabricate cross-validation results, correlations, or statistical tests** -- If you don't have actual data, say "Results to be determined" or "Analysis pending" -- Always use placeholder text like "[TO BE CALCULATED]" for unknown values -- When writing papers, use generic descriptions without specific numbers unless verified - -### When Writing Academic Papers -- Only cite actual results from running code 
or published sources -- Use placeholders for any metrics you haven't calculated -- Clearly mark sections that need empirical validation -- Never guess or estimate academic results -- If asked to complete analysis without data, explain what would need to be done - -### Consequences of Fabrication -- Fabricating data in academic work can lead to: - - Rejection from journals - - Blacklisting from future publications - - Damage to institutional reputation - - Legal consequences in funded research - - Career-ending academic misconduct charges \ No newline at end of file diff --git a/Makefile b/Makefile index 78d0904d..65347c4f 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,7 @@ documentation-dev: database: python policyengine_us_data/db/create_database_tables.py python policyengine_us_data/db/create_initial_strata.py + python policyengine_us_data/db/etl_national_targets.py python policyengine_us_data/db/etl_age.py python policyengine_us_data/db/etl_medicaid.py python policyengine_us_data/db/etl_snap.py @@ -74,9 +75,79 @@ data: mv policyengine_us_data/storage/enhanced_cps_2024.h5 policyengine_us_data/storage/dense_enhanced_cps_2024.h5 cp policyengine_us_data/storage/sparse_enhanced_cps_2024.h5 policyengine_us_data/storage/enhanced_cps_2024.h5 +data-geo: data + GEO_STACKING=true python policyengine_us_data/datasets/cps/cps.py + GEO_STACKING=true python policyengine_us_data/datasets/puf/puf.py + GEO_STACKING_MODE=true python policyengine_us_data/datasets/cps/extended_cps.py + python policyengine_us_data/datasets/cps/geo_stacking_calibration/create_stratified_cps.py 10000 + +calibration-package: data-geo + python policyengine_us_data/datasets/cps/geo_stacking_calibration/create_calibration_package.py \ + --db-path policyengine_us_data/storage/policy_data.db \ + --dataset-uri policyengine_us_data/storage/stratified_extended_cps_2023.h5 \ + --mode Stratified \ + --local-output policyengine_us_data/storage/calibration + +optimize-weights-local: calibration-package + python policyengine_us_data/datasets/cps/geo_stacking_calibration/optimize_weights.py \ + --input-dir policyengine_us_data/storage/calibration \ + --output-dir policyengine_us_data/storage/calibration \ + --total-epochs 100 \ + --device cpu + +create-state-files: optimize-weights-local + python -m policyengine_us_data.datasets.cps.geo_stacking_calibration.create_sparse_cd_stacked \ + --weights-path policyengine_us_data/storage/calibration/w_cd.npy \ + --dataset-path policyengine_us_data/storage/stratified_extended_cps_2023.h5 \ + --db-path policyengine_us_data/storage/policy_data.db \ + --output-dir policyengine_us_data/storage/cd_states + +upload-calibration-package: calibration-package + $(eval GCS_DATE := $(shell date +%Y-%m-%d-%H%M)) # For bash: GCS_DATE=$$(date +%Y-%m-%d-%H%M) + python policyengine_us_data/datasets/cps/geo_stacking_calibration/create_calibration_package.py \ + --db-path policyengine_us_data/storage/policy_data.db \ + --dataset-uri policyengine_us_data/storage/stratified_extended_cps_2023.h5 \ + --mode Stratified \ + --gcs-bucket policyengine-calibration \ + --gcs-date $(GCS_DATE) + @echo "Uploading dataset and database to GCS inputs..." 
+ gsutil cp policyengine_us_data/storage/stratified_extended_cps_2023.h5 gs://policyengine-calibration/$(GCS_DATE)/inputs/ + gsutil cp policyengine_us_data/storage/policy_data.db gs://policyengine-calibration/$(GCS_DATE)/inputs/ + @echo "" + @echo "Calibration package uploaded to GCS" + @echo "Date prefix: $(GCS_DATE)" + @echo "" + @echo "To submit GCP batch job, update batch_pipeline/config.env:" + @echo " INPUT_PATH=$(GCS_DATE)/inputs" + @echo " OUTPUT_PATH=$(GCS_DATE)/outputs" + +optimize-weights-gcp: + @echo "Submitting Cloud Batch job for weight optimization..." + @echo "Make sure you've run 'make upload-calibration-package' first" + @echo "and updated batch_pipeline/config.env with the correct paths" + @echo "" + cd policyengine_us_data/datasets/cps/geo_stacking_calibration/batch_pipeline && ./submit_batch_job.sh + +download-weights-from-gcs: + @echo "Downloading weights from GCS..." + rm -f policyengine_us_data/storage/calibration/w_cd.npy + @read -p "Enter GCS date prefix (e.g., 2025-10-22-1630): " gcs_date; \ + gsutil ls gs://policyengine-calibration/$$gcs_date/outputs/**/w_cd.npy | head -1 | xargs -I {} gsutil cp {} policyengine_us_data/storage/calibration/w_cd.npy && \ + gsutil ls gs://policyengine-calibration/$$gcs_date/outputs/**/w_cd_*.npy | xargs -I {} gsutil cp {} policyengine_us_data/storage/calibration/ && \ + echo "Weights downloaded successfully" + +upload-state-files-to-gcs: + @echo "Uploading state files to GCS..." + @read -p "Enter GCS date prefix (e.g., 2025-10-22-1721): " gcs_date; \ + gsutil -m cp policyengine_us_data/storage/cd_states/*.h5 gs://policyengine-calibration/$$gcs_date/state_files/ && \ + gsutil -m cp policyengine_us_data/storage/cd_states/*_household_mapping.csv gs://policyengine-calibration/$$gcs_date/state_files/ && \ + echo "" && \ + echo "State files uploaded to gs://policyengine-calibration/$$gcs_date/state_files/" + clean: rm -f policyengine_us_data/storage/*.h5 rm -f policyengine_us_data/storage/*.db + rm -f policyengine_us_data/storage/*.pkl git clean -fX -- '*.csv' rm -rf policyengine_us_data/docs/_build diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..8ab33e3b 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,9 @@ +- bump: minor + changes: + added: + - Targets database infrastructure for geo-stacking calibration + - Congressional district level estimation capability + - Geo-stacking calibration utilities and modeling functionality + - GEO_STACKING environment variable for specialized data pipeline + - Hierarchical validation for calibration targets + - Holdout validation framework for geo-stacking models diff --git a/docs/DATA_PIPELINE.md b/docs/DATA_PIPELINE.md new file mode 100644 index 00000000..db519b73 --- /dev/null +++ b/docs/DATA_PIPELINE.md @@ -0,0 +1,417 @@ +# PolicyEngine US Data Pipeline Documentation + +## Overview + +The PolicyEngine US data pipeline integrates Census surveys (CPS, ACS), IRS tax data (PUF, SOI), and Federal Reserve wealth data (SCF) to create a comprehensive microsimulation dataset. The pipeline produces three progressively enhanced dataset levels: +1. **CPS**: Base demographic layer from Census +2. **Extended CPS**: CPS + PUF-imputed financial variables +3. 
**Enhanced CPS**: Extended CPS + calibrated weights to match official statistics + +## The Complete Pipeline Architecture + +```bash +# Full pipeline in execution order +make download # Download private IRS data from HuggingFace +make database # Build calibration targets database +make data # Run complete pipeline: + ├── python policyengine_us_data/utils/uprating.py + ├── python policyengine_us_data/datasets/acs/acs.py + ├── python policyengine_us_data/datasets/cps/cps.py + ├── python policyengine_us_data/datasets/puf/irs_puf.py + ├── python policyengine_us_data/datasets/puf/puf.py + ├── python policyengine_us_data/datasets/cps/extended_cps.py + ├── python policyengine_us_data/datasets/cps/enhanced_cps.py + └── python policyengine_us_data/datasets/cps/small_enhanced_cps.py +make upload # Upload completed datasets to cloud storage +``` + +## Critical Pipeline Dependencies + +### Hidden Dependencies + +1. **PUF always requires CPS_2021**: The PUF generation hardcodes CPS_2021 for pension contribution imputation, regardless of target year. This creates a permanent dependency on 2021 data. + +2. **PUF_2021 is the base for all future years**: Unlike going back to 2015, years 2022+ start from PUF_2021 and apply uprating. This makes PUF_2021 a critical checkpoint. + +3. **Pre-trained models are cached**: SIPP tip model (tips.pkl) and SCF relationships are trained once and reused. These are not part of the main pipeline execution. + +4. **Database targets are required for Enhanced CPS**: The calibration targets database must be populated before running Enhanced CPS generation. + +## Private Data Management + +### Download Prerequisites +The pipeline requires private IRS data downloaded from HuggingFace: +- `puf_2015.csv`: IRS Public Use File base data +- `demographics_2015.csv`: Demographic supplement +- `soi.csv`: Statistics of Income aggregates +- `np2023_d5_mid.csv`: Census population projections + +Access controlled via `HUGGING_FACE_TOKEN` environment variable. + +### Upload Distribution +Completed datasets are uploaded to: +- **HuggingFace**: Public access at `policyengine/policyengine-us-data` +- **Google Cloud Storage**: `policyengine-us-data` bucket + +Uploaded files include: +- `enhanced_cps_2024.h5` (sparse version) +- `dense_enhanced_cps_2024.h5` (full weights) +- `small_enhanced_cps_2024.h5` (1,000 household sample) +- `pooled_3_year_cps_2023.h5` (combined 2021-2023) +- `policy_data.db` (calibration targets database) + +## The Three-Stage Dataset Hierarchy + +### Stage 1: CPS (Base Demographics) +**What it provides**: +- Household structure and demographics +- Basic income variables +- Geographic distribution +- Raw survey weights + +**Transformations applied**: +1. Immigration status via ASEC-UA algorithm (targeting 13M undocumented) +2. Rent imputed from ACS-trained model +3. Tips from pre-trained SIPP model (loaded from tips.pkl) +4. Wealth/auto loans from SCF via QRF imputation + +### Stage 2: Extended CPS (Financial Imputation) +**The Statistical Fusion Process**: +1. Train QRF models on PUF's 70+ financial variables +2. Learn relationships between demographics and finances +3. Apply patterns to CPS households +4. 
Result: CPS demographics + PUF-learned financial distributions + +**Variables Imputed**: +- Income types: wages, capital gains, dividends, pensions +- Deductions: mortgage interest, charitable, state/local taxes +- Credits: EITC-relevant amounts, child care expenses +- Business income: partnership, S-corp, farm, rental + +### Stage 3: Enhanced CPS (Calibrated Weights) +**The Calibration Process**: +Enhanced CPS reweights Extended CPS households to match official statistics through sophisticated optimization. + +**Calibration Targets**: +- **IRS SOI Statistics**: Income distributions by AGI bracket, state, filing status +- **Hard-coded totals**: Medical expenses, child support, property tax, rent +- **National/State balance**: Separate normalization for national vs state targets + +**Two Optimization Approaches**: + +1. **Dense Optimization** (Standard gradient descent): + - All households receive adjusted weights + - Smooth weight distribution + - Better for small-area estimates + +2. **Sparse Optimization** (L0 regularization via HardConcrete gates): + - Many households get zero weight + - Fewer non-zero weights but higher values + - More computationally efficient for large-scale simulations + - Uses temperature and initialization parameters to control sparsity + +The sparse version is the default distributed dataset, with dense available as `dense_enhanced_cps_2024.h5`. + +## Dataset Variants + +### Pooled CPS +Combines multiple years for increased sample size: +- **Pooled_3_Year_CPS_2023**: Merges CPS 2021, 2022, 2023 +- Maintains year indicators for time-series analysis +- Larger sample for state-level estimates + +### Small Enhanced CPS +Two reduction methods for development/testing: + +1. **Random Sampling**: 1,000 households randomly selected +2. **Sparse Selection**: Uses L0 regularization results + +Benefits: +- Fast iteration during development +- Unit testing microsimulation changes +- Reduced memory footprint (100MB vs 16GB) + +## The Two-Phase Uprating System + +### 2021 is a Methodology Boundary + +The system uses completely different uprating approaches before and after 2021: + +#### Phase 1: SOI Historical (2015 → 2021) +- Function: `uprate_puf()` in `datasets/puf/uprate_puf.py` +- Data source: IRS Statistics of Income actuals +- Method: Variable-specific growth from SOI aggregates +- Population adjustment: Divides by population growth for per-capita rates +- Special cases: Itemized deductions fixed at 2% annual growth + +#### Phase 2: Parameter Projection (2021 → Future) +- Function: `create_policyengine_uprating_factors_table()` +- Data source: PolicyEngine parameters (CBO, Census projections) +- Method: Indexed growth factors (2020 = 1.0) +- Coverage: 131+ variables with consistent methodology +- Any year >= 2021 can be generated this way + +### Why This Matters + +The 2021 boundary means: +- Historical accuracy for 2015-2021 using actual IRS data +- Forward flexibility for 2022+ using economic projections +- PUF_2021 must exist before creating any future year +- Changing pre-2021 methodology requires modifying SOI-based code + +## How Data Sources Actually Connect + +### ACS: Model Training Only +ACS_2022 doesn't contribute data to the final dataset. 
Instead: +- Trains a QRF model relating demographics to rent/property tax +- Model learns patterns like "income X in state Y → rent Z" +- These relationships apply across years (why 2022 works for 2023+) +- Located in `add_rent()` function in CPS generation + +### CPS: The Demographic Foundation +Foundation for all subsequent processing with four imputation layers. + +### PUF: Tax Detail Layer +**Critical Processing Steps**: +1. Uprating (two-phase system described above) +2. QBI simulation (W-2 wages, UBIA for Section 199A) +3. Demographics imputation for records missing age/gender +4. **Pension contributions learned from CPS_2021** (hardcoded dependency) + +**The QBI Simulation**: Since PUF lacks Section 199A details, the system: +- Simulates W-2 wages paid by businesses +- Estimates unadjusted basis of qualified property +- Assigns SSTB (specified service trade or business) status +- Based on parameters in `qbi_assumptions.yaml` + +## Technical Implementation Details + +### Memory Management +- ExtendedCPS QRF imputation: ~16GB RAM peak +- Processing 70+ variables sequentially to manage memory +- Batch processing with configurable batch sizes +- HDF5 format for efficient storage/access + +### Performance Optimization +- **Parallel processing**: Tool calls run concurrently where possible +- **Caching**: Pre-trained models cached to disk +- **Sparse storage**: Default distribution uses sparse weights +- **Incremental generation**: Can generate specific years without full rebuild + +### Error Recovery +- **Checkpoint saves**: Each major stage saves to disk +- **Resumable pipeline**: Can restart from last successful stage +- **Validation checks**: After each stage to catch issues early +- **Fallback options**: Dense weights if sparse optimization fails + +## CI/CD Integration + +### GitHub Actions Workflow +Triggered on: +- Push to main branch +- Pull requests +- Manual dispatch + +Pipeline stages: +1. **Lint**: Code quality checks +2. **Test**: + - Basic tests (every PR) + - Full suite with data build (main branch only) +3. **Publish**: PyPI release on version bump + +### Test Modes +- **Standard**: Unit tests only +- **Full Suite** (`full_suite: true`): + - Downloads private data + - Builds calibration database + - Generates all datasets + - Uploads to cloud storage + +### Environment Requirements +- **Secrets**: + - `HUGGING_FACE_TOKEN`: Private data access + - `POLICYENGINE_US_DATA_GITHUB_TOKEN`: Cross-repo operations +- **GCP Authentication**: Workload identity for uploads +- **TEST_LITE**: Reduces processing for non-production runs + +## Geographic Stacking and Entity Weights (Beta) + +### Geographic Stacking Architecture +**Note: The geo-stacking approach is currently in beta development with ongoing work to improve calibration accuracy.** + +The geo-stacking calibration creates sparse datasets where the same household characteristics can be weighted differently across multiple geographic units (states or congressional districts). This allows a single household to represent similar households in different locations with appropriate statistical weights. 
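+The sketch below (illustrative household and district labels, not the actual pipeline code) shows the basic idea: there is one weight per household-geography pair, so the same household can carry a different weight in every district, and a district-level target only constrains that district's columns:
+
+```python
+import numpy as np
+
+# Hypothetical example: 3 households stacked across 2 congressional districts.
+households = ["H1", "H2", "H3"]
+districts = ["0601", "4801"]  # illustrative GEOIDs, not real calibration IDs
+
+# One weight per (household, district) pair -- these pairs are the columns
+# of the calibration loss matrix, and the weights are what gets optimized.
+columns = [(h, d) for d in districts for h in households]
+w = np.ones(len(columns))
+
+# A district-level target row only touches that district's columns,
+# while a national target row touches every column.
+in_0601 = np.array([d == "0601" for _, d in columns])
+print("national estimate:", w.sum())
+print("district 0601 estimate:", w[in_0601].sum())
+```
+
+Scaling this pattern to every district and every stratified household is what makes the weight vector large and the resulting target matrix sparse.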
+ +### Calibration Approach +The geo-stacking method calibrates household weights to match: +- **Census demographic targets**: Age distributions, population counts +- **IRS tax statistics**: Income distributions, tax filing patterns +- **Administrative program data**: SNAP participation, Medicaid enrollment +- **National hardcoded targets**: Medical expenses, child support, tips + +For congressional districts, a stratified sampling approach reduces the CPS from 112,502 to ~13,000 households while preserving all high-income households critical for tax policy analysis. + +### Hierarchical Target Selection +When calibrating to specific geographic levels, the system uses a hierarchical fallback: +1. Use target at most specific level (e.g., congressional district) if available +2. Fall back to state-level target if CD-level doesn't exist +3. Use national target if neither CD nor state target exists + +This ensures complete coverage while respecting the most granular data available. + +### Critical Entity Weight Relationships +When creating geo-stacked datasets, PolicyEngine uses a **person-level DataFrame** structure where: +- Each row represents one person +- Household weights are repeated for each person in the household +- Tax units and other entities are represented through ID references, not separate rows + +#### Weight Assignment Rules +1. **Person weights = household weights** (NOT multiplied by persons_per_household) +2. **Tax unit weights = household weights** (derived automatically by PolicyEngine) +3. **DO NOT** explicitly set tax_unit_weight - let PolicyEngine derive from household structure + +#### Entity Reindexing Requirements +When combining DataFrames from multiple geographic units: + +1. **Households**: Must preserve groupings - all persons in a household get the SAME new household ID +2. **Tax Units**: Must stay within households - use `person_tax_unit_id` column for grouping +3. **SPM/Marital Units**: Follow same pattern as tax units + +**Common Bug**: Creating new IDs for each row instead of each entity group breaks household structure. 
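+The toy table below illustrates this layout (hypothetical IDs and weights, not real data): the household weight is simply repeated on each person row, and tax units appear only through ID columns:
+
+```python
+import pandas as pd
+
+# Two households: household 0 has 3 people split across 2 tax units,
+# household 1 has a single person.
+persons = pd.DataFrame({
+    "person_id": [0, 1, 2, 3],
+    "person_household_id": [0, 0, 0, 1],
+    "person_tax_unit_id": [0, 0, 1, 2],
+    "household_weight": [1500.0, 1500.0, 1500.0, 820.0],  # repeated per person
+})
+
+# Person weight equals the household weight (it is NOT multiplied by the
+# number of people), and tax units exist only as ID references --
+# PolicyEngine derives tax-unit weights from this structure.
+assert (
+    persons.groupby("person_household_id")["household_weight"].nunique() == 1
+).all()
+```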
+ +#### Correct Reindexing Pattern +```python +# CORRECT: One ID per household group +for old_hh_id, row_indices in hh_groups.items(): + for row_idx in row_indices: + hh_row_to_new_id[row_idx] = new_hh_id + new_hh_id += 1 # Increment AFTER all rows in group + +# WRONG: Creates new household for each person +for old_hh_id, row_indices in hh_groups.items(): + for row_idx in row_indices: + hh_row_to_new_id[row_idx] = new_hh_id + new_hh_id += 1 # BUG: Splits household +``` + +#### Weight Validation +To verify correct implementation: +```python +# These should be equal if weights are correct: +sim.calculate("person_count", map_to="household").sum() +sim.calculate("person_count", map_to="person").sum() + +# Tax unit count should be less than person count: +sim.calculate("tax_unit_count", map_to="household").sum() < +sim.calculate("person_count", map_to="household").sum() +``` + +## Data Validation Checkpoints + +### After CPS Generation +- Immigration status populations (13M undocumented target) +- Household structure integrity +- Geographic distribution +- Weight normalization + +### After PUF Processing +- QBI component reasonableness +- Pension contribution distributions +- Demographic completeness +- Tax variable consistency + +### After Extended CPS +- Financial variable distributions vs PUF +- Preservation of CPS demographics +- Total income aggregates +- Imputation quality metrics + +### After Enhanced CPS +- Target achievement rates (>95% for key variables) +- Weight distribution statistics +- State-level calibration quality +- Sparsity metrics (for sparse version) + +## Creating Datasets for Arbitrary Years + +### Creating Any Year >= 2021 + +You can create any year >= 2021 by defining a class: + +```python +class PUF_2023(PUF): + name = "puf_2023" + time_period = 2023 + file_path = STORAGE_FOLDER / "puf_2023.h5" + +PUF_2023().generate() # Automatically uprates from PUF_2021 +``` + +### Why Only 2015, 2021, 2024 Are Pre-Built + +- **2015**: IRS PUF base year (original data) +- **2021**: Methodology pivot + calibration year +- **2024**: Current year for policy analysis + +The infrastructure supports any year 2021-2034 (extent of uprating parameters). + +### The Cascade Effect + +Creating ExtendedCPS_2023 requires: +1. CPS_2023 (or uprated from CPS_2023 if no raw data) +2. PUF_2023 (uprated from PUF_2021) +3. ACS_2022 (already suitable, relationships stable) +4. SCF_2022 (wealth patterns applicable) + +Creating EnhancedCPS_2023 additionally requires: +5. ExtendedCPS_2023 (from above) +6. 
Calibration targets database (SOI + other sources) + +## Understanding the Web of Dependencies + +``` +uprating_factors.csv ──────────────────┐ + ↓ +ACS_2022 → [rent model] ────────→ CPS_2023 → ExtendedCPS_2023 → EnhancedCPS_2023 + ↑ ↑ ↑ +CPS_2021 → [pension model] ──────────┘ │ │ + ↓ │ │ +PUF_2015 → PUF_2021 → PUF_2023 ─────────────────────┘ │ + ↑ │ + [SOI data] │ + │ +calibration_targets.db ─────────────────────────────────────────────────┘ +``` + +This web means: +- Can't generate PUF without CPS_2021 existing +- Can't generate ExtendedCPS without both CPS and PUF +- Can't generate EnhancedCPS without ExtendedCPS and targets database +- Can't uprate PUF_2022+ without PUF_2021 +- But CAN reuse ACS_2022 for multiple years + +## Reproducibility Considerations + +### Ensuring Consistent Results +- **Random seeds**: Set via `set_seeds()` function +- **Model versioning**: Pre-trained models include version tags +- **Parameter freezing**: Uprating factors fixed at generation time +- **Data hashing**: Input files verified via checksums + +### Sources of Variation +- **Optimization convergence**: Different hardware may converge differently +- **Floating point precision**: GPU vs CPU differences +- **Library versions**: Especially torch, scikit-learn +- **Calibration targets**: Updates to SOI data affect results + +## Glossary + +- **QRF**: Quantile Random Forest - preserves distributions during imputation +- **SOI**: Statistics of Income - IRS published aggregates +- **QBI**: Qualified Business Income (Section 199A deduction) +- **UBIA**: Unadjusted Basis Immediately After Acquisition +- **SSTB**: Specified Service Trade or Business +- **ASEC-UA**: Algorithm for imputing undocumented status in CPS +- **HardConcrete**: Differentiable gate for L0 regularization +- **L0 Regularization**: Penalty on number of non-zero weights +- **Dense weights**: All households have positive weights +- **Sparse weights**: Many households have zero weight \ No newline at end of file diff --git a/docs/SNAP_SIMULATION_ANALYSIS.md b/docs/SNAP_SIMULATION_ANALYSIS.md new file mode 100644 index 00000000..ab7c486f --- /dev/null +++ b/docs/SNAP_SIMULATION_ANALYSIS.md @@ -0,0 +1,207 @@ +# SNAP Simulation Analysis + +## Overview + +This document analyzes the relationship between `snap_reported` (CPS survey data) and `snap` (calculated) in PolicyEngine's CPS 2023 dataset. The analysis reveals critical differences between survey-reported benefits and rule-based calculations. + +## Key Findings + +### Two Independent SNAP Variables + +PolicyEngine maintains two separate SNAP variables that operate independently: + +**snap_reported** ($42.59B): +- Source: CPS ASEC survey responses (`SPM_SNAPSUB`) +- Represents what households reported receiving in surveys +- Known underreporting: ~40% of administrative totals +- Loaded in: `policyengine_us_data/datasets/cps/cps.py:609-630` + +**snap** ($57.26B calculated): +- Calculated independently using federal eligibility rules +- Does not reference `snap_reported` values +- Calculation flow: Eligibility rules → $70.84B → Apply 82% take-up → $57.26B +- Target: $107B (USDA FY2023 administrative data) + +### The Take-up Mechanism + +**Purpose**: Models the empirical observation that ~82% of eligible households claim SNAP benefits. 
+ +**Implementation**: +```python +# policyengine_us_data/datasets/cps/cps.py:219 +generator = np.random.default_rng(seed=100) +data["snap_take_up_seed"] = generator.random(len(data["spm_unit_id"])) + +# policyengine_us/variables/gov/usda/snap/takes_up_snap_if_eligible.py +def formula(spm_unit, period, parameters): + seed = spm_unit("snap_take_up_seed", period) + takeup_rate = parameters(period).gov.usda.snap.takeup_rate # 0.82 + return seed < takeup_rate +``` + +**Effect**: Reduces calculated benefits from $70.84B to $57.26B (19.2% reduction) + +### Critical Problems Identified + +#### Problem A: Actual Recipients Zeroed Out + +**$25.89B in reported SNAP benefits are set to $0:** + +| Reason | SPM Units | Amount | Explanation | +|--------|-----------|--------|-------------| +| Deemed ineligible | 7.3M | $21.52B | Rules say they don't qualify, but they actually receive SNAP | +| Failed take-up seed | 1.2M | $4.31B | Eligible but random seed ≥ 0.82 | + +**Example case**: +- Household reports $276/year SNAP +- Eligible for $264/year calculated +- But `snap_take_up_seed = 0.861 > 0.82` +- Result: `snap = $0` + +**The take-up mechanism applies to ALL households, including those who reported receiving benefits.** + +#### Problem B: Eligibility Rules Don't Match Reality + +**7.7M SPM units actually receiving SNAP are deemed "ineligible" by PolicyEngine.** + +Evidence from sample analysis: + +| SPM Unit | Reported SNAP | Calculated | Gross Income | 130% FPL Limit | Status | +|----------|--------------|------------|--------------|----------------|---------| +| 7 | $5,160/year | $0 | $5,000/mo | $2,693/mo | 186% over limit | +| 43 | $276/year | $0 | $1,973/mo | $2,136/mo | 92% of limit | +| 78 | $3,492/year | $0 | $0/mo | $1,580/mo | 0% income but still ineligible | + +**Root causes**: + +1. **Broad-Based Categorical Eligibility (BBCE)** not modeled + - 40+ states use BBCE + - Recipients of any TANF-funded service are categorically eligible + - Income limits waived or raised to 185-200% FPL + +2. **State-specific variations** not captured + - Different income limits by state + - Varying asset tests (often waived) + - State supplements + +3. **Income comparison**: + - "Ineligible" recipients: Mean income $4,600/month + - "Eligible" units: Mean income $1,668/month + - Ratio: 2.8x higher income among actual recipients + +**PolicyEngine uses federal baseline rules, missing state-level expansions that cover millions of real recipients.** + +#### Problem C: Poor Household-Level Matching + +Overlap analysis between reported and calculated SNAP: + +| Category | Count | Notes | +|----------|-------|-------| +| Both reported AND calculated | 5.2M | Correlation: 0.55 between amounts | +| Reported but NOT calculated | 8.5M | Actual recipients zeroed out | +| Calculated but NOT reported | 11.6M | Survey underreporting | +| Neither | 107.1M | | + +**Only 37% of actual recipients (5.2M / 14M) are correctly identified, with weak correlation in benefit amounts.** + +## Why snap > snap_reported + +Despite the 82% take-up reducing calculated benefits by 19%, `snap` ($57.26B) is still 34% higher than `snap_reported` ($42.59B): + +1. **Starting point is higher**: Eligibility rules produce $70.84B before take-up +2. **Calculated entitlements exceed reports**: Rules-based calculation captures "proper" benefit amounts while surveys are imprecise +3. **Survey underreporting is severe**: Known issue in CPS ASEC data +4. 
**Emergency allotments included**: Jan-Feb 2023 had COVID-era supplements ($4.46B/month) + +The 19% reduction from take-up is smaller than the 66% increase from calculated entitlements over reported benefits. + +## Data Flow + +```mermaid +graph TD + A[CPS ASEC Survey] -->|SPM_SNAPSUB| B[snap_reported: $42.59B] + + C[Household Income/Size] -->|Eligibility Rules| D[snap_normal_allotment: $70.84B] + D -->|Random Seed < 0.82| E[snap calculated: $57.26B] + D -->|Random Seed >= 0.82| F[snap = $0] + + G[USDA Administrative] -->|Target| H[$107B FY2023] + + E -->|Reweighting| H + + style B fill:#f9f,stroke:#333 + style E fill:#bbf,stroke:#333 + style H fill:#bfb,stroke:#333 +``` + +## Implications + +### For Analysis + +1. **snap_reported** and **snap** are not comparable - they represent fundamentally different measurements +2. **Individual household accuracy is poor** - only 37% match, correlation 0.55 +3. **Aggregate totals require calibration** - raw calculations underestimate by 47% ($57B vs $107B target) + +### For Policy Simulation + +**Advantages**: +- Consistent methodology for policy reforms +- Can model eligibility rule changes +- Not anchored to survey underreporting + +**Disadvantages**: +- Destroys empirical information from actual recipients +- Misses state-level program variations +- Household-level predictions unreliable + +### For Calibration + +The Enhanced CPS reweighting process must bridge a large gap: +- Starting point: $57.26B (raw calculated) +- Target: $107B (administrative) +- Required adjustment: 87% increase via household weights + +This heavy reliance on calibration suggests the base eligibility calculations need improvement. + +## Recommendations for Future Work + +1. **Preserve reported information**: Don't zero out households that report receiving SNAP + ```python + # Proposed logic + if snap_reported > 0: + return max(snap_reported, calculated_value) + else: + return calculated_value * takes_up + ``` + +2. **Model state-level SNAP variations**: Implement BBCE and state-specific rules + +3. **Investigate eligibility rule accuracy**: Why do 7.7M actual recipients fail eligibility? + +4. **Consider conditional take-up**: Apply take-up only to households without reported benefits + +5. **Document limitations**: Make clear that household-level SNAP predictions are unreliable + +## Technical Details + +### Files Analyzed + +- Data: `policyengine-us-data/datasets/cps/cps.py` +- Calculation: `policyengine-us/variables/gov/usda/snap/snap.py` +- Take-up: `policyengine-us/variables/gov/usda/snap/takes_up_snap_if_eligible.py` +- Parameter: `policyengine-us/parameters/gov/usda/snap/takeup_rate.yaml` +- Calibration targets: `policyengine-us-data/db/etl_snap.py` + +### Dataset Information + +- Analysis date: 2025-01-13 +- Dataset: `cps_2023.h5` (uncalibrated) +- Total SPM units: 21,533 +- Reported SNAP recipients: 14.0M weighted +- Calculated SNAP recipients: 18.4M weighted + +## Conclusion + +The SNAP simulation in PolicyEngine is a **complete re-calculation that ignores reported survey values**. This approach prioritizes policy simulation consistency over empirical accuracy at the household level. The take-up mechanism reduces calculated benefits but does not bridge `snap_reported` to `snap` - they remain independent estimates representing different measurement approaches (survey vs. rules-based). + +The system relies heavily on subsequent calibration to match administrative totals, with household-level predictions showing poor accuracy (37% overlap, 0.55 correlation). 
Real SNAP recipients are frequently zeroed out, either due to incomplete state rule modeling ($21.52B) or random take-up exclusion ($4.31B). diff --git a/docs/myst.yml b/docs/myst.yml index a39b3cfb..38c9cb1e 100644 --- a/docs/myst.yml +++ b/docs/myst.yml @@ -28,6 +28,8 @@ project: - file: discussion.md - file: conclusion.md - file: appendix.md + - file: SNAP_SIMULATION_ANALYSIS.md + title: SNAP Simulation Analysis site: options: logo: logo.png diff --git a/policyengine_us_data/datasets/acs/census_acs.py b/policyengine_us_data/datasets/acs/census_acs.py index 842af627..16363087 100644 --- a/policyengine_us_data/datasets/acs/census_acs.py +++ b/policyengine_us_data/datasets/acs/census_acs.py @@ -206,3 +206,12 @@ class CensusACS_2022(CensusACS): name = "census_acs_2022.h5" file_path = STORAGE_FOLDER / "census_acs_2022.h5" time_period = 2022 + + +# TODO: 2023 ACS obviously exists, but this generation script is not +# able to extract it, potentially due to changes +# class CensusACS_2023(CensusACS): +# label = "Census ACS (2023)" +# name = "census_acs_2023.h5" +# file_path = STORAGE_FOLDER / "census_acs_2023.h5" +# time_period = 2023 diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index f932e0d5..23d155da 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -38,11 +38,15 @@ def generate(self): """ if self.raw_cps is None: - # Extrapolate from previous year + # Extrapolate from previous year or use actual data when available if self.time_period == 2025: cps_2024 = CPS_2024(require=True) arrays = cps_2024.load_dataset() arrays = uprate_cps_data(arrays, 2024, self.time_period) + elif self.time_period == 2024: + # Use actual 2024 data from CPS_2024 + cps_2024 = CPS_2024(require=True) + arrays = cps_2024.load_dataset() else: # Default to CPS 2023 for backward compatibility cps_2023 = CPS_2023(require=True) @@ -2058,6 +2062,15 @@ class CPS_2023_Full(CPS): time_period = 2023 +class CPS_2024_Full(CPS): + name = "cps_2024_full" + label = "CPS 2024 (full)" + raw_cps = CensusCPS_2024 + previous_year_raw_cps = CensusCPS_2023 + file_path = STORAGE_FOLDER / "cps_2024_full.h5" + time_period = 2024 + + class PooledCPS(Dataset): data_format = Dataset.ARRAYS input_datasets: list @@ -2117,10 +2130,15 @@ class Pooled_3_Year_CPS_2023(PooledCPS): url = "hf://policyengine/policyengine-us-data/pooled_3_year_cps_2023.h5" +geo_stacking = os.environ.get("GEO_STACKING") == "true" + if __name__ == "__main__": if test_lite: + CPS_2023().generate() CPS_2024().generate() CPS_2025().generate() + elif geo_stacking: + CPS_2023_Full().generate() else: CPS_2021().generate() CPS_2022().generate() @@ -2130,4 +2148,5 @@ class Pooled_3_Year_CPS_2023(PooledCPS): CPS_2021_Full().generate() CPS_2022_Full().generate() CPS_2023_Full().generate() + CPS_2024_Full().generate() Pooled_3_Year_CPS_2023().generate() diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 8bbe67bc..a6673bc0 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -14,10 +14,13 @@ from typing import Type from policyengine_us_data.storage import STORAGE_FOLDER from policyengine_us_data.datasets.cps.extended_cps import ( - ExtendedCPS_2024, + ExtendedCPS_2024, # NOTE (baogorek) : I made this the FULL version CPS_2019, CPS_2024, ) +from scipy import sparse as sp +from l0.calibration import SparseCalibrationWeights + import os from pathlib 
import Path import logging @@ -33,7 +36,7 @@ def reweight( original_weights, loss_matrix, targets_array, - dropout_rate=0.05, + dropout_rate=0.00, log_path="calibration_log.csv", epochs=500, l0_lambda=2.6445e-07, @@ -44,6 +47,12 @@ def reweight( set_seeds(seed) target_names = np.array(loss_matrix.columns) is_national = loss_matrix.columns.str.startswith("nation/") + + # I just realized that I already have a stratified data set which I can reweight + # I don't really need L0 right now at all! + ## Breaking in with the new L0 method + #X_sparse = sp.csr_matrix(loss_matrix.values) + loss_matrix = torch.tensor(loss_matrix.values, dtype=torch.float32) nation_normalisation_factor = is_national * (1 / is_national.sum()) state_normalisation_factor = ~is_national * (1 / (~is_national).sum()) @@ -354,7 +363,7 @@ def generate(self): class EnhancedCPS_2024(EnhancedCPS): - input_dataset = ExtendedCPS_2024 + input_dataset = "/home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/stratified_extended_cps_2024.h5" # ExtendedCPS_2024 start_year = 2024 end_year = 2024 name = "enhanced_cps_2024" diff --git a/policyengine_us_data/datasets/cps/extended_cps.py b/policyengine_us_data/datasets/cps/extended_cps.py index f28c726c..ba6ce8fa 100644 --- a/policyengine_us_data/datasets/cps/extended_cps.py +++ b/policyengine_us_data/datasets/cps/extended_cps.py @@ -320,8 +320,19 @@ def impute_income_variables( return result +class ExtendedCPS_2023(ExtendedCPS): + cps = CPS_2023_Full + puf = PUF_2023 + name = "extended_cps_2023" + label = "Extended CPS (2023)" + file_path = STORAGE_FOLDER / "extended_cps_2023.h5" + time_period = 2023 + + +# TODO (baogorek added _Full), not sure what the ramifications are, +# But I need the extra data for the lon class ExtendedCPS_2024(ExtendedCPS): - cps = CPS_2024 + cps = CPS_2024_Full puf = PUF_2024 name = "extended_cps_2024" label = "Extended CPS (2024)" @@ -330,4 +341,11 @@ class ExtendedCPS_2024(ExtendedCPS): if __name__ == "__main__": - ExtendedCPS_2024().generate() + geo_stacking_mode = ( + os.environ.get("GEO_STACKING_MODE", "").lower() == "true" + ) + + if geo_stacking_mode: + ExtendedCPS_2023().generate() + else: + ExtendedCPS_2024().generate() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/.gitignore b/policyengine_us_data/datasets/cps/local_area_calibration/.gitignore new file mode 100644 index 00000000..c10d44db --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/.gitignore @@ -0,0 +1,17 @@ +# Test files (but not verify_calibration.py) + +# Analysis scripts - uncomment specific ones to commit if needed +analyze* +# !analyze_calibration_coverage.py +# !analyze_missing_actionable.py +# !analyze_missing_variables.py + +# NumPy weight arrays +*.npy + +# Generated artifacts +metadata.json +.ipynb_checkpoints/ + +# Debug scripts (including debug_uprating.py - temporary tool) +debug* diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/README.md b/policyengine_us_data/datasets/cps/local_area_calibration/README.md new file mode 100644 index 00000000..44114b62 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/README.md @@ -0,0 +1,279 @@ +# Local Area Calibration + +Creates state-level microsimulation datasets with Congressional District (CD) level calibration weights. 
Takes Current Population Survey (CPS) data, enriches it with Public Use File (PUF) income variables, applies L0 sparse calibration to match ~34k demographic and economic targets across 436 Congressional Districts, and produces optimized datasets for each US state. + +**Key Achievement**: Reduces ~200k household dataset to ~13k households while maintaining statistical representativeness across all 436 CDs through sophisticated weight calibration. + +## Quick Start + +### Local Testing (100 epochs) +```bash +make data-geo +make calibration-package +make optimize-weights-local +make create-state-files +``` + +### Production (GCP, 4000 epochs) +```bash +make data-geo +make upload-calibration-package # Note the date prefix shown +# Edit batch_pipeline/config.env with INPUT_PATH and OUTPUT_PATH +make optimize-weights-gcp +make download-weights-from-gcs +make create-state-files +make upload-state-files-to-gcs +``` + +## Pipeline Architecture + +``` +Phase 1: Data Preparation +├── CPS_2023_Full → Extended_CPS_2023 (288MB) +└── Extended_CPS_2023 → Stratified_CPS_2023 (28MB, ~13k households) + +Phase 2: Calibration Package +├── Sparse Matrix (~34k targets × ~5.7M household-CD pairs) +├── Target Groups & Initial Weights +└── Upload → GCS://policyengine-calibration/DATE/inputs/ + +Phase 3: Weight Optimization (L0 Calibration) +├── Local: 100 epochs (testing) → ~0% sparsity +└── GCP: 4000 epochs (production) → ~87% sparsity + +Phase 4: State Dataset Creation +├── Apply weights to stratified dataset +├── Create 51 state files + optional combined file +└── Upload → GCS & Hugging Face +``` + +## Conceptual Framework + +### The Geo-Stacking Approach + +The same household dataset is treated as existing in multiple geographic areas simultaneously, creating an "empirical superpopulation" where each household can represent itself in different locations with different weights. + +**Matrix Structure:** +- **Rows = Targets** (calibration constraints) +- **Columns = Households × Geographic Areas** + +This creates a "small n, large p" problem where household weights are the parameters we estimate. + +**Sparsity Pattern Example (2 states):** +``` + H1_CA H2_CA H3_CA H1_TX H2_TX H3_TX +national_employment X X X X X X +CA_age_0_5 X X X 0 0 0 +TX_age_0_5 0 0 0 X X X +``` + +### Hierarchical Target Selection + +For each target concept: +1. If CD-level target exists → use it for that CD only +2. If no CD target but state target exists → use state target for all CDs in that state +3. If neither exists → use national target + +For administrative data (SNAP, Medicaid), always prefer admin over survey data. 
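+The fallback can be sketched as follows (a hypothetical helper with illustrative values, not the selection code used in this repository):
+
+```python
+def select_target(targets: dict, cd_geoid: str, state_fips: str) -> float:
+    """Return the most specific target available: CD, then state, then US.
+
+    `targets` maps geography codes (e.g. "0601", "06", "US") to values.
+    """
+    for geo in (cd_geoid, state_fips, "US"):
+        if geo in targets:
+            return targets[geo]
+    raise KeyError("no target found at any geographic level")
+
+
+# SNAP cost is published at the state level, so every CD in the state
+# falls back to the state administrative total (values are illustrative).
+snap_cost = {"06": 1.23e10, "US": 1.07e11}
+print(select_target(snap_cost, cd_geoid="0601", state_fips="06"))  # state
+print(select_target({"US": 1.07e11}, "0601", "06"))                # national
+```
+
+When both survey and administrative values exist at the chosen level, the administrative value is the one used.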
+ +## Target Groups + +Targets are grouped to ensure balanced optimization: + +| Group Type | Count | Description | +|------------|-------|-------------| +| National | 30 | Hardcoded US-level targets (each singleton) | +| Age | 7,848 | 18 bins × 436 CDs | +| AGI Distribution | 3,924 | 9 brackets × 436 CDs | +| SNAP Household | 436 | CD-level counts | +| SNAP Cost | 51 | State-level administrative | +| Medicaid | 436 | CD-level enrollment | +| EITC | 1,744 | 4 categories × 436 CDs | +| IRS SOI | ~25k | Various tax variables by CD | + +## Key Technical Details + +### L0 Regularization + +Creates truly sparse weights through stochastic gates: +- Gate formula: `gate = sigmoid(log_alpha/beta) * (zeta - gamma) + gamma` +- With default parameters, gates create exact zeros even with `lambda_l0=0` +- Production runs achieve ~87% sparsity (725k active from 5.7M weights) + +### Relative Loss Function + +Using `((y - y_pred) / (y + 1))^2`: +- Handles massive scale disparities (targets range from 178K to 385B) +- 10% error on $1B target = same penalty as 10% error on $100K target + +### ID Allocation System + +Each CD gets a 10,000 ID range to prevent collisions: +- Household IDs: `CD_index × 10,000` to `CD_index × 10,000 + 9,999` +- Person IDs: Add 5M offset to avoid household collision +- Max safe: ~49k per CD to stay under int32 overflow + +### State-Dependent Variables + +SNAP and other state-dependent variables require special handling: +- Matrix construction pre-calculates values for each state +- H5 creation reindexes entity IDs (same household in different CDs needs unique IDs) +- ID reindexing changes `random()` seeds, causing ~10-15% variance in random-dependent variables +- End-to-end tests use **aggregate tolerance** (~15%) rather than exact matching + +### Cache Clearing for State Swaps + +When setting `state_fips` to recalculate state-dependent benefits, cached variables must be cleared. This is subtle: + +**What to clear** (variables that need recalculation): +- Variables with `formulas` (traditional calculated variables) +- Variables with `adds` (sum of other variables, e.g., `snap_unearned_income`) +- Variables with `subtracts` (difference of variables) + +**What NOT to clear** (structural data from H5): +- ID variables: `person_id`, `household_id`, `tax_unit_id`, `spm_unit_id`, `family_id`, `marital_unit_id` +- These have formulas that generate sequential IDs (0, 1, 2, ...), but we need the original H5 values + +**Why IDs matter**: PolicyEngine's `random()` function uses entity IDs as deterministic seeds: +```python +seed = abs(entity_id * 100 + count_random_calls) +``` +If IDs are regenerated, random-dependent variables produce different results. Three variables use `random()`: +- `meets_ssi_resource_test` (SSI eligibility) +- `is_wic_at_nutritional_risk` (WIC eligibility) +- `would_claim_wic` (WIC takeup) + +**Implementation** in `calibration_utils.py` (single source of truth): +```python +def get_calculated_variables(sim): + exclude_ids = {'person_id', 'household_id', 'tax_unit_id', + 'spm_unit_id', 'family_id', 'marital_unit_id'} + return [name for name, var in sim.tax_benefit_system.variables.items() + if (var.formulas or getattr(var, 'adds', None) or getattr(var, 'subtracts', None)) + and name not in exclude_ids] +``` + +**Why same-state households also get set_input + cache clear**: The matrix builder always creates a fresh simulation, sets `state_fips`, and clears the cache—even when a household stays in its original state. This seems redundant but is intentional: + +1. 
**Consistency**: All matrix values are computed the same way, regardless of whether state changes +2. **Deterministic random()**: The `random()` function's seed includes `count_random_calls`. Clearing the cache resets this counter to 0, ensuring reproducible results. Without cache clearing, different calculation histories produce different random outcomes. +3. **Verification**: Tests can verify matrix values by replicating this exact procedure. Comparing against the original simulation (without cache clear) would show ~10-15% mismatches due to different random() counter states—not bugs, just different calculation paths. + +``` +Question 1: How are SSI/WIC related to SNAP? +The connection is through income calculation chains: +snap + └── snap_gross_income + └── snap_unearned_income (uses `adds`) + └── ssi (SSI benefit amount) + └── is_ssi_eligible + └── meets_ssi_resource_test + └── random() ← stochastic eligibility + +SSI (Supplemental Security Income) counts as unearned income for SNAP. So: +- random() determines if someone "passes" SSI's resource test (since CPS lacks actual asset data) +- This affects ssi benefit amount +- Which feeds into snap_unearned_income +- Which affects final snap calculation + +WIC doesn't directly affect SNAP, but shares similar random-dependent eligibility logic (is_wic_at_nutritional_risk, would_claim_wic). +Question 2: Why still 13.5% mismatches if we preserved IDs? + +The key is the full seed formula: +seed = abs(entity_id * 100 + count_random_calls) + +We preserved entity_id by excluding ID variables from clearing. But count_random_calls tracks how many times random() has been called for that entity during the simulation + +When we: +1. Create a fresh simulation +2. Set state_fips +3. Clear calculated variables +4. Call calculate("snap") + +The calculation order may differ from the original simulation's calculation order. Different traversal paths through the variable dependency graph → different +count_random_calls when meets_ssi_resource_test is reached → different seed → different random result. +``` + +## File Reference + +### Core Scripts +| Script | Purpose | +|--------|---------| +| `create_stratified_cps.py` | Income-based stratification sampling | +| `create_calibration_package.py` | Build optimization inputs | +| `optimize_weights.py` | L0 weight optimization | +| `stacked_dataset_builder.py` | Apply weights, create state files | +| `sparse_matrix_builder.py` | Build sparse target matrix | +| `calibration_utils.py` | Helper functions, CD mappings | + +### Data Files +| File | Purpose | +|------|---------| +| `policy_data.db` | SQLite with all calibration targets | +| `stratified_extended_cps_2023.h5` | Input dataset (~13k households) | +| `calibration_package.pkl` | Sparse matrix & metadata | +| `w_cd.npy` | Final calibration weights | + +### Batch Pipeline +| File | Purpose | +|------|---------| +| `batch_pipeline/Dockerfile` | CUDA + PyTorch container | +| `batch_pipeline/submit_batch_job.sh` | Build, push, submit to GCP | +| `batch_pipeline/config.env` | GCP settings | + +## Validation + +### Matrix Cell Lookup + +Use `household_tracer.py` to navigate the matrix: + +```python +from household_tracer import HouseholdTracer +tracer = HouseholdTracer(targets_df, matrix, household_mapping, cd_geoids, sim) + +# Find where a household appears +positions = tracer.get_household_column_positions(household_id=565) + +# Look up any cell +cell_info = tracer.lookup_matrix_cell(row_idx=10, col_idx=500) +``` + +### Key Validation Findings + +1. 
**Tax Unit vs Household**: AGI constraints apply at tax unit level. A 5-person household with 3 people in a qualifying tax unit shows matrix value 3.0 (correct). + +2. **Hierarchical Consistency**: Targets sum correctly from CD → State → National levels. + +3. **SNAP Behavior**: May use reported values from dataset (not formulas), so state changes may not affect SNAP. + +## Troubleshooting + +### "CD exceeded 10k household allocation" +Weight vector has wrong dimensions or 0% sparsity. Check sparsity is ~87% for production. + +### Memory Issues +- Local: Reduce batch size or use GCP +- State file creation: Use `--include-full-dataset` only with 32GB+ RAM + +### GCP Job Fails +1. Check paths in `config.env` +2. Run `gcloud auth configure-docker` +3. Verify input file exists in GCS + +## Known Issues + +### CD-County Mappings +Only 10 CDs have real county proportions. Remaining CDs use state's most populous county. Fix requires Census geographic relationship files. + +### Variables Excluded from Calibration +Certain high-error variables are excluded (rental income, various tax deductions). See `calibration_utils.py` for the full list. + +## Architecture Decisions + +| Decision | Rationale | +|----------|-----------| +| Stratified sampling | 93% size reduction while preserving income distribution | +| L0 regularization | Creates exact zeros for truly sparse weights | +| 10k ID ranges | Prevents int32 overflow in PolicyEngine | +| Group-wise loss | Prevents histogram variables from dominating | +| Relative loss | Handles 6 orders of magnitude in target scales | diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/add_hierarchical_check.py b/policyengine_us_data/datasets/cps/local_area_calibration/add_hierarchical_check.py new file mode 100644 index 00000000..fc78e5ec --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/add_hierarchical_check.py @@ -0,0 +1,252 @@ +""" +Quick patch to add hierarchical consistency checking to simple_holdout results. +This can be called after simple_holdout completes. +""" + +import numpy as np +import pandas as pd +import pickle +import os +from scipy import sparse as sp +import torch + + +def compute_hierarchical_consistency(calibration_package_path): + """ + Load calibration package and compute hierarchical consistency metrics. + Assumes model has been trained and weights are available. 
+ + Args: + calibration_package_path: Path to calibration_package.pkl + + Returns: + dict with hierarchical consistency metrics + """ + + # Load the package + with open(calibration_package_path, "rb") as f: + data = pickle.load(f) + + X_sparse = data["X_sparse"] + targets_df = data["targets_df"] + targets = targets_df.value.values + + # Load the most recent trained model or weights + # For now, we'll compute what the metrics would look like + # In practice, you'd load the actual weights from the trained model + + # Get CD-level targets + cd_mask = targets_df["geographic_id"].str.len() > 2 + cd_targets = targets_df[cd_mask].copy() + + # Group CDs by state and variable + hierarchical_checks = [] + + for variable in cd_targets["variable"].unique(): + var_cd_targets = cd_targets[cd_targets["variable"] == variable] + + # Extract state from CD (assuming format like '0101' where first 2 digits are state) + var_cd_targets["state"] = var_cd_targets["geographic_id"].apply( + lambda x: x[:2] if len(x) == 4 else x[:-2] + ) + + # Sum by state + state_sums = var_cd_targets.groupby("state")["value"].sum() + + # Check if we have corresponding state-level targets + state_targets = targets_df[ + (targets_df["geographic_id"].isin(state_sums.index)) + & (targets_df["variable"] == variable) + ] + + if not state_targets.empty: + for state_id in state_sums.index: + state_target = state_targets[ + state_targets["geographic_id"] == state_id + ] + if not state_target.empty: + cd_sum = state_sums[state_id] + state_val = state_target["value"].iloc[0] + rel_diff = ( + (cd_sum - state_val) / state_val + if state_val != 0 + else 0 + ) + + hierarchical_checks.append( + { + "variable": variable, + "state": state_id, + "cd_sum": cd_sum, + "state_target": state_val, + "relative_difference": rel_diff, + } + ) + + # Check national consistency + national_target = targets_df[ + (targets_df["geographic_id"] == "US") + & (targets_df["variable"] == variable) + ] + + if not national_target.empty: + cd_national_sum = var_cd_targets["value"].sum() + national_val = national_target["value"].iloc[0] + rel_diff = ( + (cd_national_sum - national_val) / national_val + if national_val != 0 + else 0 + ) + + hierarchical_checks.append( + { + "variable": variable, + "state": "US", + "cd_sum": cd_national_sum, + "state_target": national_val, + "relative_difference": rel_diff, + } + ) + + if hierarchical_checks: + checks_df = pd.DataFrame(hierarchical_checks) + + # Summary statistics + summary = { + "mean_abs_rel_diff": np.abs( + checks_df["relative_difference"] + ).mean(), + "max_abs_rel_diff": np.abs(checks_df["relative_difference"]).max(), + "n_checks": len(checks_df), + "n_perfect_matches": ( + np.abs(checks_df["relative_difference"]) < 0.001 + ).sum(), + "n_within_1pct": ( + np.abs(checks_df["relative_difference"]) < 0.01 + ).sum(), + "n_within_5pct": ( + np.abs(checks_df["relative_difference"]) < 0.05 + ).sum(), + "n_within_10pct": ( + np.abs(checks_df["relative_difference"]) < 0.10 + ).sum(), + } + + # Worst mismatches + worst = checks_df.nlargest(5, "relative_difference") + summary["worst_overestimates"] = worst[ + ["variable", "state", "relative_difference"] + ].to_dict("records") + + best = checks_df.nsmallest(5, "relative_difference") + summary["worst_underestimates"] = best[ + ["variable", "state", "relative_difference"] + ].to_dict("records") + + return {"summary": summary, "details": checks_df} + else: + return { + "summary": { + "message": "No hierarchical targets found for comparison" + }, + "details": pd.DataFrame(), + } + + 
+def analyze_holdout_hierarchical_consistency(results, targets_df): + """ + Analyze hierarchical consistency for holdout groups only. + This is useful when some groups are geographic aggregates. + + Args: + results: Output from simple_holdout + targets_df: Full targets dataframe with geographic info + + Returns: + Enhanced results dict with hierarchical analysis + """ + + # Check if any holdout groups represent state or national aggregates + holdout_group_ids = list(results["holdout_group_losses"].keys()) + + # Map group IDs to geographic levels + group_geo_analysis = [] + + for group_id in holdout_group_ids: + group_targets = targets_df[ + targets_df.index.isin( + [i for i, g in enumerate(target_groups) if g == group_id] + ) + ] + + if not group_targets.empty: + geo_ids = group_targets["geographic_id"].unique() + + # Classify the geographic level + if "US" in geo_ids: + level = "national" + elif all(len(g) <= 2 for g in geo_ids): + level = "state" + elif all(len(g) > 2 for g in geo_ids): + level = "cd" + else: + level = "mixed" + + group_geo_analysis.append( + { + "group_id": group_id, + "geographic_level": level, + "n_geos": len(geo_ids), + "loss": results["holdout_group_losses"][group_id], + } + ) + + # Add to results + if group_geo_analysis: + geo_df = pd.DataFrame(group_geo_analysis) + + # Compare performance by geographic level + level_performance = geo_df.groupby("geographic_level")["loss"].agg( + ["mean", "std", "min", "max", "count"] + ) + + results["hierarchical_analysis"] = { + "group_geographic_levels": group_geo_analysis, + "performance_by_level": level_performance.to_dict(), + "observation": "Check if state/national groups have higher loss than CD groups", + } + + return results + + +# Example usage: +if __name__ == "__main__": + # Check hierarchical consistency of targets + consistency = compute_hierarchical_consistency( + "~/Downloads/cd_calibration_data/calibration_package.pkl" + ) + + print("Hierarchical Consistency Check") + print("=" * 60) + print( + f"Mean absolute relative difference: {consistency['summary']['mean_abs_rel_diff']:.2%}" + ) + print( + f"Max absolute relative difference: {consistency['summary']['max_abs_rel_diff']:.2%}" + ) + print( + f"Checks within 1%: {consistency['summary']['n_within_1pct']}/{consistency['summary']['n_checks']}" + ) + print( + f"Checks within 5%: {consistency['summary']['n_within_5pct']}/{consistency['summary']['n_checks']}" + ) + print( + f"Checks within 10%: {consistency['summary']['n_within_10pct']}/{consistency['summary']['n_checks']}" + ) + + if "worst_overestimates" in consistency["summary"]: + print("\nWorst overestimates (CD sum > state/national target):") + for item in consistency["summary"]["worst_overestimates"][:3]: + print( + f" {item['variable']} in {item['state']}: {item['relative_difference']:.1%}" + ) diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/Dockerfile b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/Dockerfile new file mode 100644 index 00000000..61522a88 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/Dockerfile @@ -0,0 +1,26 @@ +# Use Google's Deep Learning container (optimized for GCP) +# Has PyTorch 2.x, CUDA, cuDNN, numpy, scipy, pandas, and gsutil pre-installed +FROM gcr.io/deeplearning-platform-release/pytorch-gpu.2-0:latest + +# Fix NumPy compatibility issue - force reinstall numpy compatible version +RUN pip install --no-cache-dir --force-reinstall "numpy>=1.24,<2.0" + +# Install additional 
dependencies +RUN pip install --no-cache-dir \ + google-cloud-storage + +# Install L0 package from GitHub (this might have compiled components) +RUN pip install --no-cache-dir --no-build-isolation git+https://github.com/PolicyEngine/L0.git@L0-sept + +# Create working directory +WORKDIR /app + +# Copy the optimization script +COPY optimize_weights.py /app/ +COPY run_batch_job.sh /app/ + +# Make the run script executable +RUN chmod +x /app/run_batch_job.sh + +# Set the entrypoint +ENTRYPOINT ["/app/run_batch_job.sh"] diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/README.md b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/README.md new file mode 100644 index 00000000..b45a5840 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/README.md @@ -0,0 +1,95 @@ +# Cloud Batch GPU Pipeline for Calibration Optimization + +This pipeline runs the L0 calibration optimization on GCP using Cloud Batch with GPU support. + +## Architecture +- **Cloud Batch**: Automatically provisions GPU VMs, runs the job, and tears down +- **Spot Instances**: Uses spot pricing for cost efficiency +- **GPU**: NVIDIA Tesla P100 for CUDA acceleration +- **Auto-shutdown**: VM terminates after job completion + +## Quick Start + +### For You (Original User) + +```bash +cd batch_pipeline +./submit_batch_job.sh +``` + +Your settings are already configured in `config.env`. + +### For Other Users + +1. **Run setup script:** +```bash +cd batch_pipeline +./setup.sh +``` + +2. **Edit configuration:** +```bash +# Copy and edit configuration +cp config.env .env +nano .env +``` + +Change these settings: +- `PROJECT_ID`: Your GCP project ID +- `SERVICE_ACCOUNT`: Your service account email +- `BUCKET_NAME`: Your GCS bucket name +- `INPUT_PATH`: Path to input data in bucket +- `OUTPUT_PATH`: Path for output data in bucket + +3. **Submit the job:** +```bash +./submit_batch_job.sh +``` + +4. **Monitor progress:** +```bash +./monitor_batch_job.sh +``` + +## Files +- `config.env` - Configuration template with your current settings +- `.env` - User's custom configuration (created from config.env) +- `Dockerfile` - Container with CUDA, PyTorch, L0 package +- `optimize_weights.py` - The optimization script +- `run_batch_job.sh` - Runs inside container +- `generate_config.py` - Creates batch config from .env +- `submit_batch_job.sh` - Builds, pushes, submits job +- `monitor_batch_job.sh` - Monitors job progress +- `setup.sh` - Initial setup for new users + +## How It Works + +1. `submit_batch_job.sh` reads configuration from `.env` (or `config.env`) +2. Builds Docker image with your code +3. Pushes to Google Container Registry +4. Generates `batch_job_config.json` from your settings +5. Submits job to Cloud Batch +6. 
Cloud Batch: + - Provisions spot GPU VM + - Pulls Docker image + - Downloads data from GCS + - Runs optimization + - Uploads results to GCS + - Terminates VM + +## Monitoring + +View job status: +```bash +gcloud batch jobs describe --location=us-central1 +``` + +View logs: +```bash +gcloud logging read "resource.type=batch.googleapis.com/Job AND resource.labels.job_id=" +``` + +## Cost Savings +- Spot instances: ~70% cheaper than on-demand +- Auto-shutdown: No forgotten VMs +- P100 GPU: Older but sufficient, cheaper than V100/A100 \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/batch_job_config.json b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/batch_job_config.json new file mode 100644 index 00000000..11ac2cc1 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/batch_job_config.json @@ -0,0 +1,70 @@ +{ + "taskGroups": [ + { + "taskSpec": { + "runnables": [ + { + "container": { + "imageUri": "us-docker.pkg.dev/policyengine-research/us.gcr.io/calibration-optimizer:latest", + "entrypoint": "/app/run_batch_job.sh" + } + } + ], + "computeResource": { + "cpuMilli": 8000, + "memoryMib": 32768 + }, + "maxRunDuration": "86400s", + "environment": { + "variables": { + "BUCKET_NAME": "policyengine-calibration", + "INPUT_PATH": "2025-10-22-1721/inputs", + "OUTPUT_PATH": "2025-10-22-1721/outputs", + "BETA": "0.35", + "LAMBDA_L0": "5e-7", + "LAMBDA_L2": "5e-9", + "LR": "0.1", + "TOTAL_EPOCHS": "4000", + "EPOCHS_PER_CHUNK": "1000", + "ENABLE_LOGGING": "true" + } + } + }, + "taskCount": 1, + "parallelism": 1 + } + ], + "allocationPolicy": { + "instances": [ + { + "installGpuDrivers": true, + "policy": { + "machineType": "n1-standard-2", + "provisioningModel": "SPOT", + "accelerators": [ + { + "type": "nvidia-tesla-p100", + "count": 1 + } + ], + "bootDisk": { + "sizeGb": "50" + } + } + } + ], + "location": { + "allowedLocations": [ + "zones/us-central1-a", + "zones/us-central1-b", + "zones/us-central1-c" + ] + }, + "serviceAccount": { + "email": "policyengine-research@policyengine-research.iam.gserviceaccount.com" + } + }, + "logsPolicy": { + "destination": "CLOUD_LOGGING" + } +} \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/config.env b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/config.env new file mode 100644 index 00000000..98b63f6a --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/config.env @@ -0,0 +1,41 @@ +# Cloud Batch Pipeline Configuration +# Copy this file to .env and modify for your project + +# GCP Project Configuration +PROJECT_ID=policyengine-research +REGION=us-central1 +SERVICE_ACCOUNT=policyengine-research@policyengine-research.iam.gserviceaccount.com + +# Docker Image Settings +IMAGE_NAME=calibration-optimizer +IMAGE_TAG=latest + +# GCS Bucket Configuration +BUCKET_NAME=policyengine-calibration +INPUT_PATH=2025-10-22-1721/inputs +OUTPUT_PATH=2025-10-22-1721/outputs + +# GPU Configuration +GPU_TYPE=nvidia-tesla-p100 +GPU_COUNT=1 +MACHINE_TYPE=n1-standard-2 + +# Optimization Parameters +BETA=0.35 +LAMBDA_L0=5e-7 +LAMBDA_L2=5e-9 +LR=0.1 +TOTAL_EPOCHS=4000 +EPOCHS_PER_CHUNK=1000 +ENABLE_LOGGING=true + +# Resource Limits +CPU_MILLI=8000 +MEMORY_MIB=32768 +MAX_RUN_DURATION=86400s + +# Provisioning Model (SPOT or STANDARD) +PROVISIONING_MODEL=SPOT + +# Allowed zones for the job (must be in same region as REGION above) 
+ALLOWED_ZONES=zones/us-central1-a,zones/us-central1-b,zones/us-central1-c \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/generate_config.py b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/generate_config.py new file mode 100755 index 00000000..f1d6841e --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/generate_config.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +""" +Generate Cloud Batch job configuration from environment variables +""" +import json +import os +from pathlib import Path + + +def load_env_file(env_file=".env"): + """Load environment variables from file""" + if not Path(env_file).exists(): + env_file = "config.env" + + if Path(env_file).exists(): + with open(env_file) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, value = line.split("=", 1) + os.environ[key] = value + + +def generate_config(): + """Generate batch_job_config.json from environment variables""" + + # Load environment variables + load_env_file() + + # Parse allowed zones + allowed_zones = os.getenv("ALLOWED_ZONES", "zones/us-central1-a").split( + "," + ) + + config = { + "taskGroups": [ + { + "taskSpec": { + "runnables": [ + { + "container": { + "imageUri": f"us-docker.pkg.dev/{os.getenv('PROJECT_ID')}/us.gcr.io/{os.getenv('IMAGE_NAME')}:{os.getenv('IMAGE_TAG', 'latest')}", + "entrypoint": "/app/run_batch_job.sh", + } + } + ], + "computeResource": { + "cpuMilli": int(os.getenv("CPU_MILLI", "8000")), + "memoryMib": int(os.getenv("MEMORY_MIB", "32768")), + }, + "maxRunDuration": os.getenv("MAX_RUN_DURATION", "86400s"), + "environment": { + "variables": { + "BUCKET_NAME": os.getenv("BUCKET_NAME"), + "INPUT_PATH": os.getenv("INPUT_PATH"), + "OUTPUT_PATH": os.getenv("OUTPUT_PATH"), + "BETA": os.getenv("BETA", "0.35"), + "LAMBDA_L0": os.getenv("LAMBDA_L0", "5e-7"), + "LAMBDA_L2": os.getenv("LAMBDA_L2", "5e-9"), + "LR": os.getenv("LR", "0.1"), + "TOTAL_EPOCHS": os.getenv("TOTAL_EPOCHS", "12000"), + "EPOCHS_PER_CHUNK": os.getenv( + "EPOCHS_PER_CHUNK", "1000" + ), + "ENABLE_LOGGING": os.getenv( + "ENABLE_LOGGING", "true" + ), + } + }, + }, + "taskCount": 1, + "parallelism": 1, + } + ], + "allocationPolicy": { + "instances": [ + { + "installGpuDrivers": True, + "policy": { + "machineType": os.getenv( + "MACHINE_TYPE", "n1-standard-2" + ), + "provisioningModel": os.getenv( + "PROVISIONING_MODEL", "SPOT" + ), + "accelerators": [ + { + "type": os.getenv( + "GPU_TYPE", "nvidia-tesla-p100" + ), + "count": int(os.getenv("GPU_COUNT", "1")), + } + ], + "bootDisk": {"sizeGb": "50"}, + }, + } + ], + "location": {"allowedLocations": allowed_zones}, + "serviceAccount": {"email": os.getenv("SERVICE_ACCOUNT")}, + }, + "logsPolicy": {"destination": "CLOUD_LOGGING"}, + } + + # Write the configuration + with open("batch_job_config.json", "w") as f: + json.dump(config, f, indent=2) + + print("Generated batch_job_config.json from environment configuration") + print(f"Project: {os.getenv('PROJECT_ID')}") + print( + f"Image: us-docker.pkg.dev/{os.getenv('PROJECT_ID')}/us.gcr.io/{os.getenv('IMAGE_NAME')}:{os.getenv('IMAGE_TAG')}" + ) + print(f"GPU: {os.getenv('GPU_TYPE')}") + print(f"Service Account: {os.getenv('SERVICE_ACCOUNT')}") + + +if __name__ == "__main__": + generate_config() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/monitor_batch_job.sh 
b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/monitor_batch_job.sh
new file mode 100755
index 00000000..3e8faf6e
--- /dev/null
+++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/monitor_batch_job.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+
+# Monitor Cloud Batch job status
+
+JOB_NAME="${1}"
+REGION="${2:-us-central1}"
+
+if [ -z "${JOB_NAME}" ]; then
+    echo "Usage: $0 <job-name> [region]"
+    echo "Example: $0 calibration-job-20241015-143022 us-central1"
+    exit 1
+fi
+
+echo "Monitoring job: ${JOB_NAME}"
+echo "Region: ${REGION}"
+echo "Press Ctrl+C to stop monitoring"
+echo ""
+
+# Function to get job status
+get_status() {
+    gcloud batch jobs describe ${JOB_NAME} \
+        --location=${REGION} \
+        --format="value(status.state)" 2>/dev/null
+}
+
+# Monitor loop
+while true; do
+    STATUS=$(get_status)
+    TIMESTAMP=$(date "+%Y-%m-%d %H:%M:%S")
+
+    case ${STATUS} in
+        "SUCCEEDED")
+            echo "[${TIMESTAMP}] Job ${JOB_NAME} completed successfully!"
+            echo ""
+            echo "Fetching final logs..."
+            gcloud logging read "resource.type=batch.googleapis.com/Job AND resource.labels.job_id=${JOB_NAME}" \
+                --limit=100 \
+                --format="table(timestamp,severity,textPayload)"
+            echo ""
+            echo "Job completed! Check your GCS bucket for results."
+            exit 0
+            ;;
+        "FAILED")
+            echo "[${TIMESTAMP}] Job ${JOB_NAME} failed!"
+            echo ""
+            echo "Fetching error logs..."
+            gcloud logging read "resource.type=batch.googleapis.com/Job AND resource.labels.job_id=${JOB_NAME} AND severity>=ERROR" \
+                --limit=50 \
+                --format="table(timestamp,severity,textPayload)"
+            exit 1
+            ;;
+        "RUNNING")
+            echo "[${TIMESTAMP}] Job is running..."
+            # Optionally fetch recent logs
+            echo "Recent logs:"
+            gcloud logging read "resource.type=batch.googleapis.com/Job AND resource.labels.job_id=${JOB_NAME}" \
+                --limit=5 \
+                --format="table(timestamp,textPayload)" 2>/dev/null
+            ;;
+        "PENDING"|"QUEUED"|"SCHEDULED")
+            echo "[${TIMESTAMP}] Job status: ${STATUS} - waiting for resources..."
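+            # No log output is expected in these states; Cloud Batch is
+            # still provisioning a (spot) VM for the job.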
+ ;; + *) + echo "[${TIMESTAMP}] Job status: ${STATUS}" + ;; + esac + + sleep 30 +done \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/optimize_weights.py b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/optimize_weights.py new file mode 100755 index 00000000..b1af99af --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/optimize_weights.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +import os +import argparse +from pathlib import Path +from datetime import datetime +import pickle +import torch +import numpy as np +from scipy import sparse as sp +from l0.calibration import SparseCalibrationWeights + + +def main(): + parser = argparse.ArgumentParser( + description="Run sparse L0 weight optimization" + ) + parser.add_argument( + "--input-dir", + required=True, + help="Directory containing calibration_package.pkl", + ) + parser.add_argument( + "--output-dir", required=True, help="Directory for output files" + ) + parser.add_argument( + "--beta", + type=float, + default=0.35, + help="Beta parameter for L0 regularization", + ) + parser.add_argument( + "--lambda-l0", + type=float, + default=5e-7, + help="L0 regularization strength", + ) + parser.add_argument( + "--lambda-l2", + type=float, + default=5e-9, + help="L2 regularization strength", + ) + parser.add_argument("--lr", type=float, default=0.1, help="Learning rate") + parser.add_argument( + "--total-epochs", type=int, default=12000, help="Total training epochs" + ) + parser.add_argument( + "--epochs-per-chunk", + type=int, + default=1000, + help="Epochs per logging chunk", + ) + parser.add_argument( + "--enable-logging", + action="store_true", + help="Enable detailed epoch logging", + ) + parser.add_argument( + "--device", + default="cuda", + choices=["cuda", "cpu"], + help="Device to use", + ) + + args = parser.parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Loading calibration package from {args.input_dir}") + with open(Path(args.input_dir) / "calibration_package.pkl", "rb") as f: + calibration_data = pickle.load(f) + + X_sparse = calibration_data["X_sparse"] + init_weights = calibration_data["initial_weights"] + targets_df = calibration_data["targets_df"] + targets = targets_df.value.values + + print(f"Matrix shape: {X_sparse.shape}") + print(f"Number of targets: {len(targets)}") + + target_names = [] + for _, row in targets_df.iterrows(): + geo_prefix = f"{row['geographic_id']}" + name = f"{geo_prefix}/{row['variable_desc']}" + target_names.append(name) + + model = SparseCalibrationWeights( + n_features=X_sparse.shape[1], + beta=args.beta, + gamma=-0.1, + zeta=1.1, + init_keep_prob=0.999, + init_weights=init_weights, + log_weight_jitter_sd=0.05, + log_alpha_jitter_sd=0.01, + device=args.device, + ) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if args.enable_logging: + log_path = output_dir / "cd_calibration_log.csv" + with open(log_path, "w") as f: + f.write( + "target_name,estimate,target,epoch,error,rel_error,abs_error,rel_abs_error,loss\n" + ) + print(f"Initialized incremental log at: {log_path}") + + sparsity_path = output_dir / f"cd_sparsity_history_{timestamp}.csv" + with open(sparsity_path, "w") as f: + f.write("epoch,active_weights,total_weights,sparsity_pct\n") + print(f"Initialized sparsity tracking at: {sparsity_path}") + + for chunk_start in range(0, args.total_epochs, args.epochs_per_chunk): + chunk_epochs = min( + 
args.epochs_per_chunk, args.total_epochs - chunk_start + ) + current_epoch = chunk_start + chunk_epochs + + print( + f"\nTraining epochs {chunk_start + 1} to {current_epoch} of {args.total_epochs}" + ) + + model.fit( + M=X_sparse, + y=targets, + target_groups=None, + lambda_l0=args.lambda_l0, + lambda_l2=args.lambda_l2, + lr=args.lr, + epochs=chunk_epochs, + loss_type="relative", + verbose=True, + verbose_freq=chunk_epochs, + ) + + active_info = model.get_active_weights() + active_count = active_info["count"] + total_count = X_sparse.shape[1] + sparsity_pct = 100 * (1 - active_count / total_count) + + with open(sparsity_path, "a") as f: + f.write( + f"{current_epoch},{active_count},{total_count},{sparsity_pct:.4f}\n" + ) + + if args.enable_logging: + with torch.no_grad(): + y_pred = model.predict(X_sparse).cpu().numpy() + + with open(log_path, "a") as f: + for i in range(len(targets)): + estimate = y_pred[i] + target = targets[i] + error = estimate - target + rel_error = error / target if target != 0 else 0 + abs_error = abs(error) + rel_abs_error = abs(rel_error) + loss = rel_error**2 + + f.write( + f'"{target_names[i]}",{estimate},{target},{current_epoch},' + f"{error},{rel_error},{abs_error},{rel_abs_error},{loss}\n" + ) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + with torch.no_grad(): + w = model.get_weights(deterministic=True).cpu().numpy() + + versioned_filename = f"w_cd_{timestamp}.npy" + full_path = output_dir / versioned_filename + np.save(full_path, w) + + canonical_path = output_dir / "w_cd.npy" + np.save(canonical_path, w) + + print(f"\nOptimization complete!") + print(f"Final weights saved to: {full_path}") + print(f"Canonical weights saved to: {canonical_path}") + print(f"Weights shape: {w.shape}") + print(f"Sparsity history saved to: {sparsity_path}") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/run_batch_job.sh b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/run_batch_job.sh new file mode 100755 index 00000000..e514f663 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/run_batch_job.sh @@ -0,0 +1,75 @@ +#!/bin/bash +set -e + +# Environment variables passed from Cloud Batch job config +BUCKET_NAME="${BUCKET_NAME:-policyengine-calibration}" +INPUT_PATH="${INPUT_PATH:-2024-10-08-2209/inputs}" +OUTPUT_PATH="${OUTPUT_PATH:-2024-10-08-2209/outputs}" + +# Optimization parameters (can be overridden via env vars) +BETA="${BETA:-0.35}" +LAMBDA_L0="${LAMBDA_L0:-5e-7}" +LAMBDA_L2="${LAMBDA_L2:-5e-9}" +LR="${LR:-0.1}" +TOTAL_EPOCHS="${TOTAL_EPOCHS:-12000}" +EPOCHS_PER_CHUNK="${EPOCHS_PER_CHUNK:-1000}" +ENABLE_LOGGING="${ENABLE_LOGGING:-true}" + +# Generate timestamp for this run +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +JOB_ID="${JOB_ID:-batch_job_${TIMESTAMP}}" + +echo "Starting Cloud Batch optimization job: ${JOB_ID}" +echo "Timestamp: ${TIMESTAMP}" +echo "Input: gs://${BUCKET_NAME}/${INPUT_PATH}" +echo "Output: gs://${BUCKET_NAME}/${OUTPUT_PATH}/${TIMESTAMP}" + +# Create local working directories +LOCAL_INPUT="/tmp/input" +LOCAL_OUTPUT="/tmp/output" +mkdir -p ${LOCAL_INPUT} +mkdir -p ${LOCAL_OUTPUT} + +# Download input data from GCS +echo "Downloading input data..." 
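+# The run expects calibration_package.pkl under gs://${BUCKET_NAME}/${INPUT_PATH}/;
+# metadata.json is optional and only copied if present.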
+gsutil cp "gs://${BUCKET_NAME}/${INPUT_PATH}/calibration_package.pkl" ${LOCAL_INPUT}/ +gsutil cp "gs://${BUCKET_NAME}/${INPUT_PATH}/metadata.json" ${LOCAL_INPUT}/ 2>/dev/null || echo "No metadata.json found" + +# Prepare logging flag +LOGGING_FLAG="" +if [ "${ENABLE_LOGGING}" = "true" ]; then + LOGGING_FLAG="--enable-logging" +fi + +# Run the optimization +echo "Starting optimization with parameters:" +echo " Beta: ${BETA}" +echo " Lambda L0: ${LAMBDA_L0}" +echo " Lambda L2: ${LAMBDA_L2}" +echo " Learning rate: ${LR}" +echo " Total epochs: ${TOTAL_EPOCHS}" +echo " Epochs per chunk: ${EPOCHS_PER_CHUNK}" +echo " Device: cuda" + +python /app/optimize_weights.py \ + --input-dir ${LOCAL_INPUT} \ + --output-dir ${LOCAL_OUTPUT} \ + --beta ${BETA} \ + --lambda-l0 ${LAMBDA_L0} \ + --lambda-l2 ${LAMBDA_L2} \ + --lr ${LR} \ + --total-epochs ${TOTAL_EPOCHS} \ + --epochs-per-chunk ${EPOCHS_PER_CHUNK} \ + ${LOGGING_FLAG} \ + --device cuda + +# Upload results to GCS +echo "Uploading results to GCS..." +gsutil -m cp -r ${LOCAL_OUTPUT}/* "gs://${BUCKET_NAME}/${OUTPUT_PATH}/${TIMESTAMP}/" + +# Create a completion marker +echo "{\"job_id\": \"${JOB_ID}\", \"timestamp\": \"${TIMESTAMP}\", \"status\": \"completed\"}" > ${LOCAL_OUTPUT}/job_complete.json +gsutil cp ${LOCAL_OUTPUT}/job_complete.json "gs://${BUCKET_NAME}/${OUTPUT_PATH}/${TIMESTAMP}/" + +echo "Job completed successfully!" +echo "Results uploaded to: gs://${BUCKET_NAME}/${OUTPUT_PATH}/${TIMESTAMP}/" \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/setup.sh b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/setup.sh new file mode 100755 index 00000000..0de64ff7 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/setup.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +echo "=========================================" +echo "Cloud Batch Pipeline Setup" +echo "=========================================" +echo "" + +# Check prerequisites +echo "Checking prerequisites..." + +# Check if Docker is installed +if ! command -v docker &> /dev/null; then + echo "❌ Docker is not installed" + echo " Please install Docker: https://docs.docker.com/get-docker/" + exit 1 +else + echo "✅ Docker is installed: $(docker --version)" +fi + +# Check if gcloud is installed +if ! command -v gcloud &> /dev/null; then + echo "❌ gcloud CLI is not installed" + echo " Please install gcloud: https://cloud.google.com/sdk/docs/install" + exit 1 +else + echo "✅ gcloud is installed: $(gcloud --version | head -n 1)" +fi + +# Check if authenticated +if ! gcloud auth list --filter=status:ACTIVE --format="value(account)" &> /dev/null; then + echo "❌ Not authenticated with gcloud" + echo " Please run: gcloud auth login" + exit 1 +else + ACTIVE_ACCOUNT=$(gcloud auth list --filter=status:ACTIVE --format="value(account)") + echo "✅ Authenticated as: ${ACTIVE_ACCOUNT}" +fi + +# Check Docker authentication for GCR +echo "" +echo "Configuring Docker for Google Container Registry..." +gcloud auth configure-docker --quiet + +# Create .env from config.env if it doesn't exist +if [ ! -f .env ]; then + echo "" + echo "Creating .env configuration file..." 
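+    # .env is the user-specific copy; submit_batch_job.sh loads it in
+    # preference to config.env when it exists.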
+ cp config.env .env + echo "✅ Created .env from config.env" + echo "" + echo "⚠️ IMPORTANT: Edit .env to configure your project settings:" + echo " - PROJECT_ID: Your GCP project ID" + echo " - SERVICE_ACCOUNT: Your service account email" + echo " - BUCKET_NAME: Your GCS bucket name" + echo " - INPUT_PATH: Path to input data in bucket" + echo " - OUTPUT_PATH: Path for output data in bucket" + echo "" + echo " Edit with: nano .env" +else + echo "✅ .env file already exists" +fi + +# Make scripts executable +chmod +x *.sh +echo "✅ Made all scripts executable" + +echo "" +echo "=========================================" +echo "Setup complete!" +echo "" +echo "Next steps:" +echo "1. Edit .env with your project configuration" +echo "2. Ensure your input data is in GCS" +echo "3. Run: ./submit_batch_job.sh" +echo "=========================================" \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/submit_batch_job.sh b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/submit_batch_job.sh new file mode 100755 index 00000000..8cf9bc35 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/batch_pipeline/submit_batch_job.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# Script to build Docker image, push to GCR, and submit Cloud Batch job + +# Load configuration from .env if it exists, otherwise use config.env +if [ -f .env ]; then + echo "Loading configuration from .env" + source .env +elif [ -f config.env ]; then + echo "Loading configuration from config.env" + source config.env +else + echo "Error: No configuration file found. Please copy config.env to .env and customize it." + exit 1 +fi + +# Allow command-line overrides +IMAGE_TAG="${1:-${IMAGE_TAG:-latest}}" +REGION="${2:-${REGION:-us-central1}}" +JOB_NAME="calibration-job-$(date +%Y%m%d-%H%M%S)" + +echo "===========================================" +echo "Cloud Batch Calibration Job Submission" +echo "===========================================" +echo "Project: ${PROJECT_ID}" +echo "Image: us-docker.pkg.dev/${PROJECT_ID}/us.gcr.io/${IMAGE_NAME}:${IMAGE_TAG}" +echo "Region: ${REGION}" +echo "Job Name: ${JOB_NAME}" +echo "" + +# Step 1: Build Docker image +echo "Step 1: Building Docker image..." +docker build -t us-docker.pkg.dev/${PROJECT_ID}/us.gcr.io/${IMAGE_NAME}:${IMAGE_TAG} . + +if [ $? -ne 0 ]; then + echo "Error: Docker build failed" + exit 1 +fi + +# Step 2: Push to Artifact Registry +echo "" +echo "Step 2: Pushing image to Artifact Registry..." +docker push us-docker.pkg.dev/${PROJECT_ID}/us.gcr.io/${IMAGE_NAME}:${IMAGE_TAG} + +if [ $? -ne 0 ]; then + echo "Error: Docker push failed" + echo "Make sure you're authenticated: gcloud auth configure-docker" + exit 1 +fi + +# Step 3: Generate config and submit Cloud Batch job +echo "" +echo "Step 3a: Generating job configuration..." +python3 generate_config.py + +echo "" +echo "Step 3b: Submitting Cloud Batch job..." +gcloud batch jobs submit ${JOB_NAME} \ + --location=${REGION} \ + --config=batch_job_config.json + +if [ $? -eq 0 ]; then + echo "" + echo "===========================================" + echo "Job submitted successfully!" 
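+    # The commands below can be copy-pasted; monitor_batch_job.sh wraps the
+    # same gcloud describe/logging calls in a polling loop.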
+ echo "Job Name: ${JOB_NAME}" + echo "Region: ${REGION}" + echo "" + echo "Monitor job status with:" + echo " gcloud batch jobs describe ${JOB_NAME} --location=${REGION}" + echo "" + echo "View logs with:" + echo " gcloud batch jobs list --location=${REGION}" + echo " gcloud logging read \"resource.type=batch.googleapis.com/Job AND resource.labels.job_id=${JOB_NAME}\" --limit=50" + echo "" + echo "Or use the monitoring script:" + echo " ./monitor_batch_job.sh ${JOB_NAME} ${REGION}" + echo "===========================================" +else + echo "Error: Job submission failed" + exit 1 +fi \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/build_cd_county_mappings.py b/policyengine_us_data/datasets/cps/local_area_calibration/build_cd_county_mappings.py new file mode 100644 index 00000000..451ccf24 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/build_cd_county_mappings.py @@ -0,0 +1,273 @@ +""" +Build Congressional District to County mappings using Census data. + +This script: +1. Uses Census Bureau's geographic relationship files +2. Calculates what proportion of each CD's population lives in each county +3. Saves the mappings for use in create_sparse_state_stacked.py +""" + +import pandas as pd +import numpy as np +import json +from pathlib import Path +import requests +from typing import Dict, List, Tuple + + +def get_cd_county_relationships() -> pd.DataFrame: + """ + Get CD-County relationships from Census Bureau. + + The Census provides geographic relationship files that show + how different geographic units overlap. + """ + + # Try to use local file first if it exists + cache_file = Path("cd_county_relationships_2023.csv") + + if cache_file.exists(): + print(f"Loading cached relationships from {cache_file}") + return pd.read_csv(cache_file) + + # Census API endpoint for CD-County relationships + # This uses the 2020 Census geographic relationships + # Format: https://www.census.gov/geographies/reference-files/time-series/geo/relationship-files.html + + print("Downloading CD-County relationship data from Census...") + + # We'll use the census tract level data and aggregate up + # Each tract is in exactly one county and one CD + census_api_key = "YOUR_API_KEY" # You can get one from https://api.census.gov/data/key_signup.html + + # Alternative: Use pre-processed data from PolicyEngine or other sources + # For now, let's create a simplified mapping based on known relationships + + print("Creating simplified CD-County mappings based on major counties...") + + # This is a simplified mapping - in production you'd want complete Census data + # Format: CD -> List of (county_fips, approx_proportion) + simplified_mappings = { + # California examples + "601": [ + ("06089", 0.35), + ("06103", 0.25), + ("06115", 0.20), + ("06007", 0.20), + ], # CA-01: Shasta, Tehama, Yuba, Butte counties + "652": [("06073", 1.0)], # CA-52: San Diego County + "612": [ + ("06075", 0.60), + ("06081", 0.40), + ], # CA-12: San Francisco, San Mateo + # Texas examples + "4801": [ + ("48001", 0.15), + ("48213", 0.25), + ("48423", 0.35), + ("48183", 0.25), + ], # TX-01: Multiple counties + "4838": [("48201", 1.0)], # TX-38: Harris County (Houston) + # New York examples + "3601": [ + ("36103", 0.80), + ("36059", 0.20), + ], # NY-01: Suffolk, Nassau counties + "3612": [ + ("36061", 0.50), + ("36047", 0.50), + ], # NY-12: New York (Manhattan), Kings (Brooklyn) + # Florida examples + "1201": [ + ("12033", 0.40), + ("12091", 0.30), + ("12113", 0.30), 
+ ], # FL-01: Escambia, Okaloosa, Santa Rosa + "1228": [("12086", 1.0)], # FL-28: Miami-Dade County + # Illinois example + "1701": [("17031", 1.0)], # IL-01: Cook County (Chicago) + # DC at-large + "1101": [("11001", 1.0)], # DC + } + + # Convert to DataFrame format + rows = [] + for cd_geoid, counties in simplified_mappings.items(): + for county_fips, proportion in counties: + rows.append( + { + "congressional_district_geoid": cd_geoid, + "county_fips": county_fips, + "proportion": proportion, + } + ) + + df = pd.DataFrame(rows) + + # Save for future use + df.to_csv(cache_file, index=False) + print(f"Saved relationships to {cache_file}") + + return df + + +def get_all_cds_from_database() -> List[str]: + """Get all CD GEOIDs from the database.""" + from sqlalchemy import create_engine, text + + db_path = "/home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/policy_data.db" + db_uri = f"sqlite:///{db_path}" + engine = create_engine(db_uri) + + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM stratum_constraints sc + WHERE sc.constraint_variable = 'congressional_district_geoid' + ORDER BY sc.value + """ + + with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + return [row[0] for row in result] + + +def build_complete_cd_county_mapping() -> Dict[str, Dict[str, float]]: + """ + Build a complete mapping of CD to county proportions. + + Returns: + Dict mapping CD GEOID -> {county_fips: proportion} + """ + + # Get all CDs from database + all_cds = get_all_cds_from_database() + print(f"Found {len(all_cds)} congressional districts in database") + + # Get relationships (simplified for now) + relationships = get_cd_county_relationships() + + # Build the complete mapping + cd_county_map = {} + + for cd in all_cds: + if cd in relationships["congressional_district_geoid"].values: + cd_data = relationships[ + relationships["congressional_district_geoid"] == cd + ] + cd_county_map[cd] = dict( + zip(cd_data["county_fips"], cd_data["proportion"]) + ) + else: + # For CDs not in our simplified mapping, assign to most populous county in state + state_fips = str(cd).zfill(4)[:2] # Extract state from CD GEOID + + # Default county assignments by state (most populous county) + state_default_counties = { + "01": "01073", # AL -> Jefferson County + "02": "02020", # AK -> Anchorage + "04": "04013", # AZ -> Maricopa County + "05": "05119", # AR -> Pulaski County + "06": "06037", # CA -> Los Angeles County + "08": "08031", # CO -> Denver County + "09": "09003", # CT -> Hartford County + "10": "10003", # DE -> New Castle County + "11": "11001", # DC -> District of Columbia + "12": "12086", # FL -> Miami-Dade County + "13": "13121", # GA -> Fulton County + "15": "15003", # HI -> Honolulu County + "16": "16001", # ID -> Ada County + "17": "17031", # IL -> Cook County + "18": "18097", # IN -> Marion County + "19": "19153", # IA -> Polk County + "20": "20091", # KS -> Johnson County + "21": "21111", # KY -> Jefferson County + "22": "22071", # LA -> Orleans Parish + "23": "23005", # ME -> Cumberland County + "24": "24003", # MD -> Anne Arundel County + "25": "25017", # MA -> Middlesex County + "26": "26163", # MI -> Wayne County + "27": "27053", # MN -> Hennepin County + "28": "28049", # MS -> Hinds County + "29": "29189", # MO -> St. 
Louis County + "30": "30111", # MT -> Yellowstone County + "31": "31055", # NE -> Douglas County + "32": "32003", # NV -> Clark County + "33": "33011", # NH -> Hillsborough County + "34": "34003", # NJ -> Bergen County + "35": "35001", # NM -> Bernalillo County + "36": "36047", # NY -> Kings County + "37": "37119", # NC -> Mecklenburg County + "38": "38015", # ND -> Cass County + "39": "39049", # OH -> Franklin County + "40": "40109", # OK -> Oklahoma County + "41": "41051", # OR -> Multnomah County + "42": "42101", # PA -> Philadelphia County + "44": "44007", # RI -> Providence County + "45": "45079", # SC -> Richland County + "46": "46103", # SD -> Minnehaha County + "47": "47157", # TN -> Shelby County + "48": "48201", # TX -> Harris County + "49": "49035", # UT -> Salt Lake County + "50": "50007", # VT -> Chittenden County + "51": "51059", # VA -> Fairfax County + "53": "53033", # WA -> King County + "54": "54039", # WV -> Kanawha County + "55": "55079", # WI -> Milwaukee County + "56": "56021", # WY -> Laramie County + } + + default_county = state_default_counties.get(state_fips) + if default_county: + cd_county_map[cd] = {default_county: 1.0} + else: + print(f"Warning: No mapping for CD {cd} in state {state_fips}") + + return cd_county_map + + +def save_mappings(cd_county_map: Dict[str, Dict[str, float]]): + """Save the mappings to a JSON file.""" + + output_file = Path("cd_county_mappings.json") + + with open(output_file, "w") as f: + json.dump(cd_county_map, f, indent=2) + + print(f"\nSaved CD-County mappings to {output_file}") + print(f"Total CDs mapped: {len(cd_county_map)}") + + # Show statistics + counties_per_cd = [len(counties) for counties in cd_county_map.values()] + print(f"Average counties per CD: {np.mean(counties_per_cd):.1f}") + print(f"Max counties in a CD: {max(counties_per_cd)}") + print( + f"CDs with single county: {sum(1 for c in counties_per_cd if c == 1)}" + ) + + +def main(): + """Main function to build and save CD-County mappings.""" + + print("Building Congressional District to County mappings...") + print("=" * 70) + + # Build the complete mapping + cd_county_map = build_complete_cd_county_mapping() + + # Save to file + save_mappings(cd_county_map) + + # Show sample mappings + print("\nSample mappings:") + for cd, counties in list(cd_county_map.items())[:5]: + print(f"\nCD {cd}:") + for county, proportion in counties.items(): + print(f" County {county}: {proportion:.1%}") + + print("\n✅ CD-County mapping complete!") + + return cd_county_map + + +if __name__ == "__main__": + mappings = main() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/calibrate_cds_sparse.py b/policyengine_us_data/datasets/cps/local_area_calibration/calibrate_cds_sparse.py new file mode 100644 index 00000000..352ea3dc --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/calibrate_cds_sparse.py @@ -0,0 +1,535 @@ +# ============================================================================ +# CONFIGURATION +# ============================================================================ +import os + +# Set before any CUDA operations - helps with memory fragmentation on long runs +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +# ============================================================================ +# IMPORTS +# ============================================================================ +from pathlib import Path +from datetime import datetime +from sqlalchemy import create_engine, text +import logging + +# Set up 
logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + +import torch +import numpy as np +import pandas as pd +from scipy import sparse as sp +from l0.calibration import SparseCalibrationWeights + +from policyengine_us import Microsimulation +from policyengine_us_data.datasets.cps.local_area_calibration.metrics_matrix_geo_stacking_sparse import ( + SparseGeoStackingMatrixBuilder, +) +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + create_target_groups, + download_from_huggingface, + filter_target_groups, + get_all_cds_from_database, +) +from policyengine_us_data.datasets.cps.local_area_calibration.household_tracer import HouseholdTracer + +# ============================================================================ +# STEP 1: DATA LOADING AND CD LIST RETRIEVAL +# ============================================================================ + +# db_path = download_from_huggingface("policy_data.db") +db_path = "/home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/policy_data.db" +db_uri = f"sqlite:///{db_path}" +builder = SparseGeoStackingMatrixBuilder(db_uri, time_period=2023) + +# Query all congressional district GEOIDs from database +all_cd_geoids = get_all_cds_from_database(db_uri) +print(f"Found {len(all_cd_geoids)} congressional districts in database") + +# For testing, use only 10 CDs (can change to all_cd_geoids for full run) +MODE = "Stratified" +if MODE == "Test": + # Select 10 diverse CDs from different states + # Note: CD GEOIDs are 3-4 digits, format is state_fips + district_number + cds_to_calibrate = [ + "601", # California CD 1 + "652", # California CD 52 + "3601", # New York CD 1 + "3626", # New York CD 26 + "4801", # Texas CD 1 + "4838", # Texas CD 38 + "1201", # Florida CD 1 + "1228", # Florida CD 28 + "1701", # Illinois CD 1 + "1101", # DC at-large + ] + print(f"TEST MODE: Using only {len(cds_to_calibrate)} CDs for testing") + dataset_uri = "hf://policyengine/test/extended_cps_2023.h5" +elif MODE == "Stratified": + cds_to_calibrate = all_cd_geoids + # dataset_uri = "/home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/stratified_extended_cps_2023.h5" + dataset_uri = "/home/baogorek/devl/stratified_10k.h5" + print(f"Stratified mode") +else: + cds_to_calibrate = all_cd_geoids + dataset_uri = "hf://policyengine/test/extended_cps_2023.h5" + print( + f"FULL MODE needs a lot of RAM!: Using all {len(cds_to_calibrate)} CDs" + ) + +sim = Microsimulation(dataset=dataset_uri) + +# ============================================================================ +# STEP 2: BUILD SPARSE MATRIX +# ============================================================================ + +targets_df, X_sparse, household_id_mapping = ( + builder.build_stacked_matrix_sparse( + "congressional_district", cds_to_calibrate, sim + ) +) +print(f"\nMatrix shape: {X_sparse.shape}") +print(f"Total targets: {len(targets_df)}") + +# ============================================================================ +# STEP 2.5: GROUP ANALYSIS AND OPTIONAL FILTERING +# ============================================================================ + +target_groups, group_info = create_target_groups(targets_df) + +print(f"\nAutomatic target grouping:") +print(f"Total groups: {len(np.unique(target_groups))}") +for info in group_info: + print(f" {info}") + + +tracer = HouseholdTracer(targets_df, X_sparse, household_id_mapping, cds_to_calibrate, sim) +tracer.print_matrix_structure() + +# After 
reviewing the printout above, specify group IDs to exclude +# Example: groups_to_exclude = [5, 12, 18, 23, 27] +groups_to_exclude = [ + # National -- + 0, # Group 0: National alimony_expense (1 target, value=12,610,232,250) + 1, # Group 1: National alimony_income (1 target, value=12,610,232,250) + 2, # Group 2: National charitable_deduction (1 target, value=63,343,136,630) + 3, # Group 3: National child_support_expense (1 target, value=32,010,589,559) - 51% error + 4, # Group 4: National child_support_received (1 target, value=32,010,589,559) + 5, # Group 5: National eitc (1 target, value=64,440,000,000) + 8, # Group 8: National interest_deduction (1 target, value=24,056,443,062) + 12, # Group 12: National net_worth (1 target, value=155,202,858,467,594)', + 10, # Group 10: National medical_expense_deduction (1 target, value=11,058,203,666) + 15, # Group 15: National person_count (Undocumented population) (1 target, value=19,529,896) + 17, # Group 17: National person_count_ssn_card_type=NONE (1 target, value=12,200,000)', + 18, # Group 18: National qualified_business_income_deduction (1 target, value=61,208,127,308) + 21, # Group 21: National salt_deduction (1 target, value=20,609,969,587)' + # IRS variables at the cd level --- + 34, # Group 34: Tax Units eitc_child_count==0 (436 targets across 436 geographies)', + 35, # Group 35: Tax Units eitc_child_count==1 (436 targets across 436 geographies)', + 36, # Group 36: Tax Units eitc_child_count==2 (436 targets across 436 geographies)', + 37, # Group 37: Tax Units eitc_child_count>2 (436 targets across 436 geographies)', + 31, # 'Group 31: Person Income Distribution (3924 targets across 436 geographies)' + 56, # 'Group 56: AGI Total Amount (436 targets across 436 geographies)', + 42, # Group 42: Tax Units qualified_business_income_deduction>0 (436 targets across 436 geographies) + 64, # Group 64: Qualified Business Income Deduction (436 targets across 436 geographies) + 46, # Group 46: Tax Units rental_income>0 (436 targets across 436 geographies) + 68, # Group 68: Rental Income (436 targets across 436 geographies) + 47, # Group 47: Tax Units salt>0 (436 targets across 436 geographies) + 69, # Group 69: Salt (436 targets across 436 geographies) +] + +targets_df, X_sparse, target_groups = filter_target_groups( + targets_df, X_sparse, target_groups, groups_to_exclude +) + +tracer = HouseholdTracer(targets_df, X_sparse, household_id_mapping, cds_to_calibrate, sim) +tracer.print_matrix_structure() + +household_targets = tracer.trace_household_targets(565) + + + + +# Extract target values after filtering +targets = targets_df.value.values + +print(f"\nSparse Matrix Statistics:") +print(f"- Shape: {X_sparse.shape}") +print(f"- Non-zero elements: {X_sparse.nnz:,}") +print( + f"- Percent non-zero: {100 * X_sparse.nnz / (X_sparse.shape[0] * X_sparse.shape[1]):.4f}%" +) +print( + f"- Memory usage: {(X_sparse.data.nbytes + X_sparse.indices.nbytes + X_sparse.indptr.nbytes) / 1024**2:.2f} MB" +) + +# Compare to dense matrix memory +dense_memory = ( + X_sparse.shape[0] * X_sparse.shape[1] * 4 / 1024**2 +) # 4 bytes per float32, in MB +print(f"- Dense matrix would use: {dense_memory:.2f} MB") +print( + f"- Memory savings: {100*(1 - (X_sparse.data.nbytes + X_sparse.indices.nbytes + X_sparse.indptr.nbytes)/(dense_memory * 1024**2)):.2f}%" +) + +# ============================================================================ +# STEP 3: EXPORT FOR GPU PROCESSING +# ============================================================================ + +# Create export 
directory +export_dir = os.path.expanduser("~/Downloads/cd_calibration_data") +os.makedirs(export_dir, exist_ok=True) + +# Save target groups +target_groups_path = os.path.join(export_dir, "cd_target_groups.npy") +np.save(target_groups_path, target_groups) +print(f"\nExported target groups to: {target_groups_path}") + +# Save sparse matrix +sparse_path = os.path.join(export_dir, "cd_matrix_sparse.npz") +sp.save_npz(sparse_path, X_sparse) +print(f"\nExported sparse matrix to: {sparse_path}") + +# Create target names array for epoch logging +target_names = [] +for _, row in targets_df.iterrows(): + # Add clear geographic level prefixes for better readability + if row["geographic_id"] == "US": + geo_prefix = "US" + elif row.get("stratum_group_id") == "state_snap_cost": # State SNAP costs + geo_prefix = f"ST/{row['geographic_id']}" + else: # CD targets + geo_prefix = f"CD/{row['geographic_id']}" + name = f"{geo_prefix}/{row['variable_desc']}" + target_names.append(name) + +# Save target names array (replaces pickled dataframe) +target_names_path = os.path.join(export_dir, "cd_target_names.json") +import json + +with open(target_names_path, "w") as f: + json.dump(target_names, f) +print(f"Exported target names to: {target_names_path}") + +# Save targets array for direct model.fit() use +targets_array_path = os.path.join(export_dir, "cd_targets_array.npy") +np.save(targets_array_path, targets) +print(f"Exported targets array to: {targets_array_path}") + +# Save the full targets_df for debugging +targets_df_path = os.path.join(export_dir, "cd_targets_df.csv") +targets_df.to_csv(targets_df_path, index=False) +print(f"Exported targets dataframe to: {targets_df_path}") + +# Save CD list for reference +cd_list_path = os.path.join(export_dir, "cd_list.txt") +with open(cd_list_path, "w") as f: + for cd in cds_to_calibrate: + f.write(f"{cd}\n") +print(f"Exported CD list to: {cd_list_path}") + +# ============================================================================ +# STEP 4: CALCULATE CD POPULATIONS AND INITIAL WEIGHTS +# ============================================================================ + +cd_populations = {} +for cd_geoid in cds_to_calibrate: + # Match targets for this CD using geographic_id + cd_age_targets = targets_df[ + (targets_df["geographic_id"] == cd_geoid) + & (targets_df["variable"] == "person_count") + & (targets_df["variable_desc"].str.contains("age", na=False)) + ] + if not cd_age_targets.empty: + unique_ages = cd_age_targets.drop_duplicates(subset=["variable_desc"]) + cd_populations[cd_geoid] = unique_ages["value"].sum() + +if cd_populations: + min_pop = min(cd_populations.values()) + max_pop = max(cd_populations.values()) + print(f"\nCD population range: {min_pop:,.0f} to {max_pop:,.0f}") +else: + print("\nWarning: Could not calculate CD populations from targets") + min_pop = 700000 # Approximate average CD population + +# Create arrays for both keep probabilities and initial weights +keep_probs = np.zeros(X_sparse.shape[1]) +init_weights = np.zeros(X_sparse.shape[1]) +cumulative_idx = 0 +cd_household_indices = {} # Maps CD to (start_col, end_col) in X_sparse + +# Calculate weights for ALL CDs +for cd_key, household_list in household_id_mapping.items(): + cd_geoid = cd_key.replace("cd", "") + n_households = len(household_list) + + if cd_geoid in cd_populations: + cd_pop = cd_populations[cd_geoid] + else: + cd_pop = min_pop # Use minimum as default + + # Scale initial keep probability by population + pop_ratio = cd_pop / min_pop + adjusted_keep_prob = min(0.15, 0.02 * 
np.sqrt(pop_ratio)) + keep_probs[cumulative_idx : cumulative_idx + n_households] = ( + adjusted_keep_prob + ) + + # Calculate initial weight + base_weight = cd_pop / n_households + sparsity_adjustment = 1.0 / np.sqrt(adjusted_keep_prob) + initial_weight = base_weight * sparsity_adjustment + # initial_weight = np.clip(initial_weight, 0, 100000) # Not clipping + + init_weights[cumulative_idx : cumulative_idx + n_households] = ( + initial_weight + ) + cd_household_indices[cd_geoid] = ( + cumulative_idx, + cumulative_idx + n_households, + ) + cumulative_idx += n_households + +print("\nCD-aware keep probabilities and initial weights calculated.") +print( + f"Initial weight range: {init_weights.min():.0f} to {init_weights.max():.0f}" +) +print(f"Mean initial weight: {init_weights.mean():.0f}") + +# Save initialization arrays +keep_probs_path = os.path.join(export_dir, "cd_keep_probs.npy") +np.save(keep_probs_path, keep_probs) +print(f"Exported keep probabilities to: {keep_probs_path}") + +init_weights_path = os.path.join(export_dir, "cd_init_weights.npy") +np.save(init_weights_path, init_weights) +print(f"Exported initial weights to: {init_weights_path}") + +# ============================================================================ +# STEP 6: CREATE EXPLORATION PACKAGE (BEFORE CALIBRATION) +# ============================================================================ +print("\n" + "=" * 70) +print("CREATING EXPLORATION PACKAGE") +print("=" * 70) + +# Save exploration package with just the essentials (before calibration) +exploration_package = { + "X_sparse": X_sparse, + "targets_df": targets_df, + "household_id_mapping": household_id_mapping, + "cd_household_indices": cd_household_indices, + "dataset_uri": dataset_uri, + "cds_to_calibrate": cds_to_calibrate, + "initial_weights": init_weights, + "keep_probs": keep_probs, + "target_groups": target_groups, +} + +package_path = os.path.join(export_dir, "calibration_package.pkl") +with open(package_path, "wb") as f: + import pickle + + pickle.dump(exploration_package, f) + +print(f"✅ Exploration package saved to {package_path}") +print(f" Size: {os.path.getsize(package_path) / 1024 / 1024:.1f} MB") +print("\nTo use the package:") +print(" with open('calibration_package.pkl', 'rb') as f:") +print(" data = pickle.load(f)") +print(" X_sparse = data['X_sparse']") +print(" targets_df = data['targets_df']") +print(" # See create_and_use_exploration_package.py for usage examples") + +# ============================================================================ +# STEP 7: L0 CALIBRATION WITH EPOCH LOGGING +# ============================================================================ + +print("\n" + "=" * 70) +print("RUNNING L0 CALIBRATION WITH EPOCH LOGGING") +print("=" * 70) + +# Create model with per-feature keep probabilities and weights +model = SparseCalibrationWeights( + n_features=X_sparse.shape[1], + beta=2 / 3, + gamma=-0.1, + zeta=1.1, + init_keep_prob=0.999, # keep_probs, # CD-specific keep probabilities + init_weights=init_weights, # CD population-based initial weights + log_weight_jitter_sd=0.05, + log_alpha_jitter_sd=0.01, + # device = "cuda", # Uncomment for GPU +) + +# Configuration for epoch logging +ENABLE_EPOCH_LOGGING = True # Set to False to disable logging +EPOCHS_PER_CHUNK = 2 # Train in chunks of 50 epochs +TOTAL_EPOCHS = 4 # Total epochs to train (set to 3 for quick test) +# For testing, you can use: +# EPOCHS_PER_CHUNK = 1 +# TOTAL_EPOCHS = 3 + +# Initialize CSV files for incremental writing +if ENABLE_EPOCH_LOGGING: + 
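+    # One row per target is appended at each logging point; this CSV
+    # feeds the epoch-by-epoch metrics dashboard.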
log_path = os.path.join(export_dir, "cd_calibration_log.csv") + # Write header + with open(log_path, "w") as f: + f.write( + "target_name,estimate,target,epoch,error,rel_error,abs_error,rel_abs_error,loss\n" + ) + print(f"Initialized incremental log at: {log_path}") + +# Initialize sparsity tracking CSV with timestamp +timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +sparsity_path = os.path.join( + export_dir, f"cd_sparsity_history_{timestamp}.csv" +) +with open(sparsity_path, "w") as f: + f.write("epoch,active_weights,total_weights,sparsity_pct\n") +print(f"Initialized sparsity tracking at: {sparsity_path}") + +# Train in chunks and capture metrics between chunks +for chunk_start in range(0, TOTAL_EPOCHS, EPOCHS_PER_CHUNK): + chunk_epochs = min(EPOCHS_PER_CHUNK, TOTAL_EPOCHS - chunk_start) + current_epoch = chunk_start + chunk_epochs + + print( + f"\nTraining epochs {chunk_start + 1} to {current_epoch} of {TOTAL_EPOCHS}" + ) + + model.fit( + M=X_sparse, + y=targets, + target_groups=target_groups, + lambda_l0=1.0e-6, + lambda_l2=0, + lr=0.2, + epochs=chunk_epochs, + loss_type="relative", + verbose=True, + verbose_freq=chunk_epochs, # Print at end of chunk + ) + + # Track sparsity after each chunk + active_info = model.get_active_weights() + active_count = active_info["count"] + total_count = X_sparse.shape[1] + sparsity_pct = 100 * (1 - active_count / total_count) + + with open(sparsity_path, "a") as f: + f.write( + f"{current_epoch},{active_count},{total_count},{sparsity_pct:.4f}\n" + ) + + if ENABLE_EPOCH_LOGGING: + # Capture metrics after this chunk + with torch.no_grad(): + y_pred = model.predict(X_sparse).cpu().numpy() + + # Write incrementally to CSV + with open(log_path, "a") as f: + for i in range(len(targets)): + # Calculate all metrics + estimate = y_pred[i] + target = targets[i] + error = estimate - target + rel_error = error / target if target != 0 else 0 + abs_error = abs(error) + rel_abs_error = abs(rel_error) + loss = rel_error**2 + + # Write row directly to file + f.write( + f'"{target_names[i]}",{estimate},{target},{current_epoch},' + f"{error},{rel_error},{abs_error},{rel_abs_error},{loss}\n" + ) + + # Clear GPU cache after large prediction operation + if torch.cuda.is_available(): + torch.cuda.empty_cache() +# Save epoch logging data if enabled +if ENABLE_EPOCH_LOGGING: + print(f"\nIncremental log complete at: {log_path}") + print( + f"Log contains metrics for {TOTAL_EPOCHS // EPOCHS_PER_CHUNK} logging points" + ) + +# Final evaluation +with torch.no_grad(): + y_pred = model.predict(X_sparse).cpu().numpy() + y_actual = targets + rel_errors = np.abs((y_actual - y_pred) / (y_actual + 1)) + + print(f"\nAfter {TOTAL_EPOCHS} epochs:") + print(f"Mean relative error: {np.mean(rel_errors):.2%}") + print(f"Max relative error: {np.max(rel_errors):.2%}") + + # Get sparsity info + active_info = model.get_active_weights() + final_sparsity = 100 * (1 - active_info["count"] / X_sparse.shape[1]) + print( + f"Active weights: {active_info['count']} out of {X_sparse.shape[1]} ({100*active_info['count']/X_sparse.shape[1]:.2f}%)" + ) + print(f"Final sparsity: {final_sparsity:.2f}%") + + # Save final weights + w = model.get_weights(deterministic=True).cpu().numpy() + final_weights_path = os.path.join( + export_dir, f"cd_weights_{TOTAL_EPOCHS}epochs.npy" + ) + np.save(final_weights_path, w) + print( + f"\nSaved final weights ({TOTAL_EPOCHS} epochs) to: {final_weights_path}" + ) + +print( + "\n✅ L0 calibration complete! Matrix, targets, and epoch log are ready for analysis." 
+) + +# ============================================================================ +# SUMMARY +# ============================================================================ + +print("\n" + "=" * 70) +print("CD CALIBRATION DATA EXPORT COMPLETE") +print("=" * 70) +print(f"\nAll files exported to: {export_dir}") +print("\nFiles ready for GPU transfer:") +print(f" 1. cd_matrix_sparse.npz - Sparse calibration matrix") +print(f" 2. cd_target_names.json - Target names for epoch logging") +print(f" 3. cd_targets_array.npy - Target values array") +print(f" 4. cd_targets_df.csv - Full targets dataframe for debugging") +print(f" 5. cd_keep_probs.npy - Initial keep probabilities") +print(f" 6. cd_init_weights.npy - Initial weights") +print(f" 7. cd_target_groups.npy - Target grouping for loss") +print(f" 8. cd_list.txt - List of CD GEOIDs") +if "w" in locals(): + print( + f" 9. cd_weights_{TOTAL_EPOCHS}epochs.npy - Final calibration weights" + ) +if ENABLE_EPOCH_LOGGING: + print( + f" 10. cd_calibration_log.csv - Epoch-by-epoch metrics for dashboard" + ) +print( + f" 11. cd_sparsity_history_{timestamp}.csv - Sparsity tracking over epochs" +) + +print("\nTo load on GPU platform:") +print(" import scipy.sparse as sp") +print(" import numpy as np") +print(" import pandas as pd") +print(f" X = sp.load_npz('{sparse_path}')") +print(f" targets = np.load('{targets_array_path}')") +print(f" target_groups = np.load('{target_groups_path}')") +print(f" keep_probs = np.load('{keep_probs_path}')") +print(f" init_weights = np.load('{init_weights_path}')") + +# Note: The exploration package was already created earlier (Step 6) +# It can be used immediately without waiting for calibration to complete +print("\n📦 Exploration package available at:", package_path) +print(" Can be shared with coworkers for data exploration") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py new file mode 100644 index 00000000..751eec6f --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/calibration_utils.py @@ -0,0 +1,694 @@ +""" +Shared utilities for calibration scripts. 
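+
+Includes the state FIPS mappings, simulation-cache helpers, target
+grouping, and the Hugging Face download utility used by the local-area
+calibration scripts.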
+"""
+
+import os
+import urllib.request
+import tempfile
+from typing import Tuple, List, Optional
+
+import numpy as np
+import pandas as pd
+
+from policyengine_us.variables.household.demographic.geographic.state_name import (
+    StateName,
+)
+from policyengine_us.variables.household.demographic.geographic.state_code import (
+    StateCode,
+)
+
+
+# =============================================================================
+# State/Geographic Mappings (Single Source of Truth)
+# =============================================================================
+
+STATE_CODES = {
+    1: "AL", 2: "AK", 4: "AZ", 5: "AR", 6: "CA", 8: "CO", 9: "CT", 10: "DE",
+    11: "DC", 12: "FL", 13: "GA", 15: "HI", 16: "ID", 17: "IL", 18: "IN",
+    19: "IA", 20: "KS", 21: "KY", 22: "LA", 23: "ME", 24: "MD", 25: "MA",
+    26: "MI", 27: "MN", 28: "MS", 29: "MO", 30: "MT", 31: "NE", 32: "NV",
+    33: "NH", 34: "NJ", 35: "NM", 36: "NY", 37: "NC", 38: "ND", 39: "OH",
+    40: "OK", 41: "OR", 42: "PA", 44: "RI", 45: "SC", 46: "SD", 47: "TN",
+    48: "TX", 49: "UT", 50: "VT", 51: "VA", 53: "WA", 54: "WV", 55: "WI",
+    56: "WY",
+}
+
+STATE_FIPS_TO_NAME = {
+    1: StateName.AL, 2: StateName.AK, 4: StateName.AZ, 5: StateName.AR,
+    6: StateName.CA, 8: StateName.CO, 9: StateName.CT, 10: StateName.DE,
+    11: StateName.DC, 12: StateName.FL, 13: StateName.GA, 15: StateName.HI,
+    16: StateName.ID, 17: StateName.IL, 18: StateName.IN, 19: StateName.IA,
+    20: StateName.KS, 21: StateName.KY, 22: StateName.LA, 23: StateName.ME,
+    24: StateName.MD, 25: StateName.MA, 26: StateName.MI, 27: StateName.MN,
+    28: StateName.MS, 29: StateName.MO, 30: StateName.MT, 31: StateName.NE,
+    32: StateName.NV, 33: StateName.NH, 34: StateName.NJ, 35: StateName.NM,
+    36: StateName.NY, 37: StateName.NC, 38: StateName.ND, 39: StateName.OH,
+    40: StateName.OK, 41: StateName.OR, 42: StateName.PA, 44: StateName.RI,
+    45: StateName.SC, 46: StateName.SD, 47: StateName.TN, 48: StateName.TX,
+    49: StateName.UT, 50: StateName.VT, 51: StateName.VA, 53: StateName.WA,
+    54: StateName.WV, 55: StateName.WI, 56: StateName.WY,
+}
+
+STATE_FIPS_TO_CODE = {
+    1: StateCode.AL, 2: StateCode.AK, 4: StateCode.AZ, 5: StateCode.AR,
+    6: StateCode.CA, 8: StateCode.CO, 9: StateCode.CT, 10: StateCode.DE,
+    11: StateCode.DC, 12: StateCode.FL, 13: StateCode.GA, 15: StateCode.HI,
+    16: StateCode.ID, 17: StateCode.IL, 18: StateCode.IN, 19: StateCode.IA,
+    20: StateCode.KS, 21: StateCode.KY, 22: StateCode.LA, 23: StateCode.ME,
+    24: StateCode.MD, 25: StateCode.MA, 26: StateCode.MI, 27: StateCode.MN,
+    28: StateCode.MS, 29: StateCode.MO, 30: StateCode.MT, 31: StateCode.NE,
+    32: StateCode.NV, 33: StateCode.NH, 34: StateCode.NJ, 35: StateCode.NM,
+    36: StateCode.NY, 37: StateCode.NC, 38: StateCode.ND, 39: StateCode.OH,
+    40: StateCode.OK, 41: StateCode.OR, 42: StateCode.PA, 44: StateCode.RI,
+    45: StateCode.SC, 46: StateCode.SD, 47: StateCode.TN, 48: StateCode.TX,
+    49: StateCode.UT, 50: StateCode.VT, 51: StateCode.VA, 53: StateCode.WA,
+    54: StateCode.WV, 55: StateCode.WI, 56: StateCode.WY,
+}
+
+
+# =============================================================================
+# Simulation Cache Utilities
+# =============================================================================
+
+def get_calculated_variables(sim) -> List[str]:
+    """
+    Return variables that should be cleared for state-swap recalculation.
+
+    Includes variables with formulas, adds, or subtracts.
+
+    Excludes ID variables (person_id, household_id, etc.) because:
+    1. They have formulas that generate sequential IDs (0, 1, 2, ...)
+    2. 
We need the original H5 values, not regenerated sequences + 3. PolicyEngine's random() function uses entity IDs as seeds: + seed = abs(entity_id * 100 + count_random_calls) + If IDs change, random-dependent variables (SSI resource test, + WIC nutritional risk, WIC takeup) produce different results. + """ + exclude_ids = {'person_id', 'household_id', 'tax_unit_id', 'spm_unit_id', + 'family_id', 'marital_unit_id'} + return [name for name, var in sim.tax_benefit_system.variables.items() + if (var.formulas or getattr(var, 'adds', None) or getattr(var, 'subtracts', None)) + and name not in exclude_ids] + + +def apply_op(values: np.ndarray, op: str, val: str) -> np.ndarray: + """Apply constraint operation to values array.""" + try: + parsed = float(val) + if parsed.is_integer(): + parsed = int(parsed) + except ValueError: + if val == 'True': + parsed = True + elif val == 'False': + parsed = False + else: + parsed = val + + if op in ('==', '='): + return values == parsed + if op == '>': + return values > parsed + if op == '>=': + return values >= parsed + if op == '<': + return values < parsed + if op == '<=': + return values <= parsed + if op == '!=': + return values != parsed + return np.ones(len(values), dtype=bool) + + +# ============================================================================= +# Geographic Utilities +# ============================================================================= + + +def _get_geo_level(geo_id) -> int: + """Return geographic level: 0=National, 1=State, 2=District.""" + if geo_id == 'US': + return 0 + try: + val = int(geo_id) + return 1 if val < 100 else 2 + except (ValueError, TypeError): + return 3 + + +def create_target_groups( + targets_df: pd.DataFrame, +) -> Tuple[np.ndarray, List[str]]: + """ + Automatically create target groups based on metadata. + + Grouping rules: + 1. Groups are ordered by geographic level: National → State → District + 2. Within each level, targets are grouped by variable type + 3. 
Each group contributes equally to the total loss + + Parameters + ---------- + targets_df : pd.DataFrame + DataFrame containing target metadata with columns: + - stratum_group_id: Identifier for the type of target + - geographic_id: Geographic identifier (US, state FIPS, CD GEOID) + - variable: Variable name + - value: Target value + + Returns + ------- + target_groups : np.ndarray + Array of group IDs for each target + group_info : List[str] + List of descriptive strings for each group + """ + target_groups = np.zeros(len(targets_df), dtype=int) + group_id = 0 + group_info = [] + processed_mask = np.zeros(len(targets_df), dtype=bool) + + print("\n=== Creating Target Groups ===") + + # Add geo_level column for sorting + targets_df = targets_df.copy() + targets_df['_geo_level'] = targets_df['geographic_id'].apply(_get_geo_level) + + geo_level_names = {0: "National", 1: "State", 2: "District"} + + # Process by geographic level: National (0) → State (1) → District (2) + for level in [0, 1, 2]: + level_mask = targets_df['_geo_level'] == level + if not level_mask.any(): + continue + + level_name = geo_level_names.get(level, f"Level {level}") + print(f"\n{level_name} targets:") + + # Get unique variables at this level + level_df = targets_df[level_mask & ~processed_mask] + unique_vars = sorted(level_df['variable'].unique()) + + for var_name in unique_vars: + var_mask = ( + (targets_df['variable'] == var_name) + & level_mask + & ~processed_mask + ) + + if not var_mask.any(): + continue + + matching = targets_df[var_mask] + n_targets = var_mask.sum() + n_geos = matching['geographic_id'].nunique() + + # Assign group + target_groups[var_mask] = group_id + processed_mask |= var_mask + + # Create descriptive label + stratum_group = matching['stratum_group_id'].iloc[0] + if var_name == "household_count" and stratum_group == 4: + label = "SNAP Household Count" + elif var_name == "snap": + label = "Snap" + else: + label = var_name.replace("_", " ").title() + + # Format output based on level and count + if n_targets == 1: + value = matching['value'].iloc[0] + info_str = f"{level_name} {label} (1 target, value={value:,.0f})" + print_str = f" Group {group_id}: {label} = {value:,.0f}" + else: + info_str = f"{level_name} {label} ({n_targets} targets)" + print_str = f" Group {group_id}: {label} ({n_targets} targets)" + + group_info.append(f"Group {group_id}: {info_str}") + print(print_str) + group_id += 1 + + print(f"\nTotal groups created: {group_id}") + print("=" * 40) + + return target_groups, group_info + + +# NOTE: this is for public files. A TODO is to contrast it with what we already have +def download_from_huggingface(file_name): + """Download a file from HuggingFace to a temporary location.""" + base_url = "https://huggingface.co/policyengine/test/resolve/main/" + url = base_url + file_name + + # Create temporary file + temp_dir = tempfile.gettempdir() + local_path = os.path.join(temp_dir, file_name) + + # Check if already downloaded + if not os.path.exists(local_path): + print(f"Downloading {file_name} from HuggingFace...") + urllib.request.urlretrieve(url, local_path) + print(f"Downloaded to {local_path}") + else: + print(f"Using cached {local_path}") + + return local_path + + +def uprate_target_value( + value: float, variable_name: str, from_year: int, to_year: int, sim=None +) -> float: + """ + Uprate a target value from source year to dataset year. 
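+
+    A minimal usage sketch (the value and years are illustrative, and ``sim``
+    is assumed to be an existing Microsimulation instance):
+
+        # "snap" is monetary, so the CPI-U ratio is applied:
+        #   uprated = 1_000.0 * cpi_u(2023) / cpi_u(2022)
+        uprated = uprate_target_value(1_000.0, "snap", 2022, 2023, sim)
+
+        # Count variables such as "household_count" use the population
+        # ratio instead.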
+ + Parameters + ---------- + value : float + The value to uprate + variable_name : str + Name of the variable (used to determine uprating type) + from_year : int + Source year of the value + to_year : int + Target year to uprate to + sim : Microsimulation, optional + Existing microsimulation instance for getting parameters + + Returns + ------- + float + Uprated value + """ + if from_year == to_year: + return value + + # Need PolicyEngine parameters for uprating factors + if sim is None: + from policyengine_us import Microsimulation + + sim = Microsimulation( + dataset="hf://policyengine/test/extended_cps_2023.h5" + ) + + params = sim.tax_benefit_system.parameters + + # Determine uprating type based on variable + # Count variables use population uprating + count_variables = [ + "person_count", + "household_count", + "tax_unit_count", + "spm_unit_count", + "family_count", + "marital_unit_count", + ] + + if variable_name in count_variables: + # Use population uprating for counts + try: + pop_from = params.calibration.gov.census.populations.total( + from_year + ) + pop_to = params.calibration.gov.census.populations.total(to_year) + factor = pop_to / pop_from + except Exception as e: + print( + f"Warning: Could not get population uprating for {from_year}->{to_year}: {e}" + ) + factor = 1.0 + else: + # Use CPI-U for monetary values (default) + try: + cpi_from = params.gov.bls.cpi.cpi_u(from_year) + cpi_to = params.gov.bls.cpi.cpi_u(to_year) + factor = cpi_to / cpi_from + except Exception as e: + print( + f"Warning: Could not get CPI uprating for {from_year}->{to_year}: {e}" + ) + factor = 1.0 + + return value * factor + + +def uprate_targets_df( + targets_df: pd.DataFrame, target_year: int, sim=None +) -> pd.DataFrame: + """ + Uprate all targets in a DataFrame to the target year. 
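+
+    A minimal sketch with a toy DataFrame (the row below is illustrative,
+    not real target data; ``sim`` is an existing Microsimulation):
+
+        df = pd.DataFrame(
+            {"variable": ["snap"], "value": [1_000.0], "period": [2022]}
+        )
+        out = uprate_targets_df(df, target_year=2023, sim=sim)
+        # out["value"] is scaled, and original_value, uprating_factor and
+        # uprating_source record what was applied.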
+ + Parameters + ---------- + targets_df : pd.DataFrame + DataFrame containing targets with 'period', 'variable', and 'value' columns + target_year : int + Year to uprate all targets to + sim : Microsimulation, optional + Existing microsimulation instance for getting parameters + + Returns + ------- + pd.DataFrame + DataFrame with uprated values and tracking columns: + - original_value: The value before uprating + - uprating_factor: The factor applied + - uprating_source: 'CPI-U', 'Population', or 'None' + """ + if "period" not in targets_df.columns: + return targets_df + + df = targets_df.copy() + + # Check if already uprated (avoid double uprating) + if "uprating_factor" in df.columns: + return df + + # Store original values and initialize tracking columns + df["original_value"] = df["value"] + df["uprating_factor"] = 1.0 + df["uprating_source"] = "None" + + # Identify rows needing uprating + needs_uprating = df["period"] != target_year + + if not needs_uprating.any(): + return df + + # Get parameters once + if sim is None: + from policyengine_us import Microsimulation + + sim = Microsimulation( + dataset="hf://policyengine/test/extended_cps_2023.h5" + ) + params = sim.tax_benefit_system.parameters + + # Get unique years that need uprating + unique_years = set(df.loc[needs_uprating, "period"].unique()) + + # Remove NaN values if any + unique_years = {year for year in unique_years if pd.notna(year)} + + # Pre-calculate all uprating factors + factors = {} + for from_year in unique_years: + # Convert numpy int64 to Python int for parameter lookups + from_year_int = int(from_year) + target_year_int = int(target_year) + + if from_year_int == target_year_int: + factors[(from_year, "cpi")] = 1.0 + factors[(from_year, "population")] = 1.0 + continue + + # CPI-U factor + try: + cpi_from = params.gov.bls.cpi.cpi_u(from_year_int) + cpi_to = params.gov.bls.cpi.cpi_u(target_year_int) + factors[(from_year, "cpi")] = cpi_to / cpi_from + except Exception as e: + print( + f" Warning: CPI uprating failed for {from_year_int}->{target_year_int}: {e}" + ) + factors[(from_year, "cpi")] = 1.0 + + # Population factor + try: + pop_from = params.calibration.gov.census.populations.total( + from_year_int + ) + pop_to = params.calibration.gov.census.populations.total( + target_year_int + ) + factors[(from_year, "population")] = pop_to / pop_from + except Exception as e: + print( + f" Warning: Population uprating failed for {from_year_int}->{target_year_int}: {e}" + ) + factors[(from_year, "population")] = 1.0 + + # Define count variables (use population uprating) + count_variables = { + "person_count", + "household_count", + "tax_unit_count", + "spm_unit_count", + "family_count", + "marital_unit_count", + } + + # Vectorized application of uprating factors + for from_year in unique_years: + year_mask = (df["period"] == from_year) & needs_uprating + + # Population-based variables + pop_mask = year_mask & df["variable"].isin(count_variables) + if pop_mask.any(): + factor = factors[(from_year, "population")] + df.loc[pop_mask, "value"] *= factor + df.loc[pop_mask, "uprating_factor"] = factor + df.loc[pop_mask, "uprating_source"] = "Population" + + # CPI-based variables (everything else) + cpi_mask = year_mask & ~df["variable"].isin(count_variables) + if cpi_mask.any(): + factor = factors[(from_year, "cpi")] + df.loc[cpi_mask, "value"] *= factor + df.loc[cpi_mask, "uprating_factor"] = factor + df.loc[cpi_mask, "uprating_source"] = "CPI-U" + + # Summary logging (only if factors are not all 1.0) + uprated_count = 
needs_uprating.sum() + if uprated_count > 0: + # Check if any real uprating happened + cpi_factors = df.loc[ + df["uprating_source"] == "CPI-U", "uprating_factor" + ] + pop_factors = df.loc[ + df["uprating_source"] == "Population", "uprating_factor" + ] + + cpi_changed = len(cpi_factors) > 0 and (cpi_factors != 1.0).any() + pop_changed = len(pop_factors) > 0 and (pop_factors != 1.0).any() + + if cpi_changed or pop_changed: + # Count unique source years (excluding NaN and target year) + source_years = df.loc[needs_uprating, "period"].dropna().unique() + source_years = [y for y in source_years if y != target_year] + unique_sources = len(source_years) + + print( + f"\n ✓ Uprated {uprated_count:,} targets from year(s) {sorted(source_years)} to {target_year}" + ) + + if cpi_changed: + cpi_count = (df["uprating_source"] == "CPI-U").sum() + print( + f" - {cpi_count:,} monetary targets: CPI factors {cpi_factors.min():.4f} - {cpi_factors.max():.4f}" + ) + if pop_changed: + pop_count = (df["uprating_source"] == "Population").sum() + print( + f" - {pop_count:,} count targets: Population factors {pop_factors.min():.4f} - {pop_factors.max():.4f}" + ) + + return df + + +def filter_target_groups( + targets_df: pd.DataFrame, + X_sparse, + target_groups: np.ndarray, + groups_to_exclude: List[int], +) -> Tuple[pd.DataFrame, any, np.ndarray]: + """ + Filter out specified target groups from targets_df and X_sparse. + + Parameters + ---------- + targets_df : pd.DataFrame + DataFrame containing target metadata + X_sparse : scipy.sparse matrix + Sparse calibration matrix (rows = targets, cols = households) + target_groups : np.ndarray + Array of group IDs for each target + groups_to_exclude : List[int] + List of group IDs to exclude + + Returns + ------- + filtered_targets_df : pd.DataFrame + Filtered targets dataframe + filtered_X_sparse : scipy.sparse matrix + Filtered sparse matrix + filtered_target_groups : np.ndarray + Filtered target groups array + """ + if len(groups_to_exclude) == 0: + return targets_df, X_sparse, target_groups + + keep_mask = ~np.isin(target_groups, groups_to_exclude) + + n_to_remove = (~keep_mask).sum() + is_national = targets_df["geographic_id"] == "US" + n_national_removed = is_national[~keep_mask].sum() + n_cd_removed = n_to_remove - n_national_removed + + print(f"\nExcluding groups: {groups_to_exclude}") + print(f"Total targets removed: {n_to_remove} out of {len(targets_df)}") + print(f" - CD/state-level targets removed: {n_cd_removed}") + print(f" - National-level targets removed: {n_national_removed}") + + filtered_targets_df = targets_df[keep_mask].reset_index(drop=True) + filtered_X_sparse = X_sparse[keep_mask, :] + filtered_target_groups = target_groups[keep_mask] + + print( + f"After filtering: {len(filtered_targets_df)} targets, matrix shape: {filtered_X_sparse.shape}" + ) + + return filtered_targets_df, filtered_X_sparse, filtered_target_groups + + +def get_all_cds_from_database(db_uri: str) -> List[str]: + """ + Get ordered list of all CD GEOIDs from database. + + This is the single source of truth for CD queries, replacing + duplicate inline SQL queries throughout the codebase. 
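+
+    A minimal usage sketch (the database path is illustrative):
+
+        cds = get_all_cds_from_database("sqlite:///policy_data.db")
+        # Ordered CD GEOID strings, e.g. ['101', '102', ...]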
+ + Args: + db_uri: SQLAlchemy database URI (e.g., "sqlite:///path/to/policy_data.db") + + Returns: + List of CD GEOID strings ordered by value (e.g., ['101', '102', ..., '5600']) + """ + from sqlalchemy import create_engine, text + + engine = create_engine(db_uri) + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM strata s + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 + AND sc.constraint_variable = 'congressional_district_geoid' + ORDER BY sc.value + """ + + with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + return [row[0] for row in result] + + +def get_cd_index_mapping(): + """ + Get the canonical CD GEOID to index mapping. + This MUST be consistent across all uses! + Each CD gets 10,000 IDs for each entity type. + + Returns: + dict: Maps CD GEOID string to index (0-435) + dict: Maps index to CD GEOID string + list: Ordered list of CD GEOIDs + """ + from sqlalchemy import create_engine, text + from pathlib import Path + import os + + script_dir = Path(__file__).parent + db_path = script_dir.parent.parent.parent / "storage" / "policy_data.db" + + if not db_path.exists(): + raise FileNotFoundError( + f"Database file not found at {db_path}. " + f"Current working directory: {os.getcwd()}" + ) + + db_uri = f"sqlite:///{db_path}" + engine = create_engine(db_uri) + + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM strata s + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 + AND sc.constraint_variable = "congressional_district_geoid" + ORDER BY sc.value + """ + + with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + cds_ordered = [row[0] for row in result] + + # Create bidirectional mappings + cd_to_index = {cd: idx for idx, cd in enumerate(cds_ordered)} + index_to_cd = {idx: cd for idx, cd in enumerate(cds_ordered)} + + return cd_to_index, index_to_cd, cds_ordered + + +def get_id_range_for_cd(cd_geoid, entity_type="household"): + """ + Get the ID range for a specific CD and entity type. + + Args: + cd_geoid: Congressional district GEOID string (e.g., '3701') + entity_type: Entity type ('household', 'person', 'tax_unit', 'spm_unit', 'marital_unit') + + Returns: + tuple: (start_id, end_id) inclusive + """ + cd_to_index, _, _ = get_cd_index_mapping() + + if cd_geoid not in cd_to_index: + raise ValueError(f"Unknown CD GEOID: {cd_geoid}") + + idx = cd_to_index[cd_geoid] + base_start = idx * 10_000 + base_end = base_start + 9_999 + + # Offset different entities to avoid ID collisions + # Max base ID is 435 * 10,000 + 9,999 = 4,359,999 + # Must ensure max_id * 100 < 2,147,483,647 (int32 max) + # So max_id must be < 21,474,836 + # NOTE: Currently only household/person use CD-based ranges + # Tax/SPM/marital units still use sequential numbering from 0 + offsets = { + "household": 0, # Max: 4,359,999 + "person": 5_000_000, # Max: 9,359,999 + "tax_unit": 0, # Not implemented yet + "spm_unit": 0, # Not implemented yet + "marital_unit": 0, # Not implemented yet + } + + offset = offsets.get(entity_type, 0) + return base_start + offset, base_end + offset + + +def get_cd_from_id(entity_id): + """ + Determine which CD an entity ID belongs to. 
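+
+    A minimal sketch of the round trip with get_id_range_for_cd (the GEOID
+    is illustrative and must exist in the database-backed mapping):
+
+        start, end = get_id_range_for_cd("3701", entity_type="person")
+        assert get_cd_from_id(start) == "3701"
+        assert get_cd_from_id(end) == "3701"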
+ + Args: + entity_id: The household/person/etc ID + + Returns: + str: CD GEOID + """ + # Remove offset to get base ID + # Currently only persons have offset (5M) + if entity_id >= 5_000_000: + base_id = entity_id - 5_000_000 # Person + else: + base_id = entity_id # Household (or tax/spm/marital unit) + + idx = base_id // 10_000 + _, index_to_cd, _ = get_cd_index_mapping() + + if idx not in index_to_cd: + raise ValueError( + f"ID {entity_id} (base {base_id}) maps to invalid CD index {idx}" + ) + + return index_to_cd[idx] diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/cd_county_mappings.json b/policyengine_us_data/datasets/cps/local_area_calibration/cd_county_mappings.json new file mode 100644 index 00000000..4b959bff --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/cd_county_mappings.json @@ -0,0 +1,1321 @@ +{ + "1001": { + "10003": 1.0 + }, + "101": { + "01073": 1.0 + }, + "102": { + "01073": 1.0 + }, + "103": { + "01073": 1.0 + }, + "104": { + "01073": 1.0 + }, + "105": { + "01073": 1.0 + }, + "106": { + "01073": 1.0 + }, + "107": { + "01073": 1.0 + }, + "1101": { + "11001": 1.0 + }, + "1201": { + "12033": 0.4, + "12091": 0.3, + "12113": 0.3 + }, + "1202": { + "12086": 1.0 + }, + "1203": { + "12086": 1.0 + }, + "1204": { + "12086": 1.0 + }, + "1205": { + "12086": 1.0 + }, + "1206": { + "12086": 1.0 + }, + "1207": { + "12086": 1.0 + }, + "1208": { + "12086": 1.0 + }, + "1209": { + "12086": 1.0 + }, + "1210": { + "12086": 1.0 + }, + "1211": { + "12086": 1.0 + }, + "1212": { + "12086": 1.0 + }, + "1213": { + "12086": 1.0 + }, + "1214": { + "12086": 1.0 + }, + "1215": { + "12086": 1.0 + }, + "1216": { + "12086": 1.0 + }, + "1217": { + "12086": 1.0 + }, + "1218": { + "12086": 1.0 + }, + "1219": { + "12086": 1.0 + }, + "1220": { + "12086": 1.0 + }, + "1221": { + "12086": 1.0 + }, + "1222": { + "12086": 1.0 + }, + "1223": { + "12086": 1.0 + }, + "1224": { + "12086": 1.0 + }, + "1225": { + "12086": 1.0 + }, + "1226": { + "12086": 1.0 + }, + "1227": { + "12086": 1.0 + }, + "1228": { + "12086": 1.0 + }, + "1301": { + "13121": 1.0 + }, + "1302": { + "13121": 1.0 + }, + "1303": { + "13121": 1.0 + }, + "1304": { + "13121": 1.0 + }, + "1305": { + "13121": 1.0 + }, + "1306": { + "13121": 1.0 + }, + "1307": { + "13121": 1.0 + }, + "1308": { + "13121": 1.0 + }, + "1309": { + "13121": 1.0 + }, + "1310": { + "13121": 1.0 + }, + "1311": { + "13121": 1.0 + }, + "1312": { + "13121": 1.0 + }, + "1313": { + "13121": 1.0 + }, + "1314": { + "13121": 1.0 + }, + "1501": { + "15003": 1.0 + }, + "1502": { + "15003": 1.0 + }, + "1601": { + "16001": 1.0 + }, + "1602": { + "16001": 1.0 + }, + "1701": { + "17031": 1.0 + }, + "1702": { + "17031": 1.0 + }, + "1703": { + "17031": 1.0 + }, + "1704": { + "17031": 1.0 + }, + "1705": { + "17031": 1.0 + }, + "1706": { + "17031": 1.0 + }, + "1707": { + "17031": 1.0 + }, + "1708": { + "17031": 1.0 + }, + "1709": { + "17031": 1.0 + }, + "1710": { + "17031": 1.0 + }, + "1711": { + "17031": 1.0 + }, + "1712": { + "17031": 1.0 + }, + "1713": { + "17031": 1.0 + }, + "1714": { + "17031": 1.0 + }, + "1715": { + "17031": 1.0 + }, + "1716": { + "17031": 1.0 + }, + "1717": { + "17031": 1.0 + }, + "1801": { + "18097": 1.0 + }, + "1802": { + "18097": 1.0 + }, + "1803": { + "18097": 1.0 + }, + "1804": { + "18097": 1.0 + }, + "1805": { + "18097": 1.0 + }, + "1806": { + "18097": 1.0 + }, + "1807": { + "18097": 1.0 + }, + "1808": { + "18097": 1.0 + }, + "1809": { + "18097": 1.0 + }, + "1901": { + "19153": 1.0 + }, + "1902": { + "19153": 1.0 + }, 
+ "1903": { + "19153": 1.0 + }, + "1904": { + "19153": 1.0 + }, + "2001": { + "20091": 1.0 + }, + "2002": { + "20091": 1.0 + }, + "2003": { + "20091": 1.0 + }, + "2004": { + "20091": 1.0 + }, + "201": { + "02020": 1.0 + }, + "2101": { + "21111": 1.0 + }, + "2102": { + "21111": 1.0 + }, + "2103": { + "21111": 1.0 + }, + "2104": { + "21111": 1.0 + }, + "2105": { + "21111": 1.0 + }, + "2106": { + "21111": 1.0 + }, + "2201": { + "22071": 1.0 + }, + "2202": { + "22071": 1.0 + }, + "2203": { + "22071": 1.0 + }, + "2204": { + "22071": 1.0 + }, + "2205": { + "22071": 1.0 + }, + "2206": { + "22071": 1.0 + }, + "2301": { + "23005": 1.0 + }, + "2302": { + "23005": 1.0 + }, + "2401": { + "24003": 1.0 + }, + "2402": { + "24003": 1.0 + }, + "2403": { + "24003": 1.0 + }, + "2404": { + "24003": 1.0 + }, + "2405": { + "24003": 1.0 + }, + "2406": { + "24003": 1.0 + }, + "2407": { + "24003": 1.0 + }, + "2408": { + "24003": 1.0 + }, + "2501": { + "25017": 1.0 + }, + "2502": { + "25017": 1.0 + }, + "2503": { + "25017": 1.0 + }, + "2504": { + "25017": 1.0 + }, + "2505": { + "25017": 1.0 + }, + "2506": { + "25017": 1.0 + }, + "2507": { + "25017": 1.0 + }, + "2508": { + "25017": 1.0 + }, + "2509": { + "25017": 1.0 + }, + "2601": { + "26163": 1.0 + }, + "2602": { + "26163": 1.0 + }, + "2603": { + "26163": 1.0 + }, + "2604": { + "26163": 1.0 + }, + "2605": { + "26163": 1.0 + }, + "2606": { + "26163": 1.0 + }, + "2607": { + "26163": 1.0 + }, + "2608": { + "26163": 1.0 + }, + "2609": { + "26163": 1.0 + }, + "2610": { + "26163": 1.0 + }, + "2611": { + "26163": 1.0 + }, + "2612": { + "26163": 1.0 + }, + "2613": { + "26163": 1.0 + }, + "2701": { + "27053": 1.0 + }, + "2702": { + "27053": 1.0 + }, + "2703": { + "27053": 1.0 + }, + "2704": { + "27053": 1.0 + }, + "2705": { + "27053": 1.0 + }, + "2706": { + "27053": 1.0 + }, + "2707": { + "27053": 1.0 + }, + "2708": { + "27053": 1.0 + }, + "2801": { + "28049": 1.0 + }, + "2802": { + "28049": 1.0 + }, + "2803": { + "28049": 1.0 + }, + "2804": { + "28049": 1.0 + }, + "2901": { + "29189": 1.0 + }, + "2902": { + "29189": 1.0 + }, + "2903": { + "29189": 1.0 + }, + "2904": { + "29189": 1.0 + }, + "2905": { + "29189": 1.0 + }, + "2906": { + "29189": 1.0 + }, + "2907": { + "29189": 1.0 + }, + "2908": { + "29189": 1.0 + }, + "3001": { + "30111": 1.0 + }, + "3002": { + "30111": 1.0 + }, + "3101": { + "31055": 1.0 + }, + "3102": { + "31055": 1.0 + }, + "3103": { + "31055": 1.0 + }, + "3201": { + "32003": 1.0 + }, + "3202": { + "32003": 1.0 + }, + "3203": { + "32003": 1.0 + }, + "3204": { + "32003": 1.0 + }, + "3301": { + "33011": 1.0 + }, + "3302": { + "33011": 1.0 + }, + "3401": { + "34003": 1.0 + }, + "3402": { + "34003": 1.0 + }, + "3403": { + "34003": 1.0 + }, + "3404": { + "34003": 1.0 + }, + "3405": { + "34003": 1.0 + }, + "3406": { + "34003": 1.0 + }, + "3407": { + "34003": 1.0 + }, + "3408": { + "34003": 1.0 + }, + "3409": { + "34003": 1.0 + }, + "3410": { + "34003": 1.0 + }, + "3411": { + "34003": 1.0 + }, + "3412": { + "34003": 1.0 + }, + "3501": { + "35001": 1.0 + }, + "3502": { + "35001": 1.0 + }, + "3503": { + "35001": 1.0 + }, + "3601": { + "36103": 0.8, + "36059": 0.2 + }, + "3602": { + "36047": 1.0 + }, + "3603": { + "36047": 1.0 + }, + "3604": { + "36047": 1.0 + }, + "3605": { + "36047": 1.0 + }, + "3606": { + "36047": 1.0 + }, + "3607": { + "36047": 1.0 + }, + "3608": { + "36047": 1.0 + }, + "3609": { + "36047": 1.0 + }, + "3610": { + "36047": 1.0 + }, + "3611": { + "36047": 1.0 + }, + "3612": { + "36061": 0.5, + "36047": 0.5 + }, + "3613": { + "36047": 1.0 + }, + 
"3614": { + "36047": 1.0 + }, + "3615": { + "36047": 1.0 + }, + "3616": { + "36047": 1.0 + }, + "3617": { + "36047": 1.0 + }, + "3618": { + "36047": 1.0 + }, + "3619": { + "36047": 1.0 + }, + "3620": { + "36047": 1.0 + }, + "3621": { + "36047": 1.0 + }, + "3622": { + "36047": 1.0 + }, + "3623": { + "36047": 1.0 + }, + "3624": { + "36047": 1.0 + }, + "3625": { + "36047": 1.0 + }, + "3626": { + "36047": 1.0 + }, + "3701": { + "37119": 1.0 + }, + "3702": { + "37119": 1.0 + }, + "3703": { + "37119": 1.0 + }, + "3704": { + "37119": 1.0 + }, + "3705": { + "37119": 1.0 + }, + "3706": { + "37119": 1.0 + }, + "3707": { + "37119": 1.0 + }, + "3708": { + "37119": 1.0 + }, + "3709": { + "37119": 1.0 + }, + "3710": { + "37119": 1.0 + }, + "3711": { + "37119": 1.0 + }, + "3712": { + "37119": 1.0 + }, + "3713": { + "37119": 1.0 + }, + "3714": { + "37119": 1.0 + }, + "3801": { + "38015": 1.0 + }, + "3901": { + "39049": 1.0 + }, + "3902": { + "39049": 1.0 + }, + "3903": { + "39049": 1.0 + }, + "3904": { + "39049": 1.0 + }, + "3905": { + "39049": 1.0 + }, + "3906": { + "39049": 1.0 + }, + "3907": { + "39049": 1.0 + }, + "3908": { + "39049": 1.0 + }, + "3909": { + "39049": 1.0 + }, + "3910": { + "39049": 1.0 + }, + "3911": { + "39049": 1.0 + }, + "3912": { + "39049": 1.0 + }, + "3913": { + "39049": 1.0 + }, + "3914": { + "39049": 1.0 + }, + "3915": { + "39049": 1.0 + }, + "4001": { + "40109": 1.0 + }, + "4002": { + "40109": 1.0 + }, + "4003": { + "40109": 1.0 + }, + "4004": { + "40109": 1.0 + }, + "4005": { + "40109": 1.0 + }, + "401": { + "04013": 1.0 + }, + "402": { + "04013": 1.0 + }, + "403": { + "04013": 1.0 + }, + "404": { + "04013": 1.0 + }, + "405": { + "04013": 1.0 + }, + "406": { + "04013": 1.0 + }, + "407": { + "04013": 1.0 + }, + "408": { + "04013": 1.0 + }, + "409": { + "04013": 1.0 + }, + "4101": { + "41051": 1.0 + }, + "4102": { + "41051": 1.0 + }, + "4103": { + "41051": 1.0 + }, + "4104": { + "41051": 1.0 + }, + "4105": { + "41051": 1.0 + }, + "4106": { + "41051": 1.0 + }, + "4201": { + "42101": 1.0 + }, + "4202": { + "42101": 1.0 + }, + "4203": { + "42101": 1.0 + }, + "4204": { + "42101": 1.0 + }, + "4205": { + "42101": 1.0 + }, + "4206": { + "42101": 1.0 + }, + "4207": { + "42101": 1.0 + }, + "4208": { + "42101": 1.0 + }, + "4209": { + "42101": 1.0 + }, + "4210": { + "42101": 1.0 + }, + "4211": { + "42101": 1.0 + }, + "4212": { + "42101": 1.0 + }, + "4213": { + "42101": 1.0 + }, + "4214": { + "42101": 1.0 + }, + "4215": { + "42101": 1.0 + }, + "4216": { + "42101": 1.0 + }, + "4217": { + "42101": 1.0 + }, + "4401": { + "44007": 1.0 + }, + "4402": { + "44007": 1.0 + }, + "4501": { + "45079": 1.0 + }, + "4502": { + "45079": 1.0 + }, + "4503": { + "45079": 1.0 + }, + "4504": { + "45079": 1.0 + }, + "4505": { + "45079": 1.0 + }, + "4506": { + "45079": 1.0 + }, + "4507": { + "45079": 1.0 + }, + "4601": { + "46103": 1.0 + }, + "4701": { + "47157": 1.0 + }, + "4702": { + "47157": 1.0 + }, + "4703": { + "47157": 1.0 + }, + "4704": { + "47157": 1.0 + }, + "4705": { + "47157": 1.0 + }, + "4706": { + "47157": 1.0 + }, + "4707": { + "47157": 1.0 + }, + "4708": { + "47157": 1.0 + }, + "4709": { + "47157": 1.0 + }, + "4801": { + "48001": 0.15, + "48213": 0.25, + "48423": 0.35, + "48183": 0.25 + }, + "4802": { + "48201": 1.0 + }, + "4803": { + "48201": 1.0 + }, + "4804": { + "48201": 1.0 + }, + "4805": { + "48201": 1.0 + }, + "4806": { + "48201": 1.0 + }, + "4807": { + "48201": 1.0 + }, + "4808": { + "48201": 1.0 + }, + "4809": { + "48201": 1.0 + }, + "4810": { + "48201": 1.0 + }, + "4811": { + "48201": 
1.0 + }, + "4812": { + "48201": 1.0 + }, + "4813": { + "48201": 1.0 + }, + "4814": { + "48201": 1.0 + }, + "4815": { + "48201": 1.0 + }, + "4816": { + "48201": 1.0 + }, + "4817": { + "48201": 1.0 + }, + "4818": { + "48201": 1.0 + }, + "4819": { + "48201": 1.0 + }, + "4820": { + "48201": 1.0 + }, + "4821": { + "48201": 1.0 + }, + "4822": { + "48201": 1.0 + }, + "4823": { + "48201": 1.0 + }, + "4824": { + "48201": 1.0 + }, + "4825": { + "48201": 1.0 + }, + "4826": { + "48201": 1.0 + }, + "4827": { + "48201": 1.0 + }, + "4828": { + "48201": 1.0 + }, + "4829": { + "48201": 1.0 + }, + "4830": { + "48201": 1.0 + }, + "4831": { + "48201": 1.0 + }, + "4832": { + "48201": 1.0 + }, + "4833": { + "48201": 1.0 + }, + "4834": { + "48201": 1.0 + }, + "4835": { + "48201": 1.0 + }, + "4836": { + "48201": 1.0 + }, + "4837": { + "48201": 1.0 + }, + "4838": { + "48201": 1.0 + }, + "4901": { + "49035": 1.0 + }, + "4902": { + "49035": 1.0 + }, + "4903": { + "49035": 1.0 + }, + "4904": { + "49035": 1.0 + }, + "5001": { + "50007": 1.0 + }, + "501": { + "05119": 1.0 + }, + "502": { + "05119": 1.0 + }, + "503": { + "05119": 1.0 + }, + "504": { + "05119": 1.0 + }, + "5101": { + "51059": 1.0 + }, + "5102": { + "51059": 1.0 + }, + "5103": { + "51059": 1.0 + }, + "5104": { + "51059": 1.0 + }, + "5105": { + "51059": 1.0 + }, + "5106": { + "51059": 1.0 + }, + "5107": { + "51059": 1.0 + }, + "5108": { + "51059": 1.0 + }, + "5109": { + "51059": 1.0 + }, + "5110": { + "51059": 1.0 + }, + "5111": { + "51059": 1.0 + }, + "5301": { + "53033": 1.0 + }, + "5302": { + "53033": 1.0 + }, + "5303": { + "53033": 1.0 + }, + "5304": { + "53033": 1.0 + }, + "5305": { + "53033": 1.0 + }, + "5306": { + "53033": 1.0 + }, + "5307": { + "53033": 1.0 + }, + "5308": { + "53033": 1.0 + }, + "5309": { + "53033": 1.0 + }, + "5310": { + "53033": 1.0 + }, + "5401": { + "54039": 1.0 + }, + "5402": { + "54039": 1.0 + }, + "5501": { + "55079": 1.0 + }, + "5502": { + "55079": 1.0 + }, + "5503": { + "55079": 1.0 + }, + "5504": { + "55079": 1.0 + }, + "5505": { + "55079": 1.0 + }, + "5506": { + "55079": 1.0 + }, + "5507": { + "55079": 1.0 + }, + "5508": { + "55079": 1.0 + }, + "5601": { + "56021": 1.0 + }, + "601": { + "06089": 0.35, + "06103": 0.25, + "06115": 0.2, + "06007": 0.2 + }, + "602": { + "06037": 1.0 + }, + "603": { + "06037": 1.0 + }, + "604": { + "06037": 1.0 + }, + "605": { + "06037": 1.0 + }, + "606": { + "06037": 1.0 + }, + "607": { + "06037": 1.0 + }, + "608": { + "06037": 1.0 + }, + "609": { + "06037": 1.0 + }, + "610": { + "06037": 1.0 + }, + "611": { + "06037": 1.0 + }, + "612": { + "06075": 0.6, + "06081": 0.4 + }, + "613": { + "06037": 1.0 + }, + "614": { + "06037": 1.0 + }, + "615": { + "06037": 1.0 + }, + "616": { + "06037": 1.0 + }, + "617": { + "06037": 1.0 + }, + "618": { + "06037": 1.0 + }, + "619": { + "06037": 1.0 + }, + "620": { + "06037": 1.0 + }, + "621": { + "06037": 1.0 + }, + "622": { + "06037": 1.0 + }, + "623": { + "06037": 1.0 + }, + "624": { + "06037": 1.0 + }, + "625": { + "06037": 1.0 + }, + "626": { + "06037": 1.0 + }, + "627": { + "06037": 1.0 + }, + "628": { + "06037": 1.0 + }, + "629": { + "06037": 1.0 + }, + "630": { + "06037": 1.0 + }, + "631": { + "06037": 1.0 + }, + "632": { + "06037": 1.0 + }, + "633": { + "06037": 1.0 + }, + "634": { + "06037": 1.0 + }, + "635": { + "06037": 1.0 + }, + "636": { + "06037": 1.0 + }, + "637": { + "06037": 1.0 + }, + "638": { + "06037": 1.0 + }, + "639": { + "06037": 1.0 + }, + "640": { + "06037": 1.0 + }, + "641": { + "06037": 1.0 + }, + "642": { + "06037": 1.0 + }, + 
"643": { + "06037": 1.0 + }, + "644": { + "06037": 1.0 + }, + "645": { + "06037": 1.0 + }, + "646": { + "06037": 1.0 + }, + "647": { + "06037": 1.0 + }, + "648": { + "06037": 1.0 + }, + "649": { + "06037": 1.0 + }, + "650": { + "06037": 1.0 + }, + "651": { + "06037": 1.0 + }, + "652": { + "06073": 1.0 + }, + "801": { + "08031": 1.0 + }, + "802": { + "08031": 1.0 + }, + "803": { + "08031": 1.0 + }, + "804": { + "08031": 1.0 + }, + "805": { + "08031": 1.0 + }, + "806": { + "08031": 1.0 + }, + "807": { + "08031": 1.0 + }, + "808": { + "08031": 1.0 + }, + "901": { + "09003": 1.0 + }, + "902": { + "09003": 1.0 + }, + "903": { + "09003": 1.0 + }, + "904": { + "09003": 1.0 + }, + "905": { + "09003": 1.0 + } +} \ No newline at end of file diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/create_calibration_package.py b/policyengine_us_data/datasets/cps/local_area_calibration/create_calibration_package.py new file mode 100644 index 00000000..3b03f247 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/create_calibration_package.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +import os +import argparse +from pathlib import Path +from datetime import datetime +import pickle +import json +from sqlalchemy import create_engine, text +import logging + +import numpy as np +from scipy import sparse as sp + +from policyengine_us import Microsimulation +from policyengine_us_data.datasets.cps.local_area_calibration.metrics_matrix_geo_stacking_sparse import ( + SparseGeoStackingMatrixBuilder, +) +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + create_target_groups, + filter_target_groups, +) +from policyengine_us_data.utils.data_upload import upload_files_to_gcs + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + + +def create_calibration_package( + db_path: str, + dataset_uri: str, + mode: str = "Stratified", + groups_to_exclude: list = None, + local_output_dir: str = None, + gcs_bucket: str = None, + gcs_date_prefix: str = None, +): + """ + Create a calibration package from database and dataset. 
+ + Args: + db_path: Path to policy_data.db + dataset_uri: URI for the CPS dataset (local path or hf://) + mode: "Test", "Stratified", or "Full" + groups_to_exclude: List of target group IDs to exclude + local_output_dir: Local directory to save package (optional) + gcs_bucket: GCS bucket name (optional) + gcs_date_prefix: Date prefix for GCS (e.g., "2025-10-15-1430", auto-generated if None) + + Returns: + dict with 'local_path' and/or 'gcs_path' keys + """ + + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + if groups_to_exclude is None: + groups_to_exclude = [] + + # Step 1: Load data and get CD list + db_uri = f"sqlite:///{db_path}" + builder = SparseGeoStackingMatrixBuilder(db_uri, time_period=2023) + + engine = create_engine(db_uri) + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM strata s + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 + AND sc.constraint_variable = 'congressional_district_geoid' + ORDER BY sc.value + """ + + with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + all_cd_geoids = [row[0] for row in result] + + logging.info( + f"Found {len(all_cd_geoids)} congressional districts in database" + ) + + # Select CDs based on mode + if mode == "Test": + cds_to_calibrate = [ + "601", + "652", + "3601", + "3626", + "4801", + "4838", + "1201", + "1228", + "1701", + "1101", + ] + logging.info(f"TEST MODE: Using {len(cds_to_calibrate)} CDs") + else: + cds_to_calibrate = all_cd_geoids + logging.info(f"Using all {len(cds_to_calibrate)} CDs") + + sim = Microsimulation(dataset=dataset_uri) + + # Step 2: Build sparse matrix + logging.info("Building sparse matrix...") + targets_df, X_sparse, household_id_mapping = ( + builder.build_stacked_matrix_sparse( + "congressional_district", cds_to_calibrate, sim + ) + ) + logging.info(f"Matrix shape: {X_sparse.shape}") + logging.info(f"Total targets: {len(targets_df)}") + + # Step 3: Create and filter target groups + target_groups, group_info = create_target_groups(targets_df) + + logging.info(f"Total groups: {len(np.unique(target_groups))}") + for info in group_info[:5]: + logging.info(f" {info}") + + if groups_to_exclude: + logging.info(f"Excluding {len(groups_to_exclude)} target groups") + targets_df, X_sparse, target_groups = filter_target_groups( + targets_df, X_sparse, target_groups, groups_to_exclude + ) + + targets = targets_df.value.values + + # Step 4: Calculate initial weights + cd_populations = {} + for cd_geoid in cds_to_calibrate: + cd_age_targets = targets_df[ + (targets_df["geographic_id"] == cd_geoid) + & (targets_df["variable"] == "person_count") + & (targets_df["variable_desc"].str.contains("age", na=False)) + ] + if not cd_age_targets.empty: + unique_ages = cd_age_targets.drop_duplicates( + subset=["variable_desc"] + ) + cd_populations[cd_geoid] = unique_ages["value"].sum() + + if cd_populations: + min_pop = min(cd_populations.values()) + max_pop = max(cd_populations.values()) + logging.info(f"CD population range: {min_pop:,.0f} to {max_pop:,.0f}") + else: + logging.warning("Could not calculate CD populations, using default") + min_pop = 700000 + + keep_probs = np.zeros(X_sparse.shape[1]) + init_weights = np.zeros(X_sparse.shape[1]) + cumulative_idx = 0 + cd_household_indices = {} + + for cd_key, household_list in household_id_mapping.items(): + cd_geoid = cd_key.replace("cd", "") + n_households = len(household_list) + + if cd_geoid in cd_populations: + cd_pop = cd_populations[cd_geoid] + else: + cd_pop = min_pop + + 
pop_ratio = cd_pop / min_pop + adjusted_keep_prob = min(0.15, 0.02 * np.sqrt(pop_ratio)) + keep_probs[cumulative_idx : cumulative_idx + n_households] = ( + adjusted_keep_prob + ) + + base_weight = cd_pop / n_households + sparsity_adjustment = 1.0 / np.sqrt(adjusted_keep_prob) + initial_weight = base_weight * sparsity_adjustment + + init_weights[cumulative_idx : cumulative_idx + n_households] = ( + initial_weight + ) + cd_household_indices[cd_geoid] = ( + cumulative_idx, + cumulative_idx + n_households, + ) + cumulative_idx += n_households + + logging.info( + f"Initial weight range: {init_weights.min():.0f} to {init_weights.max():.0f}" + ) + logging.info(f"Mean initial weight: {init_weights.mean():.0f}") + + # Step 5: Create calibration package + calibration_package = { + "X_sparse": X_sparse, + "targets_df": targets_df, + "household_id_mapping": household_id_mapping, + "cd_household_indices": cd_household_indices, + "dataset_uri": dataset_uri, + "cds_to_calibrate": cds_to_calibrate, + "initial_weights": init_weights, + "keep_probs": keep_probs, + "target_groups": target_groups, + } + + # Create metadata + metadata = { + "created_at": datetime.now().isoformat(), + "mode": mode, + "dataset_uri": dataset_uri, + "n_cds": len(cds_to_calibrate), + "n_targets": len(targets_df), + "n_households": X_sparse.shape[1], + "matrix_shape": X_sparse.shape, + "groups_excluded": groups_to_exclude, + } + + results = {} + + # Save locally if requested + if local_output_dir: + local_dir = Path(local_output_dir) + local_dir.mkdir(parents=True, exist_ok=True) + + pkg_path = local_dir / "calibration_package.pkl" + with open(pkg_path, "wb") as f: + pickle.dump(calibration_package, f) + + meta_path = local_dir / "metadata.json" + with open(meta_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"✅ Saved locally to {pkg_path}") + logging.info( + f" Size: {pkg_path.stat().st_size / 1024 / 1024:.1f} MB" + ) + results["local_path"] = str(pkg_path) + + # Upload to GCS if requested + if gcs_bucket: + if not gcs_date_prefix: + gcs_date_prefix = datetime.now().strftime("%Y-%m-%d-%H%M") + + gcs_prefix = f"{gcs_date_prefix}/inputs" + + # Save to temp location for upload + import tempfile + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_pkg = Path(tmpdir) / "calibration_package.pkl" + tmp_meta = Path(tmpdir) / "metadata.json" + + with open(tmp_pkg, "wb") as f: + pickle.dump(calibration_package, f) + with open(tmp_meta, "w") as f: + json.dump(metadata, f, indent=2) + + # Upload to GCS with prefix + from google.cloud import storage + import google.auth + + credentials, project_id = google.auth.default() + storage_client = storage.Client( + credentials=credentials, project=project_id + ) + bucket = storage_client.bucket(gcs_bucket) + + for local_file, blob_name in [ + (tmp_pkg, "calibration_package.pkl"), + (tmp_meta, "metadata.json"), + ]: + blob_path = f"{gcs_prefix}/{blob_name}" + blob = bucket.blob(blob_path) + blob.upload_from_filename(local_file) + logging.info(f"✅ Uploaded to gs://{gcs_bucket}/{blob_path}") + + gcs_path = f"gs://{gcs_bucket}/{gcs_prefix}" + results["gcs_path"] = gcs_path + results["gcs_prefix"] = gcs_prefix + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Create calibration package") + parser.add_argument( + "--db-path", required=True, help="Path to policy_data.db" + ) + parser.add_argument( + "--dataset-uri", + required=True, + help="Dataset URI (local path or hf://)", + ) + parser.add_argument( + "--mode", default="Stratified", 
choices=["Test", "Stratified", "Full"] + ) + parser.add_argument("--local-output", help="Local output directory") + parser.add_argument( + "--gcs-bucket", help="GCS bucket name (e.g., policyengine-calibration)" + ) + parser.add_argument( + "--gcs-date", help="GCS date prefix (default: YYYY-MM-DD-HHMM)" + ) + + args = parser.parse_args() + + # Default groups to exclude (from original script) + groups_to_exclude = [ + 0, + 1, + 2, + 3, + 4, + 5, + 8, + 12, + 10, + 15, + 17, + 18, + 21, + 34, + 35, + 36, + 37, + 31, + 56, + 42, + 64, + 46, + 68, + 47, + 69, + ] + + results = create_calibration_package( + db_path=args.db_path, + dataset_uri=args.dataset_uri, + mode=args.mode, + groups_to_exclude=groups_to_exclude, + local_output_dir=args.local_output, + gcs_bucket=args.gcs_bucket, + gcs_date_prefix=args.gcs_date, + ) + + print("\n" + "=" * 70) + print("CALIBRATION PACKAGE CREATED") + print("=" * 70) + if "local_path" in results: + print(f"Local: {results['local_path']}") + if "gcs_path" in results: + print(f"GCS: {results['gcs_path']}") + print(f"\nTo use with optimize_weights.py:") + print(f" --gcs-input gs://{args.gcs_bucket}/{results['gcs_prefix']}") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py b/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py new file mode 100644 index 00000000..6f6b6fa4 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/create_stratified_cps.py @@ -0,0 +1,306 @@ +""" +Create a stratified sample of extended_cps_2023.h5 that preserves high-income households. +This is needed for congressional district geo-stacking where the full dataset is too large. + +Strategy: +- Keep ALL households above a high income threshold (e.g., top 1%) +- Sample progressively less from lower income strata +- Ensure representation across all income levels +""" + +import numpy as np +import pandas as pd +import h5py +from policyengine_us import Microsimulation +from policyengine_core.data.dataset import Dataset +from policyengine_core.enums import Enum + + +def create_stratified_cps_dataset( + target_households=30_000, + high_income_percentile=99, # Keep ALL households above this percentile + base_dataset="hf://policyengine/test/extended_cps_2023.h5", + output_path=None, +): + """ + Create a stratified sample of CPS data preserving high-income households. 
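+
+    A minimal usage sketch (the target size mirrors the default and is
+    illustrative):
+
+        output_path = create_stratified_cps_dataset(
+            target_households=30_000, high_income_percentile=99
+        )
+        # Households above the 99th AGI percentile are all kept; lower
+        # strata are sampled at progressively smaller rates.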
+ + Args: + target_households: Target number of households in output (approximate) + high_income_percentile: Keep ALL households above this AGI percentile + output_path: Where to save the stratified h5 file + """ + print("\n" + "=" * 70) + print("CREATING STRATIFIED CPS DATASET") + print("=" * 70) + + # Load the original simulation + print("Loading original dataset...") + sim = Microsimulation(dataset=base_dataset) + + # Calculate AGI for all households + print("Calculating household AGI...") + agi = sim.calculate("adjusted_gross_income", map_to="household").values + household_ids = sim.calculate("household_id", map_to="household").values + n_households_orig = len(household_ids) + + print(f"Original dataset: {n_households_orig:,} households") + print(f"Target dataset: {target_households:,} households") + print(f"Reduction ratio: {target_households/n_households_orig:.1%}") + + # Calculate AGI percentiles + print("\nAnalyzing income distribution...") + percentiles = [0, 25, 50, 75, 90, 95, 99, 99.5, 99.9, 100] + agi_percentiles = np.percentile(agi, percentiles) + + print("AGI Percentiles:") + for p, val in zip(percentiles, agi_percentiles): + print(f" {p:5.1f}%: ${val:,.0f}") + + # Define sampling strategy + # Keep ALL high earners, sample progressively less from lower strata + high_income_threshold = np.percentile(agi, high_income_percentile) + print( + f"\nHigh-income threshold (top {100-high_income_percentile}%): ${high_income_threshold:,.0f}" + ) + + # Create strata with sampling rates + strata = [ + (99.9, 100, 1.00), # Top 0.1% - keep ALL + (99.5, 99.9, 1.00), # 99.5-99.9% - keep ALL + (99, 99.5, 1.00), # 99-99.5% - keep ALL + (95, 99, 0.80), # 95-99% - keep 80% + (90, 95, 0.60), # 90-95% - keep 60% + (75, 90, 0.40), # 75-90% - keep 40% + (50, 75, 0.25), # 50-75% - keep 25% + (25, 50, 0.15), # 25-50% - keep 15% + (0, 25, 0.10), # Bottom 25% - keep 10% + ] + + # Adjust sampling rates to hit target + print("\nInitial sampling strategy:") + expected_count = 0 + for low_p, high_p, rate in strata: + low_val = np.percentile(agi, low_p) if low_p > 0 else -np.inf + high_val = np.percentile(agi, high_p) if high_p < 100 else np.inf + in_stratum = np.sum((agi > low_val) & (agi <= high_val)) + expected = int(in_stratum * rate) + expected_count += expected + print( + f" {low_p:5.1f}-{high_p:5.1f}%: {in_stratum:6,} households × {rate:.0%} = {expected:6,}" + ) + + print(f"Expected total: {expected_count:,} households") + + # Adjust rates if needed + if expected_count > target_households * 1.1: # Allow 10% overage + adjustment = target_households / expected_count + print( + f"\nAdjusting rates by factor of {adjustment:.2f} to meet target..." 
+ ) + + # Never reduce the top percentiles + strata_adjusted = [] + for low_p, high_p, rate in strata: + if high_p >= 99: # Never reduce top 1% + strata_adjusted.append((low_p, high_p, rate)) + else: + strata_adjusted.append( + (low_p, high_p, min(1.0, rate * adjustment)) + ) + strata = strata_adjusted + + # Select households based on strata + print("\nSelecting households...") + selected_mask = np.zeros(n_households_orig, dtype=bool) + + for low_p, high_p, rate in strata: + low_val = np.percentile(agi, low_p) if low_p > 0 else -np.inf + high_val = np.percentile(agi, high_p) if high_p < 100 else np.inf + + in_stratum = (agi > low_val) & (agi <= high_val) + stratum_indices = np.where(in_stratum)[0] + n_in_stratum = len(stratum_indices) + + if rate >= 1.0: + # Keep all + selected_mask[stratum_indices] = True + n_selected = n_in_stratum + else: + # Random sample within stratum + n_to_select = int(n_in_stratum * rate) + if n_to_select > 0: + np.random.seed(42) # For reproducibility + selected_indices = np.random.choice( + stratum_indices, n_to_select, replace=False + ) + selected_mask[selected_indices] = True + n_selected = n_to_select + else: + n_selected = 0 + + print( + f" {low_p:5.1f}-{high_p:5.1f}%: Selected {n_selected:6,} / {n_in_stratum:6,} ({n_selected/max(1,n_in_stratum):.0%})" + ) + + n_selected = np.sum(selected_mask) + print( + f"\nTotal selected: {n_selected:,} households ({n_selected/n_households_orig:.1%} of original)" + ) + + # Verify high earners are preserved + high_earners_mask = agi >= high_income_threshold + n_high_earners = np.sum(high_earners_mask) + n_high_earners_selected = np.sum(selected_mask & high_earners_mask) + print(f"\nHigh earners (>=${high_income_threshold:,.0f}):") + print(f" Original: {n_high_earners:,}") + print( + f" Selected: {n_high_earners_selected:,} ({n_high_earners_selected/n_high_earners:.0%})" + ) + + # Get the selected household IDs + selected_household_ids = set(household_ids[selected_mask]) + + # Now filter the dataset using DataFrame approach (similar to create_sparse_state_stacked.py) + print("\nCreating filtered dataset...") + time_period = int(sim.default_calculation_period) + + # Convert full simulation to DataFrame + df = sim.to_input_dataframe() + + # Filter to selected households + hh_id_col = f"household_id__{time_period}" + df_filtered = df[df[hh_id_col].isin(selected_household_ids)].copy() + + print(f"Filtered DataFrame: {len(df_filtered):,} persons") + + # Create Dataset from filtered DataFrame + print("Creating Dataset from filtered DataFrame...") + stratified_dataset = Dataset.from_dataframe(df_filtered, time_period) + + # Build a simulation to convert to h5 + print("Building simulation from Dataset...") + stratified_sim = Microsimulation() + stratified_sim.dataset = stratified_dataset + stratified_sim.build_from_dataset() + + # Generate output path if not provided + if output_path is None: + output_path = "/home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/stratified_extended_cps_2023.h5" + + # Save to h5 file + print(f"\nSaving to {output_path}...") + data = {} + + # Only save input variables (not calculated/derived variables) + input_vars = set(stratified_sim.input_variables) + print(f"Found {len(input_vars)} input variables (excluding calculated variables)") + + for variable in stratified_sim.tax_benefit_system.variables: + if variable not in input_vars: + continue + + data[variable] = {} + for period in stratified_sim.get_holder(variable).get_known_periods(): + values = 
stratified_sim.get_holder(variable).get_array(period) + + # Handle different value types + if variable == "county_fips": + values = values.astype("int32") + elif stratified_sim.tax_benefit_system.variables.get( + variable + ).value_type in (Enum, str): + # Check if it's an EnumArray with decode_to_str method + if hasattr(values, "decode_to_str"): + values = values.decode_to_str().astype("S") + else: + # Already a numpy array, just ensure it's string type + values = values.astype("S") + else: + values = np.array(values) + + if values is not None: + data[variable][period] = values + + if len(data[variable]) == 0: + del data[variable] + + # Write to h5 + with h5py.File(output_path, "w") as f: + for variable, periods in data.items(): + grp = f.create_group(variable) + for period, values in periods.items(): + grp.create_dataset(str(period), data=values) + + print(f"Stratified CPS dataset saved successfully!") + + # Verify the saved file + print("\nVerifying saved file...") + with h5py.File(output_path, "r") as f: + if "household_id" in f and str(time_period) in f["household_id"]: + hh_ids = f["household_id"][str(time_period)][:] + print(f" Final households: {len(hh_ids):,}") + if "person_id" in f and str(time_period) in f["person_id"]: + person_ids = f["person_id"][str(time_period)][:] + print(f" Final persons: {len(person_ids):,}") + if ( + "household_weight" in f + and str(time_period) in f["household_weight"] + ): + weights = f["household_weight"][str(time_period)][:] + print(f" Final household weights sum: {np.sum(weights):,.0f}") + + # Final income distribution check + print("\nVerifying income distribution in stratified dataset...") + stratified_sim_verify = Microsimulation(dataset=output_path) + agi_stratified = stratified_sim_verify.calculate( + "adjusted_gross_income", map_to="household" + ).values + + print("AGI Percentiles in stratified dataset:") + for p in [0, 25, 50, 75, 90, 95, 99, 99.5, 99.9, 100]: + val = np.percentile(agi_stratified, p) + print(f" {p:5.1f}%: ${val:,.0f}") + + max_agi_original = np.max(agi) + max_agi_stratified = np.max(agi_stratified) + print(f"\nMaximum AGI:") + print(f" Original: ${max_agi_original:,.0f}") + print(f" Stratified: ${max_agi_stratified:,.0f}") + + if max_agi_stratified < max_agi_original * 0.9: + print("WARNING: May have lost some ultra-high earners!") + else: + print("Ultra-high earners preserved!") + + return output_path + + +if __name__ == "__main__": + import sys + + # Parse command line arguments + if len(sys.argv) > 1: + try: + target = int(sys.argv[1]) + print( + f"Creating stratified dataset with target of {target:,} households..." + ) + output_file = create_stratified_cps_dataset( + target_households=target + ) + except ValueError: + print(f"Invalid target households: {sys.argv[1]}") + print("Usage: python create_stratified_cps.py [target_households]") + sys.exit(1) + else: + # Default target + print( + "Creating stratified dataset with default target of 30,000 households..." + ) + output_file = create_stratified_cps_dataset(target_households=30_000) + + print(f"\nDone! 
Created: {output_file}") + print("\nTo test loading:") + print(" from policyengine_us import Microsimulation") + print(f" sim = Microsimulation(dataset='{output_file}')") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/geo_stacking_walkthrough.ipynb b/policyengine_us_data/datasets/cps/local_area_calibration/geo_stacking_walkthrough.ipynb new file mode 100644 index 00000000..0504d8f5 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/geo_stacking_walkthrough.ipynb @@ -0,0 +1,2383 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Geo-Stacking Calibration Walkthrough\n", + "\n", + "This notebook validates the sparse matrix construction and dataset creation pipeline for CD-level calibration. It traces a single household through the system to verify correctness." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Section 1: Setup & Matrix Construction\n", + "\n", + "Build the sparse calibration matrix `X_sparse` where rows are targets and columns are (household × CD) pairs." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/baogorek/envs/pe/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TEST_LITE == False\n" + ] + } + ], + "source": [ + "from sqlalchemy import create_engine, text\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from policyengine_us import Microsimulation\n", + "from policyengine_us_data.storage import STORAGE_FOLDER\n", + "from policyengine_us_data.datasets.cps.geo_stacking_calibration.metrics_matrix_geo_stacking_sparse import (\n", + " SparseGeoStackingMatrixBuilder,\n", + ")\n", + "from policyengine_us_data.datasets.cps.geo_stacking_calibration.calibration_utils import (\n", + " create_target_groups,\n", + ")\n", + "from policyengine_us_data.datasets.cps.geo_stacking_calibration.household_tracer import HouseholdTracer\n", + "from policyengine_us_data.datasets.cps.geo_stacking_calibration.create_sparse_cd_stacked import create_sparse_cd_stacked_dataset\n", + "\n", + "rng_ben = np.random.default_rng(seed=42)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "db_path = STORAGE_FOLDER / \"policy_data.db\"\n", + "db_uri = f\"sqlite:///{db_path}\"\n", + "builder = SparseGeoStackingMatrixBuilder(db_uri, time_period=2023)\n", + "\n", + "engine = create_engine(db_uri)\n", + "\n", + "query = \"\"\"\n", + "SELECT DISTINCT sc.value as cd_geoid\n", + "FROM strata s\n", + "JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id\n", + "WHERE s.stratum_group_id = 1\n", + " AND sc.constraint_variable = 'congressional_district_geoid'\n", + "ORDER BY sc.value\n", + "\"\"\"\n", + "\n", + "with engine.connect() as conn:\n", + " result = conn.execute(text(query)).fetchall()\n", + " all_cd_geoids = [row[0] for row in result]\n", + "\n", + "cds_to_calibrate = all_cd_geoids\n", + "dataset_uri = STORAGE_FOLDER / \"stratified_extended_cps_2023.h5\"\n", + "sim = Microsimulation(dataset=str(dataset_uri))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "\n", + "=== Creating Target Groups ===\n", + "\n", + "National targets (each is a singleton group):\n", + " Group 0: alimony_expense = 12,554,181,166\n", + " Group 1: alimony_income = 12,554,181,166\n", + " Group 2: charitable_deduction = 63,061,583,407\n", + " Group 3: child_support_expense = 31,868,306,036\n", + " Group 4: child_support_received = 31,868,306,036\n", + " Group 5: eitc = 64,440,000,000\n", + " Group 6: health_insurance_premiums_without_medicare_part_b = 371,796,903,749\n", + " Group 7: income_tax = 2,176,481,000,000\n", + " Group 8: interest_deduction = 23,949,514,839\n", + " Group 9: medicaid = 841,806,132,462\n", + " Group 10: medical_expense_deduction = 11,009,051,176\n", + " Group 11: medicare_part_b_premiums = 108,159,099,272\n", + " Group 12: net_worth = 154,512,998,960,600\n", + " Group 13: other_medical_expenses = 268,466,335,694\n", + " Group 14: over_the_counter_health_expenses = 71,220,353,850\n", + " Group 15: person_count_aca_ptc>0 = 19,529,896\n", + " Group 16: person_count_medicaid>0 = 71,644,763\n", + " Group 17: person_count_ssn_card_type=NONE = 12,200,000\n", + " Group 18: qualified_business_income_deduction = 60,936,063,965\n", + " Group 19: real_estate_taxes = 482,853,121,752\n", + " Group 20: rent = 709,794,088,975\n", + " Group 21: salt_deduction = 20,518,360,556\n", + " Group 22: snap = 107,062,860,000\n", + " Group 23: social_security = 1,379,268,000,000\n", + " Group 24: spm_unit_capped_housing_subsidy = 33,799,718,523\n", + " Group 25: spm_unit_capped_work_childcare_expenses = 336,065,772,739\n", + " Group 26: ssi = 60,090,000,000\n", + " Group 27: tanf = 8,691,356,192\n", + " Group 28: tip_income = 51,375,572,154\n", + " Group 29: unemployment_compensation = 35,000,000,000\n", + "\n", + "Geographic targets (grouped by variable type):\n", + " Group 30: All CD Age Distribution (7848 targets)\n", + " Group 31: All CD Person Income Distribution (3924 targets)\n", + " Group 32: All CD Medicaid Enrollment (436 targets)\n", + " Group 33: All CD Tax Units dividend_income>0 (436 targets)\n", + " Group 34: All CD Tax Units eitc_child_count==0 (436 targets)\n", + " Group 35: All CD Tax Units eitc_child_count==1 (436 targets)\n", + " Group 36: All CD Tax Units eitc_child_count==2 (436 targets)\n", + " Group 37: All CD Tax Units eitc_child_count>2 (436 targets)\n", + " Group 38: All CD Tax Units income_tax>0 (436 targets)\n", + " Group 39: All CD Tax Units income_tax_before_credits>0 (436 targets)\n", + " Group 40: All CD Tax Units medical_expense_deduction>0 (436 targets)\n", + " Group 41: All CD Tax Units net_capital_gains>0 (436 targets)\n", + " Group 42: All CD Tax Units qualified_business_income_deduction>0 (436 targets)\n", + " Group 43: All CD Tax Units qualified_dividend_income>0 (436 targets)\n", + " Group 44: All CD Tax Units real_estate_taxes>0 (436 targets)\n", + " Group 45: All CD Tax Units refundable_ctc>0 (436 targets)\n", + " Group 46: All CD Tax Units rental_income>0 (436 targets)\n", + " Group 47: All CD Tax Units salt>0 (436 targets)\n", + " Group 48: All CD Tax Units self_employment_income>0 (436 targets)\n", + " Group 49: All CD Tax Units tax_exempt_interest_income>0 (436 targets)\n", + " Group 50: All CD Tax Units tax_unit_partnership_s_corp_income>0 (436 targets)\n", + " Group 51: All CD Tax Units taxable_interest_income>0 (436 targets)\n", + " Group 52: All CD Tax Units taxable_ira_distributions>0 (436 targets)\n", + " Group 53: All CD Tax Units taxable_pension_income>0 (436 targets)\n", + " 
Group 54: All CD Tax Units taxable_social_security>0 (436 targets)\n", + " Group 55: All CD Tax Units unemployment_compensation>0 (436 targets)\n", + " Group 56: All CD AGI Total Amount (436 targets)\n", + " Group 57: All CD Dividend Income (436 targets)\n", + " Group 58: All CD Eitc (1744 targets)\n", + " Group 59: All CD SNAP Household Count (436 targets)\n", + " Group 60: All CD Income Tax (436 targets)\n", + " Group 61: All CD Income Tax Before Credits (436 targets)\n", + " Group 62: All CD Medical Expense Deduction (436 targets)\n", + " Group 63: All CD Net Capital Gains (436 targets)\n", + " Group 64: All CD Qualified Business Income Deduction (436 targets)\n", + " Group 65: All CD Qualified Dividend Income (436 targets)\n", + " Group 66: All CD Real Estate Taxes (436 targets)\n", + " Group 67: All CD Refundable Ctc (436 targets)\n", + " Group 68: All CD Rental Income (436 targets)\n", + " Group 69: All CD Salt (436 targets)\n", + " Group 70: All CD Self Employment Income (436 targets)\n", + " Group 71: State-level SNAP Cost (State) (51 targets)\n", + " Group 72: All CD Tax Exempt Interest Income (436 targets)\n", + " Group 73: All CD Tax Unit Partnership S Corp Income (436 targets)\n", + " Group 74: All CD Taxable Interest Income (436 targets)\n", + " Group 75: All CD Taxable Ira Distributions (436 targets)\n", + " Group 76: All CD Taxable Pension Income (436 targets)\n", + " Group 77: All CD Taxable Social Security (436 targets)\n", + " Group 78: All CD Unemployment Compensation (436 targets)\n", + "\n", + "Total groups created: 79\n", + "========================================\n", + "X_sparse shape: (33217, 5889488)\n", + "Number of target groups: 79\n" + ] + } + ], + "source": [ + "targets_df, X_sparse, household_id_mapping = (\n", + " builder.build_stacked_matrix_sparse(\n", + " \"congressional_district\", cds_to_calibrate, sim\n", + " )\n", + ")\n", + "\n", + "target_groups, group_info = create_target_groups(targets_df)\n", + "tracer = HouseholdTracer(targets_df, X_sparse, household_id_mapping, cds_to_calibrate, sim)\n", + "\n", + "print(f\"X_sparse shape: {X_sparse.shape}\")\n", + "print(f\"Number of target groups: {len(set(target_groups))}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Section 2: Understanding the Row Catalog\n", + "\n", + "The tracer provides a catalog of what each row (target) represents. We'll examine Group 71: SNAP Cost (State) - 51 targets across 51 states." 
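, + "\n", + "As a quick sketch of the lookup pattern (the same `tracer` methods are exercised on real rows in the cells below; the group and row chosen here are illustrative):\n", + "\n", + "```python\n", + "group_71 = tracer.get_group_rows(71)      # catalog rows for target group 71\n", + "row_loc = group_71.iloc[0]['row_index']   # pick one target row\n", + "row_info = tracer.get_row_info(row_loc)   # variable, geography, target value\n", + "```"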
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "================================================================================\n", + "MATRIX STRUCTURE BREAKDOWN\n", + "================================================================================\n", + "\n", + "Matrix dimensions: 33217 rows × 5889488 columns\n", + " Rows = 33217 targets\n", + " Columns = 13508 households × 436 CDs\n", + " = 13,508 × 436 = 5,889,488\n", + "\n", + "--------------------------------------------------------------------------------\n", + "COLUMN STRUCTURE (Households stacked by CD)\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Showing first and last 10 CDs of 436 total:\n", + "\n", + "First 10 CDs:\n", + "cd_geoid start_col end_col n_households example_household_id\n", + " 1001 0 13507 13508 25\n", + " 101 13508 27015 13508 25\n", + " 102 27016 40523 13508 25\n", + " 103 40524 54031 13508 25\n", + " 104 54032 67539 13508 25\n", + " 105 67540 81047 13508 25\n", + " 106 81048 94555 13508 25\n", + " 107 94556 108063 13508 25\n", + " 1101 108064 121571 13508 25\n", + " 1201 121572 135079 13508 25\n", + "\n", + "Last 10 CDs:\n", + "cd_geoid start_col end_col n_households example_household_id\n", + " 804 5754408 5767915 13508 25\n", + " 805 5767916 5781423 13508 25\n", + " 806 5781424 5794931 13508 25\n", + " 807 5794932 5808439 13508 25\n", + " 808 5808440 5821947 13508 25\n", + " 901 5821948 5835455 13508 25\n", + " 902 5835456 5848963 13508 25\n", + " 903 5848964 5862471 13508 25\n", + " 904 5862472 5875979 13508 25\n", + " 905 5875980 5889487 13508 25\n", + "\n", + "--------------------------------------------------------------------------------\n", + "ROW STRUCTURE (Targets by geography and variable)\n", + "--------------------------------------------------------------------------------\n", + "\n", + "Targets by geographic level:\n", + "geographic_level n_targets\n", + " unknown 33217\n", + "\n", + "Targets by stratum group:\n", + " n_targets n_unique_vars\n", + "stratum_group_id \n", + "2 8284 2\n", + "3 3924 1\n", + "4 436 1\n", + "5 436 1\n", + "6 3488 2\n", + "100 872 2\n", + "101 872 2\n", + "102 872 2\n", + "103 872 2\n", + "104 872 2\n", + "105 872 2\n", + "106 872 2\n", + "107 872 2\n", + "108 872 2\n", + "109 872 2\n", + "110 872 2\n", + "111 872 2\n", + "112 872 2\n", + "113 872 2\n", + "114 872 2\n", + "115 872 2\n", + "116 872 2\n", + "117 872 2\n", + "118 872 2\n", + "national 30 28\n", + "state_snap_cost 51 1\n", + "\n", + "--------------------------------------------------------------------------------\n", + "TARGET GROUPS (for loss calculation)\n", + "--------------------------------------------------------------------------------\n", + "\n", + "=== Creating Target Groups ===\n", + "\n", + "National targets (each is a singleton group):\n", + " Group 0: alimony_expense = 12,554,181,166\n", + " Group 1: alimony_income = 12,554,181,166\n", + " Group 2: charitable_deduction = 63,061,583,407\n", + " Group 3: child_support_expense = 31,868,306,036\n", + " Group 4: child_support_received = 31,868,306,036\n", + " Group 5: eitc = 64,440,000,000\n", + " Group 6: health_insurance_premiums_without_medicare_part_b = 371,796,903,749\n", + " Group 7: income_tax = 2,176,481,000,000\n", + " Group 8: interest_deduction = 23,949,514,839\n", + " Group 9: medicaid = 841,806,132,462\n", + " Group 10: medical_expense_deduction = 
11,009,051,176\n", + " Group 11: medicare_part_b_premiums = 108,159,099,272\n", + " Group 12: net_worth = 154,512,998,960,600\n", + " Group 13: other_medical_expenses = 268,466,335,694\n", + " Group 14: over_the_counter_health_expenses = 71,220,353,850\n", + " Group 15: person_count_aca_ptc>0 = 19,529,896\n", + " Group 16: person_count_medicaid>0 = 71,644,763\n", + " Group 17: person_count_ssn_card_type=NONE = 12,200,000\n", + " Group 18: qualified_business_income_deduction = 60,936,063,965\n", + " Group 19: real_estate_taxes = 482,853,121,752\n", + " Group 20: rent = 709,794,088,975\n", + " Group 21: salt_deduction = 20,518,360,556\n", + " Group 22: snap = 107,062,860,000\n", + " Group 23: social_security = 1,379,268,000,000\n", + " Group 24: spm_unit_capped_housing_subsidy = 33,799,718,523\n", + " Group 25: spm_unit_capped_work_childcare_expenses = 336,065,772,739\n", + " Group 26: ssi = 60,090,000,000\n", + " Group 27: tanf = 8,691,356,192\n", + " Group 28: tip_income = 51,375,572,154\n", + " Group 29: unemployment_compensation = 35,000,000,000\n", + "\n", + "Geographic targets (grouped by variable type):\n", + " Group 30: All CD Age Distribution (7848 targets)\n", + " Group 31: All CD Person Income Distribution (3924 targets)\n", + " Group 32: All CD Medicaid Enrollment (436 targets)\n", + " Group 33: All CD Tax Units dividend_income>0 (436 targets)\n", + " Group 34: All CD Tax Units eitc_child_count==0 (436 targets)\n", + " Group 35: All CD Tax Units eitc_child_count==1 (436 targets)\n", + " Group 36: All CD Tax Units eitc_child_count==2 (436 targets)\n", + " Group 37: All CD Tax Units eitc_child_count>2 (436 targets)\n", + " Group 38: All CD Tax Units income_tax>0 (436 targets)\n", + " Group 39: All CD Tax Units income_tax_before_credits>0 (436 targets)\n", + " Group 40: All CD Tax Units medical_expense_deduction>0 (436 targets)\n", + " Group 41: All CD Tax Units net_capital_gains>0 (436 targets)\n", + " Group 42: All CD Tax Units qualified_business_income_deduction>0 (436 targets)\n", + " Group 43: All CD Tax Units qualified_dividend_income>0 (436 targets)\n", + " Group 44: All CD Tax Units real_estate_taxes>0 (436 targets)\n", + " Group 45: All CD Tax Units refundable_ctc>0 (436 targets)\n", + " Group 46: All CD Tax Units rental_income>0 (436 targets)\n", + " Group 47: All CD Tax Units salt>0 (436 targets)\n", + " Group 48: All CD Tax Units self_employment_income>0 (436 targets)\n", + " Group 49: All CD Tax Units tax_exempt_interest_income>0 (436 targets)\n", + " Group 50: All CD Tax Units tax_unit_partnership_s_corp_income>0 (436 targets)\n", + " Group 51: All CD Tax Units taxable_interest_income>0 (436 targets)\n", + " Group 52: All CD Tax Units taxable_ira_distributions>0 (436 targets)\n", + " Group 53: All CD Tax Units taxable_pension_income>0 (436 targets)\n", + " Group 54: All CD Tax Units taxable_social_security>0 (436 targets)\n", + " Group 55: All CD Tax Units unemployment_compensation>0 (436 targets)\n", + " Group 56: All CD AGI Total Amount (436 targets)\n", + " Group 57: All CD Dividend Income (436 targets)\n", + " Group 58: All CD Eitc (1744 targets)\n", + " Group 59: All CD SNAP Household Count (436 targets)\n", + " Group 60: All CD Income Tax (436 targets)\n", + " Group 61: All CD Income Tax Before Credits (436 targets)\n", + " Group 62: All CD Medical Expense Deduction (436 targets)\n", + " Group 63: All CD Net Capital Gains (436 targets)\n", + " Group 64: All CD Qualified Business Income Deduction (436 targets)\n", + " Group 65: All CD Qualified Dividend Income 
(436 targets)\n", + " Group 66: All CD Real Estate Taxes (436 targets)\n", + " Group 67: All CD Refundable Ctc (436 targets)\n", + " Group 68: All CD Rental Income (436 targets)\n", + " Group 69: All CD Salt (436 targets)\n", + " Group 70: All CD Self Employment Income (436 targets)\n", + " Group 71: State-level SNAP Cost (State) (51 targets)\n", + " Group 72: All CD Tax Exempt Interest Income (436 targets)\n", + " Group 73: All CD Tax Unit Partnership S Corp Income (436 targets)\n", + " Group 74: All CD Taxable Interest Income (436 targets)\n", + " Group 75: All CD Taxable Ira Distributions (436 targets)\n", + " Group 76: All CD Taxable Pension Income (436 targets)\n", + " Group 77: All CD Taxable Social Security (436 targets)\n", + " Group 78: All CD Unemployment Compensation (436 targets)\n", + "\n", + "Total groups created: 79\n", + "========================================\n", + " Group 0: National alimony_expense (1 target, value=12,554,181,166) - rows [0]\n", + " Group 1: National alimony_income (1 target, value=12,554,181,166) - rows [1]\n", + " Group 2: National charitable_deduction (1 target, value=63,061,583,407) - rows [2]\n", + " Group 3: National child_support_expense (1 target, value=31,868,306,036) - rows [3]\n", + " Group 4: National child_support_received (1 target, value=31,868,306,036) - rows [4]\n", + " Group 5: National eitc (1 target, value=64,440,000,000) - rows [5]\n", + " Group 6: National health_insurance_premiums_without_medicare_part_b (1 target, value=371,796,903,749) - rows [6]\n", + " Group 7: National income_tax (1 target, value=2,176,481,000,000) - rows [7]\n", + " Group 8: National interest_deduction (1 target, value=23,949,514,839) - rows [8]\n", + " Group 9: National medicaid (1 target, value=841,806,132,462) - rows [9]\n", + " Group 10: National medical_expense_deduction (1 target, value=11,009,051,176) - rows [10]\n", + " Group 11: National medicare_part_b_premiums (1 target, value=108,159,099,272) - rows [11]\n", + " Group 12: National net_worth (1 target, value=154,512,998,960,600) - rows [12]\n", + " Group 13: National other_medical_expenses (1 target, value=268,466,335,694) - rows [13]\n", + " Group 14: National over_the_counter_health_expenses (1 target, value=71,220,353,850) - rows [14]\n", + " Group 15: National person_count_aca_ptc>0 (1 target, value=19,529,896) - rows [15]\n", + " Group 16: National person_count_medicaid>0 (1 target, value=71,644,763) - rows [16]\n", + " Group 17: National person_count_ssn_card_type=NONE (1 target, value=12,200,000) - rows [17]\n", + " Group 18: National qualified_business_income_deduction (1 target, value=60,936,063,965) - rows [18]\n", + " Group 19: National real_estate_taxes (1 target, value=482,853,121,752) - rows [19]\n", + " Group 20: National rent (1 target, value=709,794,088,975) - rows [20]\n", + " Group 21: National salt_deduction (1 target, value=20,518,360,556) - rows [21]\n", + " Group 22: National snap (1 target, value=107,062,860,000) - rows [22]\n", + " Group 23: National social_security (1 target, value=1,379,268,000,000) - rows [23]\n", + " Group 24: National spm_unit_capped_housing_subsidy (1 target, value=33,799,718,523) - rows [24]\n", + " Group 25: National spm_unit_capped_work_childcare_expenses (1 target, value=336,065,772,739) - rows [25]\n", + " Group 26: National ssi (1 target, value=60,090,000,000) - rows [26]\n", + " Group 27: National tanf (1 target, value=8,691,356,192) - rows [27]\n", + " Group 28: National tip_income (1 target, value=51,375,572,154) - rows [28]\n", + " Group 
29: National unemployment_compensation (1 target, value=35,000,000,000) - rows [29]\n", + " Group 30: Age Distribution (7848 targets across 436 geographies) - rows [50, 51, 52, '...', 33126, 33127]\n", + " Group 31: Person Income Distribution (3924 targets across 436 geographies) - rows [41, 42, 43, '...', 33108, 33109]\n", + " Group 32: Medicaid Enrollment (436 targets across 436 geographies) - rows [68, 144, 220, '...', 33052, 33128]\n", + " Group 33: Tax Units dividend_income>0 (436 targets across 436 geographies) - rows [77, 153, 229, '...', 33061, 33137]\n", + " Group 34: Tax Units eitc_child_count==0 (436 targets across 436 geographies) - rows [78, 154, 230, '...', 33062, 33138]\n", + " Group 35: Tax Units eitc_child_count==1 (436 targets across 436 geographies) - rows [79, 155, 231, '...', 33063, 33139]\n", + " Group 36: Tax Units eitc_child_count==2 (436 targets across 436 geographies) - rows [80, 156, 232, '...', 33064, 33140]\n", + " Group 37: Tax Units eitc_child_count>2 (436 targets across 436 geographies) - rows [81, 157, 233, '...', 33065, 33141]\n", + " Group 38: Tax Units income_tax>0 (436 targets across 436 geographies) - rows [83, 159, 235, '...', 33067, 33143]\n", + " Group 39: Tax Units income_tax_before_credits>0 (436 targets across 436 geographies) - rows [82, 158, 234, '...', 33066, 33142]\n", + " Group 40: Tax Units medical_expense_deduction>0 (436 targets across 436 geographies) - rows [84, 160, 236, '...', 33068, 33144]\n", + " Group 41: Tax Units net_capital_gains>0 (436 targets across 436 geographies) - rows [85, 161, 237, '...', 33069, 33145]\n", + " Group 42: Tax Units qualified_business_income_deduction>0 (436 targets across 436 geographies) - rows [86, 162, 238, '...', 33070, 33146]\n", + " Group 43: Tax Units qualified_dividend_income>0 (436 targets across 436 geographies) - rows [87, 163, 239, '...', 33071, 33147]\n", + " Group 44: Tax Units real_estate_taxes>0 (436 targets across 436 geographies) - rows [88, 164, 240, '...', 33072, 33148]\n", + " Group 45: Tax Units refundable_ctc>0 (436 targets across 436 geographies) - rows [89, 165, 241, '...', 33073, 33149]\n", + " Group 46: Tax Units rental_income>0 (436 targets across 436 geographies) - rows [90, 166, 242, '...', 33074, 33150]\n", + " Group 47: Tax Units salt>0 (436 targets across 436 geographies) - rows [91, 167, 243, '...', 33075, 33151]\n", + " Group 48: Tax Units self_employment_income>0 (436 targets across 436 geographies) - rows [92, 168, 244, '...', 33076, 33152]\n", + " Group 49: Tax Units tax_exempt_interest_income>0 (436 targets across 436 geographies) - rows [93, 169, 245, '...', 33077, 33153]\n", + " Group 50: Tax Units tax_unit_partnership_s_corp_income>0 (436 targets across 436 geographies) - rows [94, 170, 246, '...', 33078, 33154]\n", + " Group 51: Tax Units taxable_interest_income>0 (436 targets across 436 geographies) - rows [95, 171, 247, '...', 33079, 33155]\n", + " Group 52: Tax Units taxable_ira_distributions>0 (436 targets across 436 geographies) - rows [96, 172, 248, '...', 33080, 33156]\n", + " Group 53: Tax Units taxable_pension_income>0 (436 targets across 436 geographies) - rows [97, 173, 249, '...', 33081, 33157]\n", + " Group 54: Tax Units taxable_social_security>0 (436 targets across 436 geographies) - rows [98, 174, 250, '...', 33082, 33158]\n", + " Group 55: Tax Units unemployment_compensation>0 (436 targets across 436 geographies) - rows [99, 175, 251, '...', 33083, 33159]\n", + " Group 56: AGI Total Amount (436 targets across 436 geographies) - rows [30, 106, 182, 
'...', 33014, 33090]\n", + " Group 57: Dividend Income (436 targets across 436 geographies) - rows [31, 107, 183, '...', 33015, 33091]\n", + " Group 58: Eitc (1744 targets across 436 geographies) - rows [32, 33, 34, '...', 33094, 33095]\n", + " Group 59: SNAP Household Count (436 targets across 436 geographies) - rows [36, 112, 188, '...', 33020, 33096]\n", + " Group 60: Income Tax (436 targets across 436 geographies) - rows [38, 114, 190, '...', 33022, 33098]\n", + " Group 61: Income Tax Before Credits (436 targets across 436 geographies) - rows [37, 113, 189, '...', 33021, 33097]\n", + " Group 62: Medical Expense Deduction (436 targets across 436 geographies) - rows [39, 115, 191, '...', 33023, 33099]\n", + " Group 63: Net Capital Gains (436 targets across 436 geographies) - rows [40, 116, 192, '...', 33024, 33100]\n", + " Group 64: Qualified Business Income Deduction (436 targets across 436 geographies) - rows [69, 145, 221, '...', 33053, 33129]\n", + " Group 65: Qualified Dividend Income (436 targets across 436 geographies) - rows [70, 146, 222, '...', 33054, 33130]\n", + " Group 66: Real Estate Taxes (436 targets across 436 geographies) - rows [71, 147, 223, '...', 33055, 33131]\n", + " Group 67: Refundable Ctc (436 targets across 436 geographies) - rows [72, 148, 224, '...', 33056, 33132]\n", + " Group 68: Rental Income (436 targets across 436 geographies) - rows [73, 149, 225, '...', 33057, 33133]\n", + " Group 69: Salt (436 targets across 436 geographies) - rows [74, 150, 226, '...', 33058, 33134]\n", + " Group 70: Self Employment Income (436 targets across 436 geographies) - rows [75, 151, 227, '...', 33059, 33135]\n", + " Group 71: SNAP Cost (State) (51 targets across 51 geographies) - rows [33166, 33167, 33168, '...', 33215, 33216]\n", + " Group 72: Tax Exempt Interest Income (436 targets across 436 geographies) - rows [76, 152, 228, '...', 33060, 33136]\n", + " Group 73: Tax Unit Partnership S Corp Income (436 targets across 436 geographies) - rows [100, 176, 252, '...', 33084, 33160]\n", + " Group 74: Taxable Interest Income (436 targets across 436 geographies) - rows [101, 177, 253, '...', 33085, 33161]\n", + " Group 75: Taxable Ira Distributions (436 targets across 436 geographies) - rows [102, 178, 254, '...', 33086, 33162]\n", + " Group 76: Taxable Pension Income (436 targets across 436 geographies) - rows [103, 179, 255, '...', 33087, 33163]\n", + " Group 77: Taxable Social Security (436 targets across 436 geographies) - rows [104, 180, 256, '...', 33088, 33164]\n", + " Group 78: Unemployment Compensation (436 targets across 436 geographies) - rows [105, 181, 257, '...', 33089, 33165]\n", + "\n", + "================================================================================\n" + ] + } + ], + "source": [ + "tracer.print_matrix_structure()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Row info for first SNAP state target:\n" + ] + }, + { + "data": { + "text/plain": [ + "{'row_index': 33194,\n", + " 'variable': 'snap',\n", + " 'variable_desc': 'snap_cost_state',\n", + " 'geographic_id': '37',\n", + " 'geographic_level': 'unknown',\n", + " 'target_value': 4041086120.0,\n", + " 'stratum_id': 9799,\n", + " 'stratum_group_id': 'state_snap_cost'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "group_71 = tracer.get_group_rows(71)\n", + "row_loc = group_71.iloc[28]['row_index']\n", + "row_info = 
tracer.get_row_info(row_loc)\n", + "var = row_info['variable']\n", + "var_desc = row_info['variable_desc']\n", + "target_geo_id = int(row_info['geographic_id'])\n", + "\n", + "print(\"Row info for first SNAP state target:\")\n", + "row_info" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
row_indexvariablevariable_descgeographic_idgeographic_leveltarget_valuestratum_idstratum_group_id
3316633166snapsnap_cost_state1unknown2.048985e+099766state_snap_cost
3316733167snapsnap_cost_state10unknown2.962075e+089773state_snap_cost
3316833168snapsnap_cost_state11unknown3.793723e+089774state_snap_cost
3316933169snapsnap_cost_state12unknown6.756577e+099775state_snap_cost
3317033170snapsnap_cost_state13unknown3.232508e+099776state_snap_cost
3317133171snapsnap_cost_state15unknown8.424059e+089777state_snap_cost
3317233172snapsnap_cost_state16unknown2.494227e+089778state_snap_cost
3317333173snapsnap_cost_state17unknown5.440580e+099779state_snap_cost
3317433174snapsnap_cost_state18unknown1.302143e+099780state_snap_cost
3317533175snapsnap_cost_state19unknown5.091406e+089781state_snap_cost
\n", + "
" + ], + "text/plain": [ + " row_index variable variable_desc geographic_id geographic_level \\\n", + "33166 33166 snap snap_cost_state 1 unknown \n", + "33167 33167 snap snap_cost_state 10 unknown \n", + "33168 33168 snap snap_cost_state 11 unknown \n", + "33169 33169 snap snap_cost_state 12 unknown \n", + "33170 33170 snap snap_cost_state 13 unknown \n", + "33171 33171 snap snap_cost_state 15 unknown \n", + "33172 33172 snap snap_cost_state 16 unknown \n", + "33173 33173 snap snap_cost_state 17 unknown \n", + "33174 33174 snap snap_cost_state 18 unknown \n", + "33175 33175 snap snap_cost_state 19 unknown \n", + "\n", + " target_value stratum_id stratum_group_id \n", + "33166 2.048985e+09 9766 state_snap_cost \n", + "33167 2.962075e+08 9773 state_snap_cost \n", + "33168 3.793723e+08 9774 state_snap_cost \n", + "33169 6.756577e+09 9775 state_snap_cost \n", + "33170 3.232508e+09 9776 state_snap_cost \n", + "33171 8.424059e+08 9777 state_snap_cost \n", + "33172 2.494227e+08 9778 state_snap_cost \n", + "33173 5.440580e+09 9779 state_snap_cost \n", + "33174 1.302143e+09 9780 state_snap_cost \n", + "33175 5.091406e+08 9781 state_snap_cost " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "state_snap = tracer.row_catalog[\n", + " (tracer.row_catalog['variable'] == row_info['variable']) &\n", + " (tracer.row_catalog['variable_desc'] == row_info['variable_desc'])\n", + "].sort_values('geographic_id')\n", + "\n", + "assert state_snap.shape[0] == 51, f\"Expected 51 state SNAP targets, got {state_snap.shape[0]}\"\n", + "state_snap.head(10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Section 3: Finding an Interesting Household\n", + "\n", + "We need a household with:\n", + "- More than one person\n", + "- More than one SPM unit\n", + "- Each SPM unit has positive SNAP\n", + "\n", + "This tests that we correctly aggregate SNAP at the household level (sum across SPM units, not persons)." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
person_idhousehold_idtax_unit_idspm_unit_idfamily_idmarital_unit_id
0250125250125001251.020
110301103103011030011031.080
212501125125011250011251.099
312502125125011250011251.0101
412503125125021250011252.0100
\n", + "
" + ], + "text/plain": [ + " person_id household_id tax_unit_id spm_unit_id family_id \\\n", + "0 2501 25 2501 25001 251.0 \n", + "1 10301 103 10301 103001 1031.0 \n", + "2 12501 125 12501 125001 1251.0 \n", + "3 12502 125 12501 125001 1251.0 \n", + "4 12503 125 12502 125001 1252.0 \n", + "\n", + " marital_unit_id \n", + "0 20 \n", + "1 80 \n", + "2 99 \n", + "3 101 \n", + "4 100 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "entity_rel = pd.DataFrame(\n", + " {\n", + " \"person_id\": sim.calculate(\"person_id\", map_to=\"person\").values,\n", + " \"household_id\": sim.calculate(\"household_id\", map_to=\"person\").values,\n", + " \"tax_unit_id\": sim.calculate(\"tax_unit_id\", map_to=\"person\").values,\n", + " \"spm_unit_id\": sim.calculate(\"spm_unit_id\", map_to=\"person\").values,\n", + " \"family_id\": sim.calculate(\"family_id\", map_to=\"person\").values,\n", + " \"marital_unit_id\": sim.calculate(\"marital_unit_id\", map_to=\"person\").values,\n", + " }\n", + ")\n", + "entity_rel.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: SNAP values differ by entity level due to broadcasting:\n", + "- `sim.calculate_dataframe(['spm_unit_id', 'snap'])` - rows are SPM units\n", + "- `sim.calculate_dataframe(['household_id', 'snap'])` - rows are households\n", + "- Person-level broadcasts the SPM unit's SNAP to each person" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
person_household_idperson_countsnap_minsnap_unique
45926623122293.1999512
5672806622937.4997562
5804821683789.1999513
66839199733592.0000002
7143979722789.1999512
834011252823236.5000002
94911288393789.1999512
\n", + "
" + ], + "text/plain": [ + " person_household_id person_count snap_min snap_unique\n", + "4592 66231 2 2293.199951 2\n", + "5672 80662 2 937.499756 2\n", + "5804 82168 3 789.199951 3\n", + "6683 91997 3 3592.000000 2\n", + "7143 97972 2 789.199951 2\n", + "8340 112528 2 3236.500000 2\n", + "9491 128839 3 789.199951 2" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p_df = sim.calculate_dataframe(['person_household_id', 'person_id', 'snap'], map_to=\"person\")\n", + "\n", + "hh_stats = p_df.groupby('person_household_id').agg(\n", + " person_count=('person_id', 'nunique'),\n", + " snap_min=('snap', 'min'),\n", + " snap_unique=('snap', 'nunique')\n", + ").reset_index()\n", + "\n", + "candidates = hh_stats[(hh_stats.person_count > 1) & (hh_stats.snap_min > 0) & (hh_stats.snap_unique > 1)]\n", + "candidates.head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
person_household_idperson_idsnap__tmp_weights
197399199791997063592.00.0
197409199791997074333.50.0
197419199791997084333.50.0
\n", + "
" + ], + "text/plain": [ + "       weight  person_household_id  person_id    snap  __tmp_weights\n", + "19739     0.0                91997    9199706  3592.0            0.0\n", + "19740     0.0                91997    9199707  4333.5            0.0\n", + "19741     0.0                91997    9199708  4333.5            0.0" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hh_id = candidates.iloc[3]['person_household_id']\n", + "p_df.loc[p_df.person_household_id == hh_id]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This household has 3 persons across 2 SPM units:\n", + "- Person 1: SNAP = 3592.0 (SPM unit 91997002)\n", + "- Persons 2 and 3: SNAP = 4333.5 each (broadcast from SPM unit 91997004)\n", + "\n", + "Household-level SNAP is therefore 3592.0 + 4333.5 = 7925.5, i.e. the sum across SPM units rather than across persons." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
spm_unit_idsnap
6989919970023592.0
6990919970044333.5
\n", + "
" + ], + "text/plain": [ + " weight spm_unit_id snap\n", + "6989 0.0 91997002 3592.0\n", + "6990 0.0 91997004 4333.5" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hh_snap_goal = 3592.0 + 4333.5\n", + "\n", + "snap_df = sim.calculate_dataframe(['spm_unit_id', 'snap'])\n", + "snap_subset = entity_rel.loc[entity_rel.household_id == hh_id]\n", + "snap_df.loc[snap_df.spm_unit_id.isin(list(snap_subset.spm_unit_id))]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Household 91997.0 is from state FIPS 50\n" + ] + }, + { + "data": { + "text/plain": [ + "household_id 91997\n", + "state_fips 50\n", + "Name: 6683, dtype: int32" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hh_df = sim.calculate_dataframe(['household_id', 'state_fips'])\n", + "hh_loc = np.where(hh_df.household_id == hh_id)[0][0]\n", + "hh_one = hh_df.iloc[hh_loc]\n", + "hh_home_state = hh_one.state_fips\n", + "\n", + "print(f\"Household {hh_id} is from state FIPS {hh_home_state}\")\n", + "hh_one" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Section 4: Validate Matrix Values\n", + "\n", + "Each household appears as a column in X_sparse for every CD (436 times). For state-level SNAP targets, the matrix value should be:\n", + "- `hh_snap_goal` if the CD is in the household's home state\n", + "- `0` if the CD is in a different state" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "All 436 CD column values validated for household 91997.0\n" + ] + } + ], + "source": [ + "hh_col_lku = tracer.get_household_column_positions(hh_id)\n", + "\n", + "for cd in hh_col_lku.keys():\n", + " hh_away_state = int(cd) // 100\n", + " col_loc = hh_col_lku[cd]\n", + " col_info = tracer.get_column_info(col_loc)\n", + " \n", + " assert col_info['household_id'] == hh_id\n", + " \n", + " value_lku = tracer.lookup_matrix_cell(row_idx=row_loc, col_idx=col_loc)\n", + " assert value_lku['household']['household_id'] == hh_id\n", + " \n", + " metric = value_lku['matrix_value']\n", + " assert X_sparse[row_loc, col_loc] == metric\n", + "\n", + " if hh_away_state != target_geo_id:\n", + " assert metric == 0, f\"Expected 0 for CD {cd} (state {hh_away_state}), got {metric}\"\n", + " else:\n", + " assert metric == hh_snap_goal, f\"Expected {hh_snap_goal} for CD {cd}, got {metric}\"\n", + "\n", + "print(f\"All {len(hh_col_lku)} CD column values validated for household {hh_id}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Section 5: Create Sparse Dataset from Weights\n", + "\n", + "Test `create_sparse_cd_stacked_dataset` which reconstructs an h5 file from weight vectors. We verify:\n", + "1. Household appears in mapping file for CDs with non-zero weight\n", + "2. New household IDs correctly map back to originals\n", + "3. 
SNAP values are preserved" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "n_nonzero = 500000\n", + "total_size = X_sparse.shape[1]\n", + "\n", + "w = np.zeros(total_size)\n", + "nonzero_indices = rng_ben.choice(total_size, n_nonzero, replace=False)\n", + "w[nonzero_indices] = 2\n", + "\n", + "cd1 = '103'\n", + "cd2 = '3703'\n", + "output_dir = './temp'\n", + "w[hh_col_lku[cd1]] = 1.5\n", + "w[hh_col_lku[cd2]] = 1.7" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing subset of 2 CDs: 103, 3703...\n", + "Output path: ./temp/mapping1.h5\n", + "\n", + "Original dataset has 13,508 households\n", + "Extracted weights for 2 CDs from full weight matrix\n", + "Total active household-CD pairs: 2,292\n", + "Total weight in W matrix: 4,583\n", + "Processing CD 3703 (2/2)...\n", + "\n", + "Combining 2 CD DataFrames...\n", + "Total households across all CDs: 2,292\n", + "Combined DataFrame shape: (7054, 184)\n", + "\n", + "Weights in combined_df BEFORE reindexing:\n", + " HH weight sum: 0.01M\n", + " Person weight sum: 0.01M\n", + " Ratio: 1.00\n", + "\n", + "Reindexing all entity IDs using 25k ranges per CD...\n", + " Created 2,292 unique households across 2 CDs\n", + " Reindexing persons using 25k ranges...\n", + " Reindexing tax units...\n", + " Reindexing SPM units...\n", + " Reindexing marital units...\n", + " Final persons: 7,054\n", + " Final households: 2,292\n", + " Final tax units: 3,252\n", + " Final SPM units: 2,412\n", + " Final marital units: 5,445\n", + "\n", + "Weights in combined_df AFTER reindexing:\n", + " HH weight sum: 0.01M\n", + " Person weight sum: 0.01M\n", + " Ratio: 1.00\n", + "\n", + "Overflow check:\n", + " Max person ID after reindexing: 10,203,635\n", + " Max person ID × 100: 1,020,363,500\n", + " int32 max: 2,147,483,647\n", + " ✓ No overflow risk!\n", + "\n", + "Creating Dataset from combined DataFrame...\n", + "Building simulation from Dataset...\n", + "\n", + "Saving to ./temp/mapping1.h5...\n", + "Found 168 input variables (excluding calculated variables)\n", + "Variables saved: 180\n", + "Variables skipped: 3213\n", + "Sparse CD-stacked dataset saved successfully!\n", + "Household mapping saved to ./temp/mappings/mapping1_household_mapping.csv\n", + "\n", + "Verifying saved file...\n", + " Final households: 2,292\n", + " Final persons: 7,054\n", + " Total population (from household weights): 4,583\n", + " Total population (from person weights): 14,106\n", + " Average persons per household: 3.08\n", + "Output dataset shape: (2292, 4)\n" + ] + } + ], + "source": [ + "output_path = f\"{output_dir}/mapping1.h5\"\n", + "output_file = create_sparse_cd_stacked_dataset(\n", + " w,\n", + " cds_to_calibrate,\n", + " cd_subset=[cd1, cd2],\n", + " dataset_path=str(dataset_uri),\n", + " output_path=output_path,\n", + ")\n", + "\n", + "sim_test = Microsimulation(dataset=output_path)\n", + "df_test = sim_test.calculate_dataframe([\n", + " 'congressional_district_geoid',\n", + " 'household_id', 'household_weight', 'snap'])\n", + "\n", + "print(f\"Output dataset shape: {df_test.shape}\")\n", + "assert np.isclose(df_test.shape[0] / 2 * 436, n_nonzero, rtol=0.10)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
new_household_idoriginal_household_idcongressional_districtstate_fips
115175572919971031
1152520057991997370337
\n", + "
" + ], + "text/plain": [ + " new_household_id original_household_id congressional_district \\\n", + "1151 75572 91997 103 \n", + "1152 5200579 91997 3703 \n", + "\n", + " state_fips \n", + "1151 1 \n", + "1152 37 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mapping = pd.read_csv(f\"{output_dir}/mappings/mapping1_household_mapping.csv\")\n", + "match = mapping.loc[mapping.original_household_id == hh_id].shape[0]\n", + "assert match == 2, f\"Household should appear twice (once per CD), got {match}\"\n", + "\n", + "hh_mapping = mapping.loc[mapping.original_household_id == hh_id]\n", + "hh_mapping" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
congressional_district_geoidhousehold_idhousehold_weightsnap
0103750002.00.0
1103750012.00.0
2103750022.00.0
3103750032.00.0
4103750042.00.0
...............
2287370352011632.00.0
2288370352011642.00.0
2289370352011652.00.0
2290370352011662.00.0
2291370352011672.00.0
\n", + "

2292 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " weight congressional_district_geoid household_id household_weight \\\n", + "0 2.0 103 75000 2.0 \n", + "1 2.0 103 75001 2.0 \n", + "2 2.0 103 75002 2.0 \n", + "3 2.0 103 75003 2.0 \n", + "4 2.0 103 75004 2.0 \n", + "... ... ... ... ... \n", + "2287 2.0 3703 5201163 2.0 \n", + "2288 2.0 3703 5201164 2.0 \n", + "2289 2.0 3703 5201165 2.0 \n", + "2290 2.0 3703 5201166 2.0 \n", + "2291 2.0 3703 5201167 2.0 \n", + "\n", + " snap \n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 \n", + "... ... \n", + "2287 0.0 \n", + "2288 0.0 \n", + "2289 0.0 \n", + "2290 0.0 \n", + "2291 0.0 \n", + "\n", + "[2292 rows x 5 columns]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_test" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CD 103: weight=1.5, snap=7925.5\n" + ] + } + ], + "source": [ + "df_test_cd1 = df_test.loc[df_test.congressional_district_geoid == int(cd1)]\n", + "df_test_cd2 = df_test.loc[df_test.congressional_district_geoid == int(cd2)]\n", + "\n", + "hh_mapping_cd1 = hh_mapping.loc[hh_mapping.congressional_district == int(cd1)]\n", + "new_hh_id_cd1 = hh_mapping_cd1['new_household_id'].values[0]\n", + "\n", + "assert hh_mapping_cd1.shape[0] == 1\n", + "assert hh_mapping_cd1.original_household_id.values[0] == hh_id\n", + "\n", + "w_hh_cd1 = w[hh_col_lku[cd1]]\n", + "assert_cd1_df = df_test_cd1.loc[df_test_cd1.household_id == new_hh_id_cd1]\n", + "\n", + "assert np.isclose(assert_cd1_df.household_weight.values[0], w_hh_cd1, atol=0.001)\n", + "assert np.isclose(assert_cd1_df.snap.values[0], hh_snap_goal, atol=0.001)\n", + "\n", + "print(f\"CD {cd1}: weight={w_hh_cd1}, snap={assert_cd1_df.snap.values[0]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CD 3703: weight=1.7, snap=7925.5\n" + ] + } + ], + "source": [ + "hh_mapping_cd2 = hh_mapping.loc[hh_mapping.congressional_district == int(cd2)]\n", + "new_hh_id_cd2 = hh_mapping_cd2['new_household_id'].values[0]\n", + "\n", + "assert hh_mapping_cd2.shape[0] == 1\n", + "assert hh_mapping_cd2.original_household_id.values[0] == hh_id\n", + "\n", + "w_hh_cd2 = w[hh_col_lku[cd2]]\n", + "assert_cd2_df = df_test_cd2.loc[df_test_cd2.household_id == new_hh_id_cd2]\n", + "\n", + "assert np.isclose(assert_cd2_df.household_weight.values[0], w_hh_cd2, atol=0.001)\n", + "assert np.isclose(assert_cd2_df.snap.values[0], hh_snap_goal, atol=0.001)\n", + "\n", + "print(f\"CD {cd2}: weight={w_hh_cd2}, snap={assert_cd2_df.snap.values[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test: Zero weight excludes household from mapping" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing subset of 1 CDs: 3703...\n", + "Output path: ./temp/3703.h5\n", + "\n", + "Original dataset has 13,508 households\n", + "Extracted weights for 1 CDs from full weight matrix\n", + "Total active household-CD pairs: 1,167\n", + "Total weight in W matrix: 2,334\n", + "Processing CD 3703 (1/1)...\n", + "\n", + "Combining 1 CD DataFrames...\n", + "Total households across all CDs: 1,167\n", + "Combined DataFrame shape: (3633, 184)\n", + "\n", + "Weights in combined_df BEFORE reindexing:\n", + " HH weight 
sum: 0.01M\n", + " Person weight sum: 0.01M\n", + " Ratio: 1.00\n", + "\n", + "Reindexing all entity IDs using 25k ranges per CD...\n", + " Created 1,167 unique households across 1 CDs\n", + " Reindexing persons using 25k ranges...\n", + " Reindexing tax units...\n", + " Reindexing SPM units...\n", + " Reindexing marital units...\n", + " Final persons: 3,633\n", + " Final households: 1,167\n", + " Final tax units: 1,683\n", + " Final SPM units: 1,227\n", + " Final marital units: 2,818\n", + "\n", + "Weights in combined_df AFTER reindexing:\n", + " HH weight sum: 0.01M\n", + " Person weight sum: 0.01M\n", + " Ratio: 1.00\n", + "\n", + "Overflow check:\n", + " Max person ID after reindexing: 10,203,632\n", + " Max person ID × 100: 1,020,363,200\n", + " int32 max: 2,147,483,647\n", + " ✓ No overflow risk!\n", + "\n", + "Creating Dataset from combined DataFrame...\n", + "Building simulation from Dataset...\n", + "\n", + "Saving to ./temp/3703.h5...\n", + "Found 168 input variables (excluding calculated variables)\n", + "Variables saved: 180\n", + "Variables skipped: 3213\n", + "Sparse CD-stacked dataset saved successfully!\n", + "Household mapping saved to ./temp/mappings/3703_household_mapping.csv\n", + "\n", + "Verifying saved file...\n", + " Final households: 1,167\n", + " Final persons: 3,633\n", + " Total population (from household weights): 2,334\n", + " Total population (from person weights): 7,266\n", + " Average persons per household: 3.11\n", + "Confirmed: household 91997.0 excluded from CD 3703 mapping when weight=0\n" + ] + } + ], + "source": [ + "w[hh_col_lku[cd2]] = 0\n", + "\n", + "output_path = f\"{output_dir}/{cd2}.h5\"\n", + "output_file = create_sparse_cd_stacked_dataset(\n", + " w,\n", + " cds_to_calibrate,\n", + " cd_subset=[cd2],\n", + " dataset_path=str(dataset_uri),\n", + " output_path=output_path,\n", + ")\n", + "\n", + "sim_test = Microsimulation(dataset=output_path)\n", + "df_test = sim_test.calculate_dataframe(['household_id', 'household_weight', 'snap'])\n", + "\n", + "cd2_mapping = pd.read_csv(f\"{output_dir}/mappings/{cd2}_household_mapping.csv\")\n", + "match = cd2_mapping.loc[cd2_mapping.original_household_id == hh_id].shape[0]\n", + "assert match == 0, f\"Household with zero weight should not appear in mapping, got {match}\"\n", + "\n", + "print(f\"Confirmed: household {hh_id} excluded from CD {cd2} mapping when weight=0\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Section 6: End-to-End Validation (X @ w == sim.calculate)\n", + "\n", + "The ultimate test: verify that matrix multiplication `X_sparse @ w` matches what we get from running the simulation on the reconstructed h5 file.\n", + "\n", + "With `freeze_calculated_vars=True`, state-dependent variables like SNAP are saved to the h5 file to prevent recalculation." 
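, + "\n", + "In sketch form (variable names mirror the cells below; `geo_df` and `sim_total` are just sketch names), the check for one state-level SNAP target row is:\n", + "\n", + "```python\n", + "y_hat = X_sparse @ w                      # model-side total for every target row\n", + "geo_df = hh_snap_df[hh_snap_df.state_fips == target_geo_id]\n", + "sim_total = np.sum(geo_df.snap.values * geo_df.household_weight.values)\n", + "# the validation asserts y_hat[row_loc] and sim_total agree within tolerance\n", + "```"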
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "total_size = X_sparse.shape[1]\n", + "w = np.zeros(total_size)\n", + "n_nonzero = 50000\n", + "nonzero_indices = rng_ben.choice(total_size, n_nonzero, replace=False)\n", + "w[nonzero_indices] = 7\n", + "w[hh_col_lku[cd1]] = 11\n", + "w[hh_col_lku[cd2]] = 12\n", + "\n", + "assert np.sum(w > 0) <= n_nonzero + 2" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing all 436 congressional districts\n", + "Output path: ./temp/national.h5\n", + "\n", + "Original dataset has 13,508 households\n", + "Total active household-CD pairs: 50,002\n", + "Total weight in W matrix: 350,023\n", + "Processing CD 1201 (10/436)...\n", + "Processing CD 1211 (20/436)...\n", + "Processing CD 1221 (30/436)...\n", + "Processing CD 1303 (40/436)...\n", + "Processing CD 1313 (50/436)...\n", + "Processing CD 1705 (60/436)...\n", + "Processing CD 1715 (70/436)...\n", + "Processing CD 1808 (80/436)...\n", + "Processing CD 201 (90/436)...\n", + "Processing CD 2204 (100/436)...\n", + "Processing CD 2406 (110/436)...\n", + "Processing CD 2508 (120/436)...\n", + "Processing CD 2609 (130/436)...\n", + "Processing CD 2706 (140/436)...\n", + "Processing CD 2904 (150/436)...\n", + "Processing CD 3201 (160/436)...\n", + "Processing CD 3405 (170/436)...\n", + "Processing CD 3503 (180/436)...\n", + "Processing CD 3610 (190/436)...\n", + "Processing CD 3620 (200/436)...\n", + "Processing CD 3704 (210/436)...\n", + "Processing CD 3714 (220/436)...\n", + "Processing CD 3909 (230/436)...\n", + "Processing CD 4004 (240/436)...\n", + "Processing CD 409 (250/436)...\n", + "Processing CD 4204 (260/436)...\n", + "Processing CD 4214 (270/436)...\n", + "Processing CD 4505 (280/436)...\n", + "Processing CD 4707 (290/436)...\n", + "Processing CD 4808 (300/436)...\n", + "Processing CD 4818 (310/436)...\n", + "Processing CD 4828 (320/436)...\n", + "Processing CD 4838 (330/436)...\n", + "Processing CD 5101 (340/436)...\n", + "Processing CD 5111 (350/436)...\n", + "Processing CD 5310 (360/436)...\n", + "Processing CD 5508 (370/436)...\n", + "Processing CD 609 (380/436)...\n", + "Processing CD 619 (390/436)...\n", + "Processing CD 629 (400/436)...\n", + "Processing CD 639 (410/436)...\n", + "Processing CD 649 (420/436)...\n", + "Processing CD 807 (430/436)...\n", + "Processing CD 905 (436/436)...\n", + "\n", + "Combining 436 CD DataFrames...\n", + "Total households across all CDs: 50,002\n", + "Combined DataFrame shape: (152001, 185)\n", + "\n", + "Weights in combined_df BEFORE reindexing:\n", + " HH weight sum: 1.06M\n", + " Person weight sum: 1.06M\n", + " Ratio: 1.00\n", + "\n", + "Reindexing all entity IDs using 25k ranges per CD...\n", + " Created 50,002 unique households across 436 CDs\n", + " Reindexing persons using 25k ranges...\n", + " Reindexing tax units...\n", + " Reindexing SPM units...\n", + " Reindexing marital units...\n", + " Final persons: 152,001\n", + " Final households: 50,002\n", + " Final tax units: 70,803\n", + " Final SPM units: 52,275\n", + " Final marital units: 116,736\n", + "\n", + "Weights in combined_df AFTER reindexing:\n", + " HH weight sum: 1.06M\n", + " Person weight sum: 1.06M\n", + " Ratio: 1.00\n", + "\n", + "Overflow check:\n", + " Max person ID after reindexing: 15,875,309\n", + " Max person ID × 100: 1,587,530,900\n", + " int32 max: 2,147,483,647\n", + " ✓ No overflow risk!\n", + 
"\n", + "Creating Dataset from combined DataFrame...\n", + "Building simulation from Dataset...\n", + "\n", + "Saving to ./temp/national.h5...\n", + "Found 168 input variables (excluding calculated variables)\n", + "Also freezing 1 state-dependent calculated variables\n", + "Variables saved: 192\n", + "Variables skipped: 3212\n", + "Sparse CD-stacked dataset saved successfully!\n", + "Household mapping saved to ./temp/mappings/national_household_mapping.csv\n", + "\n", + "Verifying saved file...\n", + " Final households: 50,002\n", + " Final persons: 152,001\n", + " Total population (from household weights): 350,023\n", + " Total population (from person weights): 1,064,034\n", + " Average persons per household: 3.04\n" + ] + } + ], + "source": [ + "output_path = f\"{output_dir}/national.h5\"\n", + "output_file = create_sparse_cd_stacked_dataset(\n", + " w,\n", + " cds_to_calibrate,\n", + " dataset_path=str(dataset_uri),\n", + " output_path=output_path,\n", + " freeze_calculated_vars=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(50002, 5)\n", + "50002\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
household_idhousehold_weightcongressional_district_geoidstate_fipssnap
007.01001100.0
117.01001100.0
227.01001100.0
337.01001100.0
447.01001100.0
\n", + "
" + ], + "text/plain": [ + " household_id household_weight congressional_district_geoid state_fips \\\n", + "0 0 7.0 1001 10 \n", + "1 1 7.0 1001 10 \n", + "2 2 7.0 1001 10 \n", + "3 3 7.0 1001 10 \n", + "4 4 7.0 1001 10 \n", + "\n", + " snap \n", + "0 0.0 \n", + "1 0.0 \n", + "2 0.0 \n", + "3 0.0 \n", + "4 0.0 " + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sim_test = Microsimulation(dataset=output_path)\n", + "hh_snap_df = pd.DataFrame(sim_test.calculate_dataframe([\n", + " \"household_id\", \"household_weight\", \"congressional_district_geoid\", \"state_fips\", \"snap\"])\n", + ")\n", + "\n", + "assert np.sum(w > 0) == hh_snap_df.shape[0], f\"Expected {np.sum(w > 0)} rows, got {hh_snap_df.shape[0]}\"\n", + "print(hh_snap_df.shape)\n", + "print(np.sum(w > 0))\n", + "hh_snap_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1598, 5)\n", + "(1, 5889488)\n", + " household_id household_weight congressional_district_geoid \\\n", + "23789 5150000 7.0 3701 \n", + "23790 5150001 7.0 3701 \n", + "23791 5150002 7.0 3701 \n", + "23792 5150003 7.0 3701 \n", + "23793 5150004 7.0 3701 \n", + "\n", + " state_fips snap \n", + "23789 37 1243.5 \n", + "23790 37 0.0 \n", + "23791 37 0.0 \n", + "23792 37 0.0 \n", + "23793 37 0.0 \n", + " household_id household_weight congressional_district_geoid \\\n", + "25382 5475112 7.0 3714 \n", + "25383 5475113 7.0 3714 \n", + "25384 5475114 7.0 3714 \n", + "25385 5475115 7.0 3714 \n", + "25386 5475116 7.0 3714 \n", + "\n", + " state_fips snap \n", + "25382 37 0.0 \n", + "25383 37 0.0 \n", + "25384 37 0.0 \n", + "25385 37 0.0 \n", + "25386 37 0.0 \n" + ] + } + ], + "source": [ + "print(geo_1_df.shape)\n", + "print(X_sparse[row_loc, :].shape)\n", + "print(geo_1_df.head())\n", + "print(geo_1_df.tail())" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "geo_1_df['col_position'] = np.nan\n", + "geo_1_df['X_sparse_value'] = np.nan\n", + "geo_1_df['w_value'] = np.nan\n", + "\n", + "for i in range(geo_1_df.shape[0]):\n", + " df_hh_id_new = geo_1_df.iloc[i]['household_id']\n", + " # get the old household id\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Target row info: {'row_index': 33194, 'variable': 'snap', 'variable_desc': 'snap_cost_state', 'geographic_id': '37', 'geographic_level': 'unknown', 'target_value': 4041086120.0, 'stratum_id': 9799, 'stratum_group_id': 'state_snap_cost'}\n", + "{'row_index': 33194, 'variable': 'snap', 'variable_desc': 'snap_cost_state', 'geographic_id': '37', 'geographic_level': 'unknown', 'target_value': 4041086120.0, 'stratum_id': 9799, 'stratum_group_id': 'state_snap_cost'}\n", + "37\n", + "(1598, 5)\n", + "Matrix multiplication (X @ w)[33194] = 2,895,502.61\n", + "Simulation sum(snap * weight) for state 1 = 2,920,930.08\n", + "Matrix nonzero: 14574, Sim nonzero: 129\n", + "[np.float64(0.0), np.float64(0.0), np.float64(12.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0)]\n", + "Weight from matrix columns: 12.0\n", + "Weight from sim: 11191.0\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "Mismatch: 2920930.082221985 vs 
2895502.609931946", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[40], line 30\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWeight from matrix columns: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39msum(w_in_state)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWeight from sim: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgeo_1_df\u001b[38;5;241m.\u001b[39mhousehold_weight\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m np\u001b[38;5;241m.\u001b[39misclose(y_hat_sim, snap_hat_geo1, atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m), \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMismatch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00my_hat_sim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m vs \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msnap_hat_geo1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mEnd-to-end validation PASSED\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAssertionError\u001b[0m: Mismatch: 2920930.082221985 vs 2895502.609931946" + ] + } + ], + "source": [ + "print(f\"Target row info: {row_info}\")\n", + "\n", + "y_hat = X_sparse @ w\n", + "\n", + "# Ok, but hang on, you have two districts from two different states, but you \n", + "# didn't use them here. 
The geo should be NC\n", + "print(row_info)\n", + "print(target_geo_id)\n", + "\n", + "snap_hat_geo1 = y_hat[row_loc]\n", + "\n", + "geo_1_df = hh_snap_df.loc[hh_snap_df.state_fips == target_geo_id]\n", + "y_hat_sim = np.sum(geo_1_df.snap.values * geo_1_df.household_weight.values)\n", + "print(geo_1_df.shape)\n", + "\n", + "print(f\"Matrix multiplication (X @ w)[{row_loc}] = {snap_hat_geo1:,.2f}\")\n", + "print(f\"Simulation sum(snap * weight) for state 1 = {y_hat_sim:,.2f}\")\n", + "\n", + "# Check if household counts match\n", + "n_matrix = np.sum(X_sparse[row_loc, :].toarray() > 0)\n", + "n_sim = (geo_1_df.snap > 0).sum()\n", + "print(f\"Matrix nonzero: {n_matrix}, Sim nonzero: {n_sim}\")\n", + "\n", + "assert np.isclose(y_hat_sim, snap_hat_geo1, atol=10), f\"Mismatch: {y_hat_sim} vs {snap_hat_geo1}\"\n", + "print(\"\\nEnd-to-end validation PASSED\")" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "436" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Target row info: {'row_index': 33194, 'variable': 'snap', 'variable_desc': 'snap_cost_state', 'geographic_id': '37', 'geographic_level': 'unknown', 'target_value': 4041086120.0, 'stratum_id': 9799, 'stratum_group_id': 'state_snap_cost'}\n", + "{'row_index': 33194, 'variable': 'snap', 'variable_desc': 'snap_cost_state', 'geographic_id': '37', 'geographic_level': 'unknown', 'target_value': 4041086120.0, 'stratum_id': 9799, 'stratum_group_id': 'state_snap_cost'}\n", + "37\n", + "(1598, 5)\n", + "Matrix multiplication (X @ w)[33194] = 2,895,502.61\n", + "Simulation sum(snap * weight) for state 1 = 2,920,930.08\n", + "Matrix nonzero: 14574, Sim nonzero: 129\n", + "[np.float64(0.0), np.float64(0.0), np.float64(12.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0)]\n", + "Weight from matrix columns: 12.0\n", + "Weight from sim: 11191.0\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "Mismatch: 2920930.082221985 vs 2895502.609931946", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[40], line 30\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWeight from matrix columns: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39msum(w_in_state)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWeight from sim: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgeo_1_df\u001b[38;5;241m.\u001b[39mhousehold_weight\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m np\u001b[38;5;241m.\u001b[39misclose(y_hat_sim, snap_hat_geo1, atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m), \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMismatch: 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00my_hat_sim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m vs \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msnap_hat_geo1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mEnd-to-end validation PASSED\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAssertionError\u001b[0m: Mismatch: 2920930.082221985 vs 2895502.609931946" + ] + } + ], + "source": [ + "print(f\"Target row info: {row_info}\")\n", + "\n", + "y_hat = X_sparse @ w\n", + "\n", + "# Ok, but hang on, you have two districts from two different states, but you \n", + "# didn't use them here. The geo should be NC\n", + "print(row_info)\n", + "print(target_geo_id)\n", + "\n", + "snap_hat_geo1 = y_hat[row_loc]\n", + "\n", + "geo_1_df = hh_snap_df.loc[hh_snap_df.state_fips == target_geo_id]\n", + "y_hat_sim = np.sum(geo_1_df.snap.values * geo_1_df.household_weight.values)\n", + "print(geo_1_df.shape)\n", + "\n", + "print(f\"Matrix multiplication (X @ w)[{row_loc}] = {snap_hat_geo1:,.2f}\")\n", + "print(f\"Simulation sum(snap * weight) for state 1 = {y_hat_sim:,.2f}\")\n", + "\n", + "# Check if household counts match\n", + "n_matrix = np.sum(X_sparse[row_loc, :].toarray() > 0)\n", + "n_sim = (geo_1_df.snap > 0).sum()\n", + "print(f\"Matrix nonzero: {n_matrix}, Sim nonzero: {n_sim}\")\n", + "\n", + "# Check total weights\n", + "w_in_state = [w[hh_col_lku[cd]] for cd in hh_col_lku if int(cd)//100 == target_geo_id]\n", + "print(w_in_state)\n", + "print(f\"Weight from matrix columns: {np.sum(w_in_state)}\")\n", + "print(f\"Weight from sim: {geo_1_df.household_weight.sum()}\")\n", + "\n", + "assert np.isclose(y_hat_sim, snap_hat_geo1, atol=10), f\"Mismatch: {y_hat_sim} vs {snap_hat_geo1}\"\n", + "print(\"\\nEnd-to-end validation PASSED\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Target row info: {'row_index': 33194, 'variable': 'snap', 'variable_desc': 'snap_cost_state', 'geographic_id': '37', 'geographic_level': 'unknown', 'target_value': 4041086120.0, 'stratum_id': 9799, 'stratum_group_id': 'state_snap_cost'}\n", + "{'row_index': 33194, 'variable': 'snap', 'variable_desc': 'snap_cost_state', 'geographic_id': '37', 'geographic_level': 'unknown', 'target_value': 4041086120.0, 'stratum_id': 9799, 'stratum_group_id': 'state_snap_cost'}\n", + "37\n", + "(1598, 5)\n", + "Matrix multiplication (X @ w)[33194] = 2,895,502.61\n", + "Simulation sum(snap * weight) for state 1 = 2,920,930.08\n", + "Matrix nonzero: 14574, Sim nonzero: 129\n", + "[np.float64(0.0), np.float64(0.0), np.float64(12.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0), np.float64(0.0)]\n", + "Weight from matrix columns: 12.0\n", + "Weight from sim: 11191.0\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "Mismatch: 2920930.082221985 vs 2895502.609931946", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[40], line 30\u001b[0m\n\u001b[1;32m 27\u001b[0m 
\u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWeight from matrix columns: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39msum(w_in_state)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWeight from sim: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mgeo_1_df\u001b[38;5;241m.\u001b[39mhousehold_weight\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m np\u001b[38;5;241m.\u001b[39misclose(y_hat_sim, snap_hat_geo1, atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m), \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMismatch: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00my_hat_sim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m vs \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msnap_hat_geo1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mEnd-to-end validation PASSED\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAssertionError\u001b[0m: Mismatch: 2920930.082221985 vs 2895502.609931946" + ] + } + ], + "source": [ + "print(f\"Target row info: {row_info}\")\n", + "\n", + "y_hat = X_sparse @ w\n", + "\n", + "# Ok, but hang on, you have two districts from two different states, but you \n", + "# didn't use them here. The geo should be NC\n", + "print(row_info)\n", + "print(target_geo_id)\n", + "\n", + "snap_hat_geo1 = y_hat[row_loc]\n", + "\n", + "geo_1_df = hh_snap_df.loc[hh_snap_df.state_fips == target_geo_id]\n", + "y_hat_sim = np.sum(geo_1_df.snap.values * geo_1_df.household_weight.values)\n", + "print(geo_1_df.shape)\n", + "\n", + "print(f\"Matrix multiplication (X @ w)[{row_loc}] = {snap_hat_geo1:,.2f}\")\n", + "print(f\"Simulation sum(snap * weight) for state 1 = {y_hat_sim:,.2f}\")\n", + "\n", + "# Check if household counts match\n", + "n_matrix = np.sum(X_sparse[row_loc, :].toarray() > 0)\n", + "n_sim = (geo_1_df.snap > 0).sum()\n", + "print(f\"Matrix nonzero: {n_matrix}, Sim nonzero: {n_sim}\")\n", + "\n", + "# Check total weights\n", + "w_in_state = [w[hh_col_lku[cd]] for cd in hh_col_lku if int(cd)//100 == target_geo_id]\n", + "print(w_in_state)\n", + "print(f\"Weight from matrix columns: {np.sum(w_in_state)}\")\n", + "print(f\"Weight from sim: {geo_1_df.household_weight.sum()}\")\n", + "\n", + "assert np.isclose(y_hat_sim, snap_hat_geo1, atol=10), f\"Mismatch: {y_hat_sim} vs {snap_hat_geo1}\"\n", + "print(\"\\nEnd-to-end validation PASSED\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = \"./temp\"\n", + "w = np.load('w_cd_20251126_131911.npy')\n", + "print(len(w))\n", + "print(len(cds_to_calibrate))\n", + "\n", + "print(w)\n", + "print(dataset_uri)\n", + "output_path = f\"{output_dir}/RI.h5\"\n", + "output_file = create_sparse_cd_stacked_dataset(\n", + " w,\n", + " cds_to_calibrate,\n", + " ['4401', '4402'],\n", + " dataset_path=str(dataset_uri),\n", + " output_path=output_path,\n", + " freeze_calculated_vars=False,\n", + ")\n", + "\n", + "for i in range(51):\n", + " row_loc = group_71.iloc[i]['row_index']\n", + " row_info = tracer.get_row_info(row_loc)\n", + " var = row_info['variable']\n", + " 
var_desc = row_info['variable_desc']\n", + " target_geo_id = int(row_info['geographic_id'])\n", + " if target_geo_id == 44:\n", + " break\n", + "\n", + "print(\"Row info for first SNAP state target:\")\n", + "row_info\n", + "print(f\"Target row info: {row_info}\")\n", + "\n", + "y_hat = X_sparse @ w\n", + "snap_hat_geo44 = y_hat[row_loc]\n", + "\n", + "sim_test = Microsimulation(dataset=output_path)\n", + "hh_snap_df = pd.DataFrame(sim_test.calculate_dataframe([\n", + " \"household_id\", \"household_weight\", \"congressional_district_geoid\", \"state_fips\", \"snap\"])\n", + ")\n", + "\n", + "geo_44_df = hh_snap_df.loc[hh_snap_df.state_fips == 44]\n", + "y_hat_sim = np.sum(geo_44_df.snap.values * geo_44_df.household_weight.values)\n", + "\n", + "print(\"\\nThe calibration dashboard shows and estimate of 393.86M\")\n", + "print(f\"Matrix multiplication (X @ w)[{row_loc}] = {snap_hat_geo44:,.2f}\")\n", + "print(f\"Simulation sum(snap * weight) for state 44 = {y_hat_sim:,.2f}\")\n", + "\n", + "assert np.isclose(y_hat_sim, snap_hat_geo44, atol=10), f\"Mismatch: {y_hat_sim} vs {snap_hat_geo44}\"\n", + "print(\"\\nFull Weight from Model fitting - End-to-end validation PASSED\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "import os\n", + "\n", + "if os.path.exists('./temp'):\n", + " shutil.rmtree('./temp')\n", + " print(\"Cleaned up ./temp directory\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_path = f\"{output_dir}/3714.h5\"\n", + "output_file = create_sparse_cd_stacked_dataset(\n", + " w,\n", + " ['3714'],\n", + " dataset_path=str(dataset_uri),\n", + " output_path=output_path,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/holdout_validation.py b/policyengine_us_data/datasets/cps/local_area_calibration/holdout_validation.py new file mode 100644 index 00000000..6a65cae4 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/holdout_validation.py @@ -0,0 +1,480 @@ +import numpy as np +import pandas as pd +import torch +from scipy import sparse as sp +from typing import Tuple, List, Dict, Optional + + +def create_holdout_split( + X_sparse: sp.csr_matrix, + targets: np.ndarray, + target_groups: np.ndarray, + holdout_group_indices: List[int], +) -> Tuple[Dict, Dict]: + """ + Split data into training and holdout sets based on target group indices. 
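+
+    Illustrative usage (group IDs 10 and 25 are hypothetical):
+
+        train_data, holdout_data = create_holdout_split(
+            X_sparse, targets, target_groups, [10, 25]
+        )
+        # holdout_data["targets"] contains exactly the targets whose
+        # group ID is 10 or 25; everything else stays in train_data.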
+ + Args: + X_sparse: Sparse calibration matrix (n_targets x n_features) + targets: Target values array + target_groups: Group assignment for each target + holdout_group_indices: List of group indices to put in holdout set + + Returns: + train_data: Dict with X, targets, target_groups for training + holdout_data: Dict with X, targets, target_groups for holdout + """ + holdout_group_set = set(holdout_group_indices) + + # Create masks + holdout_mask = np.isin(target_groups, list(holdout_group_set)) + train_mask = ~holdout_mask + + # Split data + train_data = { + "X": X_sparse[train_mask, :], + "targets": targets[train_mask], + "target_groups": target_groups[train_mask], + "original_groups": target_groups[train_mask], # Keep original IDs + } + + holdout_data = { + "X": X_sparse[holdout_mask, :], + "targets": targets[holdout_mask], + "target_groups": target_groups[holdout_mask], + "original_groups": target_groups[holdout_mask], # Keep original IDs + } + + # Renumber groups to be consecutive for model training + train_data["target_groups"] = renumber_groups(train_data["target_groups"]) + # For holdout, also renumber for consistency in model evaluation + # But keep original_groups for reporting + holdout_data["target_groups"] = renumber_groups( + holdout_data["target_groups"] + ) + + return train_data, holdout_data + + +def renumber_groups(groups: np.ndarray) -> np.ndarray: + """Renumber groups to be consecutive starting from 0.""" + unique_groups = np.unique(groups) + mapping = {old: new for new, old in enumerate(unique_groups)} + return np.array([mapping[g] for g in groups]) + + +def calculate_group_losses( + model, + X_sparse: sp.csr_matrix, + targets: np.ndarray, + target_groups: np.ndarray, + loss_type: str = "relative", + original_groups: np.ndarray = None, +) -> Dict[str, float]: + """ + Calculate mean loss per group and overall mean group loss. 
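+
+    Worked example (hypothetical numbers): with loss_type="relative", a
+    group containing targets [100, 200] and predictions [110, 180] has
+    per-target errors |110 - 100| / 101 ~= 0.099 and
+    |180 - 200| / 201 ~= 0.100, so its group loss is ~0.099.
+    mean_group_mare is the unweighted mean of these group losses, so
+    small groups count as much as large ones.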
+ + Args: + model: Trained SparseCalibrationWeights model + X_sparse: Sparse calibration matrix + targets: Target values + target_groups: Group assignments (possibly renumbered) + loss_type: Type of loss ("relative" or "absolute") + original_groups: Original group IDs (optional, for reporting) + + Returns: + Dict with per-group losses and mean group loss + """ + with torch.no_grad(): + predictions = model.predict(X_sparse).cpu().numpy() + + # Calculate per-target losses + if loss_type == "relative": + # For reporting, use absolute relative error to match L0's verbose output + # L0 reports |relative_error|, not squared + losses = np.abs((predictions - targets) / (targets + 1)) + else: + # For absolute, also use non-squared for consistency + losses = np.abs(predictions - targets) + + # Use original groups if provided, otherwise use renumbered groups + groups_for_reporting = ( + original_groups if original_groups is not None else target_groups + ) + + # Calculate mean loss per group + unique_groups = np.unique(groups_for_reporting) + group_losses = {} + + for group_id in unique_groups: + group_mask = groups_for_reporting == group_id + group_losses[int(group_id)] = np.mean(losses[group_mask]) + + # Mean across groups (not weighted by group size) + mean_group_mare = np.mean(list(group_losses.values())) + + return { + "per_group": group_losses, + "mean_group_mare": mean_group_mare, + "n_groups": len(unique_groups), + } + + +def run_holdout_experiment( + X_sparse: sp.csr_matrix, + targets: np.ndarray, + target_groups: np.ndarray, + holdout_group_indices: List[int], + model_params: Dict, + training_params: Dict, +) -> Dict: + """ + Run a single holdout experiment with specified groups. + + Args: + X_sparse: Full sparse calibration matrix + targets: Full target values + target_groups: Full group assignments + holdout_group_indices: Groups to hold out + model_params: Parameters for SparseCalibrationWeights + training_params: Parameters for model.fit() + + Returns: + Dict with training and holdout results + """ + from l0.calibration import SparseCalibrationWeights + + # Split data + train_data, holdout_data = create_holdout_split( + X_sparse, targets, target_groups, holdout_group_indices + ) + + print( + f"Training samples: {len(train_data['targets'])}, " + f"Holdout samples: {len(holdout_data['targets'])}" + ) + print( + f"Training groups: {len(np.unique(train_data['target_groups']))}, " + f"Holdout groups: {len(np.unique(holdout_data['target_groups']))}" + ) + + # Create and train model + model = SparseCalibrationWeights( + n_features=X_sparse.shape[1], **model_params + ) + + model.fit( + M=train_data["X"], + y=train_data["targets"], + target_groups=train_data["target_groups"], + **training_params, + ) + + # Calculate losses with original group IDs + train_losses = calculate_group_losses( + model, + train_data["X"], + train_data["targets"], + train_data["target_groups"], + training_params.get("loss_type", "relative"), + original_groups=train_data["original_groups"], + ) + + holdout_losses = calculate_group_losses( + model, + holdout_data["X"], + holdout_data["targets"], + holdout_data["target_groups"], + training_params.get("loss_type", "relative"), + original_groups=holdout_data["original_groups"], + ) + + # Get sparsity info + active_info = model.get_active_weights() + + # Get the actual weight values + with torch.no_grad(): + weights = model.get_weights(deterministic=True).cpu().numpy() + + results = { + "train_mean_group_mare": train_losses["mean_group_mare"], + "holdout_mean_group_mare": 
holdout_losses["mean_group_mare"], + "train_group_losses": train_losses["per_group"], + "holdout_group_losses": holdout_losses["per_group"], + "n_train_groups": train_losses["n_groups"], + "n_holdout_groups": holdout_losses["n_groups"], + "active_weights": active_info["count"], + "total_weights": X_sparse.shape[1], + "sparsity_pct": 100 * (1 - active_info["count"] / X_sparse.shape[1]), + "weights": weights, # Store the weight vector + "model": model, # Optionally store the entire model object + } + + return results + + +def compute_aggregate_losses( + X_sparse: sp.csr_matrix, + weights: np.ndarray, + targets_df: pd.DataFrame, + target_groups: np.ndarray, + training_group_ids: List[int], + holdout_group_ids: List[int], +) -> Dict: + """ + Compute aggregate losses showing how well CD/state predictions aggregate to higher levels. + Returns losses organized by group_id with 'state' and 'national' sub-keys. + + Args: + X_sparse: Calibration matrix + weights: Calibrated weights + targets_df: DataFrame with geographic info and group assignments + target_groups: Group assignments array + training_group_ids: Groups used in training + holdout_group_ids: Groups held out + + Returns: + Dict with train_aggregate_losses and holdout_aggregate_losses + """ + + # Calculate predictions + predictions = X_sparse @ weights + targets_df = targets_df.copy() + targets_df["prediction"] = predictions + targets_df["group_id"] = target_groups + + # Identify which groups are training vs holdout + train_aggregate_losses = {} + holdout_aggregate_losses = {} + + # Process each unique group + for group_id in np.unique(target_groups): + group_mask = target_groups == group_id + group_targets = targets_df[group_mask].copy() + + if len(group_targets) == 0: + continue + + # Determine if this is a training or holdout group + is_training = group_id in training_group_ids + is_holdout = group_id in holdout_group_ids + + if not (is_training or is_holdout): + continue # Skip unknown groups + + # Get the primary geographic level of this group + geo_ids = group_targets["geographic_id"].unique() + + # Determine the geographic level + if "US" in geo_ids and len(geo_ids) == 1: + # National-only group - no aggregation possible, skip + continue + elif all(len(str(g)) > 2 for g in geo_ids if g != "US"): + # CD-level group - can aggregate to state and national + primary_level = "cd" + elif all(len(str(g)) <= 2 for g in geo_ids if g != "US"): + # State-level group - can aggregate to national only + primary_level = "state" + else: + # Mixed or unclear - skip + continue + + aggregate_losses = {} + + # For CD-level groups, compute state and national aggregation + if primary_level == "cd": + # Extract state from CD codes + group_targets["state"] = group_targets["geographic_id"].apply( + lambda x: ( + x[:2] + if len(str(x)) == 4 + else str(x)[:-2] if len(str(x)) == 3 else str(x)[:2] + ) + ) + + # Get the variable(s) for this group + variables = group_targets["variable"].unique() + + state_losses = [] + for variable in variables: + var_targets = group_targets[ + group_targets["variable"] == variable + ] + + # Aggregate by state + state_aggs = var_targets.groupby("state").agg( + {"value": "sum", "prediction": "sum"} + ) + + # Compute relative error for each state + for state_id, row in state_aggs.iterrows(): + if row["value"] != 0: + rel_error = abs( + (row["prediction"] - row["value"]) / row["value"] + ) + state_losses.append(rel_error) + + # Mean across all states + if state_losses: + aggregate_losses["state"] = np.mean(state_losses) + + # 
National aggregation + total_actual = group_targets["value"].sum() + total_pred = group_targets["prediction"].sum() + if total_actual != 0: + aggregate_losses["national"] = abs( + (total_pred - total_actual) / total_actual + ) + + # For state-level groups, compute national aggregation only + elif primary_level == "state": + total_actual = group_targets["value"].sum() + total_pred = group_targets["prediction"].sum() + if total_actual != 0: + aggregate_losses["national"] = abs( + (total_pred - total_actual) / total_actual + ) + + # Store in appropriate dict + if aggregate_losses: + if is_training: + train_aggregate_losses[group_id] = aggregate_losses + else: + holdout_aggregate_losses[group_id] = aggregate_losses + + return { + "train_aggregate_losses": train_aggregate_losses, + "holdout_aggregate_losses": holdout_aggregate_losses, + } + + +def simple_holdout( + X_sparse, + targets, + target_groups, + init_weights, + holdout_group_ids, + targets_df=None, # Optional: needed for hierarchical checks + check_hierarchical=False, # Optional: enable hierarchical analysis + epochs=10, + lambda_l0=8e-7, + lr=0.2, + verbose_spacing=5, + device="cuda", # Add device parameter +): + """ + Simple holdout validation for notebooks - no DataFrame dependencies. + + Args: + X_sparse: Sparse matrix from cd_matrix_sparse.npz + targets: Target values from cd_targets_array.npy + target_groups: Group assignments from cd_target_groups.npy + init_weights: Initial weights from cd_init_weights.npy + holdout_group_ids: List of group IDs to hold out (e.g. [10, 25, 47]) + targets_df: Optional DataFrame with geographic info for hierarchical checks + check_hierarchical: If True and targets_df provided, analyze hierarchical consistency + epochs: Training epochs + lambda_l0: L0 regularization parameter + lr: Learning rate + verbose_spacing: How often to print progress + device: 'cuda' for GPU, 'cpu' for CPU + + Returns: + Dictionary with train/holdout losses, summary stats, and optionally hierarchical analysis + """ + + # Model parameters (matching calibrate_cds_sparse.py) + model_params = { + "beta": 2 / 3, + "gamma": -0.1, + "zeta": 1.1, + "init_keep_prob": 0.999, + "init_weights": init_weights, + "log_weight_jitter_sd": 0.05, + "log_alpha_jitter_sd": 0.01, + "device": device, # Pass device to model + } + + training_params = { + "lambda_l0": lambda_l0, + "lambda_l2": 0, + "lr": lr, + "epochs": epochs, + "loss_type": "relative", + "verbose": True, + "verbose_freq": verbose_spacing, + } + + # Use the existing run_holdout_experiment function + results = run_holdout_experiment( + X_sparse=X_sparse, + targets=targets, + target_groups=target_groups, + holdout_group_indices=holdout_group_ids, + model_params=model_params, + training_params=training_params, + ) + + # Add hierarchical consistency check if requested + if check_hierarchical and targets_df is not None: + # Get training group IDs (all groups not in holdout) + all_group_ids = set(np.unique(target_groups)) + training_group_ids = list(all_group_ids - set(holdout_group_ids)) + + # Compute aggregate losses + aggregate_results = compute_aggregate_losses( + X_sparse=X_sparse, + weights=results["weights"], + targets_df=targets_df, + target_groups=target_groups, + training_group_ids=training_group_ids, + holdout_group_ids=holdout_group_ids, + ) + + # Add to results + results["train_aggregate_losses"] = aggregate_results[ + "train_aggregate_losses" + ] + results["holdout_aggregate_losses"] = aggregate_results[ + "holdout_aggregate_losses" + ] + + # Print summary if available + if 
( + aggregate_results["train_aggregate_losses"] + or aggregate_results["holdout_aggregate_losses"] + ): + print("\n" + "=" * 60) + print("HIERARCHICAL AGGREGATION PERFORMANCE") + print("=" * 60) + + # Show training group aggregates + if aggregate_results["train_aggregate_losses"]: + print("\nTraining groups (CD→State/National aggregation):") + for group_id, losses in list( + aggregate_results["train_aggregate_losses"].items() + )[:5]: + print(f" Group {group_id}:", end="") + if "state" in losses: + print(f" State={losses['state']:.2%}", end="") + if "national" in losses: + print(f" National={losses['national']:.2%}", end="") + print() + + # Show holdout group aggregates + if aggregate_results["holdout_aggregate_losses"]: + print("\nHoldout groups (CD→State/National aggregation):") + for group_id, losses in list( + aggregate_results["holdout_aggregate_losses"].items() + )[:5]: + print(f" Group {group_id}:", end="") + if "state" in losses: + print(f" State={losses['state']:.2%}", end="") + if "national" in losses: + print(f" National={losses['national']:.2%}", end="") + print() + print( + " → Good performance here shows hierarchical generalization!" + ) + + return results diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/household_tracer.py b/policyengine_us_data/datasets/cps/local_area_calibration/household_tracer.py new file mode 100644 index 00000000..b2525e02 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/household_tracer.py @@ -0,0 +1,1087 @@ +""" +Household tracer utility for debugging geo-stacking sparse matrices. + +This utility allows tracing a single household through the complex stacked matrix +structure to verify values match sim.calculate results. + +USAGE +===== + +Basic Setup (from calibration package): + + import pickle + from household_tracer import HouseholdTracer + + # Load calibration package + with open('calibration_package.pkl', 'rb') as f: + data = pickle.load(f) + + # Extract components + X_sparse = data['X_sparse'] + targets_df = data['targets_df'] + household_id_mapping = data['household_id_mapping'] + cds_to_calibrate = data['cds_to_calibrate'] + # Note: you also need 'sim' (Microsimulation instance) + + # Create tracer + tracer = HouseholdTracer( + targets_df, X_sparse, household_id_mapping, + cds_to_calibrate, sim + ) + +Common Operations: + + # 1. Understand what a column represents + col_info = tracer.get_column_info(100) + # Returns: {'column_index': 100, 'cd_geoid': '101', + # 'household_id': 100, 'household_index': 99} + + # 2. Access full column catalog (all column mappings) + tracer.column_catalog # DataFrame with all 4.6M column mappings + + # 3. Find where a household appears across all CDs + positions = tracer.get_household_column_positions(565) + # Returns: {'101': 564, '102': 11144, '201': 21724, ...} + + # 4. Look up a specific matrix cell with full context + cell = tracer.lookup_matrix_cell(row_idx=50, col_idx=100) + # Returns complete info about target, household, and value + + # 5. Get info about a row (target) + row_info = tracer.get_row_info(50) + + # 6. View matrix structure + tracer.print_matrix_structure() + + # 7. View column/row catalogs + tracer.print_column_catalog(max_rows=50) + tracer.print_row_catalog(max_rows=50) + + # 8. Trace all target values for a specific household + household_targets = tracer.trace_household_targets(565) + + # 9. 
Get targets by group + from calibration_utils import create_target_groups + tracer.target_groups, _ = create_target_groups(targets_df) + group_31 = tracer.get_group_rows(31) # Person count targets + +Matrix Structure: + + Columns are organized as: [CD1_households | CD2_households | ... | CD436_households] + Each CD block has n_households columns (e.g., 10,580 households) + + Formula to find column index: + column_idx = cd_block_number × n_households + household_index + + Example: Household at index 12 in CD block 371: + column_idx = 371 × 10580 + 12 = 3,925,192 +""" + +import logging +import pandas as pd +import numpy as np +from typing import Dict, List, Tuple, Optional +from scipy import sparse + +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import create_target_groups +from policyengine_us_data.datasets.cps.local_area_calibration.metrics_matrix_geo_stacking_sparse import SparseGeoStackingMatrixBuilder +from policyengine_us import Microsimulation +from sqlalchemy import create_engine, text + + +logger = logging.getLogger(__name__) + + +class HouseholdTracer: + """Trace households through geo-stacked sparse matrices for debugging.""" + + def __init__( + self, + targets_df: pd.DataFrame, + matrix: sparse.csr_matrix, + household_id_mapping: Dict[str, List[str]], + geographic_ids: List[str], + sim, + ): + """ + Initialize tracer with matrix components. + + Args: + targets_df: DataFrame of all targets + matrix: The final stacked sparse matrix + household_id_mapping: Mapping from geo keys to household ID lists + geographic_ids: List of geographic IDs in order + sim: Microsimulation instance + """ + self.targets_df = targets_df + self.matrix = matrix + self.household_id_mapping = household_id_mapping + self.geographic_ids = geographic_ids + self.sim = sim + + # Get original household info + self.original_household_ids = sim.calculate("household_id").values + self.n_households = len(self.original_household_ids) + self.n_geographies = len(geographic_ids) + + # Build reverse lookup: original_hh_id -> index in original data + self.hh_id_to_index = { + hh_id: idx for idx, hh_id in enumerate(self.original_household_ids) + } + + # Build column catalog: maps column index -> (cd_geoid, household_id, household_index) + self.column_catalog = self._build_column_catalog() + + # Build row catalog: maps row index -> target info + self.row_catalog = self._build_row_catalog() + + logger.info( + f"Tracer initialized: {self.n_households} households x {self.n_geographies} geographies" + ) + logger.info(f"Matrix shape: {matrix.shape}") + + def _build_column_catalog(self) -> pd.DataFrame: + """Build a complete catalog of all matrix columns.""" + catalog = [] + col_idx = 0 + + for geo_id in self.geographic_ids: + for hh_idx, hh_id in enumerate(self.original_household_ids): + catalog.append( + { + "column_index": col_idx, + "cd_geoid": geo_id, + "household_id": hh_id, + "household_index": hh_idx, + } + ) + col_idx += 1 + + return pd.DataFrame(catalog) + + def _build_row_catalog(self) -> pd.DataFrame: + """Build a complete catalog of all matrix rows (targets).""" + catalog = [] + + for row_idx, (_, target) in enumerate(self.targets_df.iterrows()): + catalog.append( + { + "row_index": row_idx, + "variable": target["variable"], + "variable_desc": target.get( + "variable_desc", target["variable"] + ), + "geographic_id": target.get("geographic_id", "unknown"), + "geographic_level": target.get( + "geographic_level", "unknown" + ), + "target_value": target["value"], + "stratum_id": 
target.get("stratum_id"), + "stratum_group_id": target.get( + "stratum_group_id", "unknown" + ), + } + ) + + return pd.DataFrame(catalog) + + def get_column_info(self, col_idx: int) -> Dict: + """Get information about a specific column.""" + if col_idx >= len(self.column_catalog): + raise ValueError( + f"Column index {col_idx} out of range (max: {len(self.column_catalog)-1})" + ) + return self.column_catalog.iloc[col_idx].to_dict() + + def get_row_info(self, row_idx: int) -> Dict: + """Get information about a specific row (target).""" + if row_idx >= len(self.row_catalog): + raise ValueError( + f"Row index {row_idx} out of range (max: {len(self.row_catalog)-1})" + ) + return self.row_catalog.iloc[row_idx].to_dict() + + def lookup_matrix_cell(self, row_idx: int, col_idx: int) -> Dict: + """ + Look up a specific matrix cell and return complete context. + + Args: + row_idx: Row index in matrix + col_idx: Column index in matrix + + Returns: + Dict with row info, column info, and matrix value + """ + row_info = self.get_row_info(row_idx) + col_info = self.get_column_info(col_idx) + matrix_value = self.matrix[row_idx, col_idx] + + return { + "row_index": row_idx, + "column_index": col_idx, + "matrix_value": float(matrix_value), + "target": row_info, + "household": col_info, + } + + def print_column_catalog(self, max_rows: int = 50): + """Print a sample of the column catalog.""" + print( + f"\nColumn Catalog (showing first {max_rows} of {len(self.column_catalog)}):" + ) + print(self.column_catalog.head(max_rows).to_string(index=False)) + + def print_row_catalog(self, max_rows: int = 50): + """Print a sample of the row catalog.""" + print( + f"\nRow Catalog (showing first {max_rows} of {len(self.row_catalog)}):" + ) + print(self.row_catalog.head(max_rows).to_string(index=False)) + + def print_matrix_structure(self, create_groups=True): + """Print a comprehensive breakdown of the matrix structure.""" + print("\n" + "=" * 80) + print("MATRIX STRUCTURE BREAKDOWN") + print("=" * 80) + + print( + f"\nMatrix dimensions: {self.matrix.shape[0]} rows × {self.matrix.shape[1]} columns" + ) + print(f" Rows = {len(self.row_catalog)} targets") + print( + f" Columns = {self.n_households} households × {self.n_geographies} CDs" + ) + print( + f" = {self.n_households:,} × {self.n_geographies} = {self.matrix.shape[1]:,}" + ) + + print("\n" + "-" * 80) + print("COLUMN STRUCTURE (Households stacked by CD)") + print("-" * 80) + + # Build column ranges by CD + col_ranges = [] + cumulative = 0 + for geo_id in self.geographic_ids: + start_col = cumulative + end_col = cumulative + self.n_households - 1 + col_ranges.append( + { + "cd_geoid": geo_id, + "start_col": start_col, + "end_col": end_col, + "n_households": self.n_households, + "example_household_id": self.original_household_ids[0], + } + ) + cumulative += self.n_households + + ranges_df = pd.DataFrame(col_ranges) + print(f"\nShowing first and last 10 CDs of {len(ranges_df)} total:") + print("\nFirst 10 CDs:") + print(ranges_df.head(10).to_string(index=False)) + print("\nLast 10 CDs:") + print(ranges_df.tail(10).to_string(index=False)) + + print("\n" + "-" * 80) + print("ROW STRUCTURE (Targets by geography and variable)") + print("-" * 80) + + # Summarize rows by geographic level + row_summary = ( + self.row_catalog.groupby(["geographic_level", "geographic_id"]) + .size() + .reset_index(name="n_targets") + ) + + print(f"\nTargets by geographic level:") + geo_level_summary = ( + self.row_catalog.groupby("geographic_level") + .size() + .reset_index(name="n_targets") + 
) + print(geo_level_summary.to_string(index=False)) + + print(f"\nTargets by stratum group:") + stratum_summary = ( + self.row_catalog.groupby("stratum_group_id") + .agg({"row_index": "count", "variable": lambda x: len(x.unique())}) + .rename( + columns={"row_index": "n_targets", "variable": "n_unique_vars"} + ) + ) + print(stratum_summary.to_string()) + + # Create and display target groups like calibrate_cds_sparse.py + if create_groups: + print("\n" + "-" * 80) + print("TARGET GROUPS (for loss calculation)") + print("-" * 80) + + target_groups, group_info = create_target_groups(self.targets_df) + + # Store target groups for later use + self.target_groups = target_groups + + # Use the improved labels from create_target_groups + for group_id, info in enumerate(group_info): + # Get row indices for this group + group_mask = target_groups == group_id + row_indices = np.where(group_mask)[0] + + # Format row indices for display + if len(row_indices) > 6: + row_display = f"[{row_indices[0]}, {row_indices[1]}, {row_indices[2]}, '...', {row_indices[-2]}, {row_indices[-1]}]" + else: + row_display = str(row_indices.tolist()) + + print(f" {info} - rows {row_display}") + + print("\n" + "=" * 80) + + def get_group_rows(self, group_id: int) -> pd.DataFrame: + """ + Get all rows (targets) for a specific target group. + + Args: + group_id: The target group ID + + Returns: + DataFrame with all targets in that group + """ + if not hasattr(self, "target_groups"): + self.target_groups, _ = create_target_groups(self.targets_df) + + group_mask = self.target_groups == group_id + group_targets = self.targets_df[group_mask].copy() + + # Add row indices + row_indices = np.where(group_mask)[0] + group_targets["row_index"] = row_indices + + # Reorder columns for clarity + cols = [ + "row_index", + "variable", + "geographic_id", + "value", + "description", + ] + cols = [c for c in cols if c in group_targets.columns] + group_targets = group_targets[cols] + + return group_targets + + def get_household_column_positions( + self, original_hh_id: int + ) -> Dict[str, int]: + """ + Get all column positions for a household across all geographies. + + Args: + original_hh_id: Original household ID from simulation + + Returns: + Dict mapping geo_id to column position in stacked matrix + """ + if original_hh_id not in self.hh_id_to_index: + raise ValueError( + f"Household {original_hh_id} not found in original data" + ) + + # Get the household's index in the original data + hh_index = self.hh_id_to_index[original_hh_id] + + # Calculate column positions for each geography + positions = {} + for geo_idx, geo_id in enumerate(self.geographic_ids): + # Each geography gets a block of n_households columns + col_position = geo_idx * self.n_households + hh_index + positions[geo_id] = col_position + + return positions + + def trace_household_targets(self, original_hh_id: int) -> pd.DataFrame: + """ + Extract all target values for a household across all geographies. 
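+
+        Illustrative usage (mirrors the module-level USAGE notes):
+
+            df = tracer.trace_household_targets(565)
+            # One row per target; the matrix_value_<geo_id> columns hold
+            # this household's matrix entry in each CD's column block.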
+ + Args: + original_hh_id: Original household ID to trace + + Returns: + DataFrame with target details and values for this household + """ + positions = self.get_household_column_positions(original_hh_id) + + results = [] + + for target_idx, (_, target) in enumerate(self.targets_df.iterrows()): + target_result = { + "target_idx": target_idx, + "variable": target["variable"], + "target_value": target["value"], + "geographic_id": target.get("geographic_id", "unknown"), + "stratum_group_id": target.get("stratum_group_id", "unknown"), + "description": target.get("description", ""), + } + + # Extract values for this target across all geographies + for geo_id, col_pos in positions.items(): + if col_pos < self.matrix.shape[1]: + matrix_value = self.matrix[target_idx, col_pos] + target_result[f"matrix_value_{geo_id}"] = matrix_value + else: + target_result[f"matrix_value_{geo_id}"] = np.nan + + results.append(target_result) + + return pd.DataFrame(results) + + def verify_household_target( + self, original_hh_id: int, target_idx: int, geo_id: str + ) -> Dict: + """ + Verify a specific target value for a household by comparing with sim.calculate. + + Args: + original_hh_id: Original household ID + target_idx: Target row index in matrix + geo_id: Geographic ID to check + + Returns: + Dict with verification results + """ + # Get target info + target = self.targets_df.iloc[target_idx] + variable = target["variable"] + stratum_id = target["stratum_id"] + + # Get matrix value + positions = self.get_household_column_positions(original_hh_id) + col_pos = positions[geo_id] + matrix_value = self.matrix[target_idx, col_pos] + + # Calculate expected value using sim + # Import the matrix builder to access constraint methods + + # We need a builder instance to get constraints + # This is a bit hacky but necessary for verification + db_uri = "sqlite:////home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/policy_data.db" + builder = SparseGeoStackingMatrixBuilder(db_uri) + + # Get constraints for this stratum + constraints_df = builder.get_constraints_for_stratum(stratum_id) + + # Calculate what the value should be for this household + expected_value = self._calculate_expected_value( + original_hh_id, variable, constraints_df + ) + + return { + "household_id": original_hh_id, + "target_idx": target_idx, + "geo_id": geo_id, + "variable": variable, + "stratum_id": stratum_id, + "matrix_value": float(matrix_value), + "expected_value": float(expected_value), + "matches": abs(matrix_value - expected_value) < 1e-6, + "difference": float(matrix_value - expected_value), + "constraints": ( + constraints_df.to_dict("records") + if not constraints_df.empty + else [] + ), + } + + def _calculate_expected_value( + self, original_hh_id: int, variable: str, constraints_df: pd.DataFrame + ) -> float: + """ + Calculate expected value for a household given variable and constraints. 
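+
+        Returns 0.0 if the household fails any non-geographic constraint;
+        otherwise returns the household's value of `variable` (summed over
+        household members for person-level variables, mapped to the
+        household for tax-unit variables).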
+ """ + # Get household index + hh_index = self.hh_id_to_index[original_hh_id] + + # Get target entity + target_entity = self.sim.tax_benefit_system.variables[ + variable + ].entity.key + + # Check if household satisfies all constraints + satisfies_constraints = True + + for _, constraint in constraints_df.iterrows(): + var = constraint["constraint_variable"] + op = constraint["operation"] + val = constraint["value"] + + # Skip geographic constraints (they're handled by matrix structure) + if var in ["state_fips", "congressional_district_geoid"]: + continue + + # Get constraint value for this household + constraint_entity = self.sim.tax_benefit_system.variables[ + var + ].entity.key + if constraint_entity == "person": + # For person variables, check if any person in household satisfies + person_values = self.sim.calculate(var, map_to="person").values + household_ids_person_level = self.sim.calculate( + "household_id", map_to="person" + ).values + + # Get person values for this household + household_mask = household_ids_person_level == original_hh_id + household_person_values = person_values[household_mask] + + # Parse constraint value + try: + parsed_val = float(val) + if parsed_val.is_integer(): + parsed_val = int(parsed_val) + except ValueError: + if val == "True": + parsed_val = True + elif val == "False": + parsed_val = False + else: + parsed_val = val + + # Check if any person in household satisfies constraint + if op == "==" or op == "=": + person_satisfies = household_person_values == parsed_val + elif op == ">": + person_satisfies = household_person_values > parsed_val + elif op == ">=": + person_satisfies = household_person_values >= parsed_val + elif op == "<": + person_satisfies = household_person_values < parsed_val + elif op == "<=": + person_satisfies = household_person_values <= parsed_val + elif op == "!=": + person_satisfies = household_person_values != parsed_val + else: + continue + + if not person_satisfies.any(): + satisfies_constraints = False + break + + else: + # For household/tax_unit variables, get value directly + if constraint_entity == "household": + constraint_value = self.sim.calculate(var).values[hh_index] + else: + # For tax_unit, map to household level + constraint_value = self.sim.calculate( + var, map_to="household" + ).values[hh_index] + + # Parse constraint value + try: + parsed_val = float(val) + if parsed_val.is_integer(): + parsed_val = int(parsed_val) + except ValueError: + if val == "True": + parsed_val = True + elif val == "False": + parsed_val = False + else: + parsed_val = val + + # Check constraint + if op == "==" or op == "=": + if not (constraint_value == parsed_val): + satisfies_constraints = False + break + elif op == ">": + if not (constraint_value > parsed_val): + satisfies_constraints = False + break + elif op == ">=": + if not (constraint_value >= parsed_val): + satisfies_constraints = False + break + elif op == "<": + if not (constraint_value < parsed_val): + satisfies_constraints = False + break + elif op == "<=": + if not (constraint_value <= parsed_val): + satisfies_constraints = False + break + elif op == "!=": + if not (constraint_value != parsed_val): + satisfies_constraints = False + break + + if not satisfies_constraints: + return 0.0 + + # If constraints satisfied, get the target value + if target_entity == "household": + target_value = self.sim.calculate(variable).values[hh_index] + elif target_entity == "person": + # For person variables, sum over household members + person_values = self.sim.calculate( + variable, 
map_to="person" + ).values + household_ids_person_level = self.sim.calculate( + "household_id", map_to="person" + ).values + household_mask = household_ids_person_level == original_hh_id + target_value = person_values[household_mask].sum() + else: + # For tax_unit variables, map to household + target_value = self.sim.calculate( + variable, map_to="household" + ).values[hh_index] + + return float(target_value) + + def audit_household( + self, original_hh_id: int, max_targets: int = 10 + ) -> Dict: + """ + Comprehensive audit of a household across all targets and geographies. + + Args: + original_hh_id: Household ID to audit + max_targets: Maximum number of targets to verify in detail + + Returns: + Dict with audit results + """ + logger.info(f"Auditing household {original_hh_id}") + + # Get basic info + positions = self.get_household_column_positions(original_hh_id) + all_values = self.trace_household_targets(original_hh_id) + + # Verify a sample of targets + verifications = [] + target_sample = min(max_targets, len(self.targets_df)) + + for target_idx in range( + 0, + len(self.targets_df), + max(1, len(self.targets_df) // target_sample), + ): + for geo_id in self.geographic_ids[ + :2 + ]: # Limit to first 2 geographies + try: + verification = self.verify_household_target( + original_hh_id, target_idx, geo_id + ) + verifications.append(verification) + except Exception as e: + logger.warning( + f"Could not verify target {target_idx} for geo {geo_id}: {e}" + ) + + # Summary statistics + if verifications: + matches = [v["matches"] for v in verifications] + match_rate = sum(matches) / len(matches) + max_diff = max([abs(v["difference"]) for v in verifications]) + else: + match_rate = 0.0 + max_diff = 0.0 + + return { + "household_id": original_hh_id, + "column_positions": positions, + "all_target_values": all_values, + "verifications": verifications, + "summary": { + "total_verifications": len(verifications), + "match_rate": match_rate, + "max_difference": max_diff, + "passes_audit": match_rate > 0.95 and max_diff < 1e-3, + }, + } + + +def matrix_tracer(): + """Demo the household tracer.""" + + # Setup - match calibrate_cds_sparse.py configuration exactly + db_uri = "sqlite:////home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/policy_data.db" + builder = SparseGeoStackingMatrixBuilder(db_uri, time_period=2023) + sim = Microsimulation(dataset="/home/baogorek/devl/stratified_10k.h5") + + hh_person_rel = pd.DataFrame( + { + "household_id": sim.calculate("household_id", map_to="person"), + "person_id": sim.calculate("person_id", map_to="person"), + } + ) + + # Get all congressional districts from database (like calibrate_cds_sparse.py does) + engine = create_engine(db_uri) + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM strata s + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 + AND sc.constraint_variable = 'congressional_district_geoid' + ORDER BY sc.value + """ + with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + all_cd_geoids = [row[0] for row in result] + + targets_df, matrix, household_mapping = ( + builder.build_stacked_matrix_sparse( + "congressional_district", all_cd_geoids, sim + ) + ) + target_groups, y = create_target_groups(targets_df) + + tracer = HouseholdTracer( + targets_df, matrix, household_mapping, all_cd_geoids, sim + ) + tracer.print_matrix_structure() + + # Testing national targets with a test household ----------------- + test_household = 
sim.calculate("household_id").values[100] + positions = tracer.get_household_column_positions(test_household) + + # Row 0: Alimony - Row 0 + matrix_hh_position = positions["3910"] + matrix[0, matrix_hh_position] + + # Row 0: Alimony - Row 0 + matrix_hh_position = positions["3910"] + matrix[0, matrix_hh_position] + + # Group 32: Medicaid Enrollment (436 targets across 436 geographies) - rows [69, 147, 225, '...', 33921, 33999] + group_32_mask = target_groups == 32 + group_32_targets = targets_df[group_32_mask].copy() + group_32_targets["row_index"] = np.where(group_32_mask)[0] + group_32_targets[ + [ + "target_id", + "stratum_id", + "value", + "original_value", + "geographic_id", + "variable_desc", + "uprating_factor", + "reconciliation_factor", + ] + ] + + # Note that Medicaid reporting in the surveys can sometimes be higher than the administrative totals + # Alabama is one of the states that has not expanded Medicaid under the Affordable Care Act (ACA). + # People in the gap might confuse + group_32_targets.reconciliation_factor.describe() + + cd_101_medicaid = group_32_targets[ + group_32_targets["geographic_id"] == "101" + ] + row_idx = cd_101_medicaid["row_index"].values[0] + target_value = cd_101_medicaid["value"].values[0] + + medicaid_df = sim.calculate_dataframe( + ["household_id", "medicaid"], map_to="household" + ) + medicaid_households = medicaid_df[medicaid_df["medicaid"] > 0] + + test_hh = int(medicaid_households.iloc[0]["household_id"]) + medicaid_df.loc[medicaid_df.household_id == test_hh] + positions = tracer.get_household_column_positions(test_hh) + col_idx = positions["101"] + matrix[row_idx, positions["101"]] # Should be > 0 + matrix[row_idx, positions["102"]] # Should be zero + + # But Medicaid is a person count concept. In this case, the number is 2.0 + hh_person_rel.loc[hh_person_rel.household_id == test_hh] + + person_medicaid_df = sim.calculate_dataframe( + ["person_id", "medicaid", "medicaid_enrolled"], map_to="person" + ) + person_medicaid_df.loc[person_medicaid_df.person_id.isin([56001, 56002])] + # Note that it's medicaid_enrolled that we're counting for the metrics matrix. 
+ + # Group 43: Tax Units qualified_business_income_deduction>0 (436 targets across 436 geographies) - rows [88, 166, 244, '...', 33940, 34018] + # Note that this is the COUNT of > 0 + group_43_mask = target_groups == 43 + group_43_targets = targets_df[group_43_mask].copy() + group_43_targets["row_index"] = np.where(group_43_mask)[0] + group_43_targets[ + [ + "target_id", + "stratum_id", + "value", + "original_value", + "geographic_id", + "variable_desc", + "uprating_factor", + "reconciliation_factor", + ] + ] + + cd_101_qbid = group_43_targets[group_43_targets["geographic_id"] == "101"] + row_idx = cd_101_qbid["row_index"].values[0] + target_value = cd_101_qbid["value"].values[0] + + qbid_df = sim.calculate_dataframe( + ["household_id", "qualified_business_income_deduction"], + map_to="household", + ) + qbid_households = qbid_df[ + qbid_df["qualified_business_income_deduction"] > 0 + ] + + # Check matrix for a specific QBID household + test_hh = int(qbid_households.iloc[0]["household_id"]) + positions = tracer.get_household_column_positions(test_hh) + col_idx = positions["101"] + matrix[row_idx, positions["101"]] # Should be 1.0 + matrix[row_idx, positions["102"]] # Should be zero + + qbid_df.loc[qbid_df.household_id == test_hh] + hh_person_rel.loc[hh_person_rel.household_id == test_hh] + + # Group 66: Qualified Business Income Deduction (436 targets across 436 geographies) - rows [70, 148, 226, '...', 33922, 34000] + # This is the amount! + group_66_mask = target_groups == 66 + group_66_targets = targets_df[group_66_mask].copy() + group_66_targets["row_index"] = np.where(group_66_mask)[0] + group_66_targets[ + [ + "target_id", + "stratum_id", + "value", + "original_value", + "geographic_id", + "variable_desc", + "uprating_factor", + "reconciliation_factor", + ] + ] + + cd_101_qbid_amount = group_66_targets[ + group_66_targets["geographic_id"] == "101" + ] + row_idx = cd_101_qbid_amount["row_index"].values[0] + target_value = cd_101_qbid_amount["value"].values[0] + + matrix[row_idx, positions["101"]] # Should > 1.0 + matrix[row_idx, positions["102"]] # Should be zero + + # Group 60: Household Count (436 targets across 436 geographies) - rows [36, 114, 192, '...', 33888, 33966] + group_60_mask = target_groups == 60 + group_60_targets = targets_df[group_60_mask].copy() + group_60_targets["row_index"] = np.where(group_60_mask)[0] + group_60_targets[ + [ + "target_id", + "stratum_id", + "value", + "original_value", + "geographic_id", + "variable_desc", + "uprating_factor", + "reconciliation_factor", + ] + ] + + cd_101_snap = group_60_targets[group_60_targets["geographic_id"] == "101"] + row_idx = cd_101_snap["row_index"].values[0] + target_value = cd_101_snap["value"].values[0] + + # Find households with SNAP > 0 + snap_df = sim.calculate_dataframe( + ["household_id", "snap"], map_to="household" + ) + snap_households = snap_df[snap_df["snap"] > 0] + + # Check matrix for a specific SNAP household + test_hh = int(snap_households.iloc[0]["household_id"]) + positions = tracer.get_household_column_positions(test_hh) + col_idx = positions["101"] + matrix[row_idx, positions["101"]] # Should be > 0 + matrix[row_idx, positions["102"]] # Should be zero + + # Check non-SNAP household + non_snap_hh = snap_df[snap_df["snap"] == 0].iloc[0]["household_id"] + non_snap_positions = tracer.get_household_column_positions(non_snap_hh) + matrix[row_idx, non_snap_positions["101"]] # should be 0 + + # Group 73: Snap Cost at State Level (51 targets across 51 geographies) - rows 34038-34088 ----------- + group_73_mask 
= target_groups == 73 + group_73_targets = targets_df[group_73_mask].copy() + group_73_targets["row_index"] = np.where(group_73_mask)[0] + + state_snap = group_73_targets[ + group_73_targets["geographic_id"] == "1" + ] # state_fips "1" (Alabama) + row_idx = state_snap["row_index"].values[0] + target_value = state_snap["value"].values[0] + + snap_value = matrix[row_idx, col_idx] + snap_value + + # AGI target exploration -------- + test_household = 565 + positions = tracer.get_household_column_positions(test_household) + row_idx = 27268 + one_target = targets_df.iloc[row_idx] + test_variable = one_target.variable + print(one_target.variable_desc) + print(one_target.value) + + # Get value for test household in CD 101 + matrix_hh_position = positions["101"] + value_correct = matrix[row_idx, matrix_hh_position] + print(f"Household {test_household} in CD 101: {value_correct}") + + # Get value for same household but wrong CD (e.g., '1001') + matrix_hh_position_1001 = positions["1001"] + value_incorrect = matrix[row_idx, matrix_hh_position_1001] + print(f"Household {test_household} in CD 1001 (wrong!): {value_incorrect}") + + df = sim.calculate_dataframe( + ["household_id", test_variable, "adjusted_gross_income"], + map_to="household", + ) + df.loc[df.household_id == test_household] + + # Group 78: Taxable Pension Income --------------------------------------------------------- + group_78 = tracer.get_group_rows(78) + cd_3910_target = group_78[group_78["geographic_id"] == "3910"] + + row_idx_3910 = cd_3910_target["row_index"].values[0] + print(f"Taxable Pension Income for CD 3910 is at row {row_idx_3910}") + + # Check here ------ + targets_df.iloc[row_idx_3910] + cd_3910_target + + test_variable = targets_df.iloc[row_idx_3910].variable + + # Get value for household in CD 3910 + matrix_hh_position_3910 = positions["3910"] + value_correct = matrix[row_idx_3910, matrix_hh_position_3910] + print(f"Household {test_household} in CD 3910: {value_correct}") + + # Get value for same household but wrong CD (e.g., '1001') + matrix_hh_position_1001 = positions["1001"] + value_incorrect = matrix[row_idx_3910, matrix_hh_position_1001] + print(f"Household {test_household} in CD 1001 (wrong!): {value_incorrect}") + + df = sim.calculate_dataframe( + ["household_id", test_variable], map_to="household" + ) + df.loc[df.household_id == test_household][[test_variable]] + + df.loc[df[test_variable] > 0] + + # Get all target values + all_values = tracer.trace_household_targets(test_household) + print(f"\nFound values for {len(all_values)} targets") + print(all_values.head()) + + # Verify a specific target + verification = tracer.verify_household_target( + test_household, 0, test_cds[0] + ) + print(f"\nVerification result: {verification}") + + # Full audit (TODO: not working, or at least wasn't working, on *_count metrics and targets) + audit = tracer.audit_household(test_household, max_targets=5) + print(f"\nAudit summary: {audit['summary']}") + + +def h5_tracer(): + import pandas as pd + from policyengine_us import Microsimulation + + # --- 1.
Setup: Load simulations and mapping file --- + + # Paths to the datasets and mapping file + new_dataset_path = "/home/baogorek/devl/policyengine-us-data/policyengine_us_data/datasets/cps/geo_stacking_calibration/temp/RI.h5" + original_dataset_path = "/home/baogorek/devl/stratified_10k.h5" + mapping_file_path = "./temp/RI_household_mapping.csv" + + # Initialize the two microsimulations + sim_new = Microsimulation(dataset=new_dataset_path) + sim_orig = Microsimulation(dataset=original_dataset_path) + + # Load the household ID mapping file + mapping_df = pd.read_csv(mapping_file_path) + + # --- 2. Identify households for comparison --- + + # Specify the household ID from the NEW dataset to test + test_hh_new = 2741169 + + # Find the corresponding ORIGINAL household ID using the mapping file + test_hh_orig = mapping_df.loc[ + mapping_df.new_household_id == test_hh_new + ].original_household_id.values[0] + + print( + f"Comparing new household '{test_hh_new}' with original household '{test_hh_orig}'\n" + ) + + # --- 3. Compare household-level data --- + + # Define the variables to analyze at the household level + household_vars = [ + "household_id", + "state_fips", + "congressional_district_geoid", + "adjusted_gross_income", + ] + + # Calculate dataframes for both simulations + df_new = sim_new.calculate_dataframe(household_vars, map_to="household") + df_orig = sim_orig.calculate_dataframe(household_vars, map_to="household") + + # Filter for the specific households + household_new_data = df_new.loc[df_new.household_id == test_hh_new] + household_orig_data = df_orig.loc[df_orig.household_id == test_hh_orig] + + print("--- Household-Level Comparison ---") + print("\nData from New Simulation (RI.h5):") + print(household_new_data) + print("\nData from Original Simulation (stratified_10k.h5):") + print(household_orig_data) + + # --- 4. Compare person-level data --- + + # A helper function to create a person-level dataframe from a simulation + def get_person_df(simulation): + return pd.DataFrame( + { + "household_id": simulation.calculate( + "household_id", map_to="person" + ), + "person_id": simulation.calculate( + "person_id", map_to="person" + ), + "age": simulation.calculate("age", map_to="person"), + } + ) + + # Get person-level dataframes + df_person_new = get_person_df(sim_new) + df_person_orig = get_person_df(sim_orig) + + # Filter for the members of the specific households + persons_new = df_person_new.loc[df_person_new.household_id == test_hh_new] + persons_orig = df_person_orig.loc[ + df_person_orig.household_id == test_hh_orig + ] + + print("\n\n--- Person-Level Comparison ---") + print("\nData from New Simulation (RI.h5):") + print(persons_new) + print("\nData from Original Simulation (stratified_10k.h5):") + print(persons_orig) diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/metrics_matrix_geo_stacking_sparse.py b/policyengine_us_data/datasets/cps/local_area_calibration/metrics_matrix_geo_stacking_sparse.py new file mode 100644 index 00000000..9aa24c35 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/metrics_matrix_geo_stacking_sparse.py @@ -0,0 +1,2468 @@ +""" +Sparse geo-stacking calibration matrix creation for PolicyEngine US. + +This module creates calibration matrices for the geo-stacking approach where +the same household dataset is treated as existing in multiple geographic areas. +Targets are rows, households are columns (small n, large p formulation). 
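+
+Schematically, for K geographies and N households the stacked matrix has one
+block of household columns per geography (roughly):
+
+    rows    = calibration targets (national, state, congressional district)
+    columns = [h_1 .. h_N | h_1 .. h_N | ... | h_1 .. h_N]  (one block per geography)
+
+so a household contributes to a target row only through the column block of
+the geography being calibrated and is structurally zero elsewhere.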
+ +This version builds sparse matrices directly, avoiding dense intermediate structures. +""" + +import logging +from typing import Dict, List, Optional, Tuple +import numpy as np +import pandas as pd +from scipy import sparse +from sqlalchemy import create_engine, text +from sqlalchemy.orm import Session + +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + get_calculated_variables, +) + + +logger = logging.getLogger(__name__) + + +def get_us_state_dependent_variables(): + """ + Return list of variables that should be calculated US-state-specifically. + + These are variables whose values depend on US state policy rules, + so the same household can have different values in different states. + + NOTE: Only include variables that are CALCULATED based on state policy. + Variables based on INPUT data (like salt_deduction, which uses + state_withheld_income_tax as an input) will NOT vary when state changes. + + Returns: + List of variable names that are US-state-dependent + """ + return ['snap', 'medicaid', 'salt_deduction'] + + +class SparseGeoStackingMatrixBuilder: + """Build sparse calibration matrices for geo-stacking approach. + + NOTE: Period handling is complex due to mismatched data years: + - The enhanced CPS 2024 dataset only contains 2024 data + - Targets in the database exist for different years (2022, 2023, 2024) + - For now, we pull targets from whatever year they exist and use 2024 data + - This temporal mismatch will be addressed in future iterations + """ + + def __init__(self, db_uri: str, time_period: int): + self.db_uri = db_uri + self.engine = create_engine(db_uri) + self.time_period = time_period + self._uprating_factors = None + self._params = None + self._state_specific_cache = {} # Cache for state-specific calculated values: {(hh_id, state_fips, var): value} + + @property + def uprating_factors(self): + """Lazy-load uprating factors from PolicyEngine parameters.""" + # NOTE: this is pretty limited. What kind of CPI? 
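+        #
+        # Sketch of how the factors are meant to be applied (variable and
+        # value names here are only examples): dollar-denominated targets get
+        # the CPI factor, count-style targets get the population factor, via
+        # _get_uprating_info():
+        #
+        #     factor, kind = self._get_uprating_info("adjusted_gross_income", 2022)
+        #     uprated_value = raw_value * factor      # kind == "cpi"
+        #
+        #     factor, kind = self._get_uprating_info("household_count", 2022)
+        #     uprated_count = raw_count * factor      # kind == "pop"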
+ # In [44]: self._uprating_factors + # Out[44]: + # {(2022, 'cpi'): 1.0641014696885627, + # (2022, 'pop'): 1.009365413037974, + # (2023, 'cpi'): 1.0, + # (2023, 'pop'): 1.0, + # (2024, 'cpi'): 0.9657062435037478, + # (2024, 'pop'): 0.989171581243436, + # (2025, 'cpi'): 0.937584224942492, + # (2025, 'pop'): 0.9892021773614242} + + if self._uprating_factors is None: + self._uprating_factors = self._calculate_uprating_factors() + return self._uprating_factors + + def _calculate_uprating_factors(self): + """Calculate all needed uprating factors from PolicyEngine parameters.""" + from policyengine_us import Microsimulation + + # Get a minimal sim just for parameters + if self._params is None: + sim = Microsimulation() + self._params = sim.tax_benefit_system.parameters + + factors = {} + + # Get unique years from database + query = """ + SELECT DISTINCT period + FROM targets + WHERE period IS NOT NULL + ORDER BY period + """ + with self.engine.connect() as conn: + result = conn.execute(text(query)) + years_needed = [row[0] for row in result] + + logger.info( + f"Calculating uprating factors for years {years_needed} to {self.time_period}" + ) + + for from_year in years_needed: + if from_year == self.time_period: + factors[(from_year, "cpi")] = 1.0 + factors[(from_year, "pop")] = 1.0 + continue + + # CPI factor + try: + cpi_from = self._params.gov.bls.cpi.cpi_u(from_year) + cpi_to = self._params.gov.bls.cpi.cpi_u(self.time_period) + factors[(from_year, "cpi")] = float(cpi_to / cpi_from) + except Exception as e: + logger.warning( + f"Could not calculate CPI factor for {from_year}: {e}" + ) + factors[(from_year, "cpi")] = 1.0 + + # Population factor + try: + pop_from = ( + self._params.calibration.gov.census.populations.total( + from_year + ) + ) + pop_to = self._params.calibration.gov.census.populations.total( + self.time_period + ) + factors[(from_year, "pop")] = float(pop_to / pop_from) + except Exception as e: + logger.warning( + f"Could not calculate population factor for {from_year}: {e}" + ) + factors[(from_year, "pop")] = 1.0 + + # Log the factors + for (year, type_), factor in sorted(factors.items()): + if factor != 1.0: + logger.info( + f" {year} -> {self.time_period} ({type_}): {factor:.4f}" + ) + + return factors + + def _get_uprating_info(self, variable: str, period: int): + """ + Get uprating factor and type for a single variable. + Returns (factor, uprating_type) + """ + if period == self.time_period: + return 1.0, "none" + + # Determine uprating type based on variable name + count_indicators = [ + "count", + "person", + "people", + "households", + "tax_units", + ] + is_count = any( + indicator in variable.lower() for indicator in count_indicators + ) + uprating_type = "pop" if is_count else "cpi" + + # Get factor from pre-calculated dict + factor = self.uprating_factors.get((period, uprating_type), 1.0) + + return factor, uprating_type + + def _calculate_state_specific_values(self, dataset_path: str, variables_to_calculate: List[str]): + """ + Pre-calculate state-specific values for variables that depend on state policy. + + Creates a FRESH simulation for each state to avoid PolicyEngine caching issues. + This ensures calculated variables like salt_deduction are properly recomputed + with the new state's policy rules. 
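+
+        The cache is keyed on (household_id, state_fips, variable); for
+        example (values illustrative only):
+
+            self._state_specific_cache[(12345, 6, "snap")] -> 1800.0
+
+        which lets apply_constraints_to_sim_sparse() read a household's SNAP
+        value as if it lived in California without re-running the simulation
+        for that state.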
+ + Args: + dataset_path: Path to the dataset file (e.g., stratified_10k.h5) + variables_to_calculate: List of variable names to calculate state-specifically + + Returns: + None (populates self._state_specific_cache) + """ + import gc + from policyengine_us import Microsimulation + + # State FIPS codes (skipping gaps in numbering) + valid_states = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 55, 56] + + # Get household IDs from a temporary sim (they're constant across states) + #temp_sim = Microsimulation(dataset=dataset_path) + sim = Microsimulation(dataset=dataset_path) + household_ids = sim.calculate("household_id", map_to="household").values + n_households = len(household_ids) + + logger.info(f"Calculating state-specific values for {len(variables_to_calculate)} variables " + f"across {n_households} households and {len(valid_states)} states...") + logger.info(f"This will create {n_households * len(valid_states) * len(variables_to_calculate):,} cached values") + + total_states = len(valid_states) + + # For each state, create a FRESH simulation to avoid caching issues + for state_idx, state_fips in enumerate(valid_states): + # Create brand new simulation for this state + #sim = Microsimulation(dataset=dataset_path) + + # Set ALL households to this state + sim.set_input("state_fips", self.time_period, + np.full(n_households, state_fips, dtype=np.int32)) + # Clear cached calculated variables so state changes propagate + for var in get_calculated_variables(sim): + sim.delete_arrays(var) + + # Calculate each variable for all households in this state + for var_name in variables_to_calculate: + values = sim.calculate(var_name, map_to="household").values + + # Cache all values for this state + for hh_idx, hh_id in enumerate(household_ids): + cache_key = (int(hh_id), int(state_fips), var_name) + self._state_specific_cache[cache_key] = float(values[hh_idx]) + + # Log progress + if (state_idx + 1) % 10 == 0 or state_idx == total_states - 1: + logger.info(f" Progress: {state_idx + 1}/{total_states} states complete") + + + logger.info(f"State-specific cache populated with {len(self._state_specific_cache):,} values") + + def get_best_period_for_targets( + self, query_base: str, params: dict + ) -> int: + """ + Find the best period for targets: closest year <= target_year, + or closest future year if no past years exist. 
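+
+        For example, with time_period=2023 and available periods
+        [2021, 2022, 2025], 2022 is selected (most recent year at or before
+        the target year); with only [2024, 2025] available, 2024 is selected
+        (closest future year).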
+ + Args: + query_base: SQL query that should return period column + params: Parameters for the query + + Returns: + Best period to use, or None if no targets found + """ + # Get all available periods for these targets + period_query = f""" + WITH target_periods AS ( + {query_base} + ) + SELECT DISTINCT period + FROM target_periods + WHERE period IS NOT NULL + ORDER BY period + """ + + with self.engine.connect() as conn: + result = conn.execute(text(period_query), params) + available_periods = [row[0] for row in result.fetchall()] + + if not available_periods: + return None + + # Find best period: closest <= target_year, or closest > target_year + past_periods = [p for p in available_periods if p <= self.time_period] + if past_periods: + # Return the most recent past period (closest to target) + return max(past_periods) + else: + # No past periods, return closest future period + return min(available_periods) + + def get_all_descendant_targets( + self, stratum_id: int, sim=None + ) -> pd.DataFrame: + """ + Recursively get all targets from a stratum and all its descendants. + This handles the new filer stratum layer transparently. + Selects the best period for each target (closest to target_year in the past, or closest future). + """ + query = """ + WITH RECURSIVE descendant_strata AS ( + -- Base case: the stratum itself + SELECT stratum_id + FROM strata + WHERE stratum_id = :stratum_id + + UNION ALL + + -- Recursive case: all children + SELECT s.stratum_id + FROM strata s + JOIN descendant_strata d ON s.parent_stratum_id = d.stratum_id + ), + -- Find best period for each stratum/variable combination + best_periods AS ( + SELECT + t.stratum_id, + t.variable, + CASE + -- If there are periods <= target_year, use the maximum (most recent) + WHEN MAX(CASE WHEN t.period <= :target_year THEN t.period END) IS NOT NULL + THEN MAX(CASE WHEN t.period <= :target_year THEN t.period END) + -- Otherwise use the minimum period (closest future) + ELSE MIN(t.period) + END as best_period + FROM targets t + WHERE t.stratum_id IN (SELECT stratum_id FROM descendant_strata) + GROUP BY t.stratum_id, t.variable + ) + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.value, + t.period, + t.active, + t.tolerance, + s.notes as stratum_notes, + s.stratum_group_id, + s.parent_stratum_id, + src.name as source_name, + -- Aggregate constraint info to avoid duplicate rows + (SELECT GROUP_CONCAT(sc2.constraint_variable || sc2.operation || sc2.value, '|') + FROM stratum_constraints sc2 + WHERE sc2.stratum_id = s.stratum_id + GROUP BY sc2.stratum_id) as constraint_info + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + JOIN sources src ON t.source_id = src.source_id + JOIN best_periods bp ON t.stratum_id = bp.stratum_id + AND t.variable = bp.variable + AND t.period = bp.best_period + WHERE s.stratum_id IN (SELECT stratum_id FROM descendant_strata) + ORDER BY s.stratum_id, t.variable + """ + + with self.engine.connect() as conn: + df = pd.read_sql( + query, + conn, + params={ + "stratum_id": stratum_id, + "target_year": self.time_period, + }, + ) + + if len(df) > 0: + # Log which periods were selected + periods_used = df["period"].unique() + logger.debug( + f"Selected targets from periods: {sorted(periods_used)}" + ) + + return df + + def get_hierarchical_targets( + self, + cd_stratum_id: int, + state_stratum_id: int, + national_stratum_id: int, + sim=None, + ) -> pd.DataFrame: + """ + Get targets using hierarchical fallback: CD -> State -> National. 
+ For each target concept, use the most geographically specific available. + """ + # Get all targets at each level (including descendants) + cd_targets = self.get_all_descendant_targets(cd_stratum_id, sim) + state_targets = self.get_all_descendant_targets(state_stratum_id, sim) + national_targets = self.get_all_descendant_targets( + national_stratum_id, sim + ) + + # Add geographic level to each + cd_targets["geo_level"] = "congressional_district" + cd_targets["geo_priority"] = 1 # Highest priority + state_targets["geo_level"] = "state" + state_targets["geo_priority"] = 2 + national_targets["geo_level"] = "national" + national_targets["geo_priority"] = 3 # Lowest priority + + # Combine all targets + all_targets = pd.concat( + [cd_targets, state_targets, national_targets], ignore_index=True + ) + + # Create concept identifier from variable + all constraints + def get_concept_id(row): + if not row["variable"]: + return None + + variable = row["variable"] + + # Parse constraint_info if present + if pd.notna(row.get("constraint_info")): + constraints = row["constraint_info"].split("|") + + # Filter out geographic and filer constraints + demographic_constraints = [] + irs_constraint = None + + for c in constraints: + if not any( + skip in c + for skip in [ + "state_fips", + "congressional_district_geoid", + "tax_unit_is_filer", + ] + ): + # Check if this is an IRS variable constraint + if not any( + demo in c + for demo in [ + "age", + "adjusted_gross_income", + "eitc_child_count", + "snap", + "medicaid", + ] + ): + # This is likely an IRS variable constraint like "salt>0" + irs_constraint = c + else: + demographic_constraints.append(c) + + # If we have an IRS constraint, use that as the concept + if irs_constraint: + # Extract just the variable name from something like "salt>0" + import re + + match = re.match(r"([a-zA-Z_]+)", irs_constraint) + if match: + return f"{match.group(1)}_constrained" + + # Otherwise build concept from variable + demographic constraints + if demographic_constraints: + # Sort for consistency + demographic_constraints.sort() + # Normalize operators for valid identifiers + normalized = [] + for c in demographic_constraints: + c_norm = c.replace(">=", "_gte_").replace( + "<=", "_lte_" + ) + c_norm = c_norm.replace(">", "_gt_").replace( + "<", "_lt_" + ) + c_norm = c_norm.replace("==", "_eq_").replace( + "=", "_eq_" + ) + normalized.append(c_norm) + return f"{variable}_{'_'.join(normalized)}" + + # No constraints, just the variable + return variable + + all_targets["concept_id"] = all_targets.apply(get_concept_id, axis=1) + + # Remove targets without a valid concept + all_targets = all_targets[all_targets["concept_id"].notna()] + + # For each concept, keep only the most geographically specific target + # Sort by concept and priority, then keep first of each concept + all_targets = all_targets.sort_values(["concept_id", "geo_priority"]) + selected_targets = ( + all_targets.groupby("concept_id").first().reset_index() + ) + + logger.info( + f"Hierarchical fallback selected {len(selected_targets)} targets from " + f"{len(all_targets)} total across all levels" + ) + + return selected_targets + + def get_national_targets(self, sim=None) -> pd.DataFrame: + """ + Get national-level targets from the database. + Includes both direct national targets and national targets with strata/constraints. + Selects the best period for each target (closest to target_year in the past, or closest future). 
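+
+        Example (sketch; `builder` is a SparseGeoStackingMatrixBuilder and the
+        rows returned depend entirely on the database contents):
+
+            national_df = builder.get_national_targets()
+            national_df[["variable", "value", "period", "constraint_info"]]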
+ """ + query = """ + WITH national_stratum AS ( + -- Get the national (US) stratum ID + SELECT stratum_id + FROM strata + WHERE parent_stratum_id IS NULL + LIMIT 1 + ), + national_targets AS ( + -- Get all national targets + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.value, + t.period, + t.active, + t.tolerance, + s.notes as stratum_notes, + (SELECT GROUP_CONCAT(sc2.constraint_variable || sc2.operation || sc2.value, '|') + FROM stratum_constraints sc2 + WHERE sc2.stratum_id = s.stratum_id + GROUP BY sc2.stratum_id) as constraint_info, + src.name as source_name + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + JOIN sources src ON t.source_id = src.source_id + WHERE ( + -- Direct national targets (no parent) + s.parent_stratum_id IS NULL + OR + -- National targets with strata (parent is national stratum) + s.parent_stratum_id = (SELECT stratum_id FROM national_stratum) + ) + AND UPPER(src.type) = 'HARDCODED' -- Hardcoded targets only + ), + -- Find best period for each stratum/variable combination + best_periods AS ( + SELECT + stratum_id, + variable, + CASE + -- If there are periods <= target_year, use the maximum (most recent) + WHEN MAX(CASE WHEN period <= :target_year THEN period END) IS NOT NULL + THEN MAX(CASE WHEN period <= :target_year THEN period END) + -- Otherwise use the minimum period (closest future) + ELSE MIN(period) + END as best_period + FROM national_targets + GROUP BY stratum_id, variable + ) + SELECT nt.* + FROM national_targets nt + JOIN best_periods bp ON nt.stratum_id = bp.stratum_id + AND nt.variable = bp.variable + AND nt.period = bp.best_period + ORDER BY nt.variable, nt.constraint_info + """ + + with self.engine.connect() as conn: + df = pd.read_sql( + query, conn, params={"target_year": self.time_period} + ) + + if len(df) > 0: + periods_used = df["period"].unique() + logger.info( + f"Found {len(df)} national targets from periods: {sorted(periods_used)}" + ) + else: + logger.info("No national targets found") + + return df + + def get_irs_scalar_targets( + self, geographic_stratum_id: int, geographic_level: str, sim=None + ) -> pd.DataFrame: + """ + Get IRS scalar variables from child strata with constraints. + These are now in child strata with constraints like "salt > 0" + """ + query = """ + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.value, + t.period, + t.active, + t.tolerance, + s.notes as stratum_notes, + s.stratum_group_id, + src.name as source_name + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + JOIN sources src ON t.source_id = src.source_id + WHERE s.parent_stratum_id = :stratum_id -- Look for children of geographic stratum + AND s.stratum_group_id >= 100 -- IRS strata have group_id >= 100 + AND src.name = 'IRS Statistics of Income' + AND t.variable NOT IN ('adjusted_gross_income') -- AGI handled separately + ORDER BY s.stratum_group_id, t.variable + """ + + with self.engine.connect() as conn: + df = pd.read_sql( + query, conn, params={"stratum_id": geographic_stratum_id} + ) + + # Note: Uprating removed - should be done once after matrix assembly + logger.info( + f"Found {len(df)} IRS scalar targets for {geographic_level}" + ) + return df + + def get_agi_total_target( + self, geographic_stratum_id: int, geographic_level: str, sim=None + ) -> pd.DataFrame: + """ + Get the total AGI amount for a geography. + This is a single scalar value, not a distribution. 
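+
+        Example (sketch; the GEOID is illustrative):
+
+            cd_stratum = builder.get_cd_stratum_id("601")
+            agi_df = builder.get_agi_total_target(
+                cd_stratum, "congressional_district"
+            )
+            # every returned row has variable == "adjusted_gross_income"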
+ """ + query = """ + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.value, + t.period, + t.active, + t.tolerance, + s.notes as stratum_notes, + src.name as source_name + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + JOIN sources src ON t.source_id = src.source_id + WHERE s.stratum_id = :stratum_id + AND t.variable = 'adjusted_gross_income' + """ + + with self.engine.connect() as conn: + df = pd.read_sql( + query, conn, params={"stratum_id": geographic_stratum_id} + ) + + # Note: Uprating removed - should be done once after matrix assembly + logger.info(f"Found AGI total target for {geographic_level}") + return df + + def get_demographic_targets( + self, + geographic_stratum_id: int, + stratum_group_id: int, + group_name: str, + sim=None, + ) -> pd.DataFrame: + """ + Generic function to get demographic targets for a geographic area. + Selects the best period for each target (closest to target_year in the past, or closest future). + + Args: + geographic_stratum_id: The parent geographic stratum + stratum_group_id: The demographic group (2=Age, 3=Income, 4=SNAP, 5=Medicaid, 6=EITC) + group_name: Descriptive name for logging + """ + query = """ + WITH demographic_targets AS ( + -- Get all targets for this demographic group + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.value, + t.active, + t.tolerance, + s.notes as stratum_notes, + s.stratum_group_id, + (SELECT GROUP_CONCAT(sc2.constraint_variable || sc2.operation || sc2.value, '|') + FROM stratum_constraints sc2 + WHERE sc2.stratum_id = s.stratum_id + GROUP BY sc2.stratum_id) as constraint_info, + t.period + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE s.stratum_group_id = :stratum_group_id + AND s.parent_stratum_id = :parent_id + ), + -- Find best period for each stratum/variable combination + best_periods AS ( + SELECT + stratum_id, + variable, + CASE + -- If there are periods <= target_year, use the maximum (most recent) + WHEN MAX(CASE WHEN period <= :target_year THEN period END) IS NOT NULL + THEN MAX(CASE WHEN period <= :target_year THEN period END) + -- Otherwise use the minimum period (closest future) + ELSE MIN(period) + END as best_period + FROM demographic_targets + GROUP BY stratum_id, variable + ) + SELECT dt.* + FROM demographic_targets dt + JOIN best_periods bp ON dt.stratum_id = bp.stratum_id + AND dt.variable = bp.variable + AND dt.period = bp.best_period + ORDER BY dt.variable, dt.constraint_info + """ + + with self.engine.connect() as conn: + df = pd.read_sql( + query, + conn, + params={ + "target_year": self.time_period, + "stratum_group_id": stratum_group_id, + "parent_id": geographic_stratum_id, + }, + ) + + if len(df) > 0: + periods_used = df["period"].unique() + logger.debug( + f"Found {len(df)} {group_name} targets for stratum {geographic_stratum_id} from periods: {sorted(periods_used)}" + ) + else: + logger.info( + f"No {group_name} targets found for stratum {geographic_stratum_id}" + ) + + return df + + def get_national_stratum_id(self) -> Optional[int]: + """Get stratum ID for national level.""" + query = """ + SELECT stratum_id + FROM strata + WHERE parent_stratum_id IS NULL + AND stratum_group_id = 1 -- Geographic stratum + LIMIT 1 + """ + with self.engine.connect() as conn: + result = conn.execute(text(query)).fetchone() + return result[0] if result else None + + def get_state_stratum_id(self, state_fips: str) -> Optional[int]: + """Get the stratum_id for a state.""" + query = """ + SELECT s.stratum_id + FROM strata s + JOIN stratum_constraints sc 
ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 -- Geographic + AND sc.constraint_variable = 'state_fips' + AND sc.value = :state_fips + """ + + with self.engine.connect() as conn: + result = conn.execute( + text(query), {"state_fips": state_fips} + ).fetchone() + return result[0] if result else None + + def get_state_fips_from_cd(self, cd_geoid: str) -> str: + """Extract state FIPS code from congressional district GEOID.""" + # CD GEOIDs are formatted as state_fips (1-2 digits) + district (2 digits) + # Examples: '601' -> '6', '3601' -> '36' + if len(cd_geoid) == 3: + return cd_geoid[0] # Single digit state + elif len(cd_geoid) == 4: + return cd_geoid[:2] # Two digit state + else: + raise ValueError(f"Invalid CD GEOID format: {cd_geoid}") + + def reconcile_targets_to_higher_level( + self, + lower_targets_dict: Dict[str, pd.DataFrame], + higher_level: str, + target_filters: Dict[str, any], + sim=None, + ) -> Dict[str, pd.DataFrame]: + """ + Reconcile lower-level targets to match higher-level aggregates. + Generic method that can handle CD->State or State->National reconciliation. + + Args: + lower_targets_dict: Dict mapping geography_id to its targets DataFrame + higher_level: 'state' or 'national' + target_filters: Dict with filters like {'stratum_group_id': 2} for age + sim: Microsimulation instance (if needed) + + Returns: + Dict with same structure but adjusted targets including diagnostic columns + """ + reconciled_dict = {} + + # Group lower-level geographies by their parent + if higher_level == "state": + # Group CDs by state + grouped = {} + for cd_id, targets_df in lower_targets_dict.items(): + state_fips = self.get_state_fips_from_cd(cd_id) + if state_fips not in grouped: + grouped[state_fips] = {} + grouped[state_fips][cd_id] = targets_df + else: # national + # All states belong to one national group + grouped = {"US": lower_targets_dict} + + # Process each group + for parent_id, children_dict in grouped.items(): + # Get parent-level targets + if higher_level == "state": + parent_stratum_id = self.get_state_stratum_id(parent_id) + else: # national + parent_stratum_id = self.get_national_stratum_id() + + if parent_stratum_id is None: + logger.warning( + f"Could not find {higher_level} stratum for {parent_id}" + ) + # Return unchanged + for child_id, child_df in children_dict.items(): + reconciled_dict[child_id] = child_df.copy() + continue + + # Get parent targets matching the filter + parent_targets = self._get_filtered_targets( + parent_stratum_id, target_filters + ) + + if parent_targets.empty: + # No parent targets to reconcile to + for child_id, child_df in children_dict.items(): + reconciled_dict[child_id] = child_df.copy() + continue + + # First, calculate adjustment factors for all targets + adjustment_factors = {} + for _, parent_target in parent_targets.iterrows(): + # Sum all children for this concept + total_child_sum = 0.0 + for child_id, child_df in children_dict.items(): + child_mask = self._get_matching_targets_mask( + child_df, parent_target, target_filters + ) + if child_mask.any(): + # Use ORIGINAL values, not modified ones + if ( + "original_value_pre_reconciliation" + in child_df.columns + ): + total_child_sum += child_df.loc[ + child_mask, "original_value_pre_reconciliation" + ].sum() + else: + total_child_sum += child_df.loc[ + child_mask, "value" + ].sum() + + if total_child_sum > 0: + parent_value = parent_target["value"] + factor = parent_value / total_child_sum + adjustment_factors[parent_target["variable"]] = factor + logger.info( + 
f"Calculated factor for {parent_target['variable']}: {factor:.4f} " + f"(parent={parent_value:,.0f}, children_sum={total_child_sum:,.0f})" + ) + + # Now apply the factors to each child + for child_id, child_df in children_dict.items(): + reconciled_df = self._apply_reconciliation_factors( + child_df, + parent_targets, + adjustment_factors, + child_id, + higher_level, + target_filters, + ) + reconciled_dict[child_id] = reconciled_df + + return reconciled_dict + + def _apply_reconciliation_factors( + self, + child_df: pd.DataFrame, + parent_targets: pd.DataFrame, + adjustment_factors: Dict[str, float], + child_id: str, + parent_level: str, + target_filters: Dict, + ) -> pd.DataFrame: + """Apply pre-calculated reconciliation factors to a child geography.""" + result_df = child_df.copy() + + # Add diagnostic columns if not present + if "original_value_pre_reconciliation" not in result_df.columns: + result_df["original_value_pre_reconciliation"] = result_df[ + "value" + ].copy() + if "reconciliation_factor" not in result_df.columns: + result_df["reconciliation_factor"] = 1.0 + if "reconciliation_source" not in result_df.columns: + result_df["reconciliation_source"] = "none" + if "undercount_pct" not in result_df.columns: + result_df["undercount_pct"] = 0.0 + + # Apply factors for matching targets + for _, parent_target in parent_targets.iterrows(): + var_name = parent_target["variable"] + if var_name in adjustment_factors: + matching_mask = self._get_matching_targets_mask( + result_df, parent_target, target_filters + ) + if matching_mask.any(): + factor = adjustment_factors[var_name] + # Apply to ORIGINAL value, not current value + original_vals = result_df.loc[ + matching_mask, "original_value_pre_reconciliation" + ] + result_df.loc[matching_mask, "value"] = ( + original_vals * factor + ) + result_df.loc[matching_mask, "reconciliation_factor"] = ( + factor + ) + result_df.loc[matching_mask, "reconciliation_source"] = ( + f"{parent_level}_{var_name}" + ) + result_df.loc[matching_mask, "undercount_pct"] = ( + (1 - 1 / factor) * 100 if factor != 0 else 0 + ) + + return result_df + + def _get_filtered_targets( + self, stratum_id: int, filters: Dict + ) -> pd.DataFrame: + """Get targets from database matching filters.""" + # Build query conditions + conditions = [ + "s.stratum_id = :stratum_id OR s.parent_stratum_id = :stratum_id" + ] + + for key, value in filters.items(): + if key == "stratum_group_id": + conditions.append(f"s.stratum_group_id = {value}") + elif key == "variable": + conditions.append(f"t.variable = '{value}'") + + query = f""" + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.value, + t.period, + s.stratum_group_id, + (SELECT GROUP_CONCAT(sc2.constraint_variable || sc2.operation || sc2.value, '|') + FROM stratum_constraints sc2 + WHERE sc2.stratum_id = s.stratum_id + GROUP BY sc2.stratum_id) as constraint_info + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE {' AND '.join(conditions)} + """ + + with self.engine.connect() as conn: + return pd.read_sql(query, conn, params={"stratum_id": stratum_id}) + + def _reconcile_single_geography( + self, + child_df: pd.DataFrame, + parent_targets: pd.DataFrame, + child_id: str, + parent_id: str, + parent_level: str, + filters: Dict, + all_children_dict: Dict[str, pd.DataFrame], + ) -> pd.DataFrame: + """Reconcile a single geography's targets to parent aggregates.""" + result_df = child_df.copy() + + # Add diagnostic columns if not present + if "original_value_pre_reconciliation" not in result_df.columns: + 
result_df["original_value_pre_reconciliation"] = result_df[ + "value" + ].copy() + if "reconciliation_factor" not in result_df.columns: + result_df["reconciliation_factor"] = 1.0 + if "reconciliation_source" not in result_df.columns: + result_df["reconciliation_source"] = "none" + if "undercount_pct" not in result_df.columns: + result_df["undercount_pct"] = 0.0 + + # Match targets by concept (variable + constraints) + for _, parent_target in parent_targets.iterrows(): + # Find matching child targets + matching_mask = self._get_matching_targets_mask( + result_df, parent_target, filters + ) + + if not matching_mask.any(): + continue + + # Aggregate all siblings for this concept using already-collected data + sibling_sum = 0.0 + for sibling_id, sibling_df in all_children_dict.items(): + sibling_mask = self._get_matching_targets_mask( + sibling_df, parent_target, filters + ) + if sibling_mask.any(): + sibling_sum += sibling_df.loc[sibling_mask, "value"].sum() + + if sibling_sum == 0: + logger.warning( + f"Zero sum for {parent_target['variable']} in {parent_level}" + ) + continue + + # Calculate adjustment factor + parent_value = parent_target["value"] + adjustment_factor = parent_value / sibling_sum + + # Apply adjustment + result_df.loc[matching_mask, "value"] *= adjustment_factor + result_df.loc[matching_mask, "reconciliation_factor"] = ( + adjustment_factor + ) + result_df.loc[matching_mask, "reconciliation_source"] = ( + f"{parent_level}_{parent_target['variable']}" + ) + result_df.loc[matching_mask, "undercount_pct"] = ( + 1 - 1 / adjustment_factor + ) * 100 + + logger.info( + f"Reconciled {parent_target['variable']} for {child_id}: " + f"factor={adjustment_factor:.4f}, undercount={((1-1/adjustment_factor)*100):.1f}%" + ) + + return result_df + + def _get_matching_targets_mask( + self, df: pd.DataFrame, parent_target: pd.Series, filters: Dict + ) -> pd.Series: + """Get mask for targets matching parent target concept.""" + mask = df["variable"] == parent_target["variable"] + + # Match stratum_group_id if in filters + if "stratum_group_id" in filters and "stratum_group_id" in df.columns: + mask &= df["stratum_group_id"] == filters["stratum_group_id"] + + # Match constraints based on constraint_info, ignoring geographic constraints + parent_constraint_info = parent_target.get("constraint_info") + if "constraint_info" in df.columns: + # Extract demographic constraints from parent (exclude geographic) + parent_demo_constraints = set() + if pd.notna(parent_constraint_info): + for c in str(parent_constraint_info).split("|"): + if not any( + geo in c + for geo in [ + "state_fips", + "congressional_district_geoid", + ] + ): + parent_demo_constraints.add(c) + + # Create vectorized comparison for efficiency + def extract_demo_constraints(constraint_str): + """Extract non-geographic constraints from constraint string.""" + if pd.isna(constraint_str): + return frozenset() + demo_constraints = [] + for c in str(constraint_str).split("|"): + if not any( + geo in c + for geo in [ + "state_fips", + "congressional_district_geoid", + ] + ): + demo_constraints.append(c) + return frozenset(demo_constraints) + + # Apply extraction and compare + child_demo_constraints = df["constraint_info"].apply( + extract_demo_constraints + ) + parent_demo_set = frozenset(parent_demo_constraints) + mask &= child_demo_constraints == parent_demo_set + + return mask + + def _aggregate_cd_targets_for_state( + self, state_fips: str, target_concept: pd.Series, filters: Dict + ) -> float: + """Sum CD targets for a state matching 
the concept.""" + # Get all CDs in state + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM stratum_constraints sc + JOIN strata s ON sc.stratum_id = s.stratum_id + WHERE sc.constraint_variable = 'congressional_district_geoid' + AND sc.value LIKE :state_pattern + """ + + # Determine pattern based on state_fips length + if len(state_fips) == 1: + pattern = f"{state_fips}__" # e.g., "6__" for CA + else: + pattern = f"{state_fips}__" # e.g., "36__" for NY + + with self.engine.connect() as conn: + cd_result = conn.execute(text(query), {"state_pattern": pattern}) + cd_ids = [row[0] for row in cd_result] + + # Sum targets across CDs + total = 0.0 + for cd_id in cd_ids: + cd_stratum_id = self.get_cd_stratum_id(cd_id) + if cd_stratum_id: + cd_targets = self._get_filtered_targets(cd_stratum_id, filters) + # Sum matching targets + for _, cd_target in cd_targets.iterrows(): + if self._targets_match_concept(cd_target, target_concept): + total += cd_target["value"] + + return total + + def _targets_match_concept( + self, target1: pd.Series, target2: pd.Series + ) -> bool: + """Check if two targets represent the same concept.""" + # Must have same variable + if target1["variable"] != target2["variable"]: + return False + + # Must have same constraint pattern based on constraint_info + constraint1 = target1.get("constraint_info") + constraint2 = target2.get("constraint_info") + + # Both must be either null or non-null + if pd.isna(constraint1) != pd.isna(constraint2): + return False + + # If both have constraints, they must match exactly + if pd.notna(constraint1): + return constraint1 == constraint2 + + return True + + def _aggregate_state_targets_for_national( + self, target_concept: pd.Series, filters: Dict + ) -> float: + """Sum state targets for national matching the concept.""" + # Get all states + query = """ + SELECT DISTINCT sc.value as state_fips + FROM stratum_constraints sc + JOIN strata s ON sc.stratum_id = s.stratum_id + WHERE sc.constraint_variable = 'state_fips' + """ + + with self.engine.connect() as conn: + state_result = conn.execute(text(query)) + state_fips_list = [row[0] for row in state_result] + + # Sum targets across states + total = 0.0 + for state_fips in state_fips_list: + state_stratum_id = self.get_state_stratum_id(state_fips) + if state_stratum_id: + state_targets = self._get_filtered_targets( + state_stratum_id, filters + ) + # Sum matching targets + for _, state_target in state_targets.iterrows(): + if self._targets_match_concept( + state_target, target_concept + ): + total += state_target["value"] + + return total + + def get_cd_stratum_id(self, cd_geoid: str) -> Optional[int]: + """Get the stratum_id for a congressional district.""" + query = """ + SELECT s.stratum_id + FROM strata s + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 -- Geographic + AND sc.constraint_variable = 'congressional_district_geoid' + AND sc.value = :cd_geoid + """ + + with self.engine.connect() as conn: + result = conn.execute( + text(query), {"cd_geoid": cd_geoid} + ).fetchone() + return result[0] if result else None + + def get_constraints_for_stratum(self, stratum_id: int) -> pd.DataFrame: + """Get all constraints for a specific stratum.""" + query = """ + SELECT + constraint_variable, + operation, + value, + notes + FROM stratum_constraints + WHERE stratum_id = :stratum_id + AND constraint_variable NOT IN ('state_fips', 'congressional_district_geoid') + ORDER BY constraint_variable + """ + + with self.engine.connect() as conn: + return 
pd.read_sql(query, conn, params={"stratum_id": stratum_id}) + + def apply_constraints_to_sim_sparse( + self, sim, constraints_df: pd.DataFrame, target_variable: str, + target_state_fips: Optional[int] = None + ) -> Tuple[np.ndarray, np.ndarray]: + + # TODO: is it really a good idea to skip geographic filtering? + # I'm seeing all of the US here for SNAP and I'm only in one congressional district + # We're putting a lot of faith on later functions to filter them out + """ + Apply constraints and return sparse representation (indices and values). + + *** Wow this is where the values are actually set at the household level. So + this function is really misnamed because its a crucial part of getting + the value at the household level! *** + + Note: Geographic constraints are ALWAYS skipped as geographic isolation + happens through matrix column structure in geo-stacking, not data filtering. + + Args: + sim: Microsimulation instance + constraints_df: DataFrame with constraints + target_variable: Variable to calculate + target_state_fips: If provided and variable is state-dependent, use cached state-specific values + + Returns: + Tuple of (nonzero_indices, nonzero_values) at household level + """ + + # Check if we should use US-state-specific cached values + us_state_dependent_vars = get_us_state_dependent_variables() + use_cache = (target_state_fips is not None and + target_variable in us_state_dependent_vars and + len(self._state_specific_cache) > 0) + + if use_cache: + # Use cached state-specific values instead of calculating + logger.debug(f"Using cached {target_variable} values for state {target_state_fips}") + household_ids = sim.calculate("household_id", map_to="household").values + + # Get values from cache for this state + household_values = [] + for hh_id in household_ids: + cache_key = (int(hh_id), int(target_state_fips), target_variable) + value = self._state_specific_cache.get(cache_key, 0.0) + household_values.append(value) + + household_values = np.array(household_values) + + # Apply non-geographic constraints to determine which households qualify + # (We still need to filter based on constraints like "snap > 0") + # Build entity relationship to check constraints + entity_rel = pd.DataFrame({ + "person_id": sim.calculate("person_id", map_to="person").values, + "household_id": sim.calculate("household_id", map_to="person").values, + }) + + # Start with all persons + person_constraint_mask = np.ones(len(entity_rel), dtype=bool) + + # Apply each non-geographic constraint + for _, constraint in constraints_df.iterrows(): + var = constraint["constraint_variable"] + op = constraint["operation"] + val = constraint["value"] + + if var in ["state_fips", "congressional_district_geoid"]: + continue + + # Special handling for the target variable itself + if var == target_variable: + # Map household values to person level for constraint checking + hh_value_map = dict(zip(household_ids, household_values)) + person_hh_ids = entity_rel["household_id"].values + person_target_values = np.array([hh_value_map.get(hh_id, 0.0) for hh_id in person_hh_ids]) + + # Parse constraint value + try: + parsed_val = float(val) + if parsed_val.is_integer(): + parsed_val = int(parsed_val) + except ValueError: + parsed_val = val + + # Apply operation + if op == "==" or op == "=": + mask = (person_target_values == parsed_val).astype(bool) + elif op == ">": + mask = (person_target_values > parsed_val).astype(bool) + elif op == ">=": + mask = (person_target_values >= parsed_val).astype(bool) + elif op == "<": + mask = 
(person_target_values < parsed_val).astype(bool) + elif op == "<=": + mask = (person_target_values <= parsed_val).astype(bool) + elif op == "!=": + mask = (person_target_values != parsed_val).astype(bool) + else: + continue + + person_constraint_mask = person_constraint_mask & mask + + # Aggregate to household level + entity_rel["satisfies_constraints"] = person_constraint_mask + household_mask = entity_rel.groupby("household_id")["satisfies_constraints"].any() + + # Apply mask to values + masked_values = household_values * household_mask.values + + # Return sparse representation + nonzero_indices = np.nonzero(masked_values)[0] + nonzero_values = masked_values[nonzero_indices] + + return nonzero_indices, nonzero_values + + ## Get target entity level + target_entity = sim.tax_benefit_system.variables[ + target_variable + ].entity.key + + # Build entity relationship DataFrame at person level + # This gives us the mapping between all entities + entity_rel = pd.DataFrame( + { + "person_id": sim.calculate( + "person_id", map_to="person" + ).values, + "household_id": sim.calculate( + "household_id", map_to="person" + ).values, + "tax_unit_id": sim.calculate( + "tax_unit_id", map_to="person" + ).values, + "spm_unit_id": sim.calculate( + "spm_unit_id", map_to="person" + ).values, + "family_id": sim.calculate( + "family_id", map_to="person" + ).values, + "marital_unit_id": sim.calculate( + "marital_unit_id", map_to="person" + ).values, + } + ) + + # Start with all persons satisfying constraints (will be ANDed together) + person_constraint_mask = np.ones(len(entity_rel), dtype=bool) + + # Apply each constraint at person level + for _, constraint in constraints_df.iterrows(): + var = constraint["constraint_variable"] + op = constraint["operation"] + val = constraint["value"] + + # ALWAYS skip geographic constraints - geo-stacking handles geography through matrix structure + if var in ["state_fips", "congressional_district_geoid"]: + continue + + try: + # Get constraint values at person level + # We need to explicitly map to person for non-person variables + constraint_entity = sim.tax_benefit_system.variables[ + var + ].entity.key + if constraint_entity == "person": + constraint_values = sim.calculate(var).values + else: + # For tax_unit or household variables, map to person level + # This broadcasts the values so each person gets their tax_unit/household's value + constraint_values = sim.calculate( + var, map_to="person" + ).values + + # Parse value based on type + try: + parsed_val = float(val) + if parsed_val.is_integer(): + parsed_val = int(parsed_val) + except ValueError: + if val == "True": + parsed_val = True + elif val == "False": + parsed_val = False + else: + parsed_val = val + + # Apply operation at person level + if op == "==" or op == "=": + mask = (constraint_values == parsed_val).astype(bool) + elif op == ">": + mask = (constraint_values > parsed_val).astype(bool) + elif op == ">=": + mask = (constraint_values >= parsed_val).astype(bool) + elif op == "<": + mask = (constraint_values < parsed_val).astype(bool) + elif op == "<=": + mask = (constraint_values <= parsed_val).astype(bool) + elif op == "!=": + mask = (constraint_values != parsed_val).astype(bool) + else: + logger.warning(f"Unknown operation {op}") + continue + + # AND this constraint with existing constraints + person_constraint_mask = person_constraint_mask & mask + + except Exception as e: + logger.warning( + f"Could not apply constraint {var} {op} {val}: {e}" + ) + continue + + # Add constraint mask to entity_rel + 
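+        # (so it can be aggregated to the target entity below). As a sketch of
+        # the aggregation semantics, with made-up values: a tax_unit-level
+        # constraint such as adjusted_gross_income >= 50000 marks every person
+        # in a qualifying tax unit, and an entity (here, a household) then
+        # qualifies if ANY of its members is marked:
+        #
+        #     persons:        p1, p2 -> hh1;  p3 -> hh2
+        #     person mask:    [True, False, False]
+        #     household mask: hh1 -> True, hh2 -> False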
entity_rel["satisfies_constraints"] = person_constraint_mask + + # Now aggregate constraints to target entity level + if target_entity == "person": + entity_mask = person_constraint_mask + entity_ids = entity_rel["person_id"].values + elif target_entity == "household": + household_mask = entity_rel.groupby("household_id")[ + "satisfies_constraints" + ].any() + entity_mask = household_mask.values + entity_ids = household_mask.index.values + elif target_entity == "tax_unit": + tax_unit_mask = entity_rel.groupby("tax_unit_id")[ + "satisfies_constraints" + ].any() + entity_mask = tax_unit_mask.values + entity_ids = tax_unit_mask.index.values + elif target_entity == "spm_unit": + spm_unit_mask = entity_rel.groupby("spm_unit_id")[ + "satisfies_constraints" + ].any() + entity_mask = spm_unit_mask.values + entity_ids = spm_unit_mask.index.values + else: + raise ValueError(f"Entity type {target_entity} not handled") + + target_values_raw = sim.calculate( + target_variable, map_to=target_entity + ).values + + masked_values = target_values_raw * entity_mask + + entity_df = pd.DataFrame( + { + f"{target_entity}_id": entity_ids, + "entity_masked_metric": masked_values, + } + ) + if target_entity == "household": + hh_df = entity_df + else: + entity_rel_for_agg = entity_rel[["household_id", f"{target_entity}_id"]].drop_duplicates() + hh_df = entity_rel_for_agg.merge(entity_df, on=f"{target_entity}_id") + + # Check if this is a count variable + is_count_target = target_variable.endswith("_count") + + if is_count_target: + # For counts, count unique entities per household that satisfy constraints + masked_df = hh_df.loc[hh_df["entity_masked_metric"] > 0] + household_counts = masked_df.groupby("household_id")[ + f"{target_entity}_id" + ].nunique() + all_households = hh_df["household_id"].unique() + household_values_df = pd.DataFrame( + { + "household_id": all_households, + "household_metric": household_counts.reindex( + all_households, fill_value=0 + ).values, + } + ) + else: + # For non-counts, sum the values + household_values_df = ( + hh_df.groupby("household_id")[["entity_masked_metric"]] + .sum() + .reset_index() + .rename({"entity_masked_metric": "household_metric"}, axis=1) + ) + + # Return sparse representation + household_values_df = household_values_df.sort_values( + ["household_id"] + ).reset_index(drop=True) + nonzero_indices = np.nonzero(household_values_df["household_metric"])[ + 0 + ] + nonzero_values = household_values_df.iloc[nonzero_indices][ + "household_metric" + ].values + + return nonzero_indices, nonzero_values + + def build_matrix_for_geography_sparse( + self, geographic_level: str, geographic_id: str, sim + ) -> Tuple[pd.DataFrame, sparse.csr_matrix, List[str]]: + """ + Build sparse calibration matrix for any geographic level using hierarchical fallback. 
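+
+        Example (sketch; paths and the CD GEOID are illustrative):
+
+            from policyengine_us import Microsimulation
+
+            builder = SparseGeoStackingMatrixBuilder(
+                "sqlite:///policy_data.db", time_period=2023
+            )
+            sim = Microsimulation(dataset="stratified_10k.h5")
+            targets_df, matrix, household_ids = (
+                builder.build_matrix_for_geography_sparse(
+                    "congressional_district", "601", sim
+                )
+            )
+            # matrix.shape == (len(targets_df), len(household_ids))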
+ + Returns: + Tuple of (targets_df, sparse_matrix, household_ids) + """ + national_stratum_id = ( + self.get_national_stratum_id() + ) # 1 is the id for the US stratum with no other constraints + + if geographic_level == "state": + state_stratum_id = self.get_state_stratum_id(geographic_id) + cd_stratum_id = None # No CD level for state calibration + geo_label = f"state_{geographic_id}" + if state_stratum_id is None: + raise ValueError( + f"Could not find state {geographic_id} in database" + ) + elif geographic_level == "congressional_district": + cd_stratum_id = self.get_cd_stratum_id( + geographic_id + ) # congressional district stratum with no other constraints + state_fips = self.get_state_fips_from_cd(geographic_id) + state_stratum_id = self.get_state_stratum_id(state_fips) + geo_label = f"cd_{geographic_id}" + if cd_stratum_id is None: + raise ValueError( + f"Could not find CD {geographic_id} in database" + ) + else: + raise ValueError(f"Unknown geographic level: {geographic_level}") + + # Use hierarchical fallback to get all targets + if geographic_level == "congressional_district": + # CD calibration: Use CD -> State -> National fallback + # TODO: why does CD level use a function other than get_all_descendant_targets below? + hierarchical_targets = self.get_hierarchical_targets( + cd_stratum_id, state_stratum_id, national_stratum_id, sim + ) + else: # state + # State calibration: Use State -> National fallback (no CD level) + # For state calibration, we pass state_stratum_id twice to avoid null issues + # TODO: why does state and national levels use a function other than get_hierarchical_targets above?_ + state_targets = self.get_all_descendant_targets( + state_stratum_id, sim + ) + national_targets = self.get_all_descendant_targets( + national_stratum_id, sim + ) + + # Add geographic level + state_targets["geo_level"] = "state" + state_targets["geo_priority"] = 1 + national_targets["geo_level"] = "national" + national_targets["geo_priority"] = 2 + + # Combine and deduplicate + all_targets = pd.concat( + [state_targets, national_targets], ignore_index=True + ) + + # Create concept identifier from variable + all constraints + # TODO (baogorek): Is this function defined muliple times? 
(I think it is) + def get_concept_id(row): + if not row["variable"]: + return None + + variable = row["variable"] + + # Parse constraint_info if present + # TODO (baogorek): hard-coding needs refactoring + if pd.notna(row.get("constraint_info")): + constraints = row["constraint_info"].split("|") + + # Filter out geographic and filer constraints + demographic_constraints = [] + irs_constraint = None + + for c in constraints: + if not any( + skip in c + for skip in [ + "state_fips", + "congressional_district_geoid", + "tax_unit_is_filer", + ] + ): + # Check if this is an IRS variable constraint + if not any( + demo in c + for demo in [ + "age", + "adjusted_gross_income", + "eitc_child_count", + "snap", + "medicaid", + ] + ): + # This is likely an IRS variable constraint like "salt>0" + irs_constraint = c + else: + demographic_constraints.append(c) + + # If we have an IRS constraint, use that as the concept + if irs_constraint: + # Extract just the variable name from something like "salt>0" + import re + + match = re.match(r"([a-zA-Z_]+)", irs_constraint) + if match: + return f"{match.group(1)}_constrained" + + # Otherwise build concept from variable + demographic constraints + if demographic_constraints: + # Sort for consistency + demographic_constraints.sort() + # Normalize operators for valid identifiers + normalized = [] + for c in demographic_constraints: + c_norm = c.replace(">=", "_gte_").replace( + "<=", "_lte_" + ) + c_norm = c_norm.replace(">", "_gt_").replace( + "<", "_lt_" + ) + c_norm = c_norm.replace("==", "_eq_").replace( + "=", "_eq_" + ) + normalized.append(c_norm) + return f"{variable}_{'_'.join(normalized)}" + + # No constraints, just the variable + return variable + + all_targets["concept_id"] = all_targets.apply( + get_concept_id, axis=1 + ) + all_targets = all_targets[all_targets["concept_id"].notna()] + all_targets = all_targets.sort_values( + ["concept_id", "geo_priority"] + ) + hierarchical_targets = ( + all_targets.groupby("concept_id").first().reset_index() + ) + + # Process hierarchical targets into the format expected by the rest of the code + all_targets = [] + + for _, target_row in hierarchical_targets.iterrows(): + # BUILD DESCRIPTION from variable and constraints (but not all constraints) ---- + desc_parts = [target_row["variable"]] + + # Parse constraint_info to add all constraints to description + if pd.notna(target_row.get("constraint_info")): + constraints = target_row["constraint_info"].split("|") + # Filter out geographic and filer constraints FOR DESCRIPTION + for c in constraints: + # TODO (baogorek): I get that the string is getting long, but "(filers)" doesn't add too much and geo_ids are max 4 digits + if not any( + skip in c + for skip in [ + "state_fips", + "congressional_district_geoid", + "tax_unit_is_filer", + ] + ): + desc_parts.append(c) + + # Preserve the original stratum_group_id for proper grouping + # Special handling only for truly national/geographic targets + if pd.isna(target_row["stratum_group_id"]): + # No stratum_group_id means it's a national target + group_id = "national" + elif target_row["stratum_group_id"] == 1: + # Geographic identifier (not a real target) + group_id = "geographic" + else: + # Keep the original numeric stratum_group_id + # This preserves 2=Age, 3=AGI, 4=SNAP, 5=Medicaid, 6=EITC, 100+=IRS + group_id = target_row["stratum_group_id"] + + all_targets.append( + { + "target_id": target_row.get("target_id"), + "variable": target_row["variable"], + "value": target_row["value"], + "active": target_row.get("active", 
True),
+                    "tolerance": target_row.get("tolerance", 0.05),
+                    "stratum_id": target_row["stratum_id"],
+                    "stratum_group_id": group_id,
+                    "geographic_level": target_row["geo_level"],
+                    "geographic_id": (
+                        geographic_id
+                        if target_row["geo_level"] == geographic_level
+                        else (
+                            "US"
+                            if target_row["geo_level"] == "national"
+                            else state_fips
+                        )
+                    ),
+                    "description": "_".join(desc_parts),
+                }
+            )
+
+        targets_df = pd.DataFrame(all_targets)
+
+        # Build sparse data matrix ("loss matrix" historically) ---------------------------------------
+        # NOTE: we are unapologetically at the household level at this point
+        household_ids = sim.calculate(
+            "household_id"
+        ).values  # Implicit map to "household" entity level
+        n_households = len(household_ids)
+        n_targets = len(targets_df)
+
+        # Use LIL matrix for efficient row-by-row construction
+        matrix = sparse.lil_matrix((n_targets, n_households), dtype=np.float32)
+
+        # TODO: is this where all the values are set?
+        for i, (_, target) in enumerate(targets_df.iterrows()):
+            # target = targets_df.iloc[68]
+            constraints = self.get_constraints_for_stratum(
+                target["stratum_id"]
+            )  # NOTE: will not return the geo constraint
+            # TODO: going in with snap target with index 68, and no constraints came out
+            nonzero_indices, nonzero_values = (
+                self.apply_constraints_to_sim_sparse(
+                    sim, constraints, target["variable"]
+                )
+            )
+            if len(nonzero_indices) > 0:
+                matrix[i, nonzero_indices] = nonzero_values
+
+        matrix = (
+            matrix.tocsr()
+        )  # To compressed sparse row (CSR) for efficient operations
+
+        logger.info(
+            f"Created sparse matrix for {geographic_level} {geographic_id}: shape {matrix.shape}, nnz={matrix.nnz}"
+        )
+        return targets_df, matrix, household_ids.tolist()
+
+    # TODO (baogorek): instance of hard-coding (figure it out.
This is why we have a targets database) + def get_state_snap_cost(self, state_fips: str) -> pd.DataFrame: + """Get state-level SNAP cost target (administrative data).""" + query = """ + WITH snap_targets AS ( + SELECT + t.target_id, + t.stratum_id, + t.variable, + t.value, + t.active, + t.tolerance, + t.period + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 4 -- SNAP + AND t.variable = 'snap' -- Cost variable + AND sc.constraint_variable = 'state_fips' + AND sc.value = :state_fips + ), + best_period AS ( + SELECT + CASE + WHEN MAX(CASE WHEN period <= :target_year THEN period END) IS NOT NULL + THEN MAX(CASE WHEN period <= :target_year THEN period END) + ELSE MIN(period) + END as selected_period + FROM snap_targets + ) + SELECT st.* + FROM snap_targets st + JOIN best_period bp ON st.period = bp.selected_period + """ + + with self.engine.connect() as conn: + return pd.read_sql( + query, + conn, + params={ + "state_fips": state_fips, + "target_year": self.time_period, + }, + ) + + def get_state_fips_for_cd(self, cd_geoid: str) -> str: + """Extract state FIPS from CD GEOID.""" + # CD GEOIDs are formatted as state_fips + district_number + # e.g., "601" = California (06) district 01 + if len(cd_geoid) == 3: + return str( + int(cd_geoid[:1]) + ) # Single digit state, return as string of integer + elif len(cd_geoid) == 4: + return str( + int(cd_geoid[:2]) + ) # Two digit state, return as string of integer + else: + raise ValueError(f"Unexpected CD GEOID format: {cd_geoid}") + + def build_stacked_matrix_sparse( + self, geographic_level: str, geographic_ids: List[str], sim=None + ) -> Tuple[pd.DataFrame, sparse.csr_matrix, Dict[str, List[str]]]: + """ + Build stacked sparse calibration matrix for multiple geographic areas. 
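+
+        The stacked matrix has one column block per geography (a full copy
+        of the household index), so its shape is
+        (n_targets, n_households * len(geographic_ids)): national target
+        rows are repeated across every column block, geography-specific
+        rows form a block-diagonal structure, and (for CD runs) state SNAP
+        cost rows span only the column blocks of CDs within that state.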
+ + Returns: + Tuple of (targets_df, sparse_matrix, household_id_mapping) + """ + all_targets = [] + geo_matrices = [] + household_id_mapping = {} + + # Pre-calculate US-state-specific values for state-dependent variables + if sim is not None and len(self._state_specific_cache) == 0: + us_state_dependent_vars = get_us_state_dependent_variables() + if us_state_dependent_vars: + logger.info("Pre-calculating US-state-specific values for state-dependent variables...") + # Get dataset path from sim to create fresh simulations per state + dataset_path = str(sim.dataset.__class__.file_path) + self._calculate_state_specific_values(dataset_path, us_state_dependent_vars) + + # First, get national targets once (they apply to all geographic copies) + national_targets = self.get_national_targets(sim) + national_targets_list = [] + for _, target in national_targets.iterrows(): + # Get uprating info + factor, uprating_type = self._get_uprating_info( + target["variable"], target["period"] + ) + + # Build description with all constraints from constraint_info + var_desc = target["variable"] + if "constraint_info" in target and pd.notna( + target["constraint_info"] + ): + constraints = target["constraint_info"].split("|") + # Filter out geographic and filer constraints + demo_constraints = [ + c + for c in constraints + if not any( + skip in c + for skip in [ + "state_fips", + "congressional_district_geoid", + "tax_unit_is_filer", + ] + ) + ] + if demo_constraints: + # Join all constraints with underscores + var_desc = ( + f"{target['variable']}_{'_'.join(demo_constraints)}" + ) + + national_targets_list.append( + { + "target_id": target["target_id"], + "stratum_id": target["stratum_id"], + "value": target["value"] * factor, + "original_value": target["value"], + "variable": target["variable"], + "variable_desc": var_desc, + "geographic_id": "US", + "stratum_group_id": "national", # Required for create_target_groups + "period": target["period"], + "uprating_factor": factor, + "reconciliation_factor": 1.0, + } + ) + + # Build national targets matrix ONCE before the loop + national_matrix = None + if sim is not None and len(national_targets) > 0: + import time + + start = time.time() + logger.info( + f"Building national targets matrix once... 
({len(national_targets)} targets)" + ) + household_ids = sim.calculate("household_id").values + n_households = len(household_ids) + n_national_targets = len(national_targets) + + # Build sparse matrix for national targets + national_matrix = sparse.lil_matrix( + (n_national_targets, n_households), dtype=np.float32 + ) + + for i, (_, target) in enumerate(national_targets.iterrows()): + if i % 10 == 0: + logger.info( + f" Processing national target {i+1}/{n_national_targets}: {target['variable']}" + ) + # Get constraints for this stratum + constraints = self.get_constraints_for_stratum( + target["stratum_id"] + ) + + # Get sparse representation of household values + nonzero_indices, nonzero_values = ( + self.apply_constraints_to_sim_sparse( + sim, constraints, target["variable"] + ) + ) + + # Set the sparse row + if len(nonzero_indices) > 0: + national_matrix[i, nonzero_indices] = nonzero_values + + # Convert to CSR for efficiency + national_matrix = national_matrix.tocsr() + elapsed = time.time() - start + logger.info( + f"National matrix built in {elapsed:.1f}s: shape {national_matrix.shape}, nnz={national_matrix.nnz}" + ) + + # Collect all geography targets first for reconciliation + all_geo_targets_dict = {} + + # Build matrix for each geography (CD-specific targets only) + for i, geo_id in enumerate(geographic_ids): + if i % 50 == 0: # Log every 50th CD instead of every one + logger.info( + f"Processing {geographic_level}s: {i+1}/{len(geographic_ids)} completed..." + ) + + # Get CD-specific targets directly without rebuilding national + if geographic_level == "congressional_district": + cd_stratum_id = self.get_cd_stratum_id( + geo_id + ) # The base geographic stratum + if cd_stratum_id is None: + raise ValueError(f"Could not find CD {geo_id} in database") + + # Get only CD-specific targets with deduplication + cd_targets_raw = self.get_all_descendant_targets( + cd_stratum_id, sim + ) + + # Deduplicate CD targets by concept using ALL constraints + def get_cd_concept_id(row): + """ + Creates unique concept IDs from ALL constraints, not just the first one. + This eliminates the need for hard-coded stratum_group_id logic. 
+ + Examples: + - person_count with age>4|age<10 -> person_count_age_gt_4_age_lt_10 + - person_count with adjusted_gross_income>=25000|adjusted_gross_income<50000 + -> person_count_adjusted_gross_income_gte_25000_adjusted_gross_income_lt_50000 + """ + variable = row["variable"] + + # Parse constraint_info which contains ALL constraints + if "constraint_info" in row and pd.notna( + row["constraint_info"] + ): + constraints = row["constraint_info"].split("|") + + # Filter out geographic constraints (not part of the concept) + demographic_constraints = [] + for c in constraints: + # Skip geographic and filer constraints + if not any( + skip in c + for skip in [ + "state_fips", + "congressional_district_geoid", + "tax_unit_is_filer", + ] + ): + # Normalize the constraint format for consistency + # Replace operators with text equivalents for valid Python identifiers + c_normalized = c.replace( + ">=", "_gte_" + ).replace("<=", "_lte_") + c_normalized = c_normalized.replace( + ">", "_gt_" + ).replace("<", "_lt_") + c_normalized = c_normalized.replace( + "==", "_eq_" + ).replace("=", "_eq_") + c_normalized = c_normalized.replace( + " ", "" + ) # Remove any spaces + demographic_constraints.append(c_normalized) + + # Sort for consistency (ensures same constraints always produce same ID) + demographic_constraints.sort() + + if demographic_constraints: + # Join all constraints to create unique concept + constraint_str = "_".join(demographic_constraints) + return f"{variable}_{constraint_str}" + + # No constraints, just the variable name + return variable + + cd_targets_raw["cd_concept_id"] = cd_targets_raw.apply( + get_cd_concept_id, axis=1 + ) + + if cd_targets_raw["cd_concept_id"].isna().any(): + raise ValueError( + "Error: One or more targets were found without a valid concept ID." 
+ ) + + # For each concept, keep the first occurrence (or most specific based on stratum_group_id) + # Prioritize by stratum_group_id: higher values are more specific + cd_targets_raw = cd_targets_raw.sort_values( + ["cd_concept_id", "stratum_group_id"], + ascending=[True, False], + ) + cd_targets = ( + cd_targets_raw.groupby("cd_concept_id") + .first() + .reset_index(drop=True) + ) + + if len(cd_targets_raw) != len(cd_targets): + raise ValueError( + f"CD {geo_id}: Unwanted duplication: {len(cd_targets)} unique targets from {len(cd_targets_raw)} raw targets" + ) + + # Store CD targets with stratum_group_id preserved for reconciliation + cd_targets["geographic_id"] = geo_id + all_geo_targets_dict[geo_id] = cd_targets + else: + # For state-level, collect targets for later reconciliation + state_stratum_id = self.get_state_stratum_id(geo_id) + if state_stratum_id is None: + logger.warning( + f"Could not find state {geo_id} in database" + ) + continue + state_targets = self.get_all_descendant_targets( + state_stratum_id, sim + ) + state_targets["geographic_id"] = geo_id + all_geo_targets_dict[geo_id] = state_targets + + # Reconcile targets to higher level if CD calibration + if ( + geographic_level == "congressional_district" + and all_geo_targets_dict + ): + # Age targets (stratum_group_id=2) - already match so no-op + logger.info("Reconciling CD age targets to state totals...") + reconciled_dict = self.reconcile_targets_to_higher_level( + all_geo_targets_dict, + higher_level="state", + target_filters={"stratum_group_id": 2}, # Age targets + sim=sim, + ) + all_geo_targets_dict = reconciled_dict + + # Medicaid targets (stratum_group_id=5) - needs reconciliation + # TODO(bogorek): manually trace a reconcilliation + logger.info( + "Reconciling CD Medicaid targets to state admin totals..." + ) + reconciled_dict = self.reconcile_targets_to_higher_level( + all_geo_targets_dict, + higher_level="state", + target_filters={"stratum_group_id": 5}, # Medicaid targets + sim=sim, + ) + all_geo_targets_dict = reconciled_dict + + # SNAP household targets (stratum_group_id=4) - needs reconciliation + logger.info( + "Reconciling CD SNAP household counts to state admin totals..." + ) + reconciled_dict = self.reconcile_targets_to_higher_level( + all_geo_targets_dict, + higher_level="state", + target_filters={ + "stratum_group_id": 4, + "variable": "household_count", + }, # SNAP households + sim=sim, + ) + all_geo_targets_dict = reconciled_dict + + # Now build matrices for all collected and reconciled targets + # TODO (baogorek): a lot of hard-coded stuff here, but there is an else backoff + for geo_id, geo_targets_df in all_geo_targets_dict.items(): + # Format targets + geo_target_list = [] + for _, target in geo_targets_df.iterrows(): + # Get uprating info + factor, uprating_type = self._get_uprating_info( + target["variable"], target.get("period", self.time_period) + ) + + # Apply uprating to value (may already have reconciliation factor applied) + final_value = target["value"] * factor + + # Create meaningful description based on stratum_group_id and variable + stratum_group = target.get("stratum_group_id") + + # Build descriptive prefix based on stratum_group_id + # TODO (baogorek): Usage of stratum_group is not ideal, but is this just building notes? 
+ if isinstance(stratum_group, (int, np.integer)): + if stratum_group == 2: # Age + # Use stratum_notes if available, otherwise build from constraint + if "stratum_notes" in target and pd.notna( + target.get("stratum_notes") + ): + # Extract age range from notes like "Age: 0-4, CD 601" + notes = str(target["stratum_notes"]) + if "Age:" in notes: + age_part = ( + notes.split("Age:")[1] + .split(",")[0] + .strip() + ) + desc_prefix = f"age_{age_part}" + else: + desc_prefix = "age" + else: + desc_prefix = "age" + elif stratum_group == 3: # AGI + desc_prefix = "AGI" + elif stratum_group == 4: # SNAP + desc_prefix = "SNAP_households" + elif stratum_group == 5: # Medicaid + desc_prefix = "Medicaid_enrollment" + elif stratum_group == 6: # EITC + desc_prefix = "EITC" + elif stratum_group >= 100: # IRS variables + irs_names = { + 100: "QBI_deduction", + 101: "self_employment", + 102: "net_capital_gains", + 103: "real_estate_taxes", + 104: "rental_income", + 105: "net_capital_gain", + 106: "taxable_IRA_distributions", + 107: "taxable_interest", + 108: "tax_exempt_interest", + 109: "dividends", + 110: "qualified_dividends", + 111: "partnership_S_corp", + 112: "all_filers", + 113: "unemployment_comp", + 114: "medical_deduction", + 115: "taxable_pension", + 116: "refundable_CTC", + 117: "SALT_deduction", + 118: "income_tax_paid", + 119: "income_tax_before_credits", + } + desc_prefix = irs_names.get( + stratum_group, f"IRS_{stratum_group}" + ) + # Add variable suffix for amount vs count + if target["variable"] == "tax_unit_count": + desc_prefix = f"{desc_prefix}_count" + else: + desc_prefix = f"{desc_prefix}_amount" + else: + desc_prefix = target["variable"] + else: + desc_prefix = target["variable"] + + # Just use the descriptive prefix without geographic suffix + # The geographic context is already provided elsewhere + description = desc_prefix + + # Build description with all constraints from constraint_info + var_desc = target["variable"] + if "constraint_info" in target and pd.notna( + target["constraint_info"] + ): + constraints = target["constraint_info"].split("|") + # Filter out geographic and filer constraints + demo_constraints = [ + c + for c in constraints + if not any( + skip in c + for skip in [ + "state_fips", + "congressional_district_geoid", + "tax_unit_is_filer", + ] + ) + ] + if demo_constraints: + # Join all constraints with underscores + var_desc = f"{target['variable']}_{'_'.join(demo_constraints)}" + + geo_target_list.append( + { + "target_id": target["target_id"], + "stratum_id": target["stratum_id"], + "value": final_value, + "original_value": target.get( + "original_value_pre_reconciliation", + target["value"], + ), + "variable": target["variable"], + "variable_desc": var_desc, + "geographic_id": geo_id, + "stratum_group_id": target.get( + "stratum_group_id", geographic_level + ), # Preserve original group ID + "period": target.get("period", self.time_period), + "uprating_factor": factor, + "reconciliation_factor": target.get( + "reconciliation_factor", 1.0 + ), + "undercount_pct": target.get("undercount_pct", 0.0), + } + ) + + if geo_target_list: + targets_df = pd.DataFrame(geo_target_list) + all_targets.append(targets_df) + + # Build matrix for geo-specific targets + if sim is not None: + household_ids = sim.calculate("household_id").values + n_households = len(household_ids) + n_targets = len(targets_df) + + matrix = sparse.lil_matrix( + (n_targets, n_households), dtype=np.float32 + ) + + for j, (_, target) in enumerate(targets_df.iterrows()): + constraints = 
self.get_constraints_for_stratum( + target["stratum_id"] + ) + nonzero_indices, nonzero_values = ( + self.apply_constraints_to_sim_sparse( + sim, constraints, target["variable"] + ) + ) + if len(nonzero_indices) > 0: + matrix[j, nonzero_indices] = nonzero_values + + matrix = matrix.tocsr() + geo_matrices.append(matrix) + + # Store household ID mapping + prefix = ( + "cd" + if geographic_level == "congressional_district" + else "state" + ) + household_id_mapping[f"{prefix}{geo_id}"] = [ + f"{hh_id}_{prefix}{geo_id}" for hh_id in household_ids + ] + + # If building for congressional districts, add state-level SNAP costs + state_snap_targets_list = [] + state_snap_matrices = [] + if geographic_level == "congressional_district": + # Identify unique states from the CDs + unique_states = set() + for cd_id in geographic_ids: + state_fips = self.get_state_fips_for_cd(cd_id) + unique_states.add(state_fips) + + # Get household info - must match the actual matrix columns + household_ids = sim.calculate("household_id").values + n_households = len(household_ids) + total_cols = n_households * len(geographic_ids) + + # Get SNAP cost target for each state + for state_fips in sorted(unique_states): + snap_cost_df = self.get_state_snap_cost(state_fips) + if not snap_cost_df.empty: + for _, target in snap_cost_df.iterrows(): + # Get uprating info + # TODO: why is period showing up as 2022 in my interactive run? + period = target.get("period", self.time_period) + factor, uprating_type = self._get_uprating_info( + target["variable"], period + ) + + state_snap_targets_list.append( + { + "target_id": target["target_id"], + "stratum_id": target["stratum_id"], + "value": target["value"] * factor, + "original_value": target["value"], + "variable": target["variable"], + "variable_desc": "snap_cost_state", + "geographic_id": state_fips, + "stratum_group_id": "state_snap_cost", # Special group for state SNAP costs + "period": period, + "uprating_factor": factor, + "reconciliation_factor": 1.0, + "undercount_pct": 0.0, + } + ) + + # Build matrix row for this state SNAP cost + # This row should have SNAP values for households in CDs of this state + # Get constraints for this state SNAP stratum to apply to simulation + constraints = self.get_constraints_for_stratum( + target["stratum_id"] + ) + + # Create a sparse row with correct dimensions (1 x total_cols) + row_data = [] + row_indices = [] + + # Calculate SNAP values once for ALL households (geographic isolation via matrix structure) + # Note: state_fips constraint is automatically skipped, SNAP values calculated for all + # Use state-specific cached values if available + nonzero_indices, nonzero_values = ( + self.apply_constraints_to_sim_sparse( + sim, constraints, "snap", + target_state_fips=int(state_fips) # Pass state to use cached values + ) + ) + + # Create a mapping of household indices to SNAP values + snap_value_map = dict( + zip(nonzero_indices, nonzero_values) + ) + + # Place SNAP values in ALL CD columns that belong to this state + # This creates the proper geo-stacking structure where state-level targets + # span multiple CD columns (all CDs within the state) + for cd_idx, cd_id in enumerate(geographic_ids): + cd_state_fips = self.get_state_fips_from_cd(cd_id) + if cd_state_fips == state_fips: + # This CD is in the target state - add SNAP values to its columns + col_offset = cd_idx * n_households + for hh_idx, snap_val in snap_value_map.items(): + row_indices.append(col_offset + hh_idx) + row_data.append(snap_val) + + # Create sparse matrix row + if 
row_data: + row_matrix = sparse.csr_matrix( + (row_data, ([0] * len(row_data), row_indices)), + shape=(1, total_cols), + ) + state_snap_matrices.append(row_matrix) + + # Add state SNAP targets to all_targets + if state_snap_targets_list: + all_targets.append(pd.DataFrame(state_snap_targets_list)) + + # Add national targets to the list once + if national_targets_list: + all_targets.insert(0, pd.DataFrame(national_targets_list)) + + # Combine all targets + combined_targets = pd.concat(all_targets, ignore_index=True) + + # Stack matrices + if not geo_matrices: + raise ValueError("No geo_matrices were built - this should not happen") + + # Stack geo-specific targets (block diagonal) + stacked_geo = sparse.block_diag(geo_matrices) + logger.info( + f"Stacked geo-specific matrix: shape {stacked_geo.shape}, nnz={stacked_geo.nnz}" + ) + + # Combine all matrix parts + matrix_parts = [] + if national_matrix is not None: + national_copies = [national_matrix] * len(geographic_ids) + stacked_national = sparse.hstack(national_copies) + logger.info( + f"Stacked national matrix: shape {stacked_national.shape}, nnz={stacked_national.nnz}" + ) + matrix_parts.append(stacked_national) + matrix_parts.append(stacked_geo) + + # Add state SNAP matrices if we have them (for CD calibration) + if state_snap_matrices: + stacked_state_snap = sparse.vstack(state_snap_matrices) + matrix_parts.append(stacked_state_snap) + + # Combine all parts + combined_matrix = sparse.vstack(matrix_parts) + combined_matrix = combined_matrix.tocsr() + + logger.info( + f"Created stacked sparse matrix: shape {combined_matrix.shape}, nnz={combined_matrix.nnz}" + ) + return combined_targets, combined_matrix, household_id_mapping + + +def main(): + """Example usage for California and North Carolina.""" + from policyengine_us import Microsimulation + + # Database path + db_uri = "sqlite:////home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/policy_data.db" + + # Initialize sparse builder + builder = SparseGeoStackingMatrixBuilder(db_uri, time_period=2023) + + # Create microsimulation with 2024 data + print("Loading microsimulation...") + sim = Microsimulation( + dataset="hf://policyengine/policyengine-us-data/enhanced_cps_2024.h5" + ) + + # Test single state + print("\nBuilding sparse matrix for California (FIPS 6)...") + targets_df, matrix, household_ids = ( + builder.build_matrix_for_geography_sparse("state", "6", sim) + ) + + print("\nTarget Summary:") + print(f"Total targets: {len(targets_df)}") + print(f"Matrix shape: {matrix.shape}") + print( + f"Matrix sparsity: {matrix.nnz} non-zero elements ({100*matrix.nnz/(matrix.shape[0]*matrix.shape[1]):.4f}%)" + ) + print( + f"Memory usage: {matrix.data.nbytes + matrix.indices.nbytes + matrix.indptr.nbytes} bytes" + ) + + # Test stacking multiple states + print("\n" + "=" * 70) + print( + "Testing multi-state stacking: California (6) and North Carolina (37)" + ) + print("=" * 70) + + targets_df, matrix, hh_mapping = builder.build_stacked_matrix_sparse( + "state", ["6", "37"], sim + ) + + if matrix is not None: + print(f"\nStacked matrix shape: {matrix.shape}") + print( + f"Stacked matrix sparsity: {matrix.nnz} non-zero elements ({100*matrix.nnz/(matrix.shape[0]*matrix.shape[1]):.4f}%)" + ) + print( + f"Memory usage: {matrix.data.nbytes + matrix.indices.nbytes + matrix.indptr.nbytes} bytes" + ) + + # Compare to dense matrix memory + dense_memory = ( + matrix.shape[0] * matrix.shape[1] * 4 + ) # 4 bytes per float32 + print(f"Dense matrix would use: {dense_memory} bytes") + print( 
+ f"Memory savings: {100*(1 - (matrix.data.nbytes + matrix.indices.nbytes + matrix.indptr.nbytes)/dense_memory):.2f}%" + ) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/optimize_weights.py b/policyengine_us_data/datasets/cps/local_area_calibration/optimize_weights.py new file mode 100644 index 00000000..cec00fc5 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/optimize_weights.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +import os +import argparse +from pathlib import Path +from datetime import datetime +import pickle +import torch +import numpy as np +from scipy import sparse as sp +from l0.calibration import SparseCalibrationWeights + + +def main(): + parser = argparse.ArgumentParser( + description="Run sparse L0 weight optimization" + ) + parser.add_argument( + "--input-dir", + required=True, + help="Directory containing calibration_package.pkl", + ) + parser.add_argument( + "--output-dir", required=True, help="Directory for output files" + ) + parser.add_argument( + "--beta", + type=float, + default=0.35, + help="Beta parameter for L0 regularization", + ) + parser.add_argument( + "--lambda-l0", + type=float, + default=5e-7, + help="L0 regularization strength", + ) + parser.add_argument( + "--lambda-l2", + type=float, + default=5e-9, + help="L2 regularization strength", + ) + parser.add_argument("--lr", type=float, default=0.1, help="Learning rate") + parser.add_argument( + "--total-epochs", type=int, default=12000, help="Total training epochs" + ) + parser.add_argument( + "--epochs-per-chunk", + type=int, + default=1000, + help="Epochs per logging chunk", + ) + parser.add_argument( + "--enable-logging", + action="store_true", + help="Enable detailed epoch logging", + ) + parser.add_argument( + "--device", + default="cuda", + choices=["cuda", "cpu"], + help="Device to use", + ) + + args = parser.parse_args() + + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Loading calibration package from {args.input_dir}") + with open(Path(args.input_dir) / "calibration_package.pkl", "rb") as f: + calibration_data = pickle.load(f) + + X_sparse = calibration_data["X_sparse"] + init_weights = calibration_data["initial_weights"] + targets_df = calibration_data["targets_df"] + targets = targets_df.value.values + + print(f"Matrix shape: {X_sparse.shape}") + print(f"Number of targets: {len(targets)}") + + target_names = [] + for _, row in targets_df.iterrows(): + geo_prefix = f"{row['geographic_id']}" + name = f"{geo_prefix}/{row['variable_desc']}" + target_names.append(name) + + model = SparseCalibrationWeights( + n_features=X_sparse.shape[1], + beta=args.beta, + gamma=-0.1, + zeta=1.1, + init_keep_prob=0.999, + init_weights=init_weights, + log_weight_jitter_sd=0.05, + log_alpha_jitter_sd=0.01, + device=args.device, + ) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + if args.enable_logging: + log_path = output_dir / "cd_calibration_log.csv" + with open(log_path, "w") as f: + f.write( + "target_name,estimate,target,epoch,error,rel_error,abs_error,rel_abs_error,loss\n" + ) + print(f"Initialized incremental log at: {log_path}") + + sparsity_path = output_dir / f"cd_sparsity_history_{timestamp}.csv" + with open(sparsity_path, "w") as f: + f.write("epoch,active_weights,total_weights,sparsity_pct\n") + print(f"Initialized sparsity tracking at: {sparsity_path}") + + for chunk_start in 
range(0, args.total_epochs, args.epochs_per_chunk): + chunk_epochs = min( + args.epochs_per_chunk, args.total_epochs - chunk_start + ) + current_epoch = chunk_start + chunk_epochs + + print( + f"\nTraining epochs {chunk_start + 1} to {current_epoch} of {args.total_epochs}" + ) + + model.fit( + M=X_sparse, + y=targets, + target_groups=None, + lambda_l0=args.lambda_l0, + lambda_l2=args.lambda_l2, + lr=args.lr, + epochs=chunk_epochs, + loss_type="relative", + verbose=True, + verbose_freq=chunk_epochs, + ) + + active_info = model.get_active_weights() + active_count = active_info["count"] + total_count = X_sparse.shape[1] + sparsity_pct = 100 * (1 - active_count / total_count) + + with open(sparsity_path, "a") as f: + f.write( + f"{current_epoch},{active_count},{total_count},{sparsity_pct:.4f}\n" + ) + + if args.enable_logging: + with torch.no_grad(): + y_pred = model.predict(X_sparse).cpu().numpy() + + with open(log_path, "a") as f: + for i in range(len(targets)): + estimate = y_pred[i] + target = targets[i] + error = estimate - target + rel_error = error / target if target != 0 else 0 + abs_error = abs(error) + rel_abs_error = abs(rel_error) + loss = rel_error**2 + + f.write( + f'"{target_names[i]}",{estimate},{target},{current_epoch},' + f"{error},{rel_error},{abs_error},{rel_abs_error},{loss}\n" + ) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + with torch.no_grad(): + w = model.get_weights(deterministic=True).cpu().numpy() + + versioned_filename = f"w_cd_{timestamp}.npy" + full_path = output_dir / versioned_filename + np.save(full_path, w) + + canonical_path = output_dir / "w_cd.npy" + np.save(canonical_path, w) + + print(f"\nOptimization complete!") + print(f"Final weights saved to: {full_path}") + print(f"Canonical weights saved to: {canonical_path}") + print(f"Weights shape: {w.shape}") + print(f"Sparsity history saved to: {sparsity_path}") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/run_holdout_fold.py b/policyengine_us_data/datasets/cps/local_area_calibration/run_holdout_fold.py new file mode 100644 index 00000000..42dea309 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/run_holdout_fold.py @@ -0,0 +1,184 @@ +import os +import numpy as np +import pandas as pd +import pickle +from scipy import sparse as sp +from holdout_validation import run_holdout_experiment, simple_holdout + +# Load the calibration package +export_dir = os.path.expanduser("~/Downloads/cd_calibration_data") +package_path = os.path.join(export_dir, "calibration_package.pkl") + +print(f"Loading calibration package from: {package_path}") +with open(package_path, "rb") as f: + data = pickle.load(f) + +print(f"Keys in data: {data.keys()}") + +X_sparse = data["X_sparse"] +targets_df = data["targets_df"] +targets = targets_df.value.values +target_groups = data["target_groups"] +init_weights = data["initial_weights"] +keep_probs = data["keep_probs"] + +print(f"Loaded {len(targets_df)} targets") +print(f"Target groups shape: {target_groups.shape}") +print(f"Unique groups: {len(np.unique(target_groups))}") + +# EXPLORE TARGET GROUPS ---------------------------- +unique_groups = np.unique(target_groups) +group_details = [] + +print(f"\nProcessing {len(unique_groups)} groups...") + +for group_id in unique_groups: + group_mask = target_groups == group_id + group_targets = targets_df[group_mask].copy() + + n_targets = len(group_targets) + geos = group_targets["geographic_id"].unique() + variables = 
group_targets["variable"].unique() + var_descs = group_targets["variable_desc"].unique() + + # Classify the group type + if len(geos) == 1 and len(variables) == 1: + if len(var_descs) > 1: + group_type = f"Single geo/var with {len(var_descs)} bins" + else: + group_type = "Single target" + elif len(geos) > 1 and len(variables) == 1: + group_type = f"Multi-geo ({len(geos)} geos), single var" + else: + group_type = f"Complex: {len(geos)} geos, {len(variables)} vars" + + detail = { + "group_id": group_id, + "n_targets": n_targets, + "group_type": group_type, + "geos": list(geos)[:3], # First 3 for display + "n_geos": len(geos), + "variable": ( + variables[0] if len(variables) == 1 else f"{len(variables)} vars" + ), + "sample_desc": var_descs[0] if len(var_descs) > 0 else None, + } + group_details.append(detail) + +groups_df = pd.DataFrame(group_details) + +if groups_df.empty: + print("WARNING: groups_df is empty!") + print(f"group_details has {len(group_details)} items") + if len(group_details) > 0: + print(f"First item: {group_details[0]}") +else: + print(f"\nCreated groups_df with {len(groups_df)} rows") + +# Improve the variable column for complex groups +for idx, row in groups_df.iterrows(): + if "2 vars" in str(row["variable"]) or "vars" in str(row["variable"]): + # Get the actual variables for this group + group_mask = target_groups == row["group_id"] + group_targets = targets_df[group_mask] + variables = group_targets["variable"].unique() + # Update with actual variable names + groups_df.at[idx, "variable"] = ", ".join(variables[:2]) + +# Show all groups for selection +print("\nAll target groups (use group_id for selection):") +print( + groups_df[["group_id", "n_targets", "variable", "group_type"]].to_string() +) + +# CSV export moved to end of file after results + +# INTERACTIVE HOLDOUT SELECTION ------------------------------- + +# EDIT THIS LINE: Choose your group_id values from the table above +N_GROUPS = groups_df.shape[0] + +age_ids = [30] +first_5_national_ids = [0, 1, 2, 3, 4] +second_5_national_ids = [5, 6, 7, 8, 9] +third_5_national_ids = [10, 11, 12, 13, 14] +agi_histogram_ids = [31] +agi_value_ids = [33] +eitc_cds_value_ids = [35] +last_15_national_ids = [i for i in range(15, 30)] + +union_ids = ( + age_ids + + first_5_national_ids + + second_5_national_ids + + third_5_national_ids + + agi_histogram_ids + + agi_value_ids + + eitc_cds_value_ids + + last_15_national_ids +) + +len(union_ids) + +holdout_group_ids = [i for i in range(N_GROUPS) if i not in union_ids] +len(holdout_group_ids) + + +# Make age the only holdout: +union_ids = [i for i in range(N_GROUPS) if i not in age_ids] +holdout_group_ids = age_ids + +assert len(union_ids) + len(holdout_group_ids) == N_GROUPS + +results = simple_holdout( + X_sparse=X_sparse, + targets=targets, + target_groups=target_groups, + init_weights=init_weights, + holdout_group_ids=holdout_group_ids, + targets_df=targets_df, # Pass targets_df for hierarchical analysis + check_hierarchical=True, # Enable hierarchical consistency check + epochs=2000, + lambda_l0=0, # 8e-7, + lr=0.3, + verbose_spacing=100, + device="cpu", +) + +# CREATE RESULTS DATAFRAME +# Build a comprehensive results dataframe +results_data = [] + +# Add training groups +for group_id, loss in results["train_group_losses"].items(): + # Get group info from original groups_df + if group_id in groups_df["group_id"].values: + group_info = groups_df[groups_df["group_id"] == group_id].iloc[0] + results_data.append( + { + "group_id": group_id, + "set": "train", + "loss": loss, + 
"n_targets": group_info["n_targets"], + "variable": group_info["variable"], + "group_type": group_info["group_type"], + } + ) + +# Add holdout groups (now using original IDs directly) +for group_id, loss in results["holdout_group_losses"].items(): + if group_id in groups_df["group_id"].values: + group_info = groups_df[groups_df["group_id"] == group_id].iloc[0] + results_data.append( + { + "group_id": group_id, + "set": "holdout", + "loss": loss, + "n_targets": group_info["n_targets"], + "variable": group_info["variable"], + "group_type": group_info["group_type"], + } + ) + +results_df = pd.DataFrame(results_data) +results_df = results_df.sort_values(["set", "loss"], ascending=[True, False]) diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py new file mode 100644 index 00000000..4aa28b84 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/sparse_matrix_builder.py @@ -0,0 +1,186 @@ +""" +Sparse matrix builder for geo-stacking calibration. + +Generic, database-driven approach where all constraints (including geographic) +are evaluated as masks. Geographic constraints work because we SET state_fips +before evaluating constraints. +""" + +from collections import defaultdict +from typing import Dict, List, Optional, Tuple +import numpy as np +import pandas as pd +from scipy import sparse +from sqlalchemy import create_engine, text + +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + get_calculated_variables, + apply_op, + _get_geo_level, +) + + +class SparseMatrixBuilder: + """Build sparse calibration matrices for geo-stacking.""" + + def __init__(self, db_uri: str, time_period: int, cds_to_calibrate: List[str], + dataset_path: Optional[str] = None): + self.db_uri = db_uri + self.engine = create_engine(db_uri) + self.time_period = time_period + self.cds_to_calibrate = cds_to_calibrate + self.dataset_path = dataset_path + + def _query_targets(self, target_filter: dict) -> pd.DataFrame: + """Query targets based on filter criteria using OR logic.""" + or_conditions = [] + + if "stratum_group_ids" in target_filter: + ids = ",".join(map(str, target_filter["stratum_group_ids"])) + or_conditions.append(f"s.stratum_group_id IN ({ids})") + + if "variables" in target_filter: + vars_str = ",".join(f"'{v}'" for v in target_filter["variables"]) + or_conditions.append(f"t.variable IN ({vars_str})") + + if "target_ids" in target_filter: + ids = ",".join(map(str, target_filter["target_ids"])) + or_conditions.append(f"t.target_id IN ({ids})") + + if "stratum_ids" in target_filter: + ids = ",".join(map(str, target_filter["stratum_ids"])) + or_conditions.append(f"t.stratum_id IN ({ids})") + + if not or_conditions: + raise ValueError("target_filter must specify at least one filter criterion") + + where_clause = " OR ".join(f"({c})" for c in or_conditions) + + query = f""" + SELECT t.target_id, t.stratum_id, t.variable, t.value, t.period, + s.stratum_group_id + FROM targets t + JOIN strata s ON t.stratum_id = s.stratum_id + WHERE {where_clause} + ORDER BY t.target_id + """ + + with self.engine.connect() as conn: + return pd.read_sql(query, conn) + + def _get_constraints(self, stratum_id: int) -> List[dict]: + """Get all constraints for a stratum (including geographic).""" + query = """ + SELECT constraint_variable as variable, operation, value + FROM stratum_constraints + WHERE stratum_id = :stratum_id + """ + with 
self.engine.connect() as conn: + df = pd.read_sql(query, conn, params={"stratum_id": stratum_id}) + return df.to_dict('records') + + def _get_geographic_id(self, stratum_id: int) -> str: + """Extract geographic_id from constraints for targets_df.""" + constraints = self._get_constraints(stratum_id) + for c in constraints: + if c['variable'] == 'state_fips': + return c['value'] + if c['variable'] == 'congressional_district_geoid': + return c['value'] + return 'US' + + def _create_state_sim(self, state: int, n_households: int): + """Create a fresh simulation with state_fips set to given state.""" + from policyengine_us import Microsimulation + state_sim = Microsimulation(dataset=self.dataset_path) + state_sim.set_input("state_fips", self.time_period, + np.full(n_households, state, dtype=np.int32)) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) + return state_sim + + def build_matrix(self, sim, target_filter: dict) -> Tuple[pd.DataFrame, sparse.csr_matrix, Dict[str, List[str]]]: + """ + Build sparse calibration matrix. + + Args: + sim: Microsimulation instance (used for household_ids, or as template) + target_filter: Dict specifying which targets to include + - {"stratum_group_ids": [4]} for SNAP targets + - {"target_ids": [123, 456]} for specific targets + + Returns: + Tuple of (targets_df, X_sparse, household_id_mapping) + """ + household_ids = sim.calculate("household_id", map_to="household").values + n_households = len(household_ids) + n_cds = len(self.cds_to_calibrate) + n_cols = n_households * n_cds + + targets_df = self._query_targets(target_filter) + n_targets = len(targets_df) + + if n_targets == 0: + raise ValueError("No targets found matching filter") + + targets_df['geographic_id'] = targets_df['stratum_id'].apply(self._get_geographic_id) + + # Sort by (geo_level, variable, geographic_id) for contiguous group rows + targets_df['_geo_level'] = targets_df['geographic_id'].apply(_get_geo_level) + targets_df = targets_df.sort_values(['_geo_level', 'variable', 'geographic_id']) + targets_df = targets_df.drop(columns=['_geo_level']).reset_index(drop=True) + + X = sparse.lil_matrix((n_targets, n_cols), dtype=np.float32) + + cds_by_state = defaultdict(list) + for cd_idx, cd in enumerate(self.cds_to_calibrate): + state = int(cd) // 100 + cds_by_state[state].append((cd_idx, cd)) + + for state, cd_list in cds_by_state.items(): + if self.dataset_path: + state_sim = self._create_state_sim(state, n_households) + else: + state_sim = sim + state_sim.set_input("state_fips", self.time_period, + np.full(n_households, state, dtype=np.int32)) + for var in get_calculated_variables(state_sim): + state_sim.delete_arrays(var) + + for cd_idx, cd in cd_list: + col_start = cd_idx * n_households + + for row_idx, (_, target) in enumerate(targets_df.iterrows()): + constraints = self._get_constraints(target['stratum_id']) + + mask = np.ones(n_households, dtype=bool) + for c in constraints: + if c['variable'] == 'congressional_district_geoid': + if c['operation'] in ('==', '=') and c['value'] != cd: + mask[:] = False + elif c['variable'] == 'state_fips': + if c['operation'] in ('==', '=') and int(c['value']) != state: + mask[:] = False + else: + try: + values = state_sim.calculate(c['variable'], map_to='household').values + mask &= apply_op(values, c['operation'], c['value']) + except Exception: + pass + + if not mask.any(): + continue + + target_values = state_sim.calculate(target['variable'], map_to='household').values + masked_values = (target_values * mask).astype(np.float32) 
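+                    # Column layout: each CD owns a contiguous block of
+                    # n_households columns, e.g. (hypothetically) with
+                    # 10,000 households per copy, cd_idx = 3 owns columns
+                    # 30,000-39,999, so local household index h maps to
+                    # column col_start + h.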
+ + nonzero = np.where(masked_values != 0)[0] + if len(nonzero) > 0: + X[row_idx, col_start + nonzero] = masked_values[nonzero] + + household_id_mapping = {} + for cd in self.cds_to_calibrate: + key = f"cd{cd}" + household_id_mapping[key] = [f"{hh_id}_{key}" for hh_id in household_ids] + + return targets_df, X.tocsr(), household_id_mapping diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py new file mode 100644 index 00000000..2be964fc --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/stacked_dataset_builder.py @@ -0,0 +1,953 @@ +""" +Create a sparse congressional district-stacked dataset with only non-zero weight households. +Standalone version that doesn't modify the working state stacking code. +""" +import sys +import numpy as np +import pandas as pd +import h5py +import os +import json +import random +from pathlib import Path +from policyengine_us import Microsimulation +from policyengine_core.data.dataset import Dataset +from policyengine_core.enums import Enum +from sqlalchemy import create_engine, text +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + download_from_huggingface, + get_cd_index_mapping, + get_id_range_for_cd, + get_cd_from_id, + get_all_cds_from_database, + get_calculated_variables, + STATE_CODES, + STATE_FIPS_TO_NAME, + STATE_FIPS_TO_CODE, +) +from policyengine_us.variables.household.demographic.geographic.county.county_enum import ( + County, +) + + +def load_cd_county_mappings(): + """Load CD to county mappings from JSON file.""" + #script_dir = Path(__file__).parent + #mapping_file = script_dir / "cd_county_mappings.json" + mapping_file = Path.cwd() / "cd_county_mappings.json" + if not mapping_file.exists(): + print( + "WARNING: cd_county_mappings.json not found. Counties will not be updated." + ) + return None + + with open(mapping_file, "r") as f: + return json.load(f) + + +def get_county_for_cd(cd_geoid, cd_county_mappings): + """ + Get a county FIPS code for a given congressional district. + Uses weighted random selection based on county proportions. + """ + if not cd_county_mappings or str(cd_geoid) not in cd_county_mappings: + return None + + county_props = cd_county_mappings[str(cd_geoid)] + if not county_props: + return None + + counties = list(county_props.keys()) + weights = list(county_props.values()) + + # Normalize weights to ensure they sum to 1 + total_weight = sum(weights) + if total_weight > 0: + weights = [w / total_weight for w in weights] + return random.choices(counties, weights=weights)[0] + + return None + + +def create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + cd_subset=None, + output_path=None, + dataset_path=None, +): + """ + Create a SPARSE congressional district-stacked dataset using DataFrame approach. 
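+
+    The weight vector is reshaped to (n_cds, n_households); for each CD only
+    households with a non-zero calibrated weight are kept, entity weights are
+    rebuilt from the calibrated household weights, and all entity IDs are
+    reindexed into per-CD ranges so the stacked file has no ID collisions.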
+ + Args: + w: Calibrated weight vector from L0 calibration (length = n_households * n_cds) + cds_to_calibrate: List of CD GEOID codes used in calibration + cd_subset: Optional list of CD GEOIDs to include (subset of cds_to_calibrate) + output_path: Where to save the sparse CD-stacked h5 file + dataset_path: Path to the base .h5 dataset used to create the training matrices + """ + + # Handle CD subset filtering + if cd_subset is not None: + # Validate that requested CDs are in the calibration + for cd in cd_subset: + if cd not in cds_to_calibrate: + raise ValueError(f"CD {cd} not in calibrated CDs list") + + # Get indices of requested CDs + cd_indices = [cds_to_calibrate.index(cd) for cd in cd_subset] + cds_to_process = cd_subset + + print( + f"Processing subset of {len(cd_subset)} CDs: {', '.join(cd_subset[:5])}..." + ) + else: + # Process all CDs + cd_indices = list(range(len(cds_to_calibrate))) + cds_to_process = cds_to_calibrate + print( + f"Processing all {len(cds_to_calibrate)} congressional districts" + ) + + # Generate output path if not provided + if output_path is None: + raise ValueError("No output .h5 path given") + print(f"Output path: {output_path}") + + # Check that output directory exists, create if needed + output_dir_path = os.path.dirname(output_path) + if output_dir_path and not os.path.exists(output_dir_path): + print(f"Creating output directory: {output_dir_path}") + os.makedirs(output_dir_path, exist_ok=True) + + # Load the original simulation + base_sim = Microsimulation(dataset=dataset_path) + + cd_county_mappings = load_cd_county_mappings() + + household_ids = base_sim.calculate( + "household_id", map_to="household" + ).values + n_households_orig = len(household_ids) + + # From the base sim, create mapping from household ID to index for proper filtering + hh_id_to_idx = {int(hh_id): idx for idx, hh_id in enumerate(household_ids)} + + # I.e., + # {25: 0, + # 78: 1, + # 103: 2, + # 125: 3, + + # Infer the number of households from weight vector and CD count + if len(w) % len(cds_to_calibrate) != 0: + raise ValueError( + f"Weight vector length ({len(w):,}) is not evenly divisible by " + f"number of CDs ({len(cds_to_calibrate)}). Cannot determine household count." 
+ ) + n_households_from_weights = len(w) // len(cds_to_calibrate) + + if n_households_from_weights != n_households_orig: + raise ValueError("Households from base data set do not match households from weights") + + print(f"\nOriginal dataset has {n_households_orig:,} households") + + # Process the weight vector to understand active household-CD pairs + W_full = w.reshape(len(cds_to_calibrate), n_households_orig) + # (436, 10580) + + # Extract only the CDs we want to process + if cd_subset is not None: + W = W_full[cd_indices, :] + print( + f"Extracted weights for {len(cd_indices)} CDs from full weight matrix" + ) + else: + W = W_full + + # Count total active weights: i.e., number of active households + total_active_weights = np.sum(W > 0) + total_weight_in_W = np.sum(W) + print(f"Total active household-CD pairs: {total_active_weights:,}") + print(f"Total weight in W matrix: {total_weight_in_W:,.0f}") + + # Collect DataFrames for each CD + cd_dfs = [] + total_kept_households = 0 + #total_calibrated_weight = 0 + #total_kept_weight = 0 + time_period = int(base_sim.default_calculation_period) + + for idx, cd_geoid in enumerate(cds_to_process): + # Progress every 10 CDs and at the end ---- + if (idx + 1) % 10 == 0 or (idx + 1) == len(cds_to_process): + print( + f"Processing CD {cd_geoid} ({idx + 1}/{len(cds_to_process)})..." + ) + + # Get the correct index in the weight matrix + cd_idx = idx # Index in our filtered W matrix + + # Get ALL households with non-zero weight in this CD + active_household_indices = np.where(W[cd_idx, :] > 0)[0] + + if len(active_household_indices) == 0: + continue + + # Get the household IDs for active households + active_household_ids = set( + household_ids[hh_idx] for hh_idx in active_household_indices + ) + + # Fresh simulation + cd_sim = Microsimulation(dataset=dataset_path) + + # First, create hh_df with CALIBRATED weights from the W matrix + household_ids_in_sim = cd_sim.calculate( + "household_id", map_to="household" + ).values + + # Get this CD's calibrated weights from the weight matrix + calibrated_weights_for_cd = W[cd_idx, :] # Get this CD's row from weight matrix + + # Map the calibrated weights to household IDs + hh_weight_values = [] + for hh_id in household_ids_in_sim: + hh_idx = hh_id_to_idx[int(hh_id)] # Get index in weight matrix + hh_weight_values.append(calibrated_weights_for_cd[hh_idx]) + + # TODO: do I need this? 
+ entity_rel = pd.DataFrame( + { + "person_id": cd_sim.calculate( + "person_id", map_to="person" + ).values, + "household_id": cd_sim.calculate( + "household_id", map_to="person" + ).values, + "tax_unit_id": cd_sim.calculate( + "tax_unit_id", map_to="person" + ).values, + "spm_unit_id": cd_sim.calculate( + "spm_unit_id", map_to="person" + ).values, + "family_id": cd_sim.calculate( + "family_id", map_to="person" + ).values, + "marital_unit_id": cd_sim.calculate( + "marital_unit_id", map_to="person" + ).values, + } + ) + + hh_df = pd.DataFrame( + { + "household_id": household_ids_in_sim, + "household_weight": hh_weight_values, + } + ) + counts = entity_rel.groupby('household_id')['person_id'].size().reset_index(name="persons_per_hh") + hh_df = hh_df.merge(counts) + hh_df['per_person_hh_weight'] = hh_df.household_weight / hh_df.persons_per_hh + + ## Now create person_rel with calibrated household weights + #person_ids = cd_sim.calculate("person_id", map_to="person").values + #person_household_ids = cd_sim.calculate("household_id", map_to="person").values + #person_tax_unit_ids = cd_sim.calculate("tax_unit_id", map_to="person").values + + ## Map calibrated household weights to person level + #hh_weight_map = dict(zip(hh_df['household_id'], hh_df['household_weight'])) + #person_household_weights = [hh_weight_map[int(hh_id)] for hh_id in person_household_ids] + + #person_rel = pd.DataFrame( + # { + # "person_id": person_ids, + # "household_id": person_household_ids, + # "household_weight": person_household_weights, + # "tax_unit_id": person_tax_unit_ids, + # } + #) + + ## Calculate person weights based on calibrated household weights + ## Person weight equals household weight (each person represents the household weight) + #person_rel['person_weight'] = person_rel['household_weight'] + + ## Tax unit weight: each tax unit gets the weight of its household + #tax_unit_df = person_rel.groupby('tax_unit_id').agg( + # tax_unit_weight=('household_weight', 'first') + #).reset_index() + + ## SPM unit weight: each SPM unit gets the weight of its household + #person_spm_ids = cd_sim.calculate('spm_unit_id', map_to='person').values + #person_rel['spm_unit_id'] = person_spm_ids + #spm_unit_df = person_rel.groupby('spm_unit_id').agg( + # spm_unit_weight=('household_weight', 'first') + #).reset_index() + + ## Marital unit weight: each marital unit gets the weight of its household + #person_marital_ids = cd_sim.calculate('marital_unit_id', map_to='person').values + #person_rel['marital_unit_id'] = person_marital_ids + #marital_unit_df = person_rel.groupby('marital_unit_id').agg( + # marital_unit_weight=('household_weight', 'first') + #).reset_index() + + ## Track calibrated weight for this CD + #cd_calibrated_weight = calibrated_weights_for_cd.sum() + #cd_active_weight = calibrated_weights_for_cd[calibrated_weights_for_cd > 0].sum() + + # SET WEIGHTS IN SIMULATION BEFORE EXTRACTING DATAFRAME + # This is the key - set_input updates the simulation's internal state + + non_household_cols = ['person_id', 'tax_unit_id', 'spm_unit_id', 'family_id', 'marital_unit_id'] + + new_weights_per_id = {} + for col in non_household_cols: + person_counts = entity_rel.groupby(col)['person_id'].size().reset_index(name="person_id_count") + # Below: drop duplicates to undo the broadcast join done in entity_rel + id_link = entity_rel[['household_id', col]].drop_duplicates() + hh_info = id_link.merge(hh_df) + + hh_info2 = hh_info.merge(person_counts, on=col) + if col == 'person_id': + # Person weight = household weight (each person 
represents same count as their household) + hh_info2["id_weight"] = hh_info2.household_weight + else: + hh_info2["id_weight"] = hh_info2.per_person_hh_weight * hh_info2.person_id_count + new_weights_per_id[col] = hh_info2.id_weight + + cd_sim.set_input("household_weight", time_period, hh_df.household_weight.values) + cd_sim.set_input("person_weight", time_period, new_weights_per_id['person_id']) + cd_sim.set_input("tax_unit_weight", time_period, new_weights_per_id['tax_unit_id']) + cd_sim.set_input("spm_unit_weight", time_period, new_weights_per_id['spm_unit_id']) + cd_sim.set_input("marital_unit_weight", time_period, new_weights_per_id['marital_unit_id']) + cd_sim.set_input("family_weight", time_period, new_weights_per_id['family_id']) + + # Extract state from CD GEOID and update simulation BEFORE calling to_input_dataframe() + # This ensures calculated variables (SNAP, Medicaid) use the correct state + cd_geoid_int = int(cd_geoid) + state_fips = cd_geoid_int // 100 + + cd_sim.set_input("state_fips", time_period, + np.full(n_households_orig, state_fips, dtype=np.int32)) + cd_sim.set_input("congressional_district_geoid", time_period, + np.full(n_households_orig, cd_geoid_int, dtype=np.int32)) + + # Delete cached calculated variables to ensure they're recalculated with new state + for var in get_calculated_variables(cd_sim): + cd_sim.delete_arrays(var) + + # Now extract the dataframe - calculated vars will use the updated state + df = cd_sim.to_input_dataframe() + + assert df.shape[0] == entity_rel.shape[0] # df is at the person level + + # Column names follow pattern: variable__year + hh_id_col = f"household_id__{time_period}" + cd_geoid_col = f"congressional_district_geoid__{time_period}" + hh_weight_col = f"household_weight__{time_period}" + person_weight_col = f"person_weight__{time_period}" + tax_unit_weight_col = f"tax_unit_weight__{time_period}" + person_id_col = f"person_id__{time_period}" + tax_unit_id_col = f"tax_unit_id__{time_period}" + + state_fips_col = f"state_fips__{time_period}" + state_name_col = f"state_name__{time_period}" + state_code_col = f"state_code__{time_period}" + county_fips_col = f"county_fips__{time_period}" + county_col = f"county__{time_period}" + county_str_col = f"county_str__{time_period}" + + # Filter to only active households in this CD + df_filtered = df[df[hh_id_col].isin(active_household_ids)].copy() + + ## Track weight after filtering - need to group by household since df_filtered is person-level + #df_filtered_weight = df_filtered.groupby(hh_id_col)[hh_weight_col].first().sum() + + #if abs(cd_active_weight - df_filtered_weight) > 10: + # print(f" CD {cd_geoid}: Calibrated active weight = {cd_active_weight:,.0f}, df_filtered weight = {df_filtered_weight:,.0f}, LOST {cd_active_weight - df_filtered_weight:,.0f}") + + #total_calibrated_weight += cd_active_weight + #total_kept_weight += df_filtered_weight + + # Update congressional_district_geoid to target CD + df_filtered[cd_geoid_col] = int(cd_geoid) + + # Extract state FIPS from CD GEOID (first 1-2 digits) + cd_geoid_int = int(cd_geoid) + state_fips = cd_geoid_int // 100 + + # Update state variables for consistency + df_filtered[state_fips_col] = state_fips + if state_fips in STATE_FIPS_TO_NAME: + df_filtered[state_name_col] = STATE_FIPS_TO_NAME[state_fips] + if state_fips in STATE_FIPS_TO_CODE: + df_filtered[state_code_col] = STATE_FIPS_TO_CODE[state_fips] + + # Update county variables if we have mappings + if cd_county_mappings: + # For each household, assign a county based on CD proportions + 
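+            # Illustrative draw (hypothetical proportions): if the mapping
+            # for this CD were {"37183": 0.6, "37101": 0.4}, about 60% of
+            # its households would get county FIPS 37183 from the weighted
+            # random.choices call inside get_county_for_cd.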
unique_hh_ids = df_filtered[hh_id_col].unique() + hh_to_county = {} + + for hh_id in unique_hh_ids: + county_fips = get_county_for_cd(cd_geoid, cd_county_mappings) + if county_fips: + hh_to_county[hh_id] = county_fips + else: + hh_to_county[hh_id] = "" + + if hh_to_county and any(hh_to_county.values()): + # Map household to county FIPS string + county_fips_str = df_filtered[hh_id_col].map(hh_to_county) + + # Convert FIPS string to integer for county_fips column + # Handle empty strings by converting to 0 + df_filtered[county_fips_col] = county_fips_str.apply( + lambda x: int(x) if x and x != "" else 0 + ) + + # Set county enum to UNKNOWN (since we don't have specific enum values) + df_filtered[county_col] = County.UNKNOWN + + # Set county_str to the string representation of FIPS + df_filtered[county_str_col] = county_fips_str + + cd_dfs.append(df_filtered) + total_kept_households += len(df_filtered[hh_id_col].unique()) + + print(f"\nCombining {len(cd_dfs)} CD DataFrames...") + print(f"Total households across all CDs: {total_kept_households:,}") + #print(f"\nWeight tracking:") + #print(f" Total calibrated active weight: {total_calibrated_weight:,.0f}") + #print(f" Total kept weight in df_filtered: {total_kept_weight:,.0f}") + #print(f" Weight retention: {100 * total_kept_weight / total_calibrated_weight:.2f}%") + + # Combine all CD DataFrames + combined_df = pd.concat(cd_dfs, ignore_index=True) + print(f"Combined DataFrame shape: {combined_df.shape}") + + # REINDEX ALL IDs TO PREVENT OVERFLOW AND HANDLE DUPLICATES + print("\nReindexing all entity IDs using 25k ranges per CD...") + + # Column names + hh_id_col = f"household_id__{time_period}" + person_id_col = f"person_id__{time_period}" + person_hh_id_col = f"person_household_id__{time_period}" + tax_unit_id_col = f"tax_unit_id__{time_period}" + person_tax_unit_col = f"person_tax_unit_id__{time_period}" + spm_unit_id_col = f"spm_unit_id__{time_period}" + person_spm_unit_col = f"person_spm_unit_id__{time_period}" + marital_unit_id_col = f"marital_unit_id__{time_period}" + person_marital_unit_col = f"person_marital_unit_id__{time_period}" + cd_geoid_col = f"congressional_district_geoid__{time_period}" + + # Cache the CD mapping to avoid thousands of database calls! 
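+    # ID range example: a CD with index 3 in cd_to_index gets household IDs
+    # 75_000-99_999 (i.e. 3 * 25_000 onward), and its person IDs use the
+    # same 25k block shifted by PERSON_ID_OFFSET (5,000,000), keeping IDs
+    # unique across all stacked CDs.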
+ cd_to_index, _, _ = get_cd_index_mapping() + + # Create household mapping for CSV export + household_mapping = [] + + # First, create a unique row identifier to track relationships + combined_df["_row_idx"] = range(len(combined_df)) + + # Group by household ID AND congressional district to create unique household-CD pairs + hh_groups = ( + combined_df.groupby([hh_id_col, cd_geoid_col])["_row_idx"] + .apply(list) + .to_dict() + ) + + # Assign new household IDs using 25k ranges per CD + hh_row_to_new_id = {} + cd_hh_counters = {} # Track how many households assigned per CD + + for (old_hh_id, cd_geoid), row_indices in hh_groups.items(): + # Calculate the ID range for this CD directly (avoiding function call) + cd_str = str(int(cd_geoid)) + cd_idx = cd_to_index[cd_str] + start_id = cd_idx * 25_000 + end_id = start_id + 24_999 + + # Get the next available ID in this CD's range + if cd_str not in cd_hh_counters: + cd_hh_counters[cd_str] = 0 + + new_hh_id = start_id + cd_hh_counters[cd_str] + + # Check we haven't exceeded the range + if new_hh_id > end_id: + raise ValueError( + f"CD {cd_str} exceeded its 25k household allocation" + ) + + # All rows in the same household-CD pair get the SAME new ID + for row_idx in row_indices: + hh_row_to_new_id[row_idx] = new_hh_id + + # Save the mapping + household_mapping.append( + { + "new_household_id": new_hh_id, + "original_household_id": int(old_hh_id), + "congressional_district": cd_str, + "state_fips": int(cd_str) // 100, + } + ) + + cd_hh_counters[cd_str] += 1 + + # Apply new household IDs based on row index + combined_df["_new_hh_id"] = combined_df["_row_idx"].map(hh_row_to_new_id) + + # Update household IDs + combined_df[hh_id_col] = combined_df["_new_hh_id"] + + # Update person household references - since persons are already in their households, + # person_household_id should just match the household_id of their row + combined_df[person_hh_id_col] = combined_df["_new_hh_id"] + + # Report statistics + total_households = sum(cd_hh_counters.values()) + print( + f" Created {total_households:,} unique households across {len(cd_hh_counters)} CDs" + ) + + # Now handle persons with same 25k range approach - VECTORIZED + print(" Reindexing persons using 25k ranges...") + + # OFFSET PERSON IDs by 5 million to avoid collision with household IDs + PERSON_ID_OFFSET = 5_000_000 + + # Group by CD and assign IDs in bulk for each CD + for cd_geoid_val in combined_df[cd_geoid_col].unique(): + cd_str = str(int(cd_geoid_val)) + + # Calculate the ID range for this CD directly + cd_idx = cd_to_index[cd_str] + start_id = cd_idx * 25_000 + PERSON_ID_OFFSET # Add offset for persons + end_id = start_id + 24_999 + + # Get all rows for this CD + cd_mask = combined_df[cd_geoid_col] == cd_geoid_val + n_persons_in_cd = cd_mask.sum() + + # Check we won't exceed the range + if n_persons_in_cd > (end_id - start_id + 1): + raise ValueError( + f"CD {cd_str} has {n_persons_in_cd} persons, exceeds 25k allocation" + ) + + # Create sequential IDs for this CD + new_person_ids = np.arange(start_id, start_id + n_persons_in_cd) + + # Assign all at once using loc + combined_df.loc[cd_mask, person_id_col] = new_person_ids + + # Tax units - preserve structure within households + print(" Reindexing tax units...") + # Group by household first, then handle tax units within each household + new_tax_id = 0 + for hh_id in combined_df[hh_id_col].unique(): + hh_mask = combined_df[hh_id_col] == hh_id + hh_df = combined_df[hh_mask] + + # Get unique tax units within this household + unique_tax_in_hh = 
hh_df[person_tax_unit_col].unique() + + # Create mapping for this household's tax units + for old_tax in unique_tax_in_hh: + # Update all persons with this tax unit ID in this household + mask = (combined_df[hh_id_col] == hh_id) & ( + combined_df[person_tax_unit_col] == old_tax + ) + combined_df.loc[mask, person_tax_unit_col] = new_tax_id + # Also update tax_unit_id if it exists in the DataFrame + if tax_unit_id_col in combined_df.columns: + combined_df.loc[mask, tax_unit_id_col] = new_tax_id + new_tax_id += 1 + + # SPM units - preserve structure within households + print(" Reindexing SPM units...") + new_spm_id = 0 + for hh_id in combined_df[hh_id_col].unique(): + hh_mask = combined_df[hh_id_col] == hh_id + hh_df = combined_df[hh_mask] + + # Get unique SPM units within this household + unique_spm_in_hh = hh_df[person_spm_unit_col].unique() + + for old_spm in unique_spm_in_hh: + # Update all persons with this SPM unit ID in this household + mask = (combined_df[hh_id_col] == hh_id) & ( + combined_df[person_spm_unit_col] == old_spm + ) + combined_df.loc[mask, person_spm_unit_col] = new_spm_id + # Also update spm_unit_id if it exists + if spm_unit_id_col in combined_df.columns: + combined_df.loc[mask, spm_unit_id_col] = new_spm_id + new_spm_id += 1 + + # Marital units - preserve structure within households + print(" Reindexing marital units...") + new_marital_id = 0 + for hh_id in combined_df[hh_id_col].unique(): + hh_mask = combined_df[hh_id_col] == hh_id + hh_df = combined_df[hh_mask] + + # Get unique marital units within this household + unique_marital_in_hh = hh_df[person_marital_unit_col].unique() + + for old_marital in unique_marital_in_hh: + # Update all persons with this marital unit ID in this household + mask = (combined_df[hh_id_col] == hh_id) & ( + combined_df[person_marital_unit_col] == old_marital + ) + combined_df.loc[mask, person_marital_unit_col] = new_marital_id + # Also update marital_unit_id if it exists + if marital_unit_id_col in combined_df.columns: + combined_df.loc[mask, marital_unit_id_col] = new_marital_id + new_marital_id += 1 + + # Clean up temporary columns + temp_cols = [col for col in combined_df.columns if col.startswith("_")] + combined_df = combined_df.drop(columns=temp_cols) + + print(f" Final persons: {len(combined_df):,}") + print(f" Final households: {total_households:,}") + print(f" Final tax units: {new_tax_id:,}") + print(f" Final SPM units: {new_spm_id:,}") + print(f" Final marital units: {new_marital_id:,}") + + # Check weights in combined_df AFTER reindexing + print(f"\nWeights in combined_df AFTER reindexing:") + print(f" HH weight sum: {combined_df[hh_weight_col].sum()/1e6:.2f}M") + print(f" Person weight sum: {combined_df[person_weight_col].sum()/1e6:.2f}M") + print(f" Ratio: {combined_df[person_weight_col].sum() / combined_df[hh_weight_col].sum():.2f}") + + # Verify no overflow risk + max_person_id = combined_df[person_id_col].max() + print(f"\nOverflow check:") + print(f" Max person ID after reindexing: {max_person_id:,}") + print(f" Max person ID × 100: {max_person_id * 100:,}") + print(f" int32 max: {2_147_483_647:,}") + if max_person_id * 100 < 2_147_483_647: + print(" ✓ No overflow risk!") + else: + print(" ⚠️ WARNING: Still at risk of overflow!") + + # Create Dataset from combined DataFrame + print("\nCreating Dataset from combined DataFrame...") + sparse_dataset = Dataset.from_dataframe(combined_df, time_period) + + # Build a simulation to convert to h5 + print("Building simulation from Dataset...") + sparse_sim = Microsimulation() + 
sparse_sim.dataset = sparse_dataset + sparse_sim.build_from_dataset() + + # Save to h5 file + print(f"\nSaving to {output_path}...") + data = {} + + # Only save input variables (not calculated/derived variables) + # Calculated variables like state_name, state_code will be recalculated on load + input_vars = set(sparse_sim.input_variables) + print(f"Found {len(input_vars)} input variables (excluding calculated variables)") + + vars_to_save = input_vars.copy() + + # congressional_district_geoid isn't in the original microdata and has no formula, + # so it's not in input_vars. Since we set it explicitly during stacking, save it. + vars_to_save.add('congressional_district_geoid') + + variables_saved = 0 + variables_skipped = 0 + + for variable in sparse_sim.tax_benefit_system.variables: + if variable not in vars_to_save: + variables_skipped += 1 + continue + + # Only process variables that have actual data + data[variable] = {} + for period in sparse_sim.get_holder(variable).get_known_periods(): + values = sparse_sim.get_holder(variable).get_array(period) + + # Handle different value types + if ( + sparse_sim.tax_benefit_system.variables.get( + variable + ).value_type + in (Enum, str) + and variable != "county_fips" + ): + # Handle EnumArray objects + if hasattr(values, "decode_to_str"): + values = values.decode_to_str().astype("S") + else: + # Already a regular numpy array, just convert to string type + values = values.astype("S") + elif variable == "county_fips": + values = values.astype("int32") + else: + values = np.array(values) + + if values is not None: + data[variable][period] = values + variables_saved += 1 + + if len(data[variable]) == 0: + del data[variable] + + print(f"Variables saved: {variables_saved}") + print(f"Variables skipped: {variables_skipped}") + + # Write to h5 + with h5py.File(output_path, "w") as f: + for variable, periods in data.items(): + grp = f.create_group(variable) + for period, values in periods.items(): + grp.create_dataset(str(period), data=values) + + print(f"Sparse CD-stacked dataset saved successfully!") + + # Save household mapping to CSV in a mappings subdirectory + mapping_df = pd.DataFrame(household_mapping) + output_dir = os.path.dirname(output_path) + mappings_dir = os.path.join(output_dir, "mappings") if output_dir else "mappings" + os.makedirs(mappings_dir, exist_ok=True) + csv_filename = os.path.basename(output_path).replace(".h5", "_household_mapping.csv") + csv_path = os.path.join(mappings_dir, csv_filename) + mapping_df.to_csv(csv_path, index=False) + print(f"Household mapping saved to {csv_path}") + + # Verify the saved file + print("\nVerifying saved file...") + with h5py.File(output_path, "r") as f: + if "household_id" in f and str(time_period) in f["household_id"]: + hh_ids = f["household_id"][str(time_period)][:] + print(f" Final households: {len(hh_ids):,}") + if "person_id" in f and str(time_period) in f["person_id"]: + person_ids = f["person_id"][str(time_period)][:] + print(f" Final persons: {len(person_ids):,}") + if ( + "household_weight" in f + and str(time_period) in f["household_weight"] + ): + weights = f["household_weight"][str(time_period)][:] + print( + f" Total population (from household weights): {np.sum(weights):,.0f}" + ) + if "person_weight" in f and str(time_period) in f["person_weight"]: + person_weights = f["person_weight"][str(time_period)][:] + print( + f" Total population (from person weights): {np.sum(person_weights):,.0f}" + ) + print( + f" Average persons per household: {np.sum(person_weights) / 
np.sum(weights):.2f}" + ) + + return output_path + + +def main(dataset_path, w, db_uri): + cds_to_calibrate = get_all_cds_from_database(db_uri) + + ## Verify dimensions match + # Note: this is the base dataset that was stacked repeatedly + assert_sim = Microsimulation(dataset=dataset_path) + n_hh = assert_sim.calculate("household_id", map_to="household").shape[0] + expected_length = len(cds_to_calibrate) * n_hh + + # Ensure that the data set we're rebuilding has a shape that's consistent with training + if len(w) != expected_length: + raise ValueError( + f"Weight vector length ({len(w):,}) doesn't match expected ({expected_length:,})" + ) + + # Create the .h5 files --------------------------------------------- + # National Dataset with all districts ------------------------------------------------ + # TODO: what is the cds_to_calibrate doing for us if we have the cd_subset command? + if include_full_dataset: + output_path = f"{output_dir}/national.h5" + print(f"\nCreating combined dataset with all CDs in {output_path}") + output_file = create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + dataset_path=dataset_path, + output_path=output_path, + ) + + # State Datasets with state districts --------- + if False: + for state_fips, state_code in STATE_CODES.items(): + cd_subset = [ + cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips + ] + + output_path = f"{output_dir}/{state_code}.h5" + output_file = create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + cd_subset=cd_subset, + dataset_path=dataset_path, + output_path=output_path, + ) + print(f"Created {state_code}.h5") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Create sparse CD-stacked datasets" + ) + parser.add_argument( + "--weights-path", required=True, help="Path to w_cd.npy file" + ) + parser.add_argument( + "--dataset-path", + required=True, + help="Path to stratified dataset .h5 file", + ) + parser.add_argument( + "--db-path", required=True, help="Path to policy_data.db" + ) + parser.add_argument( + "--output-dir", + default="./temp", + help="Output directory for files", + ) + parser.add_argument( + "--mode", + choices=["national", "states", "cds", "single-cd", "single-state"], + default="national", + help="Output mode: national (one file), states (per-state files), cds (per-CD files), single-cd (one CD), single-state (one state)", + ) + parser.add_argument( + "--cd", + type=str, + help="Single CD GEOID to process (only used with --mode single-cd)", + ) + parser.add_argument( + "--state", + type=str, + help="State code to process, e.g. 
RI, CA, NC (only used with --mode single-state)", + ) + + args = parser.parse_args() + dataset_path_str = args.dataset_path + weights_path_str = args.weights_path + db_path = Path(args.db_path).resolve() + output_dir = args.output_dir + mode = args.mode + + os.makedirs(output_dir, exist_ok=True) + + # Load weights + w = np.load(weights_path_str) + db_uri = f"sqlite:///{db_path}" + + # Get list of CDs from database + cds_to_calibrate = get_all_cds_from_database(db_uri) + print(f"Found {len(cds_to_calibrate)} congressional districts") + + # Verify dimensions + assert_sim = Microsimulation(dataset=dataset_path_str) + n_hh = assert_sim.calculate("household_id", map_to="household").shape[0] + expected_length = len(cds_to_calibrate) * n_hh + + if len(w) != expected_length: + raise ValueError( + f"Weight vector length ({len(w):,}) doesn't match expected ({expected_length:,})" + ) + + if mode == "national": + output_path = f"{output_dir}/national.h5" + print(f"\nCreating national dataset with all CDs: {output_path}") + create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + dataset_path=dataset_path_str, + output_path=output_path, + ) + + elif mode == "states": + for state_fips, state_code in STATE_CODES.items(): + cd_subset = [ + cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips + ] + if not cd_subset: + continue + output_path = f"{output_dir}/{state_code}.h5" + print(f"\nCreating {state_code} dataset: {output_path}") + create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + cd_subset=cd_subset, + dataset_path=dataset_path_str, + output_path=output_path, + ) + + elif mode == "cds": + for i, cd_geoid in enumerate(cds_to_calibrate): + # Convert GEOID to friendly name: 3705 -> NC-05 + cd_int = int(cd_geoid) + state_fips = cd_int // 100 + district_num = cd_int % 100 + state_code = STATE_CODES.get(state_fips, str(state_fips)) + friendly_name = f"{state_code}-{district_num:02d}" + + output_path = f"{output_dir}/{friendly_name}.h5" + print(f"\n[{i+1}/{len(cds_to_calibrate)}] Creating {friendly_name}.h5 (GEOID {cd_geoid})") + create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + cd_subset=[cd_geoid], + dataset_path=dataset_path_str, + output_path=output_path, + ) + + elif mode == "single-cd": + if not args.cd: + raise ValueError("--cd required with --mode single-cd") + if args.cd not in cds_to_calibrate: + raise ValueError(f"CD {args.cd} not in calibrated CDs list") + output_path = f"{output_dir}/{args.cd}.h5" + print(f"\nCreating single CD dataset: {output_path}") + create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + cd_subset=[args.cd], + dataset_path=dataset_path_str, + output_path=output_path, + ) + + elif mode == "single-state": + if not args.state: + raise ValueError("--state required with --mode single-state") + # Find FIPS code for this state + state_code_upper = args.state.upper() + state_fips = None + for fips, code in STATE_CODES.items(): + if code == state_code_upper: + state_fips = fips + break + if state_fips is None: + raise ValueError(f"Unknown state code: {args.state}") + + cd_subset = [cd for cd in cds_to_calibrate if int(cd) // 100 == state_fips] + if not cd_subset: + raise ValueError(f"No CDs found for state {state_code_upper}") + + output_path = f"{output_dir}/{state_code_upper}.h5" + print(f"\nCreating {state_code_upper} dataset with {len(cd_subset)} CDs: {output_path}") + create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + cd_subset=cd_subset, + dataset_path=dataset_path_str, + output_path=output_path, + ) + + print("\nDone!") diff 
--git a/policyengine_us_data/datasets/cps/local_area_calibration/test_end_to_end.py b/policyengine_us_data/datasets/cps/local_area_calibration/test_end_to_end.py new file mode 100644 index 00000000..f8fd1d7a --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/test_end_to_end.py @@ -0,0 +1,198 @@ +from sqlalchemy import create_engine, text +import pandas as pd +import numpy as np + +from policyengine_us import Microsimulation +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.datasets.cps.local_area_calibration.metrics_matrix_geo_stacking_sparse import ( + SparseGeoStackingMatrixBuilder, +) +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + create_target_groups, +) +from policyengine_us_data.datasets.cps.local_area_calibration.household_tracer import HouseholdTracer +from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import create_sparse_cd_stacked_dataset + +rng_ben = np.random.default_rng(seed=42) + +# ------ + +db_path = STORAGE_FOLDER / "policy_data.db" +db_uri = f"sqlite:///{db_path}" +builder = SparseGeoStackingMatrixBuilder(db_uri, time_period=2023) + +engine = create_engine(db_uri) + +query = """ +SELECT DISTINCT sc.value as cd_geoid +FROM strata s +JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id +WHERE s.stratum_group_id = 1 + AND sc.constraint_variable = 'congressional_district_geoid' +ORDER BY sc.value +""" + +with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + all_cd_geoids = [row[0] for row in result] + +cds_to_calibrate = all_cd_geoids +dataset_uri = STORAGE_FOLDER / "stratified_extended_cps_2023.h5" +sim = Microsimulation(dataset=str(dataset_uri)) + +# ------ +targets_df, X_sparse, household_id_mapping = ( + builder.build_stacked_matrix_sparse( + "congressional_district", cds_to_calibrate, sim + ) +) + +target_groups, group_info = create_target_groups(targets_df) +tracer = HouseholdTracer(targets_df, X_sparse, household_id_mapping, cds_to_calibrate, sim) + +# Get NC's state SNAP info: +group_71 = tracer.get_group_rows(71) +row_loc = group_71.iloc[28]['row_index'] # The row of X_sparse +row_info = tracer.get_row_info(row_loc) +var = row_info['variable'] +var_desc = row_info['variable_desc'] +target_geo_id = int(row_info['geographic_id']) + +print("Row info for first SNAP state target:") +row_info + + +# Create a weight vector +total_size = X_sparse.shape[1] + +w = np.zeros(total_size) +n_nonzero = 50000 +nonzero_indices = rng_ben.choice(total_size, n_nonzero, replace=False) +w[nonzero_indices] = 7 + +output_dir = "./temp" +h5_name = "national" +output_path = f"{output_dir}/{h5_name}.h5" +output_file = create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + dataset_path=str(dataset_uri), + output_path=output_path, +) + +sim_test = Microsimulation(dataset=output_path) +hh_snap_df = pd.DataFrame(sim_test.calculate_dataframe([ + "household_id", "household_weight", "congressional_district_geoid", "state_fips", "snap"]) +) +mapping_df = pd.read_csv(f"{output_dir}/mappings/{h5_name}_household_mapping.csv") + +merged_df = mapping_df.merge( + hh_snap_df, + how='inner', + left_on='new_household_id', + right_on='household_id' +) +fips_equal = (merged_df['state_fips_x'] == merged_df['state_fips_y']).all() +assert fips_equal + +# These are the households corresponding to the non-zero weight values +merged_df = merged_df.rename(columns={'state_fips_x': 'state_fips'}).drop(columns=['state_fips_y']) + +y_hat = 
X_sparse @ w +snap_hat_state = y_hat[row_loc] + +state_df = hh_snap_df.loc[hh_snap_df.state_fips == target_geo_id] +y_hat_sim = np.sum(state_df.snap.values * state_df.household_weight.values) +print(state_df.shape) + +assert np.isclose(y_hat_sim, snap_hat_state, atol=10), f"Mismatch: {y_hat_sim} vs {snap_hat_state}" + +merged_df['col_pos'] = merged_df.apply(lambda row: tracer.get_household_column_positions(int(row.original_household_id))[str(int(row.congressional_district))], axis=1) +merged_df['sparse_value'] = X_sparse[row_loc, merged_df['col_pos'].values].toarray().ravel() + + +# Check 1. All w not in the 50k dataframe of households are zero: +w_check = w.copy() +w_check[merged_df['col_pos']] = 0 +total_remainder = np.abs(w_check).sum() + +if total_remainder == 0: + print("Success: All indices outside the DataFrame have zero weight.") +else: + offending_indices = np.nonzero(w_check)[0] + print(f"First 5 offending indices: {offending_indices[:5]}") + +# Check 2. All sparse_value values are 0 unless state_fips = 37 +violations = merged_df[ + (merged_df['state_fips'] != 37) & + (merged_df['sparse_value'] != 0) +] + +if violations.empty: + print("Check 2 Passed: All non-37 locations have 0 sparse_value.") +else: + print(f"Check 2 Failed: Found {len(violations)} violations.") + print(violations[['state_fips', 'sparse_value']].head()) + +# Check 3. snap values are what is in the row of X_sparse for all rows where state_fips = 37 +merged_state_df = merged_df.loc[merged_df.state_fips == 37] +merged_state_df.loc[merged_state_df.snap > 0.0] + +# ------------------------------------------- +# Debugging --------------------------------- +# ------------------------------------------- +# Problem! Original household id of 178010 (new household id 5250083) +# Why does it have 2232 for snap but zero in the X_sparse matrix!? +merged_state_df.loc[merged_state_df.original_household_id == 178010] +# Let me just check the column position +tracer.get_household_column_positions(178010)['3705'] + +X_sparse[row_loc, 2850099] + +tracer.get_household_column_positions(178010)['3701'] +X_sparse[row_loc, 2796067] + +# Let's check the original home state +tracer.get_household_column_positions(178010)['1501'] +X_sparse[row_loc, 702327] + +# Are any not zero? +for cd in cds_to_calibrate: + col_loc = tracer.get_household_column_positions(178010)[cd] + val = X_sparse[row_loc, col_loc] + if val > 0: + print(f"cd {cd} has val {val}") +# Nothing! + +# Let's take a look at this household in the original simulation +debug_df = sim.calculate_dataframe(['household_id', 'state_fips', 'snap']) +debug_df.loc[debug_df.household_id == 178010] + +# Interesting. It's not either one! 
+#Out[93]:
+# weight household_id state_fips snap
+#13419 0.0 178010 15 4262.0
+
+entity_rel = pd.DataFrame(
+    {
+        "person_id": sim.calculate("person_id", map_to="person").values,
+        "household_id": sim.calculate("household_id", map_to="person").values,
+        "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values,
+        "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values,
+        "family_id": sim.calculate("family_id", map_to="person").values,
+        "marital_unit_id": sim.calculate("marital_unit_id", map_to="person").values,
+    }
+)
+
+entity_rel.loc[entity_rel.household_id == 178010]
+
+# I'm really surprised to see only one spm_unit_id
+spm_df = sim.calculate_dataframe(['spm_unit_id', 'snap'], map_to="spm_unit")
+spm_df.loc[spm_df.spm_unit_id == 178010002]
+#Out[102]:
+# weight spm_unit_id snap
+#14028 0.0 178010002 4262.0
+
+# Debugging problem
+# There are some tough questions here. Why does the base simulation show SNAP as $4262,
+# while the simulation built from the output file shows $2232 and the sparse matrix has all zeros?
diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/test_national_walkthrough.py b/policyengine_us_data/datasets/cps/local_area_calibration/test_national_walkthrough.py
new file mode 100644
index 00000000..dfed959a
--- /dev/null
+++ b/policyengine_us_data/datasets/cps/local_area_calibration/test_national_walkthrough.py
@@ -0,0 +1,516 @@
+# National Target Walkthrough:
+# This validates the sparse matrix for NATIONAL targets where:
+# - There is 1 target row (not 51 like state SNAP)
+# - Matrix values are non-zero for ALL 436 CD columns (no geographic filtering)
+
+from sqlalchemy import create_engine, text
+import pandas as pd
+import numpy as np
+
+from policyengine_us import Microsimulation
+from policyengine_us_data.storage import STORAGE_FOLDER
+from policyengine_us_data.datasets.cps.local_area_calibration.metrics_matrix_geo_stacking_sparse import (
+    SparseGeoStackingMatrixBuilder,
+)
+from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import (
+    create_target_groups,
+)
+from policyengine_us_data.datasets.cps.local_area_calibration.household_tracer import HouseholdTracer
+from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import create_sparse_cd_stacked_dataset
+
+rng_ben = np.random.default_rng(seed=42)
+
+
+# Step 1: Setup - same as SNAP walkthrough
+db_path = STORAGE_FOLDER / "policy_data.db"
+db_uri = f"sqlite:///{db_path}"
+builder = SparseGeoStackingMatrixBuilder(db_uri, time_period=2023)
+
+engine = create_engine(db_uri)
+
+query = """
+SELECT DISTINCT sc.value as cd_geoid
+FROM strata s
+JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id
+WHERE s.stratum_group_id = 1
+  AND sc.constraint_variable = 'congressional_district_geoid'
+ORDER BY sc.value
+"""
+
+with engine.connect() as conn:
+    result = conn.execute(text(query)).fetchall()
+    all_cd_geoids = [row[0] for row in result]
+
+cds_to_calibrate = all_cd_geoids
+dataset_uri = STORAGE_FOLDER / "stratified_10k.h5"
+sim = Microsimulation(dataset=str(dataset_uri))
+
+targets_df, X_sparse, household_id_mapping = (
+    builder.build_stacked_matrix_sparse(
+        "congressional_district", cds_to_calibrate, sim
+    )
+)
+
+target_groups, group_info = create_target_groups(targets_df)
+tracer = HouseholdTracer(targets_df, X_sparse, household_id_mapping, cds_to_calibrate, sim)
+
+tracer.print_matrix_structure()
+
+hh_agi_df = sim.calculate_dataframe(['household_id', 'adjusted_gross_income'])
+
+# Alimony
Expense ------------------------------------------------------------------- + +# Group 0 is national alimony_expense - a single target +group_0 = tracer.get_group_rows(0) +print(f"\nGroup 0 info:\n{group_0}") + +assert group_0.shape[0] == 1, f"Expected 1 national target, got {group_0.shape[0]}" + +row_loc = group_0.iloc[0]['row_index'] +row_info = tracer.get_row_info(row_loc) +var = row_info['variable'] + +# Is var calculated? +calculated = [v for v in sim.tax_benefit_system.variables + if v not in sim.input_variables] + +print(f"{var} is calculated by the engine: {var in calculated}") +print(f"{var} is an input: {var in sim.input_variables}") + +print(f"\nRow info for national alimony_expense target:") +print(row_info) + +assert var == 'alimony_expense', f"Expected alimony_expense, got {var}" +assert row_loc == 0, f"Expected row 0, got {row_loc}" + +# Step 3: Find a household with positive alimony_expense +# alimony_expense is a tax_unit level variable + +entity_rel = pd.DataFrame( + { + "person_id": sim.calculate("person_id", map_to="person").values, + "household_id": sim.calculate("household_id", map_to="person").values, + "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, + } +) + +# Get alimony_expense at tax_unit level +tu_df = sim.calculate_dataframe(['tax_unit_id', 'alimony_expense']) +print(f"\nTax units with alimony_expense > 0: {(tu_df.alimony_expense > 0).sum()}") +print(tu_df.loc[tu_df.alimony_expense > 0].head(10)) + +# Find households with positive alimony expense +tu_with_alimony = tu_df.loc[tu_df.alimony_expense > 0] + +# Map tax_units to households +tu_to_hh = entity_rel[['tax_unit_id', 'household_id']].drop_duplicates() +tu_with_alimony_hh = tu_with_alimony.merge(tu_to_hh, on='tax_unit_id') + +# Aggregate alimony_expense at household level (sum across tax units) +hh_alimony = tu_with_alimony_hh.groupby('household_id')['alimony_expense'].sum().reset_index() +hh_alimony.columns = ['household_id', 'alimony_expense'] +print(f"\nHouseholds with alimony_expense > 0: {hh_alimony.shape[0]}") +print(hh_alimony.head(10)) + +# Pick a test household +hh_id = hh_alimony.iloc[0]['household_id'] +hh_alimony_goal = hh_alimony.iloc[0]['alimony_expense'] + +print(f"\nTest household: {hh_id}") +print(f"Household alimony_expense: {hh_alimony_goal}") + +# Step 4: Validate Matrix Values - KEY DIFFERENCE FROM SNAP +# For national targets, the matrix value should be the SAME in ALL 436 CD columns +# (unlike state SNAP where it's only non-zero in home state CDs) + +hh_col_lku = tracer.get_household_column_positions(hh_id) + +values_found = [] +for cd in hh_col_lku.keys(): + col_loc = hh_col_lku[cd] + col_info = tracer.get_column_info(col_loc) + + assert col_info['household_id'] == hh_id + + metric = X_sparse[row_loc, col_loc] + values_found.append(metric) + + # For national target: value should be hh_alimony_goal in ALL CDs + assert metric == hh_alimony_goal, f"Expected {hh_alimony_goal} for CD {cd}, got {metric}" + +print(f"\nAll {len(hh_col_lku)} CD column values validated for household {hh_id}") +print(f"All values equal to {hh_alimony_goal}: {all(v == hh_alimony_goal for v in values_found)}") + +# Step 5: Verify a household with zero alimony also has zeros everywhere +hh_df = sim.calculate_dataframe(['household_id']) +all_hh_ids = set(hh_df.household_id.values) +hh_with_alimony_ids = set(hh_alimony.household_id.values) +hh_without_alimony = all_hh_ids - hh_with_alimony_ids + +# Pick one household without alimony +hh_zero_id = list(hh_without_alimony)[0] +hh_zero_col_lku = 
tracer.get_household_column_positions(hh_zero_id) + +for cd in list(hh_zero_col_lku.keys())[:10]: # Check first 10 CDs + col_loc = hh_zero_col_lku[cd] + metric = X_sparse[row_loc, col_loc] + assert metric == 0, f"Expected 0 for zero-alimony household {hh_zero_id} in CD {cd}, got {metric}" + +print(f"\nVerified household {hh_zero_id} (no alimony) has zeros in matrix") + +# Step 6: End-to-End Validation +# Create a sparse weight vector and verify X @ w matches simulation + +n_nonzero = 50000 +total_size = X_sparse.shape[1] + +w = np.zeros(total_size) +nonzero_indices = rng_ben.choice(total_size, n_nonzero, replace=False) +w[nonzero_indices] = 7 +w[hh_col_lku['101']] = 11 # Give our test household a specific weight in CD 101 + +output_dir = './temp' +output_path = f"{output_dir}/national_alimony_test.h5" + +output_file = create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + dataset_path=str(dataset_uri), + output_path=output_path, +) + +# Load and calculate +sim_test = Microsimulation(dataset=output_path) +hh_alimony_df = pd.DataFrame(sim_test.calculate_dataframe([ + "household_id", "household_weight", "alimony_expense"]) +) + +print(f"\nOutput dataset has {hh_alimony_df.shape[0]} households") + +# Matrix multiplication prediction +y_hat = X_sparse @ w +alimony_hat_matrix = y_hat[row_loc] + +# Simulation-based calculation (national sum) +alimony_hat_sim = np.sum(hh_alimony_df.alimony_expense.values * hh_alimony_df.household_weight.values) + +print(f"\nMatrix multiplication (X @ w)[{row_loc}] = {alimony_hat_matrix:,.2f}") +print(f"Simulation sum(alimony_expense * weight) = {alimony_hat_sim:,.2f}") + +assert np.isclose(alimony_hat_sim, alimony_hat_matrix, atol=10), f"Mismatch: {alimony_hat_sim} vs {alimony_hat_matrix}" +print("\nEnd-to-end validation PASSED") + +# ============================================================================ +# Part 2: income_tax - FEDERAL income tax (NOT state-dependent) +# ============================================================================ +# NOTE: income_tax in PolicyEngine is FEDERAL income tax only! +# It does NOT include state_income_tax. The formula is: +# income_tax = income_tax_before_refundable_credits - income_tax_refundable_credits +# Therefore, income_tax should be the SAME across all CDs for a given household. 
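+# A fuller version of the TX-vs-CA spot check below would sweep every CD column
+# for the test household (sketch only, using names defined further down):
+#   vals = {cd: X_sparse[income_tax_row, col] for cd, col in test_hh_col_lku.items()}
+#   assert len(set(vals.values())) == 1  # one federal liability, every CD column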
+ +print("\n" + "="*80) +print("PART 2: income_tax (Federal Only) - Should NOT vary by state") +print("="*80) + +print(f"\nincome_tax is calculated: {'income_tax' not in sim.input_variables}") + +# Find the income_tax target row in X_sparse (Group 7) +group_7 = tracer.get_group_rows(7) +income_tax_row = group_7.iloc[0]['row_index'] +income_tax_row_info = tracer.get_row_info(income_tax_row) +print(f"\nincome_tax row info: {income_tax_row_info}") + +# Find a high-income household for federal income_tax test +hh_agi_df = sim.calculate_dataframe(['household_id', 'adjusted_gross_income']) +high_income_hh = hh_agi_df[ + (hh_agi_df.adjusted_gross_income > 400000) & + (hh_agi_df.adjusted_gross_income < 600000) +].sort_values('adjusted_gross_income') + +if len(high_income_hh) > 0: + test_hh_id = high_income_hh.iloc[0]['household_id'] + test_hh_agi = high_income_hh.iloc[0]['adjusted_gross_income'] +else: + test_hh_id = hh_agi_df.sort_values('adjusted_gross_income', ascending=False).iloc[0]['household_id'] + test_hh_agi = hh_agi_df[hh_agi_df.household_id == test_hh_id].adjusted_gross_income.values[0] + +print(f"\nTest household for income_tax: {test_hh_id}, AGI: ${test_hh_agi:,.0f}") + +# Get matrix values for TX vs CA CDs +test_hh_col_lku = tracer.get_household_column_positions(test_hh_id) +tx_cds = [cd for cd in test_hh_col_lku.keys() if cd.startswith('48')] +ca_cds = [cd for cd in test_hh_col_lku.keys() if cd.startswith('6') and len(cd) == 3] + +if tx_cds and ca_cds: + tx_cd, ca_cd = tx_cds[0], ca_cds[0] + tx_col, ca_col = test_hh_col_lku[tx_cd], test_hh_col_lku[ca_cd] + + income_tax_tx_matrix = X_sparse[income_tax_row, tx_col] + income_tax_ca_matrix = X_sparse[income_tax_row, ca_col] + + print(f"\nincome_tax in TX CD {tx_cd}: ${income_tax_tx_matrix:,.2f}") + print(f"income_tax in CA CD {ca_cd}: ${income_tax_ca_matrix:,.2f}") + + assert income_tax_tx_matrix == income_tax_ca_matrix, \ + f"Federal income_tax should be identical across CDs! TX={income_tax_tx_matrix}, CA={income_tax_ca_matrix}" + print("\n✓ PASSED: Federal income_tax is identical across all CDs (as expected)") + + +# ============================================================================ +# Part 3: salt_deduction - NOT state-dependent (based on INPUTS) +# ============================================================================ +# IMPORTANT: salt_deduction does NOT vary by state in geo-stacking! +# +# Why? The SALT deduction formula is: +# salt_deduction = min(salt_cap, reported_salt) +# reported_salt = salt (possibly limited to AGI) +# salt = state_and_local_sales_or_income_tax + real_estate_taxes +# state_and_local_sales_or_income_tax = max(income_tax_component, sales_tax_component) +# income_tax_component = state_withheld_income_tax + local_income_tax +# +# The key variables are INPUTS from the CPS/tax data: +# - state_withheld_income_tax: INPUT (actual withholding reported) +# - local_income_tax: INPUT +# - real_estate_taxes: INPUT +# +# These represent what the household ACTUALLY PAID in their original state. +# When we change state_fips for geo-stacking, these input values don't change +# because they're historical data from tax returns, not calculated liabilities. 
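+# Illustrative example (hypothetical numbers): a filer who reported $6,000 of
+# state_withheld_income_tax and $4,000 of real_estate_taxes carries
+# salt = $10,000 into every CD column, and salt_cap is applied the same way in each.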
+# +# Truly state-dependent variables must be CALCULATED based on state policy, +# like: snap, medicaid (benefit programs with state-specific rules) + +print("\n" + "="*80) +print("PART 3: salt_deduction - Should NOT vary by state (input-based)") +print("="*80) + +#from policyengine_us_data.datasets.cps.geo_stacking_calibration.metrics_matrix_geo_stacking_sparse import get_state_dependent_variables +#state_dep_vars = get_state_dependent_variables() +#print(f"\nState-dependent variables: {state_dep_vars}") + +# Find salt_deduction target (Group 21) +group_21 = tracer.get_group_rows(21) +print(f"\nGroup 21 info:\n{group_21}") + +salt_row = group_21.iloc[0]['row_index'] +salt_row_info = tracer.get_row_info(salt_row) +print(f"\nsalt_deduction row info: {salt_row_info}") + +# Use a moderate-income household for testing +moderate_income_hh = hh_agi_df[ + (hh_agi_df.adjusted_gross_income > 75000) & + (hh_agi_df.adjusted_gross_income < 150000) +].sort_values('adjusted_gross_income') + +if len(moderate_income_hh) > 0: + salt_test_hh_id = moderate_income_hh.iloc[0]['household_id'] + salt_test_hh_agi = moderate_income_hh.iloc[0]['adjusted_gross_income'] +else: + salt_test_hh_id = test_hh_id + salt_test_hh_agi = test_hh_agi + +print(f"\nTest household for salt_deduction: {salt_test_hh_id}, AGI: ${salt_test_hh_agi:,.0f}") + +# Get column positions for this household +salt_hh_col_lku = tracer.get_household_column_positions(salt_test_hh_id) +salt_tx_cds = [cd for cd in salt_hh_col_lku.keys() if cd.startswith('48')] +salt_ca_cds = [cd for cd in salt_hh_col_lku.keys() if cd.startswith('6') and len(cd) == 3] + +# Check matrix values for TX vs CA - they SHOULD be identical (input-based) +if salt_tx_cds and salt_ca_cds: + salt_tx_cd, salt_ca_cd = salt_tx_cds[0], salt_ca_cds[0] + salt_tx_col = salt_hh_col_lku[salt_tx_cd] + salt_ca_col = salt_hh_col_lku[salt_ca_cd] + + salt_tx_matrix = X_sparse[salt_row, salt_tx_col] + salt_ca_matrix = X_sparse[salt_row, salt_ca_col] + + print(f"\nsalt_deduction for household {salt_test_hh_id}:") + print(f" TX CD {salt_tx_cd}: ${salt_tx_matrix:,.2f}") + print(f" CA CD {salt_ca_cd}: ${salt_ca_matrix:,.2f}") + + + + +# Bringing in the snap parts of the test: + +p_df = sim.calculate_dataframe(['person_household_id', 'person_id', 'snap'], map_to="person") + +hh_stats = p_df.groupby('person_household_id').agg( + person_count=('person_id', 'nunique'), + snap_min=('snap', 'min'), + snap_unique=('snap', 'nunique') +).reset_index() + +candidates = hh_stats[(hh_stats.person_count > 1) & (hh_stats.snap_min > 0) & (hh_stats.snap_unique > 1)] +candidates.head(10) + +hh_id = candidates.iloc[2]['person_household_id'] +p_df.loc[p_df.person_household_id == hh_id] + +hh_snap_goal = 7925.5 + +entity_rel = pd.DataFrame( + { + "person_id": sim.calculate("person_id", map_to="person").values, + "household_id": sim.calculate("household_id", map_to="person").values, + "tax_unit_id": sim.calculate("tax_unit_id", map_to="person").values, + "spm_unit_id": sim.calculate("spm_unit_id", map_to="person").values, + "family_id": sim.calculate("family_id", map_to="person").values, + "marital_unit_id": sim.calculate("marital_unit_id", map_to="person").values, + } +) + +snap_df = sim.calculate_dataframe(['spm_unit_id', 'snap']) +snap_subset = entity_rel.loc[entity_rel.household_id == hh_id] +snap_df.loc[snap_df.spm_unit_id.isin(list(snap_subset.spm_unit_id))] + + +hh_df = sim.calculate_dataframe(['household_id', 'state_fips']) +hh_loc = np.where(hh_df.household_id == hh_id)[0][0] +hh_one = hh_df.iloc[hh_loc] 
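+# The block below drops this SNAP household into two CDs ('601' and '2001') with
+# distinct calibration weights (1.5 and 1.7), rebuilds a two-CD .h5, and uses the
+# household-mapping CSV to confirm each copy keeps its own weight and its
+# household-level snap total (hh_snap_goal) in the output dataset.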
+hh_home_state = hh_one.state_fips +hh_col_lku = tracer.get_household_column_positions(hh_id) + +print(f"Household {hh_id} is from state FIPS {hh_home_state}") +hh_one + +n_nonzero = 1000000 +total_size = X_sparse.shape[1] + +w = np.zeros(total_size) +nonzero_indices = rng_ben.choice(total_size, n_nonzero, replace=False) +w[nonzero_indices] = 2 + +cd1 = '601' +cd2 = '2001' +output_dir = './temp' +w[hh_col_lku[cd1]] = 1.5 +w[hh_col_lku[cd2]] = 1.7 + +output_path = f"{output_dir}/mapping1.h5" +output_file = create_sparse_cd_stacked_dataset( + w, + cds_to_calibrate, + cd_subset=[cd1, cd2], + dataset_path=str(dataset_uri), + output_path=output_path, +) + +sim_test = Microsimulation(dataset=output_path) +df_test = sim_test.calculate_dataframe([ + 'congressional_district_geoid', + 'household_id', 'household_weight', 'snap']) + +print(f"Output dataset shape: {df_test.shape}") +assert np.isclose(df_test.shape[0] / 2 * 436, n_nonzero, rtol=0.10) + +mapping = pd.read_csv(f"{output_dir}/mapping1_household_mapping.csv") +match = mapping.loc[mapping.original_household_id == hh_id].shape[0] +assert match == 2, f"Household should appear twice (once per CD), got {match}" + +hh_mapping = mapping.loc[mapping.original_household_id == hh_id] +hh_mapping + +df_test_cd1 = df_test.loc[df_test.congressional_district_geoid == int(cd1)] +df_test_cd2 = df_test.loc[df_test.congressional_district_geoid == int(cd2)] + +hh_mapping_cd1 = hh_mapping.loc[hh_mapping.congressional_district == int(cd1)] +new_hh_id_cd1 = hh_mapping_cd1['new_household_id'].values[0] + +assert hh_mapping_cd1.shape[0] == 1 +assert hh_mapping_cd1.original_household_id.values[0] == hh_id + +w_hh_cd1 = w[hh_col_lku[cd1]] +assert_cd1_df = df_test_cd1.loc[df_test_cd1.household_id == new_hh_id_cd1] + +assert np.isclose(assert_cd1_df.household_weight.values[0], w_hh_cd1, atol=0.001) +assert np.isclose(assert_cd1_df.snap.values[0], hh_snap_goal, atol=0.001) + +print(f"CD {cd1}: weight={w_hh_cd1}, snap={assert_cd1_df.snap.values[0]}") +assert_cd1_df + + +hh_mapping_cd2 = hh_mapping.loc[hh_mapping.congressional_district == int(cd2)] +new_hh_id_cd2 = hh_mapping_cd2['new_household_id'].values[0] + +assert hh_mapping_cd2.shape[0] == 1 +assert hh_mapping_cd2.original_household_id.values[0] == hh_id + +w_hh_cd2 = w[hh_col_lku[cd2]] +assert_cd2_df = df_test_cd2.loc[df_test_cd2.household_id == new_hh_id_cd2] + +assert np.isclose(assert_cd2_df.household_weight.values[0], w_hh_cd2, atol=0.001) +assert np.isclose(assert_cd2_df.snap.values[0], hh_snap_goal, atol=0.001) + +print(f"CD {cd2}: weight={w_hh_cd2}, snap={assert_cd2_df.snap.values[0]}") + +## Another household that requires BBCE to get in + +# Calculate household-level variables +hh_df = sim.calculate_dataframe([ + 'household_id', + 'state_fips', + 'snap_gross_income_fpg_ratio', + 'gross_income', + 'snap', + 'spm_unit_size', + 'is_snap_eligible', + 'is_tanf_non_cash_eligible' +], map_to="household") + +# Filter for BBCE-relevant households +# Between 130% and 200% FPL (where CA qualifies via BBCE, KS doesn't) +candidates = hh_df[ + (hh_df['snap_gross_income_fpg_ratio'] >= 1.50) & + (hh_df['snap_gross_income_fpg_ratio'] <= 1.80) & + (hh_df['is_tanf_non_cash_eligible'] > 1) +].copy() + +# Sort by FPG ratio to find households near 165% +candidates['distance_from_165'] = abs(candidates['snap_gross_income_fpg_ratio'] - 1.65) +candidates_sorted = candidates.sort_values('distance_from_165') + +# Show top 10 candidates +candidates_sorted[['household_id', 'state_fips', 'snap_gross_income_fpg_ratio', 'snap', 
'is_snap_eligible', 'spm_unit_size']].head(10) + + +# There was always a reason why I couldn't get the BBCE pathway to work! +from policyengine_us import Microsimulation + +# Load CPS 2023 +sim = Microsimulation(dataset="hf://policyengine/policyengine-us-data/cps_2023.h5") + + +# Find PURE BBCE cases - no elderly/disabled exemption +ca_bbce_pure = candidates[ + #(candidates['state_fips'] == 6) & + (candidates['snap_gross_income_fpg_ratio'] >= 1.30) & + (candidates['snap_gross_income_fpg_ratio'] <= 2.0) & + (candidates['is_tanf_non_cash_eligible'] > 0) & + (candidates['meets_snap_categorical_eligibility'] > 0) & + (candidates['is_snap_eligible'] > 0) & + (candidates['snap'] > 0) +].copy() + +# Now check which ones FAIL the normal gross test +for idx, row in ca_bbce_pure.head(20).iterrows(): + hh_id = row['household_id'] + check = sim.calculate_dataframe( + ['household_id', 'meets_snap_gross_income_test', 'has_usda_elderly_disabled'], + map_to='household' + ) + hh_check = check[check['household_id'] == hh_id].iloc[0] + if hh_check['meets_snap_gross_income_test'] == 0: + print(f"HH {hh_id}: Pure BBCE case! (no elderly/disabled exemption)") + print(f" Gross FPL: {row['snap_gross_income_fpg_ratio']:.1%}") + print(f" SNAP: ${row['snap']:.2f}") + break + + +# Cleanup +import shutil +import os +if os.path.exists('./temp'): + shutil.rmtree('./temp') + print("\nCleaned up ./temp directory") diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/test_snap_end_to_end.py b/policyengine_us_data/datasets/cps/local_area_calibration/test_snap_end_to_end.py new file mode 100644 index 00000000..b028d552 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/test_snap_end_to_end.py @@ -0,0 +1,128 @@ +""" +End-to-end test for SNAP calibration pipeline. + +Tests that: +1. Sparse matrix is built correctly for SNAP targets +2. H5 file creation via create_sparse_cd_stacked_dataset works +3. Matrix prediction (X @ w) matches simulation output within tolerance + +Uses ~15% aggregate tolerance due to ID reindexing changing random() seeds. 
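+
+The core check is the calibration identity for one state-level SNAP row:
+    (X_sparse @ w)[row] is compared against sum(snap * household_weight) over
+    that state's households in the rebuilt .h5 file.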
+""" + +from sqlalchemy import create_engine, text +import numpy as np +import pandas as pd +from policyengine_us import Microsimulation +from policyengine_us_data.storage import STORAGE_FOLDER +from sparse_matrix_builder import SparseMatrixBuilder +from household_tracer import HouseholdTracer +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + create_target_groups, +) +from policyengine_us_data.datasets.cps.local_area_calibration.stacked_dataset_builder import ( + create_sparse_cd_stacked_dataset, +) + + +def get_test_cds(db_uri): + """Get a subset of CDs for testing: NC, HI, MT, AK.""" + engine = create_engine(db_uri) + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM strata s + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 + AND sc.constraint_variable = 'congressional_district_geoid' + AND ( + sc.value LIKE '37__' -- NC (14 CDs) + OR sc.value LIKE '150_' -- HI (2 CDs) + OR sc.value LIKE '300_' -- MT (2 CDs) + OR sc.value = '200' OR sc.value = '201' -- AK (2 CDs) + ) + ORDER BY sc.value + """ + with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + return [row[0] for row in result] + + +def test_snap_end_to_end(): + """Test that matrix prediction matches H5 simulation output for SNAP.""" + rng = np.random.default_rng(seed=42) + + db_path = STORAGE_FOLDER / "policy_data.db" + db_uri = f"sqlite:///{db_path}" + dataset_uri = STORAGE_FOLDER / "stratified_extended_cps_2023.h5" + + test_cds = get_test_cds(db_uri) + print(f"Testing with {len(test_cds)} CDs: {test_cds[:5]}...") + + # Build sparse matrix + sim = Microsimulation(dataset=str(dataset_uri)) + builder = SparseMatrixBuilder( + db_uri, time_period=2023, cds_to_calibrate=test_cds, dataset_path=str(dataset_uri) + ) + + print("Building SNAP matrix...") + targets_df, X_sparse, household_id_mapping = builder.build_matrix( + sim, target_filter={"stratum_group_ids": [4], "variables": ["snap"]} + ) + + target_groups, group_info = create_target_groups(targets_df) + tracer = HouseholdTracer(targets_df, X_sparse, household_id_mapping, test_cds, sim) + tracer.print_matrix_structure() + + # Find NC state SNAP row (state_fips=37) + group_2 = tracer.get_group_rows(2) + nc_row = group_2[group_2['geographic_id'].astype(str) == '37'] + if nc_row.empty: + nc_row = group_2.iloc[[0]] + row_loc = int(nc_row.iloc[0]['row_index']) + row_info = tracer.get_row_info(row_loc) + target_geo_id = int(row_info['geographic_id']) + print(f"Testing state FIPS {target_geo_id}: {row_info['variable']}") + + # Create random weights + total_size = X_sparse.shape[1] + w = np.zeros(total_size) + n_nonzero = 50000 + nonzero_indices = rng.choice(total_size, n_nonzero, replace=False) + w[nonzero_indices] = 7 + + # Create H5 file + output_dir = "./temp" + h5_name = "test_snap" + output_path = f"{output_dir}/{h5_name}.h5" + + print("Creating H5 file...") + create_sparse_cd_stacked_dataset( + w, test_cds, dataset_path=str(dataset_uri), output_path=output_path + ) + + # Load and verify + sim_test = Microsimulation(dataset=output_path) + hh_test_df = pd.DataFrame( + sim_test.calculate_dataframe([ + "household_id", "household_weight", "state_fips", "snap" + ]) + ) + + # Compare matrix prediction to simulation + y_hat = X_sparse @ w + snap_hat_matrix = y_hat[row_loc] + + state_df = hh_test_df[hh_test_df.state_fips == target_geo_id] + snap_hat_sim = np.sum(state_df.snap.values * state_df.household_weight.values) + + relative_diff = abs(snap_hat_sim - snap_hat_matrix) / 
(snap_hat_matrix + 1) + print(f"\nAggregate comparison:") + print(f" Matrix prediction: {snap_hat_matrix:,.0f}") + print(f" Simulation output: {snap_hat_sim:,.0f}") + print(f" Relative diff: {relative_diff:.1%}") + + assert relative_diff < 0.15, f"Aggregate mismatch too large: {relative_diff:.1%}" + print("\n✓ End-to-end test PASSED") + + +if __name__ == "__main__": + test_snap_end_to_end() diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/test_sparse_matrix_verification.py b/policyengine_us_data/datasets/cps/local_area_calibration/test_sparse_matrix_verification.py new file mode 100644 index 00000000..0cefc375 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/test_sparse_matrix_verification.py @@ -0,0 +1,541 @@ +""" +Verification tests for the sparse matrix builder. + +RATIONALE +========= +The sparse matrix X_sparse contains pre-calculated values for households +"transplanted" to different congressional districts. When a household moves +to a CD in a different state, state-dependent benefits like SNAP are +recalculated under the destination state's rules. + +This creates a verification challenge: we can't easily verify that SNAP +*should* be $11,560 in NC vs $14,292 in AK without reimplementing the +entire SNAP formula. However, we CAN verify: + +1. CONSISTENCY: X_sparse values match an independently-created simulation + with state_fips set to the destination state. This confirms the sparse + matrix builder correctly uses PolicyEngine's calculation engine. + +2. SAME-STATE INVARIANCE: When a household's original state equals the + destination CD's state, the value should exactly match the original + simulation. Any mismatch here is definitively a bug (not a policy difference). + +3. GEOGRAPHIC MASKING: Zero cells should be zero because of geographic + constraint mismatches: + - State-level targets: only CDs in that state have non-zero values + - CD-level targets: only that specific CD has non-zero values (even + same-state different-CD columns should be zero) + - National targets: NO geographic masking - all CD columns can have + non-zero values, but values DIFFER by destination state because + benefits are recalculated under each state's rules + +By verifying these properties, we confirm the sparse matrix builder is +working correctly without needing to understand every state-specific +policy formula. + +CACHE CLEARING LESSON +===================== +When setting state_fips via set_input(), you MUST clear cached calculated +variables to force recalculation. Use get_calculated_variables() which +returns variables with formulas - these are the ones that need recalculation. + +DO NOT use `var not in sim.input_variables` - this misses variables that +are BOTH inputs AND have formulas (12 such variables exist). If any of +these are in the dependency chain, the recalculation will use stale values. + +Correct pattern: + sim.set_input("state_fips", period, new_values) + for var in get_calculated_variables(sim): + sim.delete_arrays(var) + +USAGE +===== +Run interactively or with pytest: + + python test_sparse_matrix_verification.py + pytest test_sparse_matrix_verification.py -v +""" + +import numpy as np +import pandas as pd +from typing import List + +from policyengine_us import Microsimulation +from sparse_matrix_builder import SparseMatrixBuilder, get_calculated_variables + + +def test_column_indexing(X_sparse, tracer, test_cds) -> bool: + """ + Test 1: Verify column indexing roundtrip. 
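+
+    Illustrative sizes: with n_households = 10_000, the household at index 42
+    stacked into the CD at index 3 should sit in the column given by the
+    formula below, i.e. 3 * 10_000 + 42 = 30_042.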
+ + Column index = cd_idx * n_households + household_index + This is pure math - if this fails, everything else is unreliable. + """ + n_hh = tracer.n_households + hh_ids = tracer.original_household_ids + errors = [] + + test_cases = [] + for cd_idx in [0, len(test_cds)//2, len(test_cds)-1]: + for hh_idx in [0, 100, n_hh-1]: + test_cases.append((cd_idx, hh_idx)) + + for cd_idx, hh_idx in test_cases: + cd = test_cds[cd_idx] + hh_id = hh_ids[hh_idx] + expected_col = cd_idx * n_hh + hh_idx + col_info = tracer.get_column_info(expected_col) + positions = tracer.get_household_column_positions(hh_id) + pos_col = positions[cd] + + if col_info['cd_geoid'] != cd: + errors.append(f"CD mismatch at col {expected_col}") + if col_info['household_index'] != hh_idx: + errors.append(f"HH index mismatch at col {expected_col}") + if col_info['household_id'] != hh_id: + errors.append(f"HH ID mismatch at col {expected_col}") + if pos_col != expected_col: + errors.append(f"Position mismatch for hh {hh_id}, cd {cd}") + + expected_cols = len(test_cds) * n_hh + if X_sparse.shape[1] != expected_cols: + errors.append(f"Matrix width mismatch: expected {expected_cols}, got {X_sparse.shape[1]}") + + if errors: + print("X Column indexing FAILED:") + for e in errors: + print(f" {e}") + return False + + print(f"[PASS] Column indexing: {len(test_cases)} cases, {len(test_cds)} CDs x {n_hh} households") + return True + + +def test_same_state_matches_original(X_sparse, targets_df, tracer, sim, test_cds, + dataset_path, n_samples=200, seed=42) -> bool: + """ + Test 2: Same-state non-zero cells must match fresh same-state simulation. + + When household stays in same state, X_sparse should contain the value + calculated from a fresh simulation with state_fips set to that state + (same as the matrix builder does). 
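+
+    Per the CACHE CLEARING LESSON above, each per-state comparison simulation
+    calls set_input("state_fips", ...) and then deletes every calculated
+    array, so X_sparse is checked against freshly recomputed values.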
+ """ + rng = np.random.default_rng(seed) + n_hh = tracer.n_households + hh_ids = tracer.original_household_ids + hh_states = sim.calculate("state_fips", map_to="household").values + + state_sims = {} + def get_state_sim(state): + if state not in state_sims: + s = Microsimulation(dataset=dataset_path) + s.set_input("state_fips", 2023, np.full(n_hh, state, dtype=np.int32)) + for var in get_calculated_variables(s): + s.delete_arrays(var) + state_sims[state] = s + return state_sims[state] + + nonzero_rows, nonzero_cols = X_sparse.nonzero() + + same_state_indices = [] + for i in range(len(nonzero_rows)): + col_idx = nonzero_cols[i] + cd_idx = col_idx // n_hh + hh_idx = col_idx % n_hh + cd = test_cds[cd_idx] + dest_state = int(cd) // 100 + orig_state = int(hh_states[hh_idx]) + if dest_state == orig_state: + same_state_indices.append(i) + + if not same_state_indices: + print("[WARN] No same-state non-zero cells found") + return True + + sample_idx = rng.choice(same_state_indices, min(n_samples, len(same_state_indices)), replace=False) + errors = [] + + for idx in sample_idx: + row_idx = nonzero_rows[idx] + col_idx = nonzero_cols[idx] + cd_idx = col_idx // n_hh + hh_idx = col_idx % n_hh + cd = test_cds[cd_idx] + dest_state = int(cd) // 100 + variable = targets_df.iloc[row_idx]['variable'] + actual = float(X_sparse[row_idx, col_idx]) + state_sim = get_state_sim(dest_state) + expected = float(state_sim.calculate(variable, map_to='household').values[hh_idx]) + + if not np.isclose(actual, expected, atol=0.5): + errors.append({ + 'hh_id': hh_ids[hh_idx], + 'variable': variable, + 'actual': actual, + 'expected': expected + }) + + if errors: + print(f"X Same-state verification FAILED: {len(errors)}/{len(sample_idx)} mismatches") + for e in errors[:5]: + print(f" hh={e['hh_id']}, var={e['variable']}: {e['actual']:.2f} vs {e['expected']:.2f}") + return False + + print(f"[PASS] Same-state: {len(sample_idx)}/{len(sample_idx)} match fresh same-state simulation") + return True + + +def test_cross_state_matches_swapped_sim(X_sparse, targets_df, tracer, test_cds, + dataset_path, n_samples=200, seed=42) -> bool: + """ + Test 3: Cross-state non-zero cells must match state-swapped simulation. + + When household moves to different state, X_sparse should contain the + value calculated from a fresh simulation with state_fips set to destination state. 
+ """ + rng = np.random.default_rng(seed) + sim_orig = Microsimulation(dataset=dataset_path) + n_hh = tracer.n_households + hh_ids = tracer.original_household_ids + hh_states = sim_orig.calculate("state_fips", map_to="household").values + + state_sims = {} + def get_state_sim(state): + if state not in state_sims: + s = Microsimulation(dataset=dataset_path) + s.set_input("state_fips", 2023, np.full(n_hh, state, dtype=np.int32)) + for var in get_calculated_variables(s): + s.delete_arrays(var) + state_sims[state] = s + return state_sims[state] + + nonzero_rows, nonzero_cols = X_sparse.nonzero() + + cross_state_indices = [] + for i in range(len(nonzero_rows)): + col_idx = nonzero_cols[i] + cd_idx = col_idx // n_hh + hh_idx = col_idx % n_hh + cd = test_cds[cd_idx] + dest_state = int(cd) // 100 + orig_state = int(hh_states[hh_idx]) + if dest_state != orig_state: + cross_state_indices.append(i) + + if not cross_state_indices: + print("[WARN] No cross-state non-zero cells found") + return True + + sample_idx = rng.choice(cross_state_indices, min(n_samples, len(cross_state_indices)), replace=False) + errors = [] + + for idx in sample_idx: + row_idx = nonzero_rows[idx] + col_idx = nonzero_cols[idx] + cd_idx = col_idx // n_hh + hh_idx = col_idx % n_hh + cd = test_cds[cd_idx] + dest_state = int(cd) // 100 + variable = targets_df.iloc[row_idx]['variable'] + actual = float(X_sparse[row_idx, col_idx]) + state_sim = get_state_sim(dest_state) + expected = float(state_sim.calculate(variable, map_to='household').values[hh_idx]) + + if not np.isclose(actual, expected, atol=0.5): + errors.append({ + 'hh_id': hh_ids[hh_idx], + 'orig_state': int(hh_states[hh_idx]), + 'dest_state': dest_state, + 'variable': variable, + 'actual': actual, + 'expected': expected + }) + + if errors: + print(f"X Cross-state verification FAILED: {len(errors)}/{len(sample_idx)} mismatches") + for e in errors[:5]: + print(f" hh={e['hh_id']}, {e['orig_state']}->{e['dest_state']}: {e['actual']:.2f} vs {e['expected']:.2f}") + return False + + print(f"[PASS] Cross-state: {len(sample_idx)}/{len(sample_idx)} match state-swapped simulation") + return True + + +def test_state_level_zero_masking(X_sparse, targets_df, tracer, test_cds, + n_samples=100, seed=42) -> bool: + """ + Test 4: State-level targets have zeros for wrong-state CD columns. + + For a target with geographic_id=37 (NC), columns for CDs in other states + (HI, MT, AK) should all be zero. 
+ """ + rng = np.random.default_rng(seed) + n_hh = tracer.n_households + + state_targets = [] + for row_idx in range(len(targets_df)): + geo_id = targets_df.iloc[row_idx].get('geographic_id', 'US') + if geo_id != 'US': + try: + val = int(geo_id) + if val < 100: + state_targets.append((row_idx, val)) + except (ValueError, TypeError): + pass + + if not state_targets: + print("[WARN] No state-level targets found") + return True + + errors = [] + checked = 0 + sample_targets = rng.choice(len(state_targets), min(20, len(state_targets)), replace=False) + + for idx in sample_targets: + row_idx, target_state = state_targets[idx] + other_state_cds = [(i, cd) for i, cd in enumerate(test_cds) + if int(cd) // 100 != target_state] + if not other_state_cds: + continue + + sample_cds = rng.choice(len(other_state_cds), min(5, len(other_state_cds)), replace=False) + for cd_sample_idx in sample_cds: + cd_idx, cd = other_state_cds[cd_sample_idx] + sample_hh = rng.choice(n_hh, min(5, n_hh), replace=False) + for hh_idx in sample_hh: + col_idx = cd_idx * n_hh + hh_idx + actual = X_sparse[row_idx, col_idx] + checked += 1 + if actual != 0: + errors.append({'row': row_idx, 'cd': cd, 'value': float(actual)}) + + if errors: + print(f"X State-level masking FAILED: {len(errors)}/{checked} should be zero") + return False + + print(f"[PASS] State-level masking: {checked}/{checked} wrong-state cells are zero") + return True + + +def test_cd_level_zero_masking(X_sparse, targets_df, tracer, test_cds, seed=42) -> bool: + """ + Test 5: CD-level targets have zeros for other CDs, even same-state. + + For a target with geographic_id=3707, columns for CDs 3701-3706, 3708-3714 + should all be zero, even though they're all in NC (state 37). + + Note: Requires test_cds to include multiple CDs from the same state as + some CD-level target geographic_ids. + """ + rng = np.random.default_rng(seed) + n_hh = tracer.n_households + + cd_targets_with_same_state = [] + for row_idx in range(len(targets_df)): + geo_id = targets_df.iloc[row_idx].get('geographic_id', 'US') + if geo_id != 'US': + try: + val = int(geo_id) + if val >= 100: + target_state = val // 100 + same_state_other_cds = [cd for cd in test_cds + if int(cd) // 100 == target_state and cd != geo_id] + if same_state_other_cds: + cd_targets_with_same_state.append((row_idx, geo_id, same_state_other_cds)) + except (ValueError, TypeError): + pass + + if not cd_targets_with_same_state: + print("[WARN] No CD-level targets with same-state other CDs in test_cds") + return True + + errors = [] + same_state_checks = 0 + + for row_idx, target_cd, other_cds in cd_targets_with_same_state[:10]: + for cd in other_cds: + cd_idx = test_cds.index(cd) + for hh_idx in rng.choice(n_hh, 3, replace=False): + col_idx = cd_idx * n_hh + hh_idx + actual = X_sparse[row_idx, col_idx] + same_state_checks += 1 + if actual != 0: + errors.append({'target_cd': target_cd, 'other_cd': cd, 'value': float(actual)}) + + if errors: + print(f"X CD-level masking FAILED: {len(errors)} same-state-different-CD non-zero values") + for e in errors[:5]: + print(f" target={e['target_cd']}, other={e['other_cd']}, value={e['value']}") + return False + + print(f"[PASS] CD-level masking: {same_state_checks} same-state-different-CD checks, all zero") + return True + + +def test_national_no_geo_masking(X_sparse, targets_df, tracer, sim, test_cds, + dataset_path, seed=42) -> bool: + """ + Test 6: National targets have no geographic masking. + + National targets (geographic_id='US') can have non-zero values for ANY CD. 
+ Moreover, values DIFFER by destination state because benefits are + recalculated under each state's rules. + + Example: Household 177332 (originally AK with SNAP=$14,292) + - X_sparse[national_row, AK_CD_col] = $14,292 (staying in AK) + - X_sparse[national_row, NC_CD_col] = $11,560 (recalculated for NC) + + We verify by: + 1. Finding households with non-zero values in the national target + 2. Checking they have values in multiple states' CD columns + 3. Confirming values differ between states (due to recalculation) + """ + rng = np.random.default_rng(seed) + n_hh = tracer.n_households + hh_ids = tracer.original_household_ids + + national_rows = [i for i in range(len(targets_df)) + if targets_df.iloc[i].get('geographic_id', 'US') == 'US'] + + if not national_rows: + print("[WARN] No national targets found") + return True + + states_in_test = sorted(set(int(cd) // 100 for cd in test_cds)) + cds_by_state = {state: [cd for cd in test_cds if int(cd) // 100 == state] + for state in states_in_test} + + print(f" States in test: {states_in_test}") + + for row_idx in national_rows: + variable = targets_df.iloc[row_idx]['variable'] + + # Find households with non-zero values in this national target + row_data = X_sparse.getrow(row_idx) + nonzero_cols = row_data.nonzero()[1] + + if len(nonzero_cols) == 0: + print(f"X National target row {row_idx} ({variable}) has no non-zero values!") + return False + + # Pick a few households that have non-zero values + sample_cols = rng.choice(nonzero_cols, min(5, len(nonzero_cols)), replace=False) + + households_checked = 0 + households_with_multi_state_values = 0 + + for col_idx in sample_cols: + hh_idx = col_idx % n_hh + hh_id = hh_ids[hh_idx] + + # Get this household's values across different states + values_by_state = {} + for state, cds in cds_by_state.items(): + cd = cds[0] # Just check first CD in each state + cd_idx = test_cds.index(cd) + state_col = cd_idx * n_hh + hh_idx + val = float(X_sparse[row_idx, state_col]) + if val != 0: + values_by_state[state] = val + + households_checked += 1 + if len(values_by_state) > 1: + households_with_multi_state_values += 1 + + print(f" Row {row_idx} ({variable}): {households_with_multi_state_values}/{households_checked} " + f"households have values in multiple states") + + print(f"[PASS] National targets: no geographic masking, values vary by destination state") + return True + + +def run_all_tests(X_sparse, targets_df, tracer, sim, test_cds, dataset_path) -> bool: + """Run all verification tests and return overall pass/fail.""" + print("=" * 70) + print("SPARSE MATRIX VERIFICATION TESTS") + print("=" * 70) + + results = [] + + print("\n[Test 1] Column Indexing") + results.append(test_column_indexing(X_sparse, tracer, test_cds)) + + print("\n[Test 2] Same-State Values Match Fresh Sim") + results.append(test_same_state_matches_original(X_sparse, targets_df, tracer, sim, test_cds, dataset_path)) + + print("\n[Test 3] Cross-State Values Match State-Swapped Sim") + results.append(test_cross_state_matches_swapped_sim(X_sparse, targets_df, tracer, test_cds, dataset_path)) + + print("\n[Test 4] State-Level Zero Masking") + results.append(test_state_level_zero_masking(X_sparse, targets_df, tracer, test_cds)) + + print("\n[Test 5] CD-Level Zero Masking (Same-State-Different-CD)") + results.append(test_cd_level_zero_masking(X_sparse, targets_df, tracer, test_cds)) + + print("\n[Test 6] National Targets No Geo Masking") + results.append(test_national_no_geo_masking(X_sparse, targets_df, tracer, sim, test_cds, dataset_path)) + + 
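+    # Each test above returned a bool; the summary below counts passes and the
+    # __main__ caller turns all(results) into the process exit code.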
print("\n" + "=" * 70) + passed = sum(results) + total = len(results) + if passed == total: + print(f"ALL TESTS PASSED ({passed}/{total})") + else: + print(f"SOME TESTS FAILED ({passed}/{total} passed)") + print("=" * 70) + + return all(results) + + +if __name__ == "__main__": + from sqlalchemy import create_engine, text + from policyengine_us_data.storage import STORAGE_FOLDER + from household_tracer import HouseholdTracer + + print("Setting up verification tests...") + + db_path = STORAGE_FOLDER / "policy_data.db" + db_uri = f"sqlite:///{db_path}" + dataset_path = str(STORAGE_FOLDER / "stratified_extended_cps_2023.h5") + + # Test with NC, HI, MT, AK CDs (manageable size, includes same-state CDs for Test 5) + engine = create_engine(db_uri) + query = """ + SELECT DISTINCT sc.value as cd_geoid + FROM strata s + JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id + WHERE s.stratum_group_id = 1 + AND sc.constraint_variable = 'congressional_district_geoid' + AND ( + sc.value LIKE '37__' + OR sc.value LIKE '150_' + OR sc.value LIKE '300_' + OR sc.value = '200' OR sc.value = '201' + ) + ORDER BY sc.value + """ + with engine.connect() as conn: + result = conn.execute(text(query)).fetchall() + test_cds = [row[0] for row in result] + + print(f"Testing with {len(test_cds)} CDs from 4 states") + + sim = Microsimulation(dataset=dataset_path) + builder = SparseMatrixBuilder( + db_uri, time_period=2023, + cds_to_calibrate=test_cds, + dataset_path=dataset_path + ) + + print("Building sparse matrix...") + targets_df, X_sparse, household_id_mapping = builder.build_matrix( + sim, + target_filter={"stratum_group_ids": [4], "variables": ["snap"]} + ) + + tracer = HouseholdTracer(targets_df, X_sparse, household_id_mapping, test_cds, sim) + + print(f"Matrix shape: {X_sparse.shape}, non-zero: {X_sparse.nnz}\n") + + success = run_all_tests(X_sparse, targets_df, tracer, sim, test_cds, dataset_path) + exit(0 if success else 1) diff --git a/policyengine_us_data/datasets/cps/local_area_calibration/weight_diagnostics.py b/policyengine_us_data/datasets/cps/local_area_calibration/weight_diagnostics.py new file mode 100644 index 00000000..eab9de36 --- /dev/null +++ b/policyengine_us_data/datasets/cps/local_area_calibration/weight_diagnostics.py @@ -0,0 +1,499 @@ +#!/usr/bin/env python +""" +Weight diagnostics for geo-stacked calibration (states or congressional districts). +Analyzes calibration weights to understand sparsity patterns and accuracy. 
+""" + +import os +import sys +import argparse +import numpy as np +import pandas as pd +from scipy import sparse as sp +from policyengine_us import Microsimulation +from policyengine_us_data.datasets.cps.local_area_calibration.calibration_utils import ( + create_target_groups, +) + + +def load_calibration_data(geo_level="state"): + """Load calibration matrix, weights, and targets for the specified geo level.""" + + if geo_level == "state": + export_dir = os.path.expanduser("~/Downloads/state_calibration_data") + weight_file = "/home/baogorek/Downloads/w_array_20250908_185748.npy" + matrix_file = "X_sparse.npz" + targets_file = "targets_df.pkl" + dataset_uri = "hf://policyengine/test/extended_cps_2023.h5" + else: # congressional_district + export_dir = os.path.expanduser("~/Downloads/cd_calibration_data") + weight_file = "w_cd_20250911_102023.npy" + matrix_file = "cd_matrix_sparse.npz" + targets_file = "cd_targets_df.pkl" + dataset_uri = "/home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/stratified_extended_cps_2023.h5" + + print(f"Loading {geo_level} calibration data...") + + # Check for weight file in multiple locations + if os.path.exists(weight_file): + w = np.load(weight_file) + elif os.path.exists( + os.path.join(export_dir, os.path.basename(weight_file)) + ): + w = np.load(os.path.join(export_dir, os.path.basename(weight_file))) + else: + print(f"Error: Weight file not found at {weight_file}") + sys.exit(1) + + # Load matrix + matrix_path = os.path.join(export_dir, matrix_file) + if os.path.exists(matrix_path): + X_sparse = sp.load_npz(matrix_path) + else: + # Try downloading from huggingface for states + if geo_level == "state": + from policyengine_us_data.datasets.cps.geo_stacking_calibration.calibration_utils import ( + download_from_huggingface, + ) + + X_sparse = sp.load_npz(download_from_huggingface(matrix_file)) + else: + print(f"Error: Matrix file not found at {matrix_path}") + sys.exit(1) + + # Load targets + targets_path = os.path.join(export_dir, targets_file) + if os.path.exists(targets_path): + targets_df = pd.read_pickle(targets_path) + else: + # Try downloading from huggingface for states + if geo_level == "state": + from policyengine_us_data.datasets.cps.geo_stacking_calibration.calibration_utils import ( + download_from_huggingface, + ) + + targets_df = pd.read_pickle( + download_from_huggingface(targets_file) + ) + else: + print(f"Error: Targets file not found at {targets_path}") + sys.exit(1) + + # Load simulation + print(f"Loading simulation from {dataset_uri}...") + sim = Microsimulation(dataset=dataset_uri) + sim.build_from_dataset() + + return w, X_sparse, targets_df, sim + + +def analyze_weight_statistics(w): + """Analyze basic weight statistics.""" + print("\n" + "=" * 70) + print("WEIGHT STATISTICS") + print("=" * 70) + + n_active = sum(w != 0) + print(f"Total weights: {len(w):,}") + print(f"Active weights (non-zero): {n_active:,}") + print(f"Sparsity: {100*n_active/len(w):.2f}%") + + if n_active > 0: + active_weights = w[w != 0] + print(f"\nActive weight statistics:") + print(f" Min: {active_weights.min():.2f}") + print(f" Max: {active_weights.max():.2f}") + print(f" Mean: {active_weights.mean():.2f}") + print(f" Median: {np.median(active_weights):.2f}") + print(f" Std: {active_weights.std():.2f}") + + return n_active + + +def analyze_prediction_errors(w, X_sparse, targets_df): + """Analyze prediction errors.""" + print("\n" + "=" * 70) + print("PREDICTION ERROR ANALYSIS") + print("=" * 70) + + # Calculate predictions + y_pred = 
X_sparse @ w + y_actual = targets_df["value"].values + + correlation = np.corrcoef(y_pred, y_actual)[0, 1] + print(f"Correlation between predicted and actual: {correlation:.4f}") + + # Calculate errors + abs_errors = np.abs(y_actual - y_pred) + rel_errors = np.abs((y_actual - y_pred) / (y_actual + 1)) + + targets_df["y_pred"] = y_pred + targets_df["abs_error"] = abs_errors + targets_df["rel_error"] = rel_errors + + # Overall statistics + print(f"\nOverall error statistics:") + print(f" Mean relative error: {np.mean(rel_errors):.2%}") + print(f" Median relative error: {np.median(rel_errors):.2%}") + print(f" Max relative error: {np.max(rel_errors):.2%}") + print(f" 95th percentile: {np.percentile(rel_errors, 95):.2%}") + print(f" 99th percentile: {np.percentile(rel_errors, 99):.2%}") + + return targets_df + + +def analyze_geographic_errors(targets_df, geo_level="state"): + """Analyze errors by geographic region.""" + print("\n" + "=" * 70) + print(f"ERROR ANALYSIS BY {geo_level.upper()}") + print("=" * 70) + + # Filter for geographic targets + geo_targets = targets_df[targets_df["geographic_id"] != "US"] + + if geo_targets.empty: + print("No geographic targets found") + return + + geo_errors = ( + geo_targets.groupby("geographic_id") + .agg({"rel_error": ["mean", "median", "max", "count"]}) + .round(4) + ) + + geo_errors = geo_errors.sort_values(("rel_error", "mean"), ascending=False) + + print(f"\nTop 10 {geo_level}s with highest mean relative error:") + for geo_id in geo_errors.head(10).index: + geo_data = geo_errors.loc[geo_id] + n_targets = geo_data[("rel_error", "count")] + mean_err = geo_data[("rel_error", "mean")] + max_err = geo_data[("rel_error", "max")] + median_err = geo_data[("rel_error", "median")] + + if geo_level == "congressional_district": + state_fips = geo_id[:-2] if len(geo_id) > 2 else geo_id + district = geo_id[-2:] + label = f"CD {geo_id} (State {state_fips}, District {district})" + else: + label = f"State {geo_id}" + + print( + f"{label}: Mean={mean_err:.1%}, Median={median_err:.1%}, Max={max_err:.1%} ({n_targets:.0f} targets)" + ) + + +def analyze_target_type_errors(targets_df): + """Analyze errors by target type.""" + print("\n" + "=" * 70) + print("ERROR ANALYSIS BY TARGET TYPE") + print("=" * 70) + + type_errors = ( + targets_df.groupby("stratum_group_id") + .agg({"rel_error": ["mean", "median", "max", "count"]}) + .round(4) + ) + + type_errors = type_errors.sort_values( + ("rel_error", "mean"), ascending=False + ) + + group_name_map = { + 2: "Age histogram", + 3: "AGI distribution", + 4: "SNAP", + 5: "Medicaid", + 6: "EITC", + } + + print("\nError by target type (sorted by mean error):") + for type_id in type_errors.index: + type_data = type_errors.loc[type_id] + n_targets = type_data[("rel_error", "count")] + mean_err = type_data[("rel_error", "mean")] + max_err = type_data[("rel_error", "max")] + median_err = type_data[("rel_error", "median")] + + type_label = group_name_map.get(type_id, f"Type {type_id}") + print( + f"{type_label:30}: Mean={mean_err:.1%}, Median={median_err:.1%}, Max={max_err:.1%} ({n_targets:.0f} targets)" + ) + + +def analyze_worst_targets(targets_df, n=10): + """Show worst performing individual targets.""" + print("\n" + "=" * 70) + print(f"WORST PERFORMING TARGETS (Top {n})") + print("=" * 70) + + worst_targets = targets_df.nlargest(n, "rel_error") + for idx, row in worst_targets.iterrows(): + if row["geographic_id"] == "US": + geo_label = "National" + elif ( + "congressional_district" in targets_df.columns + or len(row["geographic_id"]) 
> 2 + ): + geo_label = f"CD {row['geographic_id']}" + else: + geo_label = f"State {row['geographic_id']}" + + print( + f"\n{geo_label} - {row['variable']} (Group {row['stratum_group_id']})" + ) + print(f" Description: {row['description']}") + print( + f" Target: {row['value']:,.0f}, Predicted: {row['y_pred']:,.0f}" + ) + print(f" Relative Error: {row['rel_error']:.1%}") + + +def analyze_weight_distribution(w, sim, geo_level="state"): + """Analyze how weights are distributed across geographic regions.""" + print("\n" + "=" * 70) + print("WEIGHT DISTRIBUTION ANALYSIS") + print("=" * 70) + + household_ids = sim.calculate("household_id", map_to="household").values + n_households_total = len(household_ids) + + if geo_level == "state": + geos = [ + "1", + "2", + "4", + "5", + "6", + "8", + "9", + "10", + "11", + "12", + "13", + "15", + "16", + "17", + "18", + "19", + "20", + "21", + "22", + "23", + "24", + "25", + "26", + "27", + "28", + "29", + "30", + "31", + "32", + "33", + "34", + "35", + "36", + "37", + "38", + "39", + "40", + "41", + "42", + "44", + "45", + "46", + "47", + "48", + "49", + "50", + "51", + "53", + "54", + "55", + "56", + ] + else: + # For CDs, need to get list from weights length + n_geos = len(w) // n_households_total + print(f"Detected {n_geos} geographic units") + return + + n_households_per_geo = n_households_total + + # Map weights to geographic regions + weight_to_geo = {} + for geo_idx, geo_id in enumerate(geos): + start_idx = geo_idx * n_households_per_geo + for hh_idx in range(n_households_per_geo): + weight_idx = start_idx + hh_idx + if weight_idx < len(w): + weight_to_geo[weight_idx] = geo_id + + # Count active weights per geo + active_weights_by_geo = {} + for idx, weight_val in enumerate(w): + if weight_val != 0: + geo = weight_to_geo.get(idx, "unknown") + if geo not in active_weights_by_geo: + active_weights_by_geo[geo] = [] + active_weights_by_geo[geo].append(weight_val) + + # Calculate activation rates + activation_rates = [] + for geo in geos: + if geo in active_weights_by_geo: + n_active = len(active_weights_by_geo[geo]) + rate = n_active / n_households_per_geo + total_weight = sum(active_weights_by_geo[geo]) + activation_rates.append((geo, rate, n_active, total_weight)) + else: + activation_rates.append((geo, 0, 0, 0)) + + activation_rates.sort(key=lambda x: x[1], reverse=True) + + print(f"\nTop 5 {geo_level}s by activation rate:") + for geo, rate, n_active, total_weight in activation_rates[:5]: + print( + f" {geo_level.title()} {geo}: {100*rate:.1f}% active ({n_active}/{n_households_per_geo}), Sum={total_weight:,.0f}" + ) + + print(f"\nBottom 5 {geo_level}s by activation rate:") + for geo, rate, n_active, total_weight in activation_rates[-5:]: + print( + f" {geo_level.title()} {geo}: {100*rate:.1f}% active ({n_active}/{n_households_per_geo}), Sum={total_weight:,.0f}" + ) + + +def export_calibration_log(targets_df, output_file, geo_level="state"): + """Export results to calibration log CSV format.""" + print("\n" + "=" * 70) + print("EXPORTING CALIBRATION LOG") + print("=" * 70) + + log_rows = [] + for idx, row in targets_df.iterrows(): + # Create hierarchical target name + if row["geographic_id"] == "US": + target_name = f"nation/{row['variable']}/{row['description']}" + elif geo_level == "congressional_district": + target_name = f"CD{row['geographic_id']}/{row['variable']}/{row['description']}" + else: + target_name = f"US{row['geographic_id']}/{row['variable']}/{row['description']}" + + # Calculate metrics + estimate = row["y_pred"] + target = 
row["value"] + error = estimate - target + rel_error = error / target if target != 0 else 0 + + log_rows.append( + { + "target_name": target_name, + "estimate": estimate, + "target": target, + "epoch": 0, + "error": error, + "rel_error": rel_error, + "abs_error": abs(error), + "rel_abs_error": abs(rel_error), + "loss": rel_error**2, + } + ) + + calibration_log_df = pd.DataFrame(log_rows) + calibration_log_df.to_csv(output_file, index=False) + print(f"Saved calibration log to: {output_file}") + print(f"Total rows: {len(calibration_log_df):,}") + + return calibration_log_df + + +def main(): + """Run weight diagnostics based on command line arguments.""" + parser = argparse.ArgumentParser(description="Analyze calibration weights") + parser.add_argument( + "--geo", + choices=["state", "congressional_district", "cd"], + default="state", + help="Geographic level (default: state)", + ) + parser.add_argument( + "--weight-file", type=str, help="Path to weight file (optional)" + ) + parser.add_argument( + "--export-csv", type=str, help="Export calibration log to CSV file" + ) + parser.add_argument( + "--worst-n", + type=int, + default=10, + help="Number of worst targets to show (default: 10)", + ) + + args = parser.parse_args() + + # Normalize geo level + geo_level = "congressional_district" if args.geo == "cd" else args.geo + + print("\n" + "=" * 70) + print(f"{geo_level.upper()} CALIBRATION WEIGHT DIAGNOSTICS") + print("=" * 70) + + # Load data + w, X_sparse, targets_df, sim = load_calibration_data(geo_level) + + # Override weight file if specified + if args.weight_file: + print(f"Loading weights from: {args.weight_file}") + w = np.load(args.weight_file) + + # Basic weight statistics + n_active = analyze_weight_statistics(w) + + if n_active == 0: + print("\n❌ No active weights found! 
Check weight file.") + sys.exit(1) + + # Analyze prediction errors + targets_df = analyze_prediction_errors(w, X_sparse, targets_df) + + # Geographic error analysis + analyze_geographic_errors(targets_df, geo_level) + + # Target type error analysis + analyze_target_type_errors(targets_df) + + # Worst performing targets + analyze_worst_targets(targets_df, args.worst_n) + + # Weight distribution analysis + analyze_weight_distribution(w, sim, geo_level) + + # Export to CSV if requested + if args.export_csv: + export_calibration_log(targets_df, args.export_csv, geo_level) + + # Group-wise performance + print("\n" + "=" * 70) + print("GROUP-WISE PERFORMANCE") + print("=" * 70) + + target_groups, group_info = create_target_groups(targets_df) + rel_errors = targets_df["rel_error"].values + + group_means = [] + for group_id in np.unique(target_groups): + group_mask = target_groups == group_id + group_errors = rel_errors[group_mask] + group_means.append(np.mean(group_errors)) + + print(f"Mean of group means: {np.mean(group_means):.2%}") + print(f"Max group mean: {np.max(group_means):.2%}") + + print("\n" + "=" * 70) + print("WEIGHT DIAGNOSTICS COMPLETE") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/datasets/cps/small_enhanced_cps.py b/policyengine_us_data/datasets/cps/small_enhanced_cps.py index 83a5eba0..c9fe8fb6 100644 --- a/policyengine_us_data/datasets/cps/small_enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/small_enhanced_cps.py @@ -1,6 +1,7 @@ import pandas as pd import numpy as np import h5py +import os from policyengine_us import Microsimulation from policyengine_us_data.datasets import EnhancedCPS_2024 @@ -78,7 +79,24 @@ def create_sparse_ecps(): # Write the data to an h5 data = {} + + essential_vars = {'person_id', 'household_id', 'tax_unit_id', 'spm_unit_id', + 'marital_unit_id', 'person_weight', 'household_weight', + 'person_household_id', 'person_tax_unit_id', 'person_spm_unit_id', + 'person_marital_unit_id'} + for variable in sim.tax_benefit_system.variables: + var_def = sim.tax_benefit_system.variables[variable] + + # Skip calculated variables (those with formulas) unless they're essential IDs/weights + if variable not in essential_vars: + if var_def.formulas: + continue + + # Skip aggregate variables (those with adds/subtracts) + if (hasattr(var_def, 'adds') and var_def.adds) or (hasattr(var_def, 'subtracts') and var_def.subtracts): + continue + data[variable] = {} for time_period in sim.get_holder(variable).get_known_periods(): values = sim.get_holder(variable).get_array(time_period) diff --git a/policyengine_us_data/datasets/puf/puf.py b/policyengine_us_data/datasets/puf/puf.py index cac9ad61..08b457dc 100644 --- a/policyengine_us_data/datasets/puf/puf.py +++ b/policyengine_us_data/datasets/puf/puf.py @@ -732,6 +732,13 @@ class PUF_2021(PUF): url = "release://policyengine/irs-soi-puf/1.8.0/puf_2021.h5" +class PUF_2023(PUF): + label = "PUF 2023" + name = "puf_2023" + time_period = 2023 + file_path = STORAGE_FOLDER / "puf_2023.h5" + + class PUF_2024(PUF): label = "PUF 2024 (2015-based)" name = "puf_2024" @@ -748,6 +755,11 @@ class PUF_2024(PUF): } if __name__ == "__main__": - PUF_2015().generate() - PUF_2021().generate() - PUF_2024().generate() + geo_stacking = os.environ.get("GEO_STACKING") == "true" + + if geo_stacking: + PUF_2023().generate() + else: + PUF_2015().generate() + PUF_2021().generate() + PUF_2024().generate() diff --git a/policyengine_us_data/db/DATABASE_GUIDE.md b/policyengine_us_data/db/DATABASE_GUIDE.md new 
file mode 100644 index 00000000..a3ebdd98 --- /dev/null +++ b/policyengine_us_data/db/DATABASE_GUIDE.md @@ -0,0 +1,466 @@ +# PolicyEngine US Data - Database Getting Started Guide + +## Current Task: Matrix Generation for Calibration Targets + +### Objective +Create a comprehensive matrix of calibration targets with the following requirements: +1. **Rows grouped by target type** - All age targets together, all income targets together, etc. +2. **Known counts per group** - Each group has a predictable number of entries (e.g., 18 age groups, 9 income brackets) +3. **Source selection** - Ability to specify which data source to use when multiple exist +4. **Geographic filtering** - Ability to select specific geographic levels (national, state, or congressional district) + +### Implementation Strategy +The `stratum_group_id` field now categorizes strata by conceptual type, making matrix generation straightforward: +- Query by `stratum_group_id` to get all related targets together +- Each demographic group appears consistently across all 488 geographic areas +- Join with `sources` table to filter/identify data provenance +- Use parent-child relationships to navigate geographic hierarchy + +### Example Matrix Query +```sql +-- Generate matrix for a specific geography (e.g., national level) +SELECT + CASE s.stratum_group_id + WHEN 2 THEN 'Age' + WHEN 3 THEN 'Income' + WHEN 4 THEN 'SNAP' + WHEN 5 THEN 'Medicaid' + WHEN 6 THEN 'EITC' + END AS group_name, + s.notes AS stratum_description, + t.variable, + t.value, + src.name AS source +FROM strata s +JOIN targets t ON s.stratum_id = t.stratum_id +JOIN sources src ON t.source_id = src.source_id +WHERE s.parent_stratum_id = 1 -- National level (or any specific geography) + AND s.stratum_group_id > 1 -- Exclude geographic strata +ORDER BY s.stratum_group_id, s.stratum_id; +``` + +## Overview +This database uses a hierarchical stratum-based model to organize US demographic and economic data for PolicyEngine calibration. The core concept is that data is organized into "strata" - population subgroups defined by constraints. + +## Key Concepts + +### Strata Hierarchy +The database uses a parent-child hierarchy: +``` +United States (national) +├── States (51 including DC) +│ ├── Congressional Districts (436 total) +│ │ ├── Age groups (18 brackets per geographic area) +│ │ ├── Income groups (AGI stubs) +│ │ └── Other demographic strata (EITC recipients, SNAP, Medicaid, etc.) +``` + +### Stratum Groups +The `stratum_group_id` field categorizes strata by their conceptual type: +- `1`: Geographic boundaries (US, states, congressional districts) +- `2`: Age-based strata (18 age groups per geography) +- `3`: Income/AGI-based strata (9 income brackets per geography) +- `4`: SNAP recipient strata (1 per geography) +- `5`: Medicaid enrollment strata (1 per geography) +- `6`: EITC recipient strata (4 groups by qualifying children per geography) + +### UCGID Translation +The Census Bureau uses UCGIDs (Universal Census Geographic IDs) in their API responses: +- `0100000US`: National level +- `0400000USXX`: State (XX = state FIPS code) +- `5001800USXXDD`: Congressional district (XX = state FIPS, DD = district number) + +We parse these into our internal model using `state_fips` and `congressional_district_geoid`. + +### Constraint Operations +All constraints use standardized operators: +- `==`: Equals +- `!=`: Not equals +- `>`: Greater than +- `>=`: Greater than or equal +- `<`: Less than +- `<=`: Less than or equal + +## Database Structure + +### Core Tables +1. 
**strata**: Main table for population subgroups + - `stratum_id`: Primary key + - `parent_stratum_id`: Links to parent in hierarchy + - `stratum_group_id`: Conceptual category (1=Geographic, 2=Age, 3=Income, 4=SNAP, 5=Medicaid, 6=EITC) + - `definition_hash`: Unique hash of constraints for deduplication + +2. **stratum_constraints**: Defines rules for each stratum + - `constraint_variable`: Variable name (e.g., "age", "state_fips") + - `operation`: Comparison operator (==, >, <, etc.) + - `value`: Constraint value + +3. **targets**: Stores actual data values + - `variable`: PolicyEngine US variable name + - `period`: Year + - `value`: Numerical value + - `source_id`: Foreign key to sources table + - `active`: Boolean flag for active/inactive targets + - `tolerance`: Allowed relative error percentage + +### Metadata Tables (New) +4. **sources**: Data source metadata + - `source_id`: Primary key (auto-generated) + - `name`: Source name (e.g., "IRS Statistics of Income") + - `type`: SourceType enum (administrative, survey, hardcoded) + - `vintage`: Year or version of data + - `description`: Detailed description + - `url`: Reference URL + - `notes`: Additional notes + +5. **variable_groups**: Logical groupings of related variables + - `group_id`: Primary key (auto-generated) + - `name`: Unique group name (e.g., "age_distribution", "snap_recipients") + - `category`: High-level category (demographic, benefit, tax, income, expense) + - `is_histogram`: Whether this represents a distribution + - `is_exclusive`: Whether variables are mutually exclusive + - `aggregation_method`: How to aggregate (sum, weighted_avg, etc.) + - `display_order`: Order for display in matrices/reports + - `description`: What this group represents + +6. **variable_metadata**: Display information for variables + - `metadata_id`: Primary key + - `variable`: PolicyEngine variable name + - `group_id`: Foreign key to variable_groups + - `display_name`: Human-readable name + - `display_order`: Order within group + - `units`: Units of measurement (dollars, count, percent) + - `is_primary`: Whether this is a primary vs derived variable + - `notes`: Additional notes + +## Building the Database + +### Step 1: Create Tables +```bash +source ~/envs/pe/bin/activate +cd policyengine_us_data/db +python create_database_tables.py +``` + +### Step 2: Create Geographic Hierarchy +```bash +python create_initial_strata.py +``` +Creates: 1 national + 51 state + 436 congressional district strata + +### Step 3: Load Data (in order) +```bash +# Age demographics (Census ACS) +python etl_age.py + +# Economic data (IRS SOI) +python etl_irs_soi.py + +# Benefits data +python etl_medicaid.py +python etl_snap.py + +# National hardcoded targets +python etl_national_targets.py +``` + +### Step 4: Validate +```bash +python validate_hierarchy.py +``` + +Expected output: +- 488 geographic strata +- 8,784 age strata (18 age groups × 488 areas) +- All strata have unique definition hashes + +## Common Utility Functions + +Located in `policyengine_us_data/utils/db.py`: + +- `parse_ucgid(ucgid_str)`: Convert Census UCGID to geographic info +- `get_geographic_strata(session)`: Get mapping of geographic strata IDs +- `get_stratum_by_id(session, id)`: Retrieve stratum by ID +- `get_stratum_children(session, id)`: Get child strata +- `get_stratum_parent(session, id)`: Get parent stratum + +## ETL Script Pattern + +Each ETL script follows this pattern: + +1. **Extract**: Pull data from source (Census API, IRS files, etc.) +2. 
**Transform**: + - Parse UCGIDs to get geographic info + - Map to existing geographic strata + - Create demographic strata as children +3. **Load**: + - Check for existing strata to avoid duplicates + - Add constraints and targets + - Commit to database + +## Important Notes + +### Avoiding Duplicates +Always check if a stratum exists before creating: +```python +existing_stratum = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == parent_id, + Stratum.stratum_group_id == group_id, # Use appropriate group_id (2 for age, 3 for income, etc.) + Stratum.notes == note + ) +).first() +``` + +### Geographic Constraints +- National strata: No geographic constraints needed +- State strata: `state_fips` constraint +- District strata: `congressional_district_geoid` constraint + +### Congressional District Normalization +- District 00 → 01 (at-large districts) +- DC district 98 → 01 (delegate district) + +### IRS AGI Ranges +AGI stubs use >= for lower bound, < for upper bound: +- Stub 3: $10,000 <= AGI < $25,000 +- Stub 4: $25,000 <= AGI < $50,000 +- etc. + +## Troubleshooting + +### "WARNING: Expected 8784 age strata, found 16104" +**Status: RESOLVED** + +The validation script was incorrectly counting all demographic strata (stratum_group_id = 0) as age strata. After implementing the new stratum_group_id scheme (1=Geographic, 2=Age, 3=Income, etc.), the validation correctly identifies 8,784 age strata. + +Expected: 8,784 age strata (18 age groups × 488 geographic areas) +Actual: 8,784 age strata + +**RESOLVED**: Fixed validation script to only count strata with "Age" in notes, not all demographic strata + +### Fixed: Synthetic Variable Names +Previously, the IRS SOI ETL was creating invalid variable names like `eitc_tax_unit_count` that don't exist in PolicyEngine. Now correctly uses `tax_unit_count` with appropriate stratum constraints to indicate what's being counted. + +### UCGID strings in notes +Legacy UCGID references have been replaced with human-readable identifiers: +- "US" for national +- "State FIPS X" for states +- "CD XXXX" for congressional districts + +### Mixed operation types +All operations now use standardized symbols (==, >, <, etc.) validated by ConstraintOperation enum. 
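+
+### Quick constraint sanity checks
+The queries below are a quick sketch for verifying the conventions above; they
+only use columns already documented in this guide (`strata`, `stratum_constraints`).
+```sql
+-- 1. Confirm only the standardized operators appear
+SELECT DISTINCT operation FROM stratum_constraints;
+
+-- 2. Spot-check how income strata encode AGI bounds
+--    (stratum_group_id = 3 marks Income/AGI strata; expect a >= lower bound
+--     and a < upper bound per stratum, per the IRS AGI Ranges note above)
+SELECT s.stratum_id,
+       s.notes,
+       sc.constraint_variable,
+       sc.operation,
+       sc.value
+FROM strata s
+JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id
+WHERE s.stratum_group_id = 3
+ORDER BY s.stratum_id, sc.operation
+LIMIT 10;
+```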
+ +## Database Location +`/home/baogorek/devl/policyengine-us-data/policyengine_us_data/storage/policy_data.db` + +## Example SQLite Queries with New Metadata Features + +### Compare Administrative vs Survey Data for SNAP +```sql +-- Compare SNAP household counts from different source types +SELECT + s.type AS source_type, + s.name AS source_name, + st.notes AS location, + t.value AS household_count +FROM targets t +JOIN sources s ON t.source_id = s.source_id +JOIN strata st ON t.stratum_id = st.stratum_id +WHERE t.variable = 'household_count' + AND st.notes LIKE '%SNAP%' +ORDER BY s.type, st.notes; +``` + +### Get All Variables in a Group with Their Metadata +```sql +-- List all EITC-related variables with their display information +SELECT + vm.display_name, + vm.variable, + vm.units, + vm.display_order, + vg.description AS group_description +FROM variable_metadata vm +JOIN variable_groups vg ON vm.group_id = vg.group_id +WHERE vg.name = 'eitc_recipients' +ORDER BY vm.display_order; +``` + +### Create a Matrix of Benefit Programs by Source Type +```sql +-- Show all benefit programs with admin vs survey values at national level +SELECT + vg.name AS benefit_program, + vm.variable, + vm.display_name, + SUM(CASE WHEN s.type = 'administrative' THEN t.value END) AS admin_value, + SUM(CASE WHEN s.type = 'survey' THEN t.value END) AS survey_value +FROM variable_groups vg +JOIN variable_metadata vm ON vg.group_id = vm.group_id +LEFT JOIN targets t ON vm.variable = t.variable AND t.stratum_id = 1 +LEFT JOIN sources s ON t.source_id = s.source_id +WHERE vg.category = 'benefit' +GROUP BY vg.name, vm.variable, vm.display_name +ORDER BY vg.display_order, vm.display_order; +``` + +### Find All Data from IRS SOI Source +```sql +-- List all variables and values from IRS Statistics of Income +SELECT + t.variable, + vm.display_name, + t.value / 1e9 AS value_billions, + vm.units +FROM targets t +JOIN sources s ON t.source_id = s.source_id +LEFT JOIN variable_metadata vm ON t.variable = vm.variable +WHERE s.name = 'IRS Statistics of Income' + AND t.stratum_id = 1 -- National totals +ORDER BY t.value DESC; +``` + +### Analyze Data Coverage by Source Type +```sql +-- Show data point counts and geographic coverage by source type +SELECT + s.type AS source_type, + COUNT(DISTINCT t.target_id) AS total_targets, + COUNT(DISTINCT t.variable) AS unique_variables, + COUNT(DISTINCT st.stratum_id) AS geographic_coverage, + s.name AS source_name, + s.vintage +FROM sources s +LEFT JOIN targets t ON s.source_id = t.source_id +LEFT JOIN strata st ON t.stratum_id = st.stratum_id +GROUP BY s.source_id, s.type, s.name, s.vintage +ORDER BY s.type, total_targets DESC; +``` + +### Find Variables That Appear in Multiple Sources +```sql +-- Identify variables with both administrative and survey data +SELECT + t.variable, + vm.display_name, + GROUP_CONCAT(DISTINCT s.type) AS source_types, + COUNT(DISTINCT s.source_id) AS source_count +FROM targets t +JOIN sources s ON t.source_id = s.source_id +LEFT JOIN variable_metadata vm ON t.variable = vm.variable +GROUP BY t.variable, vm.display_name +HAVING COUNT(DISTINCT s.type) > 1 +ORDER BY source_count DESC; +``` + +### Show Variable Group Hierarchy +```sql +-- Display all variable groups with their categories and metadata +SELECT + vg.display_order, + vg.category, + vg.name, + vg.description, + CASE WHEN vg.is_histogram THEN 'Yes' ELSE 'No' END AS is_histogram, + vg.aggregation_method, + COUNT(vm.variable) AS variable_count +FROM variable_groups vg +LEFT JOIN variable_metadata vm ON 
vg.group_id = vm.group_id +GROUP BY vg.group_id +ORDER BY vg.display_order; +``` + +### Audit Query: Find Variables Without Metadata +```sql +-- Identify variables in targets that lack metadata entries +SELECT DISTINCT + t.variable, + COUNT(*) AS usage_count, + GROUP_CONCAT(DISTINCT s.name) AS sources_using +FROM targets t +LEFT JOIN variable_metadata vm ON t.variable = vm.variable +LEFT JOIN sources s ON t.source_id = s.source_id +WHERE vm.metadata_id IS NULL +GROUP BY t.variable +ORDER BY usage_count DESC; +``` + +### Query by Stratum Group +```sql +-- Get all age-related strata and their targets +SELECT + s.stratum_id, + s.notes, + t.variable, + t.value, + src.name AS source +FROM strata s +JOIN targets t ON s.stratum_id = t.stratum_id +JOIN sources src ON t.source_id = src.source_id +WHERE s.stratum_group_id = 2 -- Age strata +LIMIT 20; + +-- Count strata by group +SELECT + stratum_group_id, + CASE stratum_group_id + WHEN 1 THEN 'Geographic' + WHEN 2 THEN 'Age' + WHEN 3 THEN 'Income/AGI' + WHEN 4 THEN 'SNAP' + WHEN 5 THEN 'Medicaid' + WHEN 6 THEN 'EITC' + END AS group_name, + COUNT(*) AS stratum_count +FROM strata +GROUP BY stratum_group_id +ORDER BY stratum_group_id; +``` + +## Key Improvements Made +1. Removed UCGID as a constraint variable (legacy Census concept) +2. Standardized constraint operations with validation +3. Consolidated duplicate code (parse_ucgid, get_geographic_strata) +4. Fixed epsilon hack in IRS AGI ranges +5. ~~Added proper duplicate checking in age ETL (still has known bug causing duplicates)~~ **RESOLVED** +6. Improved human-readable notes without UCGID strings +7. **NEW: Added metadata tables for sources, variable groups, and variable metadata** +8. **NEW: Fixed synthetic variable name bug (e.g., eitc_tax_unit_count → tax_unit_count)** +9. **NEW: Auto-generated source IDs instead of hardcoding** +10. **NEW: Proper categorization of admin vs survey data for same concepts** +11. **NEW: Implemented conceptual stratum_group_id scheme for better organization and querying** + +## Known Issues / TODOs + +### IMPORTANT: stratum_id vs state_fips Codes +**WARNING**: The `stratum_id` is an auto-generated sequential ID and has NO relationship to FIPS codes, despite some confusing coincidences: +- California: stratum_id = 6, state_fips = "06" (coincidental match!) +- North Carolina: stratum_id = 35, state_fips = "37" (no match) +- Ohio: stratum_id = 37, state_fips = "39" (no match) + +When querying for states, ALWAYS use the `state_fips` constraint value, never assume stratum_id matches FIPS. The calibration code correctly uses `get_state_stratum_id(state_fips)` to look up the proper stratum_id. + +Example of correct lookup: +```sql +-- Find North Carolina's stratum_id by FIPS code +SELECT s.stratum_id, s.notes +FROM strata s +JOIN stratum_constraints sc ON s.stratum_id = sc.stratum_id +WHERE sc.constraint_variable = 'state_fips' + AND sc.value = '37'; -- Returns stratum_id = 35 +``` + +### Type Conversion for Constraint Values +**DESIGN DECISION**: The `value` column in `stratum_constraints` must store heterogeneous data types as strings. The calibration code deserializes these (lines 233-247 in `metrics_matrix_geo_stacking.py`): +- Numeric strings → int/float (for age, income constraints) +- "True"/"False" → Python booleans (for medicaid_enrolled, snap_enrolled) +- Other strings remain strings (for state_fips, which may have leading zeros) + +This explicit type conversion is necessary and correct. 
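+
+A minimal sketch of these rules, for illustration only -- it is not the actual
+code in `metrics_matrix_geo_stacking.py`, and keying the string-preservation case
+on the constraint variable name is an assumption:
+
+```python
+def deserialize_constraint_value(variable: str, raw: str):
+    """Convert a stored constraint string back to a typed Python value."""
+    if raw in ("True", "False"):  # e.g., medicaid_enrolled, snap_enrolled
+        return raw == "True"
+    if variable == "state_fips":  # keep leading zeros, e.g., "06"
+        return raw
+    try:
+        return int(raw)           # e.g., age bounds, CD GEOIDs
+    except ValueError:
+        pass
+    try:
+        return float(raw)         # e.g., AGI bounds
+    except ValueError:
+        return raw                # anything else stays a string
+```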
The alternative of using "1"/"0" for booleans would work but be less clear in the database. + +### Medicaid Data Structure +- Medicaid uses `person_count` variable (not `medicaid`) because it's structured as a histogram with constraints +- State-level targets use administrative data (T-MSIS source) +- Congressional district level uses survey data (ACS source) +- No national Medicaid target exists (intentionally, to avoid double-counting when using state-level data) \ No newline at end of file diff --git a/policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md b/policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md new file mode 100644 index 00000000..3d722516 --- /dev/null +++ b/policyengine_us_data/db/IRS_SOI_DATA_ISSUE.md @@ -0,0 +1,109 @@ +# IRS SOI Data Inconsistency: A59664 Units Issue + +## Summary +The IRS Statistics of Income (SOI) Congressional District data file has an undocumented data inconsistency where column A59664 (EITC amount for 3+ children) is reported in **dollars** instead of **thousands of dollars** like all other monetary columns. + +## Discovery Date +December 2024 + +## Affected Data +- **File**: https://www.irs.gov/pub/irs-soi/22incd.csv (and likely other years) +- **Column**: A59664 - "Earned income credit with three qualifying children amount" +- **Issue**: Value is in dollars, not thousands of dollars + +## Evidence + +### 1. Documentation States All Money in Thousands +From the IRS SOI documentation: "For all the files, the money amounts are reported in thousands of dollars." + +### 2. Data Analysis Shows Inconsistency +California example from 2022 data: +``` +A59661 (EITC 0 children): 284,115 (thousands) = $284M ✓ +A59662 (EITC 1 child): 2,086,260 (thousands) = $2.1B ✓ +A59663 (EITC 2 children): 2,067,922 (thousands) = $2.1B ✓ +A59664 (EITC 3+ children): 1,248,669,042 (if thousands) = $1.25 TRILLION ✗ +``` + +### 3. Total EITC Confirms the Issue +``` +A59660 (Total EITC): 5,687,167 (thousands) = $5.69B + +Sum with A59664 as dollars: $5.69B ✓ (matches!) +Sum with A59664 as thousands: $1.25T ✗ (way off!) +``` + +### 4. Pattern Across All States +The ratio of A59664 to A59663 is consistently ~600x across all states: +- California: 603.8x +- North Carolina: 598.9x +- New York: 594.2x +- Texas: 691.5x + +If both were in the same units, this ratio should be 0.5-2x. + +## Additional Finding: "Three" Means "Three or More" + +The documentation says "three qualifying children" but the data shows this represents "three or more": +- Sum of N59661 + N59662 + N59663 + N59664 = 23,261,270 +- N59660 (Total EITC recipients) = 23,266,630 +- Difference: 5,360 (0.02% - essentially equal) + +This confirms that category 4 represents families with 3+ children, not exactly 3. + +## Fix Applied + +In `etl_irs_soi.py`, we now divide A59664 by 1000 before applying the standard multiplier: + +```python +if amount_col == 'A59664': + # Convert from dollars to thousands to match other columns + rec_amounts["target_value"] /= 1_000 +``` + +## Impact Before Fix +- EITC calibration targets for 3+ children were 1000x too high +- California target: $1.25 trillion instead of $1.25 billion +- Made calibration impossible to converge for EITC + +## Verification Steps +1. Download IRS SOI data for any year +2. Check A59660 (total EITC) value +3. Sum A59661-A59664 with A59664 divided by 1000 +4. Confirm sum matches A59660 + +## Recommendation for IRS +The IRS should either: +1. Fix the data to report A59664 in thousands like other columns +2. 
Document this exception clearly in their documentation + +## Verification Code + +To verify this issue or check if the IRS has fixed it: + +```python +import pandas as pd + +# Load IRS data +df = pd.read_csv('https://www.irs.gov/pub/irs-soi/22incd.csv') +us_data = df[(df['STATE'] == 'US') & (df['agi_stub'] == 0)] + +# Get EITC values +a61 = us_data['A59661'].values[0] * 1000 # 0 children (convert from thousands) +a62 = us_data['A59662'].values[0] * 1000 # 1 child +a63 = us_data['A59663'].values[0] * 1000 # 2 children +a64 = us_data['A59664'].values[0] # 3+ children (already in dollars!) +total = us_data['A59660'].values[0] * 1000 # Total EITC + +print(f'Sum with A59664 as dollars: ${(a61 + a62 + a63 + a64):,.0f}') +print(f'Total EITC (A59660): ${total:,.0f}') +print(f'Match: {abs(total - (a61 + a62 + a63 + a64)) < 1e6}') + +# Check ratio to confirm inconsistency +ratio = us_data['A59664'].values[0] / us_data['A59663'].values[0] +print(f'\nA59664/A59663 ratio: {ratio:.1f}x') +print('(Should be ~0.5-2x if same units, but is ~600x)') +``` + +## Related Files +- `/home/baogorek/devl/policyengine-us-data/policyengine_us_data/db/etl_irs_soi.py` - ETL script with fix and auto-detection \ No newline at end of file diff --git a/policyengine_us_data/db/create_database_tables.py b/policyengine_us_data/db/create_database_tables.py index df03772d..d9675dc7 100644 --- a/policyengine_us_data/db/create_database_tables.py +++ b/policyengine_us_data/db/create_database_tables.py @@ -11,6 +11,7 @@ SQLModel, create_engine, ) +from pydantic import validator from policyengine_us.system import system from policyengine_us_data.storage import STORAGE_FOLDER @@ -30,6 +31,17 @@ ) +class ConstraintOperation(str, Enum): + """Allowed operations for stratum constraints.""" + + EQ = "==" # Equals + NE = "!=" # Not equals + GT = ">" # Greater than + GE = ">=" # Greater than or equal + LT = "<" # Less than + LE = "<=" # Less than or equal + + class Stratum(SQLModel, table=True): """Represents a unique population subgroup (stratum).""" @@ -89,13 +101,13 @@ class StratumConstraint(SQLModel, table=True): __tablename__ = "stratum_constraints" stratum_id: int = Field(foreign_key="strata.stratum_id", primary_key=True) - constraint_variable: USVariable = Field( + constraint_variable: str = Field( primary_key=True, description="The variable the constraint applies to (e.g., 'age').", ) operation: str = Field( primary_key=True, - description="The comparison operator (e.g., 'greater_than_or_equal').", + description="The comparison operator (==, !=, >, >=, <, <=).", ) value: str = Field( description="The value for the constraint rule (e.g., '25')." @@ -106,6 +118,16 @@ class StratumConstraint(SQLModel, table=True): strata_rel: Stratum = Relationship(back_populates="constraints_rel") + @validator("operation") + def validate_operation(cls, v): + """Validate that the operation is one of the allowed values.""" + allowed_ops = [op.value for op in ConstraintOperation] + if v not in allowed_ops: + raise ValueError( + f"Invalid operation '{v}'. Must be one of: {', '.join(allowed_ops)}" + ) + return v + class Target(SQLModel, table=True): """Stores the data values for a specific stratum.""" @@ -137,7 +159,9 @@ class Target(SQLModel, table=True): default=None, description="The numerical value of the target variable." ) source_id: Optional[int] = Field( - default=None, description="Identifier for the data source." 
+ default=None, + foreign_key="sources.source_id", + description="Identifier for the data source.", ) active: bool = Field( default=True, @@ -153,6 +177,134 @@ class Target(SQLModel, table=True): ) strata_rel: Stratum = Relationship(back_populates="targets_rel") + source_rel: Optional["Source"] = Relationship() + + +class SourceType(str, Enum): + """Types of data sources.""" + + ADMINISTRATIVE = "administrative" + SURVEY = "survey" + SYNTHETIC = "synthetic" + DERIVED = "derived" + HARDCODED = ( + "hardcoded" # Values from various sources, hardcoded into the system + ) + + +class Source(SQLModel, table=True): + """Metadata about data sources.""" + + __tablename__ = "sources" + __table_args__ = ( + UniqueConstraint("name", "vintage", name="uq_source_name_vintage"), + ) + + source_id: Optional[int] = Field( + default=None, + primary_key=True, + description="Unique identifier for the data source.", + ) + name: str = Field( + description="Name of the data source (e.g., 'IRS SOI', 'Census ACS').", + index=True, + ) + type: SourceType = Field( + description="Type of data source (administrative, survey, etc.)." + ) + description: Optional[str] = Field( + default=None, description="Detailed description of the data source." + ) + url: Optional[str] = Field( + default=None, + description="URL or reference to the original data source.", + ) + vintage: Optional[str] = Field( + default=None, description="Version or release date of the data source." + ) + notes: Optional[str] = Field( + default=None, description="Additional notes about the source." + ) + + +class VariableGroup(SQLModel, table=True): + """Groups of related variables that form logical units.""" + + __tablename__ = "variable_groups" + + group_id: Optional[int] = Field( + default=None, + primary_key=True, + description="Unique identifier for the variable group.", + ) + name: str = Field( + description="Name of the variable group (e.g., 'age_distribution', 'snap_recipients').", + index=True, + unique=True, + ) + category: str = Field( + description="High-level category (e.g., 'demographic', 'benefit', 'tax', 'income').", + index=True, + ) + is_histogram: bool = Field( + default=False, + description="Whether this group represents a histogram/distribution.", + ) + is_exclusive: bool = Field( + default=False, + description="Whether variables in this group are mutually exclusive.", + ) + aggregation_method: Optional[str] = Field( + default=None, + description="How to aggregate variables in this group (sum, weighted_avg, etc.).", + ) + display_order: Optional[int] = Field( + default=None, + description="Order for displaying this group in matrices/reports.", + ) + description: Optional[str] = Field( + default=None, description="Description of what this group represents." 
+ ) + + +class VariableMetadata(SQLModel, table=True): + """Maps PolicyEngine variables to their groups and provides metadata.""" + + __tablename__ = "variable_metadata" + __table_args__ = ( + UniqueConstraint("variable", name="uq_variable_metadata_variable"), + ) + + metadata_id: Optional[int] = Field(default=None, primary_key=True) + variable: str = Field( + description="PolicyEngine variable name.", index=True + ) + group_id: Optional[int] = Field( + default=None, + foreign_key="variable_groups.group_id", + description="ID of the variable group this belongs to.", + ) + display_name: Optional[str] = Field( + default=None, + description="Human-readable name for display in matrices.", + ) + display_order: Optional[int] = Field( + default=None, + description="Order within its group for display purposes.", + ) + units: Optional[str] = Field( + default=None, + description="Units of measurement (dollars, count, percent, etc.).", + ) + is_primary: bool = Field( + default=True, + description="Whether this is a primary variable vs derived/auxiliary.", + ) + notes: Optional[str] = Field( + default=None, description="Additional notes about the variable." + ) + + group_rel: Optional[VariableGroup] = Relationship() # This SQLAlchemy event listener works directly with the SQLModel class @@ -169,7 +321,13 @@ def calculate_definition_hash(mapper, connection, target: Stratum): return if not target.constraints_rel: # Handle cases with no constraints - target.definition_hash = hashlib.sha256(b"").hexdigest() + # Include parent_stratum_id to make hash unique per parent + parent_str = ( + str(target.parent_stratum_id) if target.parent_stratum_id else "" + ) + target.definition_hash = hashlib.sha256( + parent_str.encode("utf-8") + ).hexdigest() return constraint_strings = [ @@ -178,7 +336,11 @@ def calculate_definition_hash(mapper, connection, target: Stratum): ] constraint_strings.sort() - fingerprint_text = "\n".join(constraint_strings) + # Include parent_stratum_id in the hash to ensure uniqueness per parent + parent_str = ( + str(target.parent_stratum_id) if target.parent_stratum_id else "" + ) + fingerprint_text = parent_str + "\n" + "\n".join(constraint_strings) h = hashlib.sha256(fingerprint_text.encode("utf-8")) target.definition_hash = h.hexdigest() diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index 5653948b..17345fa7 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -1,72 +1,186 @@ from typing import Dict +import requests +import pandas as pd import pandas as pd from sqlmodel import Session, create_engine from policyengine_us_data.storage import STORAGE_FOLDER - - -from policyengine_us.variables.household.demographic.geographic.ucgid.ucgid_enum import ( - UCGID, -) from policyengine_us_data.db.create_database_tables import ( Stratum, StratumConstraint, ) -def main(): - # Get the implied hierarchy by the UCGID enum -------- - rows = [] - for node in UCGID: - codes = node.get_hierarchical_codes() - rows.append( - { - "name": node.name, - "code": codes[0], - "parent": codes[1] if len(codes) > 1 else None, - } - ) +def fetch_congressional_districts(year): + + # Fetch from Census API + base_url = f"https://api.census.gov/data/{year}/acs/acs5" + params = { + "get": "NAME", + "for": "congressional district:*", + "in": "state:*", + } + + response = requests.get(base_url, params=params) + data = response.json() - hierarchy_df = ( - pd.DataFrame(rows) - 
.sort_values(["parent", "code"], na_position="first") - .reset_index(drop=True) + df = pd.DataFrame(data[1:], columns=data[0]) + df["state_fips"] = df["state"].astype(int) + df = df[df["state_fips"] <= 56].copy() + df["district_number"] = df["congressional district"].apply( + lambda x: 0 if x in ["ZZ", "98"] else int(x) ) + # Filter out statewide summary records for multi-district states + df["n_districts"] = df.groupby("state_fips")["state_fips"].transform( + "count" + ) + df = df[(df["n_districts"] == 1) | (df["district_number"] > 0)].copy() + df = df.drop(columns=["n_districts"]) + + df.loc[df["district_number"] == 0, "district_number"] = 1 + df["congressional_district_geoid"] = ( + df["state_fips"] * 100 + df["district_number"] + ) + + df = df[ + [ + "state_fips", + "district_number", + "congressional_district_geoid", + "NAME", + ] + ] + df = df.sort_values("congressional_district_geoid") + + return df + + +def main(): + # State FIPS to name/abbreviation mapping + STATE_NAMES = { + 1: "Alabama (AL)", + 2: "Alaska (AK)", + 4: "Arizona (AZ)", + 5: "Arkansas (AR)", + 6: "California (CA)", + 8: "Colorado (CO)", + 9: "Connecticut (CT)", + 10: "Delaware (DE)", + 11: "District of Columbia (DC)", + 12: "Florida (FL)", + 13: "Georgia (GA)", + 15: "Hawaii (HI)", + 16: "Idaho (ID)", + 17: "Illinois (IL)", + 18: "Indiana (IN)", + 19: "Iowa (IA)", + 20: "Kansas (KS)", + 21: "Kentucky (KY)", + 22: "Louisiana (LA)", + 23: "Maine (ME)", + 24: "Maryland (MD)", + 25: "Massachusetts (MA)", + 26: "Michigan (MI)", + 27: "Minnesota (MN)", + 28: "Mississippi (MS)", + 29: "Missouri (MO)", + 30: "Montana (MT)", + 31: "Nebraska (NE)", + 32: "Nevada (NV)", + 33: "New Hampshire (NH)", + 34: "New Jersey (NJ)", + 35: "New Mexico (NM)", + 36: "New York (NY)", + 37: "North Carolina (NC)", + 38: "North Dakota (ND)", + 39: "Ohio (OH)", + 40: "Oklahoma (OK)", + 41: "Oregon (OR)", + 42: "Pennsylvania (PA)", + 44: "Rhode Island (RI)", + 45: "South Carolina (SC)", + 46: "South Dakota (SD)", + 47: "Tennessee (TN)", + 48: "Texas (TX)", + 49: "Utah (UT)", + 50: "Vermont (VT)", + 51: "Virginia (VA)", + 53: "Washington (WA)", + 54: "West Virginia (WV)", + 55: "Wisconsin (WI)", + 56: "Wyoming (WY)", + } + + # Fetch congressional district data for year 2023 + year = 2023 + cd_df = fetch_congressional_districts(year) + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" engine = create_engine(DATABASE_URL) - # map the ucgid_str 'code' to auto-generated 'stratum_id' - code_to_stratum_id: Dict[str, int] = {} - with Session(engine) as session: - for _, row in hierarchy_df.iterrows(): - parent_code = row["parent"] + # Truncate existing tables + session.query(StratumConstraint).delete() + session.query(Stratum).delete() + session.commit() - parent_id = ( - code_to_stratum_id.get(parent_code) if parent_code else None + # Create national level stratum + us_stratum = Stratum( + parent_stratum_id=None, + notes="United States", + stratum_group_id=1, + ) + us_stratum.constraints_rel = [] # No constraints for national level + session.add(us_stratum) + session.flush() + us_stratum_id = us_stratum.stratum_id + + # Track state strata for parent relationships + state_stratum_ids = {} + + # Create state-level strata + unique_states = cd_df["state_fips"].unique() + for state_fips in sorted(unique_states): + state_name = STATE_NAMES.get( + state_fips, f"State FIPS {state_fips}" ) - - new_stratum = Stratum( - parent_stratum_id=parent_id, - notes=f'{row["name"]} (ucgid {row["code"]})', + state_stratum = Stratum( + 
parent_stratum_id=us_stratum_id, + notes=state_name, stratum_group_id=1, ) - - new_stratum.constraints_rel = [ + state_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["code"], + constraint_variable="state_fips", + operation="==", + value=str(state_fips), ) ] - - session.add(new_stratum) - + session.add(state_stratum) session.flush() + state_stratum_ids[state_fips] = state_stratum.stratum_id + + # Create congressional district strata + for _, row in cd_df.iterrows(): + state_fips = row["state_fips"] + cd_geoid = row["congressional_district_geoid"] + name = row["NAME"] - code_to_stratum_id[row["code"]] = new_stratum.stratum_id + cd_stratum = Stratum( + parent_stratum_id=state_stratum_ids[state_fips], + notes=f"{name} (CD GEOID {cd_geoid})", + stratum_group_id=1, + ) + cd_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(cd_geoid), + ) + ] + session.add(cd_stratum) session.commit() diff --git a/policyengine_us_data/db/etl_age.py b/policyengine_us_data/db/etl_age.py index bb83067c..e878458d 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -1,6 +1,6 @@ import pandas as pd import numpy as np -from sqlmodel import Session, create_engine +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER @@ -8,8 +8,15 @@ Stratum, StratumConstraint, Target, + SourceType, ) from policyengine_us_data.utils.census import get_census_docs, pull_acs_table +from policyengine_us_data.utils.db import parse_ucgid, get_geographic_strata +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, +) LABEL_TO_SHORT = { @@ -82,17 +89,12 @@ def transform_age_data(age_data, docs): df_long["age_less_than"] = age_bounds[["lt"]] df_long["variable"] = "person_count" df_long["reform_id"] = 0 - df_long["source_id"] = 1 df_long["active"] = True return df_long -def get_parent_geo(geo): - return {"National": None, "State": "National", "District": "State"}[geo] - - -def load_age_data(df_long, geo, year, stratum_lookup=None): +def load_age_data(df_long, geo, year): # Quick data quality check before loading ---- if geo == "National": @@ -108,51 +110,151 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" engine = create_engine(DATABASE_URL) - if stratum_lookup is None: - if geo != "National": - raise ValueError("Include stratum_lookup unless National geo") - stratum_lookup = {"National": {}} - else: - stratum_lookup[geo] = {} - with Session(engine) as session: + # Get or create the Census ACS source + census_source = get_or_create_source( + session, + name="Census ACS Table S0101", + source_type=SourceType.SURVEY, + vintage=f"{year} ACS 5-year estimates", + description="American Community Survey Age and Sex demographics", + url="https://data.census.gov/", + notes="Age distribution in 18 brackets across all geographic levels", + ) + + # Get or create the age distribution variable group + age_group = get_or_create_variable_group( + session, + name="age_distribution", + category="demographic", + is_histogram=True, + is_exclusive=True, + aggregation_method="sum", + display_order=1, + description="Age distribution in 18 brackets (0-4, 5-9, ..., 85+)", + ) + + # Create variable metadata for person_count + get_or_create_variable_metadata( + session, 
+ variable="person_count", + group=age_group, + display_name="Population Count", + display_order=1, + units="count", + notes="Number of people in age bracket", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + for _, row in df_long.iterrows(): - # Create the parent Stratum object. - # We will attach children to it before adding it to the session. - note = f"Age: {row['age_range']}, Geo: {row['ucgid_str']}" - parent_geo = get_parent_geo(geo) - parent_stratum_id = ( - stratum_lookup[parent_geo][row["age_range"]] - if parent_geo - else None - ) + # Parse the UCGID to determine geographic info + geo_info = parse_ucgid(row["ucgid_str"]) + + # Determine parent stratum based on geographic level + if geo_info["type"] == "national": + parent_stratum_id = geo_strata["national"] + elif geo_info["type"] == "state": + parent_stratum_id = geo_strata["state"][geo_info["state_fips"]] + elif geo_info["type"] == "district": + parent_stratum_id = geo_strata["district"][ + geo_info["congressional_district_geoid"] + ] + else: + raise ValueError(f"Unknown geography type: {geo_info['type']}") + + # Create the age stratum as a child of the geographic stratum + # Build a proper geographic identifier for the notes + if geo_info["type"] == "national": + geo_desc = "US" + elif geo_info["type"] == "state": + geo_desc = f"State FIPS {geo_info['state_fips']}" + elif geo_info["type"] == "district": + geo_desc = f"CD {geo_info['congressional_district_geoid']}" + else: + geo_desc = "Unknown" + + note = f"Age: {row['age_range']}, {geo_desc}" + + # Check if this age stratum already exists + existing_stratum = session.exec( + select(Stratum).where( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == 2, # Age strata group + Stratum.notes == note, + ) + ).first() + + if existing_stratum: + # Update the existing stratum's target instead of creating a duplicate + existing_target = session.exec( + select(Target).where( + Target.stratum_id == existing_stratum.stratum_id, + Target.variable == row["variable"], + Target.period == year, + ) + ).first() + + if existing_target: + # Update existing target + existing_target.value = row["value"] + else: + # Add new target to existing stratum + new_target = Target( + stratum_id=existing_stratum.stratum_id, + variable=row["variable"], + period=year, + value=row["value"], + source_id=census_source.source_id, + active=row["active"], + ) + session.add(new_target) + continue # Skip creating a new stratum new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=2, # Age strata group notes=note, ) - # Create constraints and link them to the parent's relationship attribute. 
- new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], - ), + # Create constraints including both age and geographic for uniqueness + new_stratum.constraints_rel = [] + + # Add geographic constraints based on level + if geo_info["type"] == "state": + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ) + ) + elif geo_info["type"] == "district": + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(geo_info["congressional_district_geoid"]), + ) + ) + # For national level, no geographic constraint needed + + # Add age constraints + new_stratum.constraints_rel.append( StratumConstraint( constraint_variable="age", - operation="greater_than", + operation=">", value=str(row["age_greater_than"]), - ), - ] + ) + ) age_lt_value = row["age_less_than"] if not np.isinf(age_lt_value): new_stratum.constraints_rel.append( StratumConstraint( constraint_variable="age", - operation="less_than", + operation="<", value=str(row["age_less_than"]), ) ) @@ -163,7 +265,7 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): variable=row["variable"], period=year, value=row["value"], - source_id=row["source_id"], + source_id=census_source.source_id, active=row["active"], ) ) @@ -172,15 +274,9 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): # The 'cascade' setting will handle the children automatically. session.add(new_stratum) - # Flush to get the id - session.flush() - stratum_lookup[geo][row["age_range"]] = new_stratum.stratum_id - # Commit all the new objects at once. session.commit() - return stratum_lookup - if __name__ == "__main__": @@ -199,8 +295,8 @@ def load_age_data(df_long, geo, year, stratum_lookup=None): long_district_df = transform_age_data(district_df, docs) # --- Load -------- - national_strata_lku = load_age_data(long_national_df, "National", year) - state_strata_lku = load_age_data( - long_state_df, "State", year, national_strata_lku - ) - load_age_data(long_district_df, "District", year, state_strata_lku) + # Note: The geographic strata must already exist in the database + # (created by create_initial_strata.py) + load_age_data(long_national_df, "National", year) + load_age_data(long_state_df, "State", year) + load_age_data(long_district_df, "District", year) diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index 786abb1c..46601e8c 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from sqlmodel import Session, create_engine +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER @@ -11,13 +11,20 @@ Stratum, StratumConstraint, Target, + SourceType, ) from policyengine_us_data.utils.db import ( get_stratum_by_id, - get_simple_stratum_by_ucgid, get_root_strata, get_stratum_children, get_stratum_parent, + parse_ucgid, + get_geographic_strata, +) +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, ) from policyengine_us_data.utils.census import TERRITORY_UCGIDS from policyengine_us_data.storage.calibration_targets.make_district_mapping import ( @@ -26,19 +33,17 @@ """See the 22incddocguide.docx manual from the IRS SOI""" -# 
Let's make this work with strict inequalities -# Language in the doc: '$10,000 under $25,000' -epsilon = 0.005 # i.e., half a penny +# Language in the doc: '$10,000 under $25,000' means >= $10,000 and < $25,000 AGI_STUB_TO_INCOME_RANGE = { - 1: (-np.inf, 1), - 2: (1 - epsilon, 10_000), - 3: (10_000 - epsilon, 25_000), - 4: (25_000 - epsilon, 50_000), - 5: (50_000 - epsilon, 75_000), - 6: (75_000 - epsilon, 100_000), - 7: (100_000 - epsilon, 200_000), - 8: (200_000 - epsilon, 500_000), - 9: (500_000 - epsilon, np.inf), + 1: (-np.inf, 1), # Under $1 (negative AGI allowed) + 2: (1, 10_000), # $1 under $10,000 + 3: (10_000, 25_000), # $10,000 under $25,000 + 4: (25_000, 50_000), # $25,000 under $50,000 + 5: (50_000, 75_000), # $50,000 under $75,000 + 6: (75_000, 100_000), # $75,000 under $100,000 + 7: (100_000, 200_000), # $100,000 under $200,000 + 8: (200_000, 500_000), # $200,000 under $500,000 + 9: (500_000, np.inf), # $500,000 or more } @@ -61,6 +66,15 @@ def make_records( breakdown_col: Optional[str] = None, multiplier: int = 1_000, ): + """ + Create standardized records from IRS SOI data. + + IMPORTANT DATA INCONSISTENCY (discovered 2024-12): + The IRS SOI documentation states "money amounts are reported in thousands of dollars." + This is true for almost all columns EXCEPT A59664 (EITC with 3+ children amount), + which is already in dollars, not thousands. This appears to be a data quality issue + in the IRS SOI file itself. We handle this special case below. + """ df = df.rename( {count_col: "tax_unit_count", amount_col: amount_name}, axis=1 ).copy() @@ -71,8 +85,31 @@ def make_records( rec_counts = create_records(df, breakdown_col, "tax_unit_count") rec_amounts = create_records(df, breakdown_col, amount_name) - rec_amounts["target_value"] *= multiplier # Only the amounts get * 1000 - rec_counts["target_variable"] = f"{amount_name}_tax_unit_count" + + # SPECIAL CASE: A59664 (EITC with 3+ children) is already in dollars, not thousands! + # All other EITC amounts (A59661-A59663) are correctly in thousands. + # This was verified by checking that A59660 (total EITC) equals the sum only when + # A59664 is treated as already being in dollars. + if amount_col == "A59664": + # Check if IRS has fixed the data inconsistency + # If values are < 10 million, they're likely already in thousands (fixed) + max_value = rec_amounts["target_value"].max() + if max_value < 10_000_000: + print( + f"WARNING: A59664 values appear to be in thousands (max={max_value:,.0f})" + ) + print("The IRS may have fixed their data inconsistency.") + print( + "Please verify and remove the special case handling if confirmed." 
+ ) + # Don't apply the fix - data appears to already be in thousands + else: + # Convert from dollars to thousands to match other columns + rec_amounts["target_value"] /= 1_000 + + rec_amounts["target_value"] *= multiplier # Apply standard multiplier + # Note: tax_unit_count is the correct variable - the stratum constraints + # indicate what is being counted (e.g., eitc > 0 for EITC recipients) return rec_counts, rec_amounts @@ -150,7 +187,42 @@ def extract_soi_data() -> pd.DataFrame: In the file below, "22" is 2022, "in" is individual returns, "cd" is congressional districts """ - return pd.read_csv("https://www.irs.gov/pub/irs-soi/22incd.csv") + df = pd.read_csv("https://www.irs.gov/pub/irs-soi/22incd.csv") + + # Validate EITC data consistency (check if IRS fixed the A59664 issue) + us_data = df[(df["STATE"] == "US") & (df["agi_stub"] == 0)] + if not us_data.empty and all( + col in us_data.columns + for col in ["A59660", "A59661", "A59662", "A59663", "A59664"] + ): + total_eitc = us_data["A59660"].values[0] + sum_as_thousands = ( + us_data["A59661"].values[0] + + us_data["A59662"].values[0] + + us_data["A59663"].values[0] + + us_data["A59664"].values[0] + ) + sum_mixed = ( + us_data["A59661"].values[0] + + us_data["A59662"].values[0] + + us_data["A59663"].values[0] + + us_data["A59664"].values[0] / 1000 + ) + + # Check which interpretation matches the total + if abs(total_eitc - sum_as_thousands) < 100: # Within 100K (thousands) + print("=" * 60) + print("ALERT: IRS may have fixed the A59664 data inconsistency!") + print(f"Total EITC (A59660): {total_eitc:,.0f}") + print(f"Sum treating A59664 as thousands: {sum_as_thousands:,.0f}") + print("These now match! Please verify and update the code.") + print("=" * 60) + elif abs(total_eitc - sum_mixed) < 100: + print( + "Note: A59664 still has the units inconsistency (in dollars, not thousands)" + ) + + return df def transform_soi_data(raw_df): @@ -159,14 +231,20 @@ def transform_soi_data(raw_df): dict(code="59661", name="eitc", breakdown=("eitc_child_count", 0)), dict(code="59662", name="eitc", breakdown=("eitc_child_count", 1)), dict(code="59663", name="eitc", breakdown=("eitc_child_count", 2)), - dict(code="59664", name="eitc", breakdown=("eitc_child_count", "3+")), + dict( + code="59664", name="eitc", breakdown=("eitc_child_count", "3+") + ), # Doc says "three" but data shows this is 3+ dict( code="04475", name="qualified_business_income_deduction", breakdown=None, ), + dict(code="00900", name="self_employment_income", breakdown=None), + dict( + code="01000", name="net_capital_gains", breakdown=None + ), # Not to be confused with the always positive net_capital_gain dict(code="18500", name="real_estate_taxes", breakdown=None), - dict(code="01000", name="net_capital_gain", breakdown=None), + dict(code="25870", name="rental_income", breakdown=None), dict(code="01400", name="taxable_ira_distributions", breakdown=None), dict(code="00300", name="taxable_interest_income", breakdown=None), dict(code="00400", name="tax_exempt_interest_income", breakdown=None), @@ -184,6 +262,7 @@ def transform_soi_data(raw_df): dict(code="11070", name="refundable_ctc", breakdown=None), dict(code="18425", name="salt", breakdown=None), dict(code="06500", name="income_tax", breakdown=None), + dict(code="05800", name="income_tax_before_credits", breakdown=None), ] # National --------------- @@ -291,6 +370,294 @@ def load_soi_data(long_dfs, year): session = Session(engine) + # Get or create the IRS SOI source + irs_source = get_or_create_source( + session, + name="IRS 
Statistics of Income", + source_type=SourceType.ADMINISTRATIVE, + vintage=f"{year} Tax Year", + description="IRS Statistics of Income administrative tax data", + url="https://www.irs.gov/statistics", + notes="Tax return data by congressional district, state, and national levels", + ) + + # Create variable groups + agi_group = get_or_create_variable_group( + session, + name="agi_distribution", + category="income", + is_histogram=True, + is_exclusive=True, + aggregation_method="sum", + display_order=4, + description="Adjusted Gross Income distribution by IRS income stubs", + ) + + eitc_group = get_or_create_variable_group( + session, + name="eitc_recipients", + category="tax", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=5, + description="Earned Income Tax Credit by number of qualifying children", + ) + + ctc_group = get_or_create_variable_group( + session, + name="ctc_recipients", + category="tax", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=6, + description="Child Tax Credit recipients and amounts", + ) + + income_components_group = get_or_create_variable_group( + session, + name="income_components", + category="income", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=7, + description="Components of income (interest, dividends, capital gains, etc.)", + ) + + deductions_group = get_or_create_variable_group( + session, + name="tax_deductions", + category="tax", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=8, + description="Tax deductions (SALT, medical, real estate, etc.)", + ) + + # Create variable metadata + # EITC - both amount and count use same variable with different constraints + get_or_create_variable_metadata( + session, + variable="eitc", + group=eitc_group, + display_name="EITC Amount", + display_order=1, + units="dollars", + notes="EITC amounts by number of qualifying children", + ) + + # For counts, tax_unit_count is used with appropriate constraints + get_or_create_variable_metadata( + session, + variable="tax_unit_count", + group=None, # This spans multiple groups based on constraints + display_name="Tax Unit Count", + display_order=100, + units="count", + notes="Number of tax units - meaning depends on stratum constraints", + ) + + # CTC + get_or_create_variable_metadata( + session, + variable="refundable_ctc", + group=ctc_group, + display_name="Refundable CTC", + display_order=1, + units="dollars", + ) + + # AGI and related + get_or_create_variable_metadata( + session, + variable="adjusted_gross_income", + group=agi_group, + display_name="Adjusted Gross Income", + display_order=1, + units="dollars", + ) + + get_or_create_variable_metadata( + session, + variable="person_count", + group=agi_group, + display_name="Person Count", + display_order=3, + units="count", + notes="Number of people in tax units by AGI bracket", + ) + + # Income components + income_vars = [ + ("taxable_interest_income", "Taxable Interest", 1), + ("tax_exempt_interest_income", "Tax-Exempt Interest", 2), + ("dividend_income", "Ordinary Dividends", 3), + ("qualified_dividend_income", "Qualified Dividends", 4), + ("net_capital_gain", "Net Capital Gain", 5), + ("taxable_ira_distributions", "Taxable IRA Distributions", 6), + ("taxable_pension_income", "Taxable Pensions", 7), + ("taxable_social_security", "Taxable Social Security", 8), + ("unemployment_compensation", "Unemployment Compensation", 9), + ( + "tax_unit_partnership_s_corp_income", + 
"Partnership/S-Corp Income", + 10, + ), + ] + + for var_name, display_name, order in income_vars: + get_or_create_variable_metadata( + session, + variable=var_name, + group=income_components_group, + display_name=display_name, + display_order=order, + units="dollars", + ) + + # Deductions + deduction_vars = [ + ("salt", "State and Local Taxes", 1), + ("real_estate_taxes", "Real Estate Taxes", 2), + ("medical_expense_deduction", "Medical Expenses", 3), + ("qualified_business_income_deduction", "QBI Deduction", 4), + ] + + for var_name, display_name, order in deduction_vars: + get_or_create_variable_metadata( + session, + variable=var_name, + group=deductions_group, + display_name=display_name, + display_order=order, + units="dollars", + ) + + # Income tax + get_or_create_variable_metadata( + session, + variable="income_tax", + group=None, # Could create a tax_liability group if needed + display_name="Income Tax", + display_order=1, + units="dollars", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + + # Create filer strata as intermediate layer between geographic and IRS-specific strata + # All IRS data represents only tax filers, not the entire population + filer_strata = {"national": None, "state": {}, "district": {}} + + # National filer stratum - check if it exists first + national_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == geo_strata["national"], + Stratum.notes == "United States - Tax Filers", + ) + .first() + ) + + if not national_filer_stratum: + national_filer_stratum = Stratum( + parent_stratum_id=geo_strata["national"], + stratum_group_id=2, # Filer population group + notes="United States - Tax Filers", + ) + national_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + session.add(national_filer_stratum) + session.flush() + + filer_strata["national"] = national_filer_stratum.stratum_id + + # State filer strata + for state_fips, state_geo_stratum_id in geo_strata["state"].items(): + # Check if state filer stratum exists + state_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == state_geo_stratum_id, + Stratum.notes == f"State FIPS {state_fips} - Tax Filers", + ) + .first() + ) + + if not state_filer_stratum: + state_filer_stratum = Stratum( + parent_stratum_id=state_geo_stratum_id, + stratum_group_id=2, # Filer population group + notes=f"State FIPS {state_fips} - Tax Filers", + ) + state_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(state_fips), + ), + ] + session.add(state_filer_stratum) + session.flush() + + filer_strata["state"][state_fips] = state_filer_stratum.stratum_id + + # District filer strata + for district_geoid, district_geo_stratum_id in geo_strata[ + "district" + ].items(): + # Check if district filer stratum exists + district_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == district_geo_stratum_id, + Stratum.notes + == f"Congressional District {district_geoid} - Tax Filers", + ) + .first() + ) + + if not district_filer_stratum: + district_filer_stratum = Stratum( + parent_stratum_id=district_geo_stratum_id, + stratum_group_id=2, # Filer population group + notes=f"Congressional District {district_geoid} - Tax Filers", + ) + 
district_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(district_geoid), + ), + ] + session.add(district_filer_stratum) + session.flush() + + filer_strata["district"][ + district_geoid + ] = district_filer_stratum.stratum_id + + session.commit() + # Load EITC data -------------------------------------------------------- eitc_data = { "0": (long_dfs[0], long_dfs[1]), @@ -299,73 +666,147 @@ def load_soi_data(long_dfs, year): "3+": (long_dfs[6], long_dfs[7]), } - stratum_lookup = {"State": {}, "District": {}} + eitc_stratum_lookup = {"national": {}, "state": {}, "district": {}} for n_children in eitc_data.keys(): eitc_count_i, eitc_amount_i = eitc_data[n_children] for i in range(eitc_count_i.shape[0]): ucgid_i = eitc_count_i[["ucgid_str"]].iloc[i].values[0] - note = f"Geo: {ucgid_i}, EITC received with {n_children} children" + geo_info = parse_ucgid(ucgid_i) - if len(ucgid_i) == 9: # National. - new_stratum = Stratum( - parent_stratum_id=None, stratum_group_id=0, notes=note - ) - elif len(ucgid_i) == 11: # State - new_stratum = Stratum( - parent_stratum_id=stratum_lookup["National"], - stratum_group_id=0, - notes=note, + # Determine parent stratum based on geographic level - use filer strata not geo strata + if geo_info["type"] == "national": + parent_stratum_id = filer_strata["national"] + note = f"National EITC received with {n_children} children (filers)" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + elif geo_info["type"] == "state": + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] + note = f"State FIPS {geo_info['state_fips']} EITC received with {n_children} children (filers)" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ), + ] + elif geo_info["type"] == "district": + parent_stratum_id = filer_strata["district"][ + geo_info["congressional_district_geoid"] + ] + note = f"Congressional District {geo_info['congressional_district_geoid']} EITC received with {n_children} children (filers)" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(geo_info["congressional_district_geoid"]), + ), + ] + + # Check if stratum already exists + existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == 6, + Stratum.notes == note, ) - elif len(ucgid_i) == 13: # District + .first() + ) + + if existing_stratum: + new_stratum = existing_stratum + else: new_stratum = Stratum( - parent_stratum_id=stratum_lookup["State"][ - "0400000US" + ucgid_i[9:11] - ], - stratum_group_id=0, + parent_stratum_id=parent_stratum_id, + stratum_group_id=6, # EITC strata group notes=note, ) - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=ucgid_i, - ), - ] - if n_children == "3+": - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="eitc_child_count", - operation="greater_than", - value="2", + new_stratum.constraints_rel = 
constraints + if n_children == "3+": + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="eitc_child_count", + operation=">", + value="2", + ) ) - ) - else: - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="eitc_child_count", - operation="equals", - value=f"{n_children}", + else: + new_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="eitc_child_count", + operation="==", + value=f"{n_children}", + ) ) - ) - new_stratum.targets_rel = [ - Target( - variable="eitc", - period=year, - value=eitc_amount_i.iloc[i][["target_value"]].values[0], - source_id=5, - active=True, + session.add(new_stratum) + session.flush() + + # Get both count and amount values + count_value = eitc_count_i.iloc[i][["target_value"]].values[0] + amount_value = eitc_amount_i.iloc[i][["target_value"]].values[0] + + # Check if targets already exist and update or create them + for variable, value in [ + ("tax_unit_count", count_value), + ("eitc", amount_value), + ]: + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == new_stratum.stratum_id, + Target.variable == variable, + Target.period == year, + ) + .first() ) - ] + + if existing_target: + existing_target.value = value + existing_target.source_id = irs_source.source_id + else: + new_stratum.targets_rel.append( + Target( + variable=variable, + period=year, + value=value, + source_id=irs_source.source_id, + active=True, + ) + ) session.add(new_stratum) session.flush() - if len(ucgid_i) == 9: - stratum_lookup["National"] = new_stratum.stratum_id - elif len(ucgid_i) == 11: - stratum_lookup["State"][ucgid_i] = new_stratum.stratum_id + # Store lookup for later use + if geo_info["type"] == "national": + eitc_stratum_lookup["national"][ + n_children + ] = new_stratum.stratum_id + elif geo_info["type"] == "state": + key = (geo_info["state_fips"], n_children) + eitc_stratum_lookup["state"][key] = new_stratum.stratum_id + elif geo_info["type"] == "district": + key = (geo_info["congressional_district_geoid"], n_children) + eitc_stratum_lookup["district"][key] = new_stratum.stratum_id session.commit() @@ -377,30 +818,140 @@ def load_soi_data(long_dfs, year): == "adjusted_gross_income" and long_dfs[i][["breakdown_variable"]].values[0] == "one" ][0] + # IRS variables start at stratum_group_id 100 + irs_group_id_start = 100 + for j in range(8, first_agi_index, 2): count_j, amount_j = long_dfs[j], long_dfs[j + 1] + count_variable_name = count_j.iloc[0][["target_variable"]].values[ + 0 + ] # Should be tax_unit_count amount_variable_name = amount_j.iloc[0][["target_variable"]].values[0] + + # Assign a unique stratum_group_id for this IRS variable + stratum_group_id = irs_group_id_start + (j - 8) // 2 + print( - f"Loading amount data for IRS SOI data on {amount_variable_name}" + f"Loading count and amount data for IRS SOI data on {amount_variable_name} (group_id={stratum_group_id})" ) + for i in range(count_j.shape[0]): ucgid_i = count_j[["ucgid_str"]].iloc[i].values[0] + geo_info = parse_ucgid(ucgid_i) + + # Get parent filer stratum (not geographic stratum) + if geo_info["type"] == "national": + parent_stratum_id = filer_strata["national"] + geo_description = "National" + elif geo_info["type"] == "state": + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] + geo_description = f"State {geo_info['state_fips']}" + elif geo_info["type"] == "district": + parent_stratum_id = filer_strata["district"][ + geo_info["congressional_district_geoid"] + ] + geo_description 
= ( + f"CD {geo_info['congressional_district_geoid']}" + ) - # Reusing an existing stratum this time, since there is no breakdown - stratum = get_simple_stratum_by_ucgid(session, ucgid_i) - amount_value = amount_j.iloc[i][["target_value"]].values[0] + # Create child stratum with constraint for this IRS variable + # Note: This stratum will have the constraint that amount_variable > 0 + note = f"{geo_description} filers with {amount_variable_name} > 0" - stratum.targets_rel.append( - Target( - variable=amount_variable_name, - period=year, - value=amount_value, - source_id=5, - active=True, + # Check if child stratum already exists + existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == stratum_group_id, ) + .first() ) - session.add(stratum) + if existing_stratum: + child_stratum = existing_stratum + else: + # Create new child stratum with constraint + child_stratum = Stratum( + parent_stratum_id=parent_stratum_id, + stratum_group_id=stratum_group_id, + notes=note, + ) + + # Add constraints - filer status and this IRS variable must be positive + child_stratum.constraints_rel.extend( + [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable=amount_variable_name, + operation=">", + value="0", + ), + ] + ) + + # Add geographic constraints if applicable + if geo_info["type"] == "state": + child_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ) + ) + elif geo_info["type"] == "district": + child_stratum.constraints_rel.append( + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str( + geo_info["congressional_district_geoid"] + ), + ) + ) + + session.add(child_stratum) + session.flush() + + count_value = count_j.iloc[i][["target_value"]].values[0] + amount_value = amount_j.iloc[i][["target_value"]].values[0] + + # Check if targets already exist and update or create them + for variable, value in [ + (count_variable_name, count_value), + (amount_variable_name, amount_value), + ]: + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == child_stratum.stratum_id, + Target.variable == variable, + Target.period == year, + ) + .first() + ) + + if existing_target: + existing_target.value = value + existing_target.source_id = irs_source.source_id + else: + child_stratum.targets_rel.append( + Target( + variable=variable, + period=year, + value=value, + source_id=irs_source.source_id, + active=True, + ) + ) + + session.add(child_stratum) session.flush() session.commit() @@ -411,16 +962,49 @@ def load_soi_data(long_dfs, year): for i in range(agi_values.shape[0]): ucgid_i = agi_values[["ucgid_str"]].iloc[i].values[0] - stratum = get_simple_stratum_by_ucgid(session, ucgid_i) - stratum.targets_rel.append( - Target( - variable="adjusted_gross_income", - period=year, - value=agi_values.iloc[i][["target_value"]].values[0], - source_id=5, - active=True, + geo_info = parse_ucgid(ucgid_i) + + # Add target to existing FILER stratum (not geographic stratum) + if geo_info["type"] == "national": + stratum = session.get(Stratum, filer_strata["national"]) + elif geo_info["type"] == "state": + stratum = session.get( + Stratum, filer_strata["state"][geo_info["state_fips"]] + ) + elif geo_info["type"] == "district": + stratum = session.get( + Stratum, + filer_strata["district"][ + 
geo_info["congressional_district_geoid"] + ], ) + + # Check if target already exists + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == stratum.stratum_id, + Target.variable == "adjusted_gross_income", + Target.period == year, + ) + .first() ) + + if existing_target: + existing_target.value = agi_values.iloc[i][ + ["target_value"] + ].values[0] + existing_target.source_id = irs_source.source_id + else: + stratum.targets_rel.append( + Target( + variable="adjusted_gross_income", + period=year, + value=agi_values.iloc[i][["target_value"]].values[0], + source_id=irs_source.source_id, + active=True, + ) + ) session.add(stratum) session.flush() @@ -437,93 +1021,167 @@ def load_soi_data(long_dfs, year): agi_income_lower, agi_income_upper = AGI_STUB_TO_INCOME_RANGE[agi_stub] # Make a National Stratum for each AGI Stub even w/o associated national target - note = f"Geo: 0100000US, AGI > {agi_income_lower}, AGI < {agi_income_upper}" - nat_stratum = Stratum( - parent_stratum_id=None, stratum_group_id=0, notes=note - ) - nat_stratum.constraints_rel.extend( - [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value="0100000US", - ), - StratumConstraint( - constraint_variable="adjusted_gross_income", - operation="greater_than", - value=str(agi_income_lower), - ), - StratumConstraint( - constraint_variable="adjusted_gross_income", - operation="less_than", - value=str(agi_income_upper), - ), - ] + note = f"National filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" + + # Check if national AGI stratum already exists + nat_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == filer_strata["national"], + Stratum.stratum_group_id == 3, + Stratum.notes == note, + ) + .first() ) - session.add(nat_stratum) - session.flush() - stratum_lookup = { - "National": nat_stratum.stratum_id, - "State": {}, - "District": {}, - } - for i in range(agi_df.shape[0]): - ucgid_i = agi_df[["ucgid_str"]].iloc[i].values[0] - note = f"Geo: {ucgid_i}, AGI > {agi_income_lower}, AGI < {agi_income_upper}" - - person_count = agi_df.iloc[i][["target_value"]].values[0] - - if len(ucgid_i) == 11: # State - new_stratum = Stratum( - parent_stratum_id=stratum_lookup["National"], - stratum_group_id=0, - notes=note, - ) - elif len(ucgid_i) == 13: # District - new_stratum = Stratum( - parent_stratum_id=stratum_lookup["State"][ - "0400000US" + ucgid_i[9:11] - ], - stratum_group_id=0, - notes=note, - ) - new_stratum.constraints_rel.extend( + if not nat_stratum: + nat_stratum = Stratum( + parent_stratum_id=filer_strata["national"], + stratum_group_id=3, # Income/AGI strata group + notes=note, + ) + nat_stratum.constraints_rel.extend( [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=ucgid_i, + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", ), StratumConstraint( constraint_variable="adjusted_gross_income", - operation="greater_than", + operation=">=", value=str(agi_income_lower), ), StratumConstraint( constraint_variable="adjusted_gross_income", - operation="less_than", + operation="<", value=str(agi_income_upper), ), ] ) - new_stratum.targets_rel.append( - Target( - variable="person_count", - period=year, - value=person_count, - source_id=5, - active=True, + session.add(nat_stratum) + session.flush() + + agi_stratum_lookup = { + "national": nat_stratum.stratum_id, + "state": {}, + "district": {}, + } + for i in range(agi_df.shape[0]): + ucgid_i = agi_df[["ucgid_str"]].iloc[i].values[0] + geo_info = 
parse_ucgid(ucgid_i) + person_count = agi_df.iloc[i][["target_value"]].values[0] + + if geo_info["type"] == "state": + parent_stratum_id = filer_strata["state"][ + geo_info["state_fips"] + ] + note = f"State FIPS {geo_info['state_fips']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="state_fips", + operation="==", + value=str(geo_info["state_fips"]), + ), + ] + elif geo_info["type"] == "district": + parent_stratum_id = filer_strata["district"][ + geo_info["congressional_district_geoid"] + ] + note = f"Congressional District {geo_info['congressional_district_geoid']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" + constraints = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ), + StratumConstraint( + constraint_variable="congressional_district_geoid", + operation="==", + value=str(geo_info["congressional_district_geoid"]), + ), + ] + else: + continue # Skip if not state or district (shouldn't happen, but defensive) + + # Check if stratum already exists + existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == parent_stratum_id, + Stratum.stratum_group_id == 3, + Stratum.notes == note, ) + .first() ) + if existing_stratum: + new_stratum = existing_stratum + else: + new_stratum = Stratum( + parent_stratum_id=parent_stratum_id, + stratum_group_id=3, # Income/AGI strata group + notes=note, + ) + new_stratum.constraints_rel = constraints + new_stratum.constraints_rel.extend( + [ + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation=">=", + value=str(agi_income_lower), + ), + StratumConstraint( + constraint_variable="adjusted_gross_income", + operation="<", + value=str(agi_income_upper), + ), + ] + ) + session.add(new_stratum) + session.flush() + + # Check if target already exists and update or create it + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == new_stratum.stratum_id, + Target.variable == "person_count", + Target.period == year, + ) + .first() + ) + + if existing_target: + existing_target.value = person_count + existing_target.source_id = irs_source.source_id + else: + new_stratum.targets_rel.append( + Target( + variable="person_count", + period=year, + value=person_count, + source_id=irs_source.source_id, + active=True, + ) + ) + session.add(new_stratum) session.flush() - if len(ucgid_i) == 9: - stratum_lookup["National"] = new_stratum.stratum_id - elif len(ucgid_i) == 11: - stratum_lookup["State"][ucgid_i] = new_stratum.stratum_id + if geo_info["type"] == "state": + agi_stratum_lookup["state"][ + geo_info["state_fips"] + ] = new_stratum.stratum_id + elif geo_info["type"] == "district": + agi_stratum_lookup["district"][ + geo_info["congressional_district_geoid"] + ] = new_stratum.stratum_id session.commit() diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index 926a0d88..b4e79c9a 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -1,7 +1,8 @@ import requests import pandas as pd -from sqlmodel import Session, create_engine +import numpy as np +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER @@ -9,38 +10,91 @@ Stratum, StratumConstraint, Target, + SourceType, +) +from policyengine_us_data.utils.census import ( 
+ STATE_ABBREV_TO_FIPS, + pull_acs_table, +) +from policyengine_us_data.utils.db import parse_ucgid, get_geographic_strata +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, ) -from policyengine_us_data.utils.census import STATE_ABBREV_TO_FIPS -def extract_medicaid_data(year): - base_url = ( - f"https://api.census.gov/data/{year}/acs/acs1/subject?get=group(S2704)" - ) - url = f"{base_url}&for=congressional+district:*" - response = requests.get(url) - response.raise_for_status() +def extract_administrative_medicaid_data(year): + item = "6165f45b-ca93-5bb5-9d06-db29c692a360" - data = response.json() + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", + "Accept": "application/json", + "Accept-Language": "en-US,en;q=0.5", + } - headers = data[0] - data_rows = data[1:] - cd_survey_df = pd.DataFrame(data_rows, columns=headers) + try: + session = requests.Session() + session.headers.update(headers) - item = "6165f45b-ca93-5bb5-9d06-db29c692a360" - response = requests.get( - f"https://data.medicaid.gov/api/1/metastore/schemas/dataset/items/{item}?show-reference-ids=false" - ) - metadata = response.json() + metadata_url = f"https://data.medicaid.gov/api/1/metastore/schemas/dataset/items/{item}?show-reference-ids=false" + print(f"Attempting to fetch Medicaid metadata from: {metadata_url}") - data_url = metadata["distribution"][0]["data"]["downloadURL"] - state_admin_df = pd.read_csv(data_url) + response = session.get(metadata_url, timeout=30) + response.raise_for_status() - return cd_survey_df, state_admin_df + metadata = response.json() + if ( + "distribution" not in metadata + or len(metadata["distribution"]) == 0 + ): + raise ValueError( + f"No distribution found in metadata for item {item}" + ) + + data_url = metadata["distribution"][0]["data"]["downloadURL"] + print(f"Downloading Medicaid data from: {data_url}") + + try: + state_admin_df = pd.read_csv(data_url) + print( + f"Successfully downloaded {len(state_admin_df)} rows of Medicaid administrative data" + ) + return state_admin_df + except Exception as csv_error: + print(f"\nError downloading CSV from: {data_url}") + print(f"Error: {csv_error}") + print( + f"\nThe metadata API returned successfully, but the data file doesn't exist." + ) + print(f"This suggests the dataset has been updated/moved.") + print(f"Please visit https://data.medicaid.gov/ and search for:") + print( + f" - 'Medicaid Enrollment' or 'T-MSIS' or 'Performance Indicators'" + ) + print(f"Then update the item ID in the code (currently: {item})\n") + raise + + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + print(f"\n404 Error: Medicaid metadata item not found.") + print(f"The item ID '{item}' may have changed.") + print( + f"Please check https://data.medicaid.gov/ for updated dataset IDs." 
+ ) + print(f"Search for 'Medicaid Enrollment' or 'T-MSIS' datasets.\n") + raise + except requests.exceptions.RequestException as e: + print(f"Error downloading Medicaid data: {e}") + raise -def transform_medicaid_data(state_admin_df, cd_survey_df, year): +def extract_survey_medicaid_data(year): + return pull_acs_table("S2704", "District", year) + + +def transform_administrative_medicaid_data(state_admin_df, year): reporting_period = year * 100 + 12 print(f"Reporting period is {reporting_period}") state_df = state_admin_df.loc[ @@ -55,22 +109,19 @@ def transform_medicaid_data(state_admin_df, cd_survey_df, year): state_df["FIPS"] = state_df["State Abbreviation"].map(STATE_ABBREV_TO_FIPS) - cd_df = cd_survey_df[ - ["GEO_ID", "state", "congressional district", "S2704_C02_006E"] - ] - - nc_cd_sum = cd_df.loc[cd_df.state == "37"].S2704_C02_006E.astype(int).sum() - nc_state_sum = state_df.loc[state_df.FIPS == "37"][ - "Total Medicaid Enrollment" - ].values[0] - assert nc_cd_sum > 0.5 * nc_state_sum - assert nc_cd_sum <= nc_state_sum - state_df = state_df.rename( columns={"Total Medicaid Enrollment": "medicaid_enrollment"} ) state_df["ucgid_str"] = "0400000US" + state_df["FIPS"].astype(str) + return state_df[["ucgid_str", "medicaid_enrollment"]] + + +def transform_survey_medicaid_data(cd_survey_df): + cd_df = cd_survey_df[ + ["GEO_ID", "state", "congressional district", "S2704_C02_006E"] + ] + cd_df = cd_df.rename( columns={ "S2704_C02_006E": "medicaid_enrollment", @@ -79,8 +130,7 @@ def transform_medicaid_data(state_admin_df, cd_survey_df, year): ) cd_df = cd_df.loc[cd_df.state != "72"] - out_cols = ["ucgid_str", "medicaid_enrollment"] - return state_df[out_cols], cd_df[out_cols] + return cd_df[["ucgid_str", "medicaid_enrollment"]] def load_medicaid_data(long_state, long_cd, year): @@ -88,24 +138,67 @@ def load_medicaid_data(long_state, long_cd, year): DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" engine = create_engine(DATABASE_URL) - stratum_lookup = {} - with Session(engine) as session: + # Get or create sources + admin_source = get_or_create_source( + session, + name="Medicaid T-MSIS", + source_type=SourceType.ADMINISTRATIVE, + vintage=f"{year} Final Report", + description="Medicaid Transformed MSIS administrative enrollment data", + url="https://data.medicaid.gov/", + notes="State-level Medicaid enrollment from administrative records", + ) + + survey_source = get_or_create_source( + session, + name="Census ACS Table S2704", + source_type=SourceType.SURVEY, + vintage=f"{year} ACS 1-year estimates", + description="American Community Survey health insurance coverage data", + url="https://data.census.gov/", + notes="Congressional district level Medicaid coverage from ACS", + ) + + # Get or create Medicaid variable group + medicaid_group = get_or_create_variable_group( + session, + name="medicaid_recipients", + category="benefit", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=3, + description="Medicaid enrollment and spending", + ) + + # Create variable metadata + # Note: The actual target variable used is "person_count" with medicaid_enrolled==True constraint + # This metadata entry is kept for consistency with the actual variable being used + get_or_create_variable_metadata( + session, + variable="person_count", + group=medicaid_group, + display_name="Medicaid Enrollment", + display_order=1, + units="count", + notes="Number of people enrolled in Medicaid (person_count with medicaid_enrolled==True)", + ) + + # Fetch existing geographic 
strata + geo_strata = get_geographic_strata(session) + # National ---------------- + # Create a Medicaid stratum as child of the national geographic stratum nat_stratum = Stratum( - parent_stratum_id=None, - stratum_group_id=0, - notes="Geo: 0100000US Medicaid Enrolled", + parent_stratum_id=geo_strata["national"], + stratum_group_id=5, # Medicaid strata group + notes="National Medicaid Enrolled", ) nat_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value="0100000US", - ), StratumConstraint( constraint_variable="medicaid_enrolled", - operation="equals", + operation="==", value="True", ), ] @@ -113,29 +206,36 @@ def load_medicaid_data(long_state, long_cd, year): session.add(nat_stratum) session.flush() - stratum_lookup["National"] = nat_stratum.stratum_id + medicaid_stratum_lookup = { + "national": nat_stratum.stratum_id, + "state": {}, + } # State ------------------- - stratum_lookup["State"] = {} for _, row in long_state.iterrows(): + # Parse the UCGID to get state_fips + geo_info = parse_ucgid(row["ucgid_str"]) + state_fips = geo_info["state_fips"] - note = f"Geo: {row['ucgid_str']} Medicaid Enrolled" - parent_stratum_id = nat_stratum.stratum_id + # Get the parent geographic stratum + parent_stratum_id = geo_strata["state"][state_fips] + + note = f"State FIPS {state_fips} Medicaid Enrolled" new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=5, # Medicaid strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="state_fips", + operation="==", + value=str(state_fips), ), StratumConstraint( constraint_variable="medicaid_enrolled", - operation="equals", + operation="==", value="True", ), ] @@ -144,36 +244,41 @@ def load_medicaid_data(long_state, long_cd, year): variable="person_count", period=year, value=row["medicaid_enrollment"], - source_id=2, + source_id=admin_source.source_id, active=True, ) ) session.add(new_stratum) session.flush() - stratum_lookup["State"][row["ucgid_str"]] = new_stratum.stratum_id + medicaid_stratum_lookup["state"][ + state_fips + ] = new_stratum.stratum_id # District ------------------- for _, row in long_cd.iterrows(): + # Parse the UCGID to get district info + geo_info = parse_ucgid(row["ucgid_str"]) + cd_geoid = geo_info["congressional_district_geoid"] - note = f"Geo: {row['ucgid_str']} Medicaid Enrolled" - parent_stratum_id = stratum_lookup["State"][ - f'0400000US{row["ucgid_str"][-4:-2]}' - ] + # Get the parent geographic stratum + parent_stratum_id = geo_strata["district"][cd_geoid] + + note = f"Congressional District {cd_geoid} Medicaid Enrolled" new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=5, # Medicaid strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="congressional_district_geoid", + operation="==", + value=str(cd_geoid), ), StratumConstraint( constraint_variable="medicaid_enrolled", - operation="equals", + operation="==", value="True", ), ] @@ -182,7 +287,7 @@ def load_medicaid_data(long_state, long_cd, year): variable="person_count", period=year, value=row["medicaid_enrollment"], - source_id=2, + source_id=survey_source.source_id, active=True, ) ) @@ -192,17 +297,36 @@ def load_medicaid_data(long_state, long_cd, year): session.commit() -if __name__ 
== "__main__": - +def main(): year = 2023 # Extract ------------------------------ - cd_survey_df, state_admin_df = extract_medicaid_data(year) + state_admin_df = extract_administrative_medicaid_data(year) + cd_survey_df = extract_survey_medicaid_data(year) # Transform ------------------- - long_state, long_cd = transform_medicaid_data( - state_admin_df, cd_survey_df, year + long_state = transform_administrative_medicaid_data(state_admin_df, year) + long_cd = transform_survey_medicaid_data(cd_survey_df) + + # Validate consistency between sources + nc_cd_sum = ( + long_cd.loc[long_cd.ucgid_str.str.contains("5001800US37")] + .medicaid_enrollment.astype(int) + .sum() ) + nc_state_sum = long_state.loc[long_state.ucgid_str == "0400000US37"][ + "medicaid_enrollment" + ].values[0] + assert ( + nc_cd_sum > 0.5 * nc_state_sum + ), f"NC CD sum ({nc_cd_sum}) is too low compared to state sum ({nc_state_sum})" + assert ( + nc_cd_sum <= nc_state_sum + ), f"NC CD sum ({nc_cd_sum}) exceeds state sum ({nc_state_sum})" # Load ----------------------- load_medicaid_data(long_state, long_cd, year) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py new file mode 100644 index 00000000..262a83d9 --- /dev/null +++ b/policyengine_us_data/db/etl_national_targets.py @@ -0,0 +1,666 @@ +from sqlmodel import Session, create_engine +import pandas as pd + +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, + Target, + SourceType, +) +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, +) + + +def extract_national_targets(): + """ + Extract national calibration targets from various sources. 
+ + Returns + ------- + dict + Dictionary containing: + - direct_sum_targets: Variables that can be summed directly + - tax_filer_targets: Tax-related variables requiring filer constraint + - conditional_count_targets: Enrollment counts requiring constraints + - cbo_targets: List of CBO projection targets + - treasury_targets: List of Treasury/JCT targets + """ + + # Initialize PolicyEngine for parameter access + from policyengine_us import Microsimulation + + sim = Microsimulation( + dataset="hf://policyengine/policyengine-us-data/cps_2023.h5" + ) + + # Direct sum targets - these are regular variables that can be summed + # Store with their actual source year (2024 for hardcoded values from loss.py) + HARDCODED_YEAR = 2024 + + # Separate tax-related targets that need filer constraint + tax_filer_targets = [ + { + "variable": "salt_deduction", + "value": 21.247e9, + "source": "Joint Committee on Taxation", + "notes": "SALT deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "medical_expense_deduction", + "value": 11.4e9, + "source": "Joint Committee on Taxation", + "notes": "Medical expense deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "charitable_deduction", + "value": 65.301e9, + "source": "Joint Committee on Taxation", + "notes": "Charitable deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "interest_deduction", + "value": 24.8e9, + "source": "Joint Committee on Taxation", + "notes": "Mortgage interest deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + { + "variable": "qualified_business_income_deduction", + "value": 63.1e9, + "source": "Joint Committee on Taxation", + "notes": "QBI deduction tax expenditure", + "year": HARDCODED_YEAR, + }, + ] + + direct_sum_targets = [ + { + "variable": "alimony_income", + "value": 13e9, + "source": "Survey-reported (post-TCJA grandfathered)", + "notes": "Alimony received - survey reported, not tax-filer restricted", + "year": HARDCODED_YEAR, + }, + { + "variable": "alimony_expense", + "value": 13e9, + "source": "Survey-reported (post-TCJA grandfathered)", + "notes": "Alimony paid - survey reported, not tax-filer restricted", + "year": HARDCODED_YEAR, + }, + { + "variable": "medicaid", + "value": 871.7e9, + "source": "https://www.cms.gov/files/document/highlights.pdf", + "notes": "CMS 2023 highlights document - total Medicaid spending", + "year": HARDCODED_YEAR, + }, + { + "variable": "net_worth", + "value": 160e12, + "source": "Federal Reserve SCF", + "notes": "Total household net worth", + "year": HARDCODED_YEAR, + }, + { + "variable": "health_insurance_premiums_without_medicare_part_b", + "value": 385e9, + "source": "MEPS/NHEA", + "notes": "Health insurance premiums excluding Medicare Part B", + "year": HARDCODED_YEAR, + }, + { + "variable": "other_medical_expenses", + "value": 278e9, + "source": "MEPS/NHEA", + "notes": "Out-of-pocket medical expenses", + "year": HARDCODED_YEAR, + }, + { + "variable": "medicare_part_b_premiums", + "value": 112e9, + "source": "CMS Medicare data", + "notes": "Medicare Part B premium payments", + "year": HARDCODED_YEAR, + }, + { + "variable": "over_the_counter_health_expenses", + "value": 72e9, + "source": "Consumer Expenditure Survey", + "notes": "OTC health products and supplies", + "year": HARDCODED_YEAR, + }, + { + "variable": "child_support_expense", + "value": 33e9, + "source": "Census Bureau", + "notes": "Child support payments", + "year": HARDCODED_YEAR, + }, + { + "variable": "child_support_received", + "value": 33e9, 
+ "source": "Census Bureau", + "notes": "Child support received", + "year": HARDCODED_YEAR, + }, + { + "variable": "spm_unit_capped_work_childcare_expenses", + "value": 348e9, + "source": "Census Bureau SPM", + "notes": "Work and childcare expenses for SPM", + "year": HARDCODED_YEAR, + }, + { + "variable": "spm_unit_capped_housing_subsidy", + "value": 35e9, + "source": "HUD/Census", + "notes": "Housing subsidies", + "year": HARDCODED_YEAR, + }, + { + "variable": "tanf", + "value": 9e9, + "source": "HHS/ACF", + "notes": "TANF cash assistance", + "year": HARDCODED_YEAR, + }, + { + "variable": "real_estate_taxes", + "value": 500e9, + "source": "Census Bureau", + "notes": "Property taxes paid", + "year": HARDCODED_YEAR, + }, + { + "variable": "rent", + "value": 735e9, + "source": "Census Bureau/BLS", + "notes": "Rental payments", + "year": HARDCODED_YEAR, + }, + { + "variable": "tip_income", + "value": 53.2e9, + "source": "IRS Form W-2 Box 7 statistics", + "notes": "Social security tips uprated 40% to account for underreporting", + "year": HARDCODED_YEAR, + }, + ] + + # Conditional count targets - these need strata with constraints + # Store with actual source year + conditional_count_targets = [ + { + "constraint_variable": "medicaid", + "stratum_group_id": 5, # Medicaid strata group + "person_count": 72_429_055, + "source": "CMS/HHS administrative data", + "notes": "Medicaid enrollment count", + "year": HARDCODED_YEAR, + }, + { + "constraint_variable": "aca_ptc", + "stratum_group_id": None, # Will use a generic stratum or create new group + "person_count": 19_743_689, + "source": "CMS marketplace data", + "notes": "ACA Premium Tax Credit recipients", + "year": HARDCODED_YEAR, + }, + ] + + # Add SSN card type NONE targets for multiple years + # Based on loss.py lines 445-460 + ssn_none_targets_by_year = [ + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", # Need to specify the value we're checking for + "stratum_group_id": 7, # New group for SSN card type + "person_count": 11.0e6, + "source": "DHS Office of Homeland Security Statistics", + "notes": "Undocumented population estimate for Jan 1, 2022", + "year": 2022, + }, + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", + "stratum_group_id": 7, + "person_count": 12.2e6, + "source": "Center for Migration Studies ACS-based residual estimate", + "notes": "Undocumented population estimate (published May 2025)", + "year": 2023, + }, + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", + "stratum_group_id": 7, + "person_count": 13.0e6, + "source": "Reuters synthesis of experts", + "notes": "Undocumented population central estimate (~13-14 million)", + "year": 2024, + }, + { + "constraint_variable": "ssn_card_type", + "constraint_value": "NONE", + "stratum_group_id": 7, + "person_count": 13.0e6, + "source": "Reuters synthesis of experts", + "notes": "Same midpoint carried forward - CBP data show 95% drop in border apprehensions", + "year": 2025, + }, + ] + + conditional_count_targets.extend(ssn_none_targets_by_year) + + # CBO projection targets - get for a specific year + CBO_YEAR = 2023 # Year the CBO projections are for + cbo_vars = [ + "income_tax", + "snap", + "social_security", + "ssi", + "unemployment_compensation", + ] + + cbo_targets = [] + for variable_name in cbo_vars: + try: + value = sim.tax_benefit_system.parameters( + CBO_YEAR + ).calibration.gov.cbo._children[variable_name] + cbo_targets.append( + { + "variable": variable_name, + "value": float(value), + 
"source": "CBO Budget Projections", + "notes": f"CBO projection for {variable_name}", + "year": CBO_YEAR, + } + ) + except (KeyError, AttributeError) as e: + print( + f"Warning: Could not extract CBO parameter for {variable_name}: {e}" + ) + + # Treasury/JCT targets (EITC) - get for a specific year + TREASURY_YEAR = 2023 + try: + eitc_value = sim.tax_benefit_system.parameters.calibration.gov.treasury.tax_expenditures.eitc( + TREASURY_YEAR + ) + treasury_targets = [ + { + "variable": "eitc", + "value": float(eitc_value), + "source": "Treasury/JCT Tax Expenditures", + "notes": "EITC tax expenditure", + "year": TREASURY_YEAR, + } + ] + except (KeyError, AttributeError) as e: + print(f"Warning: Could not extract Treasury EITC parameter: {e}") + treasury_targets = [] + + return { + "direct_sum_targets": direct_sum_targets, + "tax_filer_targets": tax_filer_targets, + "conditional_count_targets": conditional_count_targets, + "cbo_targets": cbo_targets, + "treasury_targets": treasury_targets, + } + + +def transform_national_targets(raw_targets): + """ + Transform extracted targets into standardized format for loading. + + Parameters + ---------- + raw_targets : dict + Dictionary from extract_national_targets() + + Returns + ------- + tuple + (direct_targets_df, tax_filer_df, conditional_targets) + - direct_targets_df: DataFrame with direct sum targets + - tax_filer_df: DataFrame with tax-related targets needing filer constraint + - conditional_targets: List of conditional count targets + """ + + # Process direct sum targets (non-tax items and some CBO items) + # Note: income_tax from CBO and eitc from Treasury need filer constraint + cbo_non_tax = [ + t for t in raw_targets["cbo_targets"] if t["variable"] != "income_tax" + ] + cbo_tax = [ + t for t in raw_targets["cbo_targets"] if t["variable"] == "income_tax" + ] + + all_direct_targets = raw_targets["direct_sum_targets"] + cbo_non_tax + + # Tax-related targets that need filer constraint + all_tax_filer_targets = ( + raw_targets["tax_filer_targets"] + + cbo_tax + + raw_targets["treasury_targets"] # EITC + ) + + direct_df = ( + pd.DataFrame(all_direct_targets) + if all_direct_targets + else pd.DataFrame() + ) + tax_filer_df = ( + pd.DataFrame(all_tax_filer_targets) + if all_tax_filer_targets + else pd.DataFrame() + ) + + # Conditional targets stay as list for special processing + conditional_targets = raw_targets["conditional_count_targets"] + + return direct_df, tax_filer_df, conditional_targets + + +def load_national_targets( + direct_targets_df, tax_filer_df, conditional_targets +): + """ + Load national targets into the database. 
+ + Parameters + ---------- + direct_targets_df : pd.DataFrame + DataFrame with direct sum target data + tax_filer_df : pd.DataFrame + DataFrame with tax-related targets needing filer constraint + conditional_targets : list + List of conditional count targets requiring strata + """ + + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + engine = create_engine(DATABASE_URL) + + with Session(engine) as session: + # Get or create the calibration source + calibration_source = get_or_create_source( + session, + name="PolicyEngine Calibration Targets", + source_type=SourceType.HARDCODED, + vintage="Mixed (2023-2024)", + description="National calibration targets from various authoritative sources", + url=None, + notes="Aggregated from CMS, IRS, CBO, Treasury, and other federal sources", + ) + + # Get the national stratum + us_stratum = ( + session.query(Stratum) + .filter(Stratum.parent_stratum_id == None) + .first() + ) + + if not us_stratum: + raise ValueError( + "National stratum not found. Run create_initial_strata.py first." + ) + + # Process direct sum targets + for _, target_data in direct_targets_df.iterrows(): + target_year = target_data["year"] + # Check if target already exists + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == us_stratum.stratum_id, + Target.variable == target_data["variable"], + Target.period == target_year, + ) + .first() + ) + + # Combine source info into notes + notes_parts = [] + if pd.notna(target_data.get("notes")): + notes_parts.append(target_data["notes"]) + notes_parts.append( + f"Source: {target_data.get('source', 'Unknown')}" + ) + combined_notes = " | ".join(notes_parts) + + if existing_target: + # Update existing target + existing_target.value = target_data["value"] + existing_target.notes = combined_notes + print(f"Updated target: {target_data['variable']}") + else: + # Create new target + target = Target( + stratum_id=us_stratum.stratum_id, + variable=target_data["variable"], + period=target_year, + value=target_data["value"], + source_id=calibration_source.source_id, + active=True, + notes=combined_notes, + ) + session.add(target) + print(f"Added target: {target_data['variable']}") + + # Process tax-related targets that need filer constraint + if not tax_filer_df.empty: + # Get or create the national filer stratum + national_filer_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == us_stratum.stratum_id, + Stratum.notes == "United States - Tax Filers", + ) + .first() + ) + + if not national_filer_stratum: + # Create national filer stratum + national_filer_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=2, # Filer population group + notes="United States - Tax Filers", + ) + national_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable="tax_unit_is_filer", + operation="==", + value="1", + ) + ] + session.add(national_filer_stratum) + session.flush() + print("Created national filer stratum") + + # Add tax-related targets to filer stratum + for _, target_data in tax_filer_df.iterrows(): + target_year = target_data["year"] + # Check if target already exists + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == national_filer_stratum.stratum_id, + Target.variable == target_data["variable"], + Target.period == target_year, + ) + .first() + ) + + # Combine source info into notes + notes_parts = [] + if pd.notna(target_data.get("notes")): + notes_parts.append(target_data["notes"]) + notes_parts.append( + 
f"Source: {target_data.get('source', 'Unknown')}" + ) + combined_notes = " | ".join(notes_parts) + + if existing_target: + # Update existing target + existing_target.value = target_data["value"] + existing_target.notes = combined_notes + print(f"Updated filer target: {target_data['variable']}") + else: + # Create new target + target = Target( + stratum_id=national_filer_stratum.stratum_id, + variable=target_data["variable"], + period=target_year, + value=target_data["value"], + source_id=calibration_source.source_id, + active=True, + notes=combined_notes, + ) + session.add(target) + print(f"Added filer target: {target_data['variable']}") + + # Process conditional count targets (enrollment counts) + for cond_target in conditional_targets: + constraint_var = cond_target["constraint_variable"] + stratum_group_id = cond_target.get("stratum_group_id") + target_year = cond_target["year"] + + # Determine stratum group ID and constraint details + if constraint_var == "medicaid": + stratum_group_id = 5 # Medicaid strata group + stratum_notes = "National Medicaid Enrollment" + constraint_operation = ">" + constraint_value = "0" + elif constraint_var == "aca_ptc": + stratum_group_id = ( + 6 # EITC group or could create new ACA group + ) + stratum_notes = "National ACA Premium Tax Credit Recipients" + constraint_operation = ">" + constraint_value = "0" + elif constraint_var == "ssn_card_type": + stratum_group_id = 7 # SSN card type group + stratum_notes = "National Undocumented Population" + constraint_operation = "=" + constraint_value = cond_target.get("constraint_value", "NONE") + else: + stratum_notes = f"National {constraint_var} Recipients" + constraint_operation = ">" + constraint_value = "0" + + # Check if this stratum already exists + existing_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id == us_stratum.stratum_id, + Stratum.stratum_group_id == stratum_group_id, + Stratum.notes == stratum_notes, + ) + .first() + ) + + if existing_stratum: + # Update the existing target in this stratum + existing_target = ( + session.query(Target) + .filter( + Target.stratum_id == existing_stratum.stratum_id, + Target.variable == "person_count", + Target.period == target_year, + ) + .first() + ) + + if existing_target: + existing_target.value = cond_target["person_count"] + print(f"Updated enrollment target for {constraint_var}") + else: + # Add new target to existing stratum + new_target = Target( + stratum_id=existing_stratum.stratum_id, + variable="person_count", + period=target_year, + value=cond_target["person_count"], + source_id=calibration_source.source_id, + active=True, + notes=f"{cond_target['notes']} | Source: {cond_target['source']}", + ) + session.add(new_target) + print(f"Added enrollment target for {constraint_var}") + else: + # Create new stratum with constraint + new_stratum = Stratum( + parent_stratum_id=us_stratum.stratum_id, + stratum_group_id=stratum_group_id, + notes=stratum_notes, + ) + + # Add constraint + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=constraint_var, + operation=constraint_operation, + value=constraint_value, + ) + ] + + # Add target + new_stratum.targets_rel = [ + Target( + variable="person_count", + period=target_year, + value=cond_target["person_count"], + source_id=calibration_source.source_id, + active=True, + notes=f"{cond_target['notes']} | Source: {cond_target['source']}", + ) + ] + + session.add(new_stratum) + print( + f"Created stratum and target for {constraint_var} enrollment" + ) + + 
session.commit() + + total_targets = ( + len(direct_targets_df) + + len(tax_filer_df) + + len(conditional_targets) + ) + print(f"\nSuccessfully loaded {total_targets} national targets") + print(f" - {len(direct_targets_df)} direct sum targets") + print(f" - {len(tax_filer_df)} tax filer targets") + print( + f" - {len(conditional_targets)} enrollment count targets (as strata)" + ) + + +def main(): + """Main ETL pipeline for national targets.""" + + # Extract + print("Extracting national targets...") + raw_targets = extract_national_targets() + + # Transform + print("Transforming targets...") + direct_targets_df, tax_filer_df, conditional_targets = ( + transform_national_targets(raw_targets) + ) + + # Load + print("Loading targets into database...") + load_national_targets(direct_targets_df, tax_filer_df, conditional_targets) + + print("\nETL pipeline complete!") + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index 1fba44a4..cf1f5f43 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -5,7 +5,7 @@ import pandas as pd import numpy as np import us -from sqlmodel import Session, create_engine +from sqlmodel import Session, create_engine, select from policyengine_us_data.storage import STORAGE_FOLDER @@ -13,11 +13,21 @@ Stratum, StratumConstraint, Target, + Source, + SourceType, + VariableGroup, + VariableMetadata, ) from policyengine_us_data.utils.census import ( pull_acs_table, STATE_NAME_TO_FIPS, ) +from policyengine_us_data.utils.db import parse_ucgid, get_geographic_strata +from policyengine_us_data.utils.db_metadata import ( + get_or_create_source, + get_or_create_variable_group, + get_or_create_variable_metadata, +) def extract_administrative_snap_data(year=2023): @@ -149,24 +159,65 @@ def load_administrative_snap_data(df_states, year): DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" engine = create_engine(DATABASE_URL) - stratum_lookup = {} - with Session(engine) as session: + # Get or create the administrative source + admin_source = get_or_create_source( + session, + name="USDA FNS SNAP Data", + source_type=SourceType.ADMINISTRATIVE, + vintage=f"FY {year}", + description="SNAP administrative data from USDA Food and Nutrition Service", + url="https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap", + notes="State-level administrative totals for households and costs", + ) + + # Get or create the SNAP variable group + snap_group = get_or_create_variable_group( + session, + name="snap_recipients", + category="benefit", + is_histogram=False, + is_exclusive=False, + aggregation_method="sum", + display_order=2, + description="SNAP (food stamps) recipient counts and benefits", + ) + + # Get or create variable metadata + get_or_create_variable_metadata( + session, + variable="snap", + group=snap_group, + display_name="SNAP Benefits", + display_order=1, + units="dollars", + notes="Annual SNAP benefit costs", + ) + + get_or_create_variable_metadata( + session, + variable="household_count", + group=snap_group, + display_name="SNAP Household Count", + display_order=2, + units="count", + notes="Number of households receiving SNAP", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + # National ---------------- + # Create a SNAP stratum as child of the national geographic stratum nat_stratum = Stratum( - parent_stratum_id=None, - stratum_group_id=0, - notes="Geo: 0100000US Received SNAP Benefits", + 
parent_stratum_id=geo_strata["national"], + stratum_group_id=4, # SNAP strata group + notes="National Received SNAP Benefits", ) nat_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value="0100000US", - ), StratumConstraint( constraint_variable="snap", - operation="greater_than", + operation=">", value="0", ), ] @@ -175,29 +226,33 @@ def load_administrative_snap_data(df_states, year): session.add(nat_stratum) session.flush() - stratum_lookup["National"] = nat_stratum.stratum_id + snap_stratum_lookup = {"national": nat_stratum.stratum_id, "state": {}} # State ------------------- - stratum_lookup["State"] = {} for _, row in df_states.iterrows(): + # Parse the UCGID to get state_fips + geo_info = parse_ucgid(row["ucgid_str"]) + state_fips = geo_info["state_fips"] + + # Get the parent geographic stratum + parent_stratum_id = geo_strata["state"][state_fips] - note = f"Geo: {row['ucgid_str']} Received SNAP Benefits" - parent_stratum_id = nat_stratum.stratum_id + note = f"State FIPS {state_fips} Received SNAP Benefits" new_stratum = Stratum( parent_stratum_id=parent_stratum_id, - stratum_group_id=0, + stratum_group_id=4, # SNAP strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="state_fips", + operation="==", + value=str(state_fips), ), StratumConstraint( constraint_variable="snap", - operation="greater_than", + operation=">", value="0", ), ] @@ -207,7 +262,7 @@ def load_administrative_snap_data(df_states, year): variable="household_count", period=year, value=row["Households"], - source_id=3, + source_id=admin_source.source_id, active=True, ) ) @@ -216,49 +271,70 @@ def load_administrative_snap_data(df_states, year): variable="snap", period=year, value=row["Cost"], - source_id=3, + source_id=admin_source.source_id, active=True, ) ) session.add(new_stratum) session.flush() - stratum_lookup["State"][row["ucgid_str"]] = new_stratum.stratum_id + snap_stratum_lookup["state"][state_fips] = new_stratum.stratum_id session.commit() - return stratum_lookup + return snap_stratum_lookup -def load_survey_snap_data(survey_df, year, stratum_lookup=None): - """Use an already defined stratum_lookup to load the survey SNAP data""" +def load_survey_snap_data(survey_df, year, snap_stratum_lookup): + """Use an already defined snap_stratum_lookup to load the survey SNAP data - if stratum_lookup is None: - raise ValueError("stratum_lookup must be provided") + Note: snap_stratum_lookup should contain the SNAP strata created by + load_administrative_snap_data, so we don't recreate them. 
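+
+    Expected shape of snap_stratum_lookup, as built by
+    load_administrative_snap_data (IDs are whatever the database assigned):
+        {"national": <stratum_id>, "state": {<state_fips>: <stratum_id>}}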
+ """ DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" engine = create_engine(DATABASE_URL) with Session(engine) as session: + # Get or create the survey source + survey_source = get_or_create_source( + session, + name="Census ACS Table S2201", + source_type=SourceType.SURVEY, + vintage=f"{year} ACS 5-year estimates", + description="American Community Survey SNAP/Food Stamps data", + url="https://data.census.gov/", + notes="Congressional district level SNAP household counts from ACS", + ) + + # Fetch existing geographic strata + geo_strata = get_geographic_strata(session) + # Create new strata for districts whose households recieve SNAP benefits district_df = survey_df.copy() for _, row in district_df.iterrows(): - note = f"Geo: {row['ucgid_str']} Received SNAP Benefits" - state_ucgid_str = "0400000US" + row["ucgid_str"][9:11] - state_stratum_id = stratum_lookup["State"][state_ucgid_str] + # Parse the UCGID to get district info + geo_info = parse_ucgid(row["ucgid_str"]) + cd_geoid = geo_info["congressional_district_geoid"] + + # Get the parent geographic stratum + parent_stratum_id = geo_strata["district"][cd_geoid] + + note = f"Congressional District {cd_geoid} Received SNAP Benefits" + new_stratum = Stratum( - parent_stratum_id=state_stratum_id, - stratum_group_id=0, + parent_stratum_id=parent_stratum_id, + stratum_group_id=4, # SNAP strata group notes=note, ) new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="in", - value=row["ucgid_str"], + constraint_variable="congressional_district_geoid", + operation="==", + value=str(cd_geoid), ), StratumConstraint( constraint_variable="snap", - operation="greater_than", + operation=">", value="0", ), ] @@ -267,7 +343,7 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): variable="household_count", period=year, value=row["snap_household_ct"], - source_id=4, + source_id=survey_source.source_id, active=True, ) ) @@ -276,7 +352,7 @@ def load_survey_snap_data(survey_df, year, stratum_lookup=None): session.commit() - return stratum_lookup + return snap_stratum_lookup def main(): @@ -291,8 +367,8 @@ def main(): district_survey_df = transform_survey_snap_data(raw_survey_df) # Load ----------- - stratum_lookup = load_administrative_snap_data(state_admin_df, year) - load_survey_snap_data(district_survey_df, year, stratum_lookup) + snap_stratum_lookup = load_administrative_snap_data(state_admin_df, year) + load_survey_snap_data(district_survey_df, year, snap_stratum_lookup) if __name__ == "__main__": diff --git a/policyengine_us_data/db/migrate_stratum_group_ids.py b/policyengine_us_data/db/migrate_stratum_group_ids.py new file mode 100644 index 00000000..03583ad5 --- /dev/null +++ b/policyengine_us_data/db/migrate_stratum_group_ids.py @@ -0,0 +1,135 @@ +""" +TODO: what is this file? Do we still need it? + + +Migration script to update stratum_group_id values to represent conceptual categories. 
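+
+The script infers each stratum's group from its constraint variables and is
+idempotent: a stratum is rewritten only when its stored stratum_group_id
+differs from the inferred group.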
+ +New scheme: +- 1: Geographic (US, states, congressional districts) +- 2: Age-based strata +- 3: Income/AGI-based strata +- 4: SNAP recipient strata +- 5: Medicaid enrollment strata +- 6: EITC recipient strata +""" + +from sqlmodel import Session, create_engine, select +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, +) + + +def migrate_stratum_group_ids(): + """Update stratum_group_id values based on constraint variables.""" + + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + engine = create_engine(DATABASE_URL) + + with Session(engine) as session: + print("Starting stratum_group_id migration...") + print("=" * 60) + + # Track updates + updates = { + "Geographic": 0, + "Age": 0, + "Income/AGI": 0, + "SNAP": 0, + "Medicaid": 0, + "EITC": 0, + } + + # Get all strata + all_strata = session.exec(select(Stratum)).unique().all() + + for stratum in all_strata: + # Get constraints for this stratum + constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == stratum.stratum_id + ) + ).all() + + # Determine new group_id based on constraints + constraint_vars = [c.constraint_variable for c in constraints] + + # Geographic strata (no demographic constraints) + if not constraint_vars or all( + cv in ["state_fips", "congressional_district_geoid"] + for cv in constraint_vars + ): + if stratum.stratum_group_id != 1: + stratum.stratum_group_id = 1 + updates["Geographic"] += 1 + + # Age strata + elif "age" in constraint_vars: + if stratum.stratum_group_id != 2: + stratum.stratum_group_id = 2 + updates["Age"] += 1 + + # Income/AGI strata + elif "adjusted_gross_income" in constraint_vars: + if stratum.stratum_group_id != 3: + stratum.stratum_group_id = 3 + updates["Income/AGI"] += 1 + + # SNAP strata + elif "snap" in constraint_vars: + if stratum.stratum_group_id != 4: + stratum.stratum_group_id = 4 + updates["SNAP"] += 1 + + # Medicaid strata + elif "medicaid_enrolled" in constraint_vars: + if stratum.stratum_group_id != 5: + stratum.stratum_group_id = 5 + updates["Medicaid"] += 1 + + # EITC strata + elif "eitc_child_count" in constraint_vars: + if stratum.stratum_group_id != 6: + stratum.stratum_group_id = 6 + updates["EITC"] += 1 + + # Commit changes + session.commit() + + # Report results + print("\nMigration complete!") + print("-" * 60) + print("Updates made:") + for category, count in updates.items(): + if count > 0: + print(f" {category:15}: {count:5} strata updated") + + # Verify final counts + print("\nFinal stratum_group_id distribution:") + print("-" * 60) + + group_names = { + 1: "Geographic", + 2: "Age", + 3: "Income/AGI", + 4: "SNAP", + 5: "Medicaid", + 6: "EITC", + } + + for group_id, name in group_names.items(): + count = len( + session.exec( + select(Stratum).where(Stratum.stratum_group_id == group_id) + ) + .unique() + .all() + ) + print(f" Group {group_id} ({name:12}): {count:5} strata") + + print("\n✅ Migration successful!") + + +if __name__ == "__main__": + migrate_stratum_group_ids() diff --git a/policyengine_us_data/db/validate_hierarchy.py b/policyengine_us_data/db/validate_hierarchy.py new file mode 100644 index 00000000..95964a10 --- /dev/null +++ b/policyengine_us_data/db/validate_hierarchy.py @@ -0,0 +1,324 @@ +""" +Validation script to ensure the parent-child hierarchy is working correctly. +Checks geographic and age strata relationships. 
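+
+The script prints a summary of all issues found and exits with status 1 if
+any check fails (0 otherwise), so it can be used as a CI gate.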
+""" + +import sys +from sqlmodel import Session, create_engine, select +from policyengine_us_data.storage import STORAGE_FOLDER +from policyengine_us_data.db.create_database_tables import ( + Stratum, + StratumConstraint, +) + + +def validate_geographic_hierarchy(session): + """Validate the geographic hierarchy: US -> States -> Congressional Districts""" + + print("\n" + "=" * 60) + print("VALIDATING GEOGRAPHIC HIERARCHY") + print("=" * 60) + + errors = [] + + # Check US stratum exists and has no parent + us_stratum = session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, Stratum.parent_stratum_id == None + ) + ).first() + + if not us_stratum: + errors.append( + "ERROR: No US-level stratum found (should have parent_stratum_id = None)" + ) + else: + print( + f"✓ US stratum found: {us_stratum.notes} (ID: {us_stratum.stratum_id})" + ) + + # Check it has no constraints + us_constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == us_stratum.stratum_id + ) + ).all() + + if us_constraints: + errors.append( + f"ERROR: US stratum has {len(us_constraints)} constraints, should have 0" + ) + else: + print("✓ US stratum has no constraints (correct)") + + # Check states + states = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + Stratum.parent_stratum_id == us_stratum.stratum_id, + ) + ) + .unique() + .all() + ) + + print(f"\n✓ Found {len(states)} state strata") + if len(states) != 51: # 50 states + DC + errors.append( + f"WARNING: Expected 51 states (including DC), found {len(states)}" + ) + + # Verify each state has proper constraints + state_ids = {} + for state in states[:5]: # Sample first 5 states + constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == state.stratum_id + ) + ).all() + + state_fips_constraint = [ + c for c in constraints if c.constraint_variable == "state_fips" + ] + if not state_fips_constraint: + errors.append( + f"ERROR: State '{state.notes}' has no state_fips constraint" + ) + else: + state_ids[state.stratum_id] = state.notes + print( + f" - {state.notes}: state_fips = {state_fips_constraint[0].value}" + ) + + # Check congressional districts + print("\nChecking Congressional Districts...") + + # Count total CDs (including delegate districts) + all_cds = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + ( + Stratum.notes.like("%Congressional District%") + | Stratum.notes.like("%Delegate District%") + ), + ) + ) + .unique() + .all() + ) + + print(f"✓ Found {len(all_cds)} congressional/delegate districts") + if len(all_cds) != 436: + errors.append( + f"WARNING: Expected 436 congressional districts (including DC delegate), found {len(all_cds)}" + ) + + # Verify CDs are children of correct states (spot check) + wyoming_id = None + for state in states: + if "Wyoming" in state.notes: + wyoming_id = state.stratum_id + break + + if wyoming_id: + # Check Wyoming's congressional district + wyoming_cds = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + Stratum.parent_stratum_id == wyoming_id, + Stratum.notes.like("%Congressional%"), + ) + ) + .unique() + .all() + ) + + if len(wyoming_cds) != 1: + errors.append( + f"ERROR: Wyoming should have 1 CD, found {len(wyoming_cds)}" + ) + else: + print(f"✓ Wyoming has correct number of CDs: 1") + + # Verify no other state's CDs are incorrectly parented to Wyoming + wrong_parent_cds = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id == 1, + 
Stratum.parent_stratum_id == wyoming_id, + ~Stratum.notes.like("%Wyoming%"), + Stratum.notes.like("%Congressional%"), + ) + ) + .unique() + .all() + ) + + if wrong_parent_cds: + errors.append( + f"ERROR: Found {len(wrong_parent_cds)} non-Wyoming CDs incorrectly parented to Wyoming" + ) + for cd in wrong_parent_cds[:5]: + errors.append(f" - {cd.notes}") + else: + print( + "✓ No congressional districts incorrectly parented to Wyoming" + ) + + return errors + + +def validate_demographic_strata(session): + """Validate demographic strata are properly attached to geographic strata""" + + print("\n" + "=" * 60) + print("VALIDATING DEMOGRAPHIC STRATA") + print("=" * 60) + + errors = [] + + # Group names for the new scheme + group_names = { + 2: ("Age", 18), + 3: ("Income/AGI", 9), + 4: ("SNAP", 1), + 5: ("Medicaid", 1), + 6: ("EITC", 4), + } + + # Validate each demographic group + for group_id, (name, expected_per_geo) in group_names.items(): + strata = ( + session.exec( + select(Stratum).where(Stratum.stratum_group_id == group_id) + ) + .unique() + .all() + ) + + expected_total = expected_per_geo * 488 # 488 geographic areas + print(f"\n{name} strata (group {group_id}):") + print(f" Found: {len(strata)}") + print( + f" Expected: {expected_total} ({expected_per_geo} × 488 geographic areas)" + ) + + if len(strata) != expected_total: + errors.append( + f"WARNING: {name} has {len(strata)} strata, expected {expected_total}" + ) + + # Check parent relationships for a sample of demographic strata + print("\nChecking parent relationships (sample):") + sample_strata = ( + session.exec( + select(Stratum).where( + Stratum.stratum_group_id > 1 + ) # All demographic groups + ) + .unique() + .all()[:100] + ) # Take first 100 + + correct_parents = 0 + wrong_parents = 0 + no_parents = 0 + + for stratum in sample_strata: + if stratum.parent_stratum_id: + parent = session.get(Stratum, stratum.parent_stratum_id) + if parent and parent.stratum_group_id == 1: # Geographic parent + correct_parents += 1 + else: + wrong_parents += 1 + errors.append( + f"ERROR: Stratum {stratum.stratum_id} has non-geographic parent" + ) + else: + no_parents += 1 + errors.append(f"ERROR: Stratum {stratum.stratum_id} has no parent") + + print(f" Sample of {len(sample_strata)} demographic strata:") + print(f" - With geographic parent: {correct_parents}") + print(f" - With wrong parent: {wrong_parents}") + print(f" - With no parent: {no_parents}") + + return errors + + +def validate_constraint_uniqueness(session): + """Check that constraint combinations produce unique hashes""" + + print("\n" + "=" * 60) + print("VALIDATING CONSTRAINT UNIQUENESS") + print("=" * 60) + + errors = [] + + # Check for duplicate definition_hashes + all_strata = session.exec(select(Stratum)).unique().all() + hash_counts = {} + + for stratum in all_strata: + if stratum.definition_hash in hash_counts: + hash_counts[stratum.definition_hash].append(stratum) + else: + hash_counts[stratum.definition_hash] = [stratum] + + duplicates = { + h: strata for h, strata in hash_counts.items() if len(strata) > 1 + } + + if duplicates: + errors.append( + f"ERROR: Found {len(duplicates)} duplicate definition_hashes" + ) + for hash_val, strata in list(duplicates.items())[:3]: # Show first 3 + errors.append( + f" Hash {hash_val[:10]}... 
appears {len(strata)} times:" + ) + for s in strata[:3]: + errors.append(f" - ID {s.stratum_id}: {s.notes[:50]}") + else: + print(f"✓ All {len(all_strata)} strata have unique definition_hashes") + + return errors + + +def main(): + """Run all validation checks""" + + DATABASE_URL = f"sqlite:///{STORAGE_FOLDER / 'policy_data.db'}" + engine = create_engine(DATABASE_URL) + + all_errors = [] + + with Session(engine) as session: + # Run validation checks + all_errors.extend(validate_geographic_hierarchy(session)) + all_errors.extend(validate_demographic_strata(session)) + all_errors.extend(validate_constraint_uniqueness(session)) + + # Summary + print("\n" + "=" * 60) + print("VALIDATION SUMMARY") + print("=" * 60) + + if all_errors: + print(f"\n❌ Found {len(all_errors)} issues:\n") + for error in all_errors: + print(f" {error}") + sys.exit(1) + else: + print("\n✅ All validation checks passed!") + print(" - Geographic hierarchy is correct") + print(" - Demographic strata properly organized and attached") + print(" - All constraint combinations are unique") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/policyengine_us_data/geography/zip_codes.csv.gz b/policyengine_us_data/geography/zip_codes.csv.gz deleted file mode 100644 index 2007b6ed..00000000 Binary files a/policyengine_us_data/geography/zip_codes.csv.gz and /dev/null differ diff --git a/policyengine_us_data/storage/download_private_prerequisites.py b/policyengine_us_data/storage/download_private_prerequisites.py index 26696d6c..3e080274 100644 --- a/policyengine_us_data/storage/download_private_prerequisites.py +++ b/policyengine_us_data/storage/download_private_prerequisites.py @@ -27,3 +27,9 @@ local_folder=FOLDER, version=None, ) +download( + repo="policyengine/policyengine-us-data", + repo_filename="policy_data.db", + local_folder=FOLDER, + version=None, +) diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index e99eed01..9c9b5aa4 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -18,6 +18,13 @@ def upload_datasets(): # STORAGE_FOLDER / "policy_data.db", ] + cd_states_dir = STORAGE_FOLDER / "cd_states" + if cd_states_dir.exists(): + state_files = list(cd_states_dir.glob("*.h5")) + if state_files: + print(f"Found {len(state_files)} state files in cd_states/") + dataset_files.extend(state_files) + # Filter to only existing files existing_files = [] for file_path in dataset_files: diff --git a/policyengine_us_data/tests/test_database.py b/policyengine_us_data/tests/test_database.py index c36ef828..4782fff5 100644 --- a/policyengine_us_data/tests/test_database.py +++ b/policyengine_us_data/tests/test_database.py @@ -29,14 +29,14 @@ def test_stratum_hash_and_relationships(engine): stratum.constraints_rel = [ StratumConstraint( constraint_variable="ucgid_str", - operation="equals", + operation="==", value="0400000US30", ), StratumConstraint( - constraint_variable="age", operation="greater_than", value="20" + constraint_variable="age", operation=">", value="20" ), StratumConstraint( - constraint_variable="age", operation="less_than", value="65" + constraint_variable="age", operation="<", value="65" ), ] stratum.targets_rel = [ @@ -48,9 +48,9 @@ def test_stratum_hash_and_relationships(engine): "\n".join( sorted( [ - "ucgid_str|equals|0400000US30", - "age|greater_than|20", - "age|less_than|65", + "ucgid_str|==|0400000US30", + "age|>|20", + "age|<|65", ] ) 
).encode("utf-8") @@ -67,7 +67,7 @@ def test_unique_definition_hash(engine): s1.constraints_rel = [ StratumConstraint( constraint_variable="ucgid_str", - operation="equals", + operation="==", value="0400000US30", ) ] @@ -77,7 +77,7 @@ def test_unique_definition_hash(engine): s2.constraints_rel = [ StratumConstraint( constraint_variable="ucgid_str", - operation="equals", + operation="==", value="0400000US30", ) ] diff --git a/policyengine_us_data/tests/test_datasets/test_cd_state_files.py b/policyengine_us_data/tests/test_datasets/test_cd_state_files.py new file mode 100644 index 00000000..f59fcd9e --- /dev/null +++ b/policyengine_us_data/tests/test_datasets/test_cd_state_files.py @@ -0,0 +1,101 @@ +import pytest +from pathlib import Path +from policyengine_us import Microsimulation +from policyengine_core.data import Dataset + + +STATE_FILES_DIR = Path("policyengine_us_data/storage/cd_states") + +EXPECTED_CONGRESSIONAL_DISTRICTS = { + "NC": 14, + "CA": 52, + "TX": 38, + "FL": 28, + "NY": 26, + "PA": 17, +} + + +@pytest.mark.district_level_validation +@pytest.mark.parametrize( + "state_code,expected_districts", + [ + ("NC", 14), + ("CA", 52), + ("TX", 38), + ("FL", 28), + ("NY", 26), + ("PA", 17), + ], +) +def test_state_congressional_districts(state_code, expected_districts): + state_file = STATE_FILES_DIR / f"{state_code}.h5" + + if not state_file.exists(): + pytest.skip(f"State file {state_code}.h5 not yet generated") + + dataset = Dataset.from_file(state_file) + sim = Microsimulation(dataset=dataset) + + cd_geoids = sim.calculate("congressional_district_geoid") + unique_districts = len(set(cd_geoids)) + + assert unique_districts == expected_districts, ( + f"{state_code} should have {expected_districts} congressional districts, " + f"but found {unique_districts}" + ) + + +@pytest.mark.district_level_validation +def test_nc_has_positive_weights(): + state_file = STATE_FILES_DIR / "NC.h5" + + if not state_file.exists(): + pytest.skip("NC.h5 not yet generated") + + dataset = Dataset.from_file(state_file) + data = dataset.load_dataset() + weights = data["household_weight"]["2023"] + + assert (weights > 0).all(), "All household weights should be positive" + assert weights.sum() > 0, "Total weight should be positive" + + +@pytest.mark.district_level_validation +def test_nc_household_count_reasonable(): + state_file = STATE_FILES_DIR / "NC.h5" + + if not state_file.exists(): + pytest.skip("NC.h5 not yet generated") + + dataset = Dataset.from_file(state_file) + data = dataset.load_dataset() + weights = data["household_weight"]["2023"] + + total_households = weights.sum() + + NC_MIN_HOUSEHOLDS = 3_500_000 + NC_MAX_HOUSEHOLDS = 5_000_000 + + assert NC_MIN_HOUSEHOLDS < total_households < NC_MAX_HOUSEHOLDS, ( + f"NC total weighted households ({total_households:,.0f}) outside " + f"expected range ({NC_MIN_HOUSEHOLDS:,} - {NC_MAX_HOUSEHOLDS:,})" + ) + + +@pytest.mark.district_level_validation +def test_all_state_files_have_mapping_csv(): + state_files = list(STATE_FILES_DIR.glob("*.h5")) + + if not state_files: + pytest.skip("No state files generated yet") + + for state_file in state_files: + state_code = state_file.stem + if state_code == "cd_calibration": + continue + + mapping_file = STATE_FILES_DIR / f"{state_code}_household_mapping.csv" + assert ( + mapping_file.exists() + ), f"Missing household mapping CSV for {state_code}" diff --git a/policyengine_us_data/tests/test_uprating.py b/policyengine_us_data/tests/test_uprating.py new file mode 100644 index 00000000..cd2bf62c --- /dev/null +++ 
b/policyengine_us_data/tests/test_uprating.py @@ -0,0 +1,204 @@ +""" +Unit tests for calibration target uprating functionality. +""" + +import pytest +import pandas as pd +import numpy as np +from policyengine_us import Microsimulation +from policyengine_us_data.datasets.cps.geo_stacking_calibration.calibration_utils import ( + uprate_targets_df, +) + + +@pytest.fixture(scope="module") +def sim(): + """Create a microsimulation instance for testing.""" + return Microsimulation( + dataset="hf://policyengine/test/extended_cps_2023.h5" + ) + + +@pytest.fixture +def test_targets_2023(): + """Create test data with various source years to uprate to 2023.""" + return pd.DataFrame( + [ + # Income values from 2022 (should use CPI-U) + {"variable": "income_tax", "value": 1000000, "period": 2022}, + {"variable": "wages", "value": 5000000, "period": 2022}, + # Count values from 2022 (should use Population) + {"variable": "person_count", "value": 100000, "period": 2022}, + {"variable": "household_count", "value": 40000, "period": 2022}, + # Values from 2023 (should NOT be uprated) + {"variable": "income_tax", "value": 1100000, "period": 2023}, + {"variable": "person_count", "value": 101000, "period": 2023}, + # Values from 2024 (should be DOWNRATED to 2023) + {"variable": "income_tax", "value": 1200000, "period": 2024}, + {"variable": "person_count", "value": 102000, "period": 2024}, + ] + ) + + +def test_uprating_adds_tracking_columns(test_targets_2023, sim): + """Test that uprating adds the expected tracking columns.""" + uprated = uprate_targets_df(test_targets_2023, target_year=2023, sim=sim) + + assert "original_value" in uprated.columns + assert "uprating_factor" in uprated.columns + assert "uprating_source" in uprated.columns + + +def test_no_uprating_for_target_year(test_targets_2023, sim): + """Test that values from the target year are not uprated.""" + uprated = uprate_targets_df(test_targets_2023, target_year=2023, sim=sim) + + # Filter for 2023 data + target_year_data = uprated[uprated["period"] == 2023] + + # Check that 2023 data was not modified + assert (target_year_data["uprating_factor"] == 1.0).all() + assert (target_year_data["uprating_source"] == "None").all() + assert ( + target_year_data["value"] == target_year_data["original_value"] + ).all() + + +def test_cpi_uprating_for_monetary_values(test_targets_2023, sim): + """Test that monetary values use CPI-U uprating.""" + uprated = uprate_targets_df(test_targets_2023, target_year=2023, sim=sim) + + # Check income tax from 2022 + income_2022 = uprated[ + (uprated["variable"] == "income_tax") & (uprated["period"] == 2022) + ].iloc[0] + assert income_2022["uprating_source"] == "CPI-U" + assert ( + income_2022["uprating_factor"] > 1.0 + ) # Should be inflated from 2022 to 2023 + assert ( + abs(income_2022["uprating_factor"] - 1.0641) < 0.001 + ) # Expected CPI factor + + # Check wages from 2022 + wages_2022 = uprated[ + (uprated["variable"] == "wages") & (uprated["period"] == 2022) + ].iloc[0] + assert wages_2022["uprating_source"] == "CPI-U" + assert ( + wages_2022["uprating_factor"] == income_2022["uprating_factor"] + ) # Same CPI factor + + +def test_population_uprating_for_counts(test_targets_2023, sim): + """Test that count variables use population uprating.""" + uprated = uprate_targets_df(test_targets_2023, target_year=2023, sim=sim) + + # Check person count from 2022 + person_2022 = uprated[ + (uprated["variable"] == "person_count") & (uprated["period"] == 2022) + ].iloc[0] + assert person_2022["uprating_source"] == "Population" + 
assert ( + person_2022["uprating_factor"] > 1.0 + ) # Population grew from 2022 to 2023 + assert ( + abs(person_2022["uprating_factor"] - 1.0094) < 0.001 + ) # Expected population factor + + # Check household count from 2022 + household_2022 = uprated[ + (uprated["variable"] == "household_count") + & (uprated["period"] == 2022) + ].iloc[0] + assert household_2022["uprating_source"] == "Population" + assert ( + household_2022["uprating_factor"] == person_2022["uprating_factor"] + ) # Same population factor + + +def test_downrating_from_future_years(test_targets_2023, sim): + """Test that values from future years are correctly downrated.""" + uprated = uprate_targets_df(test_targets_2023, target_year=2023, sim=sim) + + # Check income tax from 2024 (should be downrated) + income_2024 = uprated[ + (uprated["variable"] == "income_tax") & (uprated["period"] == 2024) + ].iloc[0] + assert income_2024["uprating_source"] == "CPI-U" + assert ( + income_2024["uprating_factor"] < 1.0 + ) # Should be deflated from 2024 to 2023 + assert ( + abs(income_2024["uprating_factor"] - 0.9700) < 0.001 + ) # Expected CPI factor + + # Check person count from 2024 + person_2024 = uprated[ + (uprated["variable"] == "person_count") & (uprated["period"] == 2024) + ].iloc[0] + assert person_2024["uprating_source"] == "Population" + assert ( + person_2024["uprating_factor"] < 1.0 + ) # Population was higher in 2024 + assert ( + abs(person_2024["uprating_factor"] - 0.9892) < 0.001 + ) # Expected population factor + + +def test_values_are_modified_correctly(test_targets_2023, sim): + """Test that values are actually modified by the uprating factors.""" + uprated = uprate_targets_df(test_targets_2023, target_year=2023, sim=sim) + + for _, row in uprated.iterrows(): + if row["uprating_factor"] != 1.0: + # Check that value was modified + expected_value = row["original_value"] * row["uprating_factor"] + assert ( + abs(row["value"] - expected_value) < 1.0 + ) # Allow for rounding + + +def test_no_double_uprating(test_targets_2023, sim): + """Test that calling uprate_targets_df twice doesn't double-uprate.""" + uprated_once = uprate_targets_df( + test_targets_2023, target_year=2023, sim=sim + ) + uprated_twice = uprate_targets_df(uprated_once, target_year=2023, sim=sim) + + # Values should be identical after second call + pd.testing.assert_series_equal( + uprated_once["value"], uprated_twice["value"] + ) + pd.testing.assert_series_equal( + uprated_once["uprating_factor"], uprated_twice["uprating_factor"] + ) + + +def test_numpy_int_compatibility(sim): + """Test that numpy int64 types work correctly (regression test).""" + # Create data with numpy int64 period column + data = pd.DataFrame( + { + "variable": ["income_tax"], + "value": [1000000], + "period": np.array([2022], dtype=np.int64), + } + ) + + # This should not raise an exception + uprated = uprate_targets_df(data, target_year=2023, sim=sim) + + # And should actually uprate + assert uprated["uprating_factor"].iloc[0] > 1.0 + assert uprated["value"].iloc[0] > uprated["original_value"].iloc[0] + + +def test_missing_period_column(): + """Test that missing period column is handled gracefully.""" + data = pd.DataFrame({"variable": ["income_tax"], "value": [1000000]}) + + result = uprate_targets_df(data, target_year=2023) + + # Should return unchanged + pd.testing.assert_frame_equal(result, data) diff --git a/policyengine_us_data/utils/__init__.py b/policyengine_us_data/utils/__init__.py index 2b93ecbf..c473dc6f 100644 --- a/policyengine_us_data/utils/__init__.py +++ 
b/policyengine_us_data/utils/__init__.py @@ -1,5 +1,5 @@ from .soi import * from .uprating import * from .loss import * -from .l0 import * +from .l0_modules import * from .seed import * diff --git a/policyengine_us_data/utils/db.py b/policyengine_us_data/utils/db.py index a8081db4..1cde7c7e 100644 --- a/policyengine_us_data/utils/db.py +++ b/policyengine_us_data/utils/db.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Dict from sqlmodel import Session, select import sqlalchemy as sa @@ -66,3 +66,87 @@ def get_stratum_parent(session: Session, stratum_id: int) -> Optional[Stratum]: if child_stratum: return child_stratum.parent_rel return None + + +def parse_ucgid(ucgid_str: str) -> Dict: + """Parse UCGID string to extract geographic information. + + UCGID (Universal Census Geographic ID) is a Census Bureau format + for identifying geographic areas. + + Returns: + dict with keys: 'type' ('national', 'state', 'district'), + 'state_fips' (if applicable), + 'district_number' (if applicable), + 'congressional_district_geoid' (if applicable) + """ + if ucgid_str == "0100000US": + return {"type": "national"} + elif ucgid_str.startswith("0400000US"): + state_fips = int(ucgid_str[9:]) + return {"type": "state", "state_fips": state_fips} + elif ucgid_str.startswith("5001800US"): + # Format: 5001800USSSDD where SS is state FIPS, DD is district + state_and_district = ucgid_str[9:] + state_fips = int(state_and_district[:2]) + district_number = int(state_and_district[2:]) + # Convert district 00 to 01 for at-large districts (matches create_initial_strata.py) + # Also convert DC's delegate district 98 to 01 + if district_number == 0 or ( + state_fips == 11 and district_number == 98 + ): + district_number = 1 + cd_geoid = state_fips * 100 + district_number + return { + "type": "district", + "state_fips": state_fips, + "district_number": district_number, + "congressional_district_geoid": cd_geoid, + } + else: + raise ValueError(f"Unknown UCGID format: {ucgid_str}") + + +def get_geographic_strata(session: Session) -> Dict: + """Fetch existing geographic strata from database. 
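+
+    Geographic strata are identified by stratum_group_id == 1; the stratum
+    with no constraints is treated as the national level.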
+ + Returns dict mapping: + - 'national': stratum_id for US + - 'state': {state_fips: stratum_id} + - 'district': {congressional_district_geoid: stratum_id} + """ + strata_map = { + "national": None, + "state": {}, + "district": {}, + } + + # Get all strata with stratum_group_id = 1 (geographic strata) + stmt = select(Stratum).where(Stratum.stratum_group_id == 1) + geographic_strata = session.exec(stmt).unique().all() + + for stratum in geographic_strata: + # Get constraints for this stratum + constraints = session.exec( + select(StratumConstraint).where( + StratumConstraint.stratum_id == stratum.stratum_id + ) + ).all() + + if not constraints: + # No constraints = national level + strata_map["national"] = stratum.stratum_id + else: + # Check constraint types + constraint_vars = { + c.constraint_variable: c.value for c in constraints + } + + if "congressional_district_geoid" in constraint_vars: + cd_geoid = int(constraint_vars["congressional_district_geoid"]) + strata_map["district"][cd_geoid] = stratum.stratum_id + elif "state_fips" in constraint_vars: + state_fips = int(constraint_vars["state_fips"]) + strata_map["state"][state_fips] = stratum.stratum_id + + return strata_map diff --git a/policyengine_us_data/utils/db_metadata.py b/policyengine_us_data/utils/db_metadata.py new file mode 100644 index 00000000..5058c408 --- /dev/null +++ b/policyengine_us_data/utils/db_metadata.py @@ -0,0 +1,151 @@ +""" +Utility functions for managing database metadata (sources, variable groups, etc.) +""" + +from typing import Optional +from sqlmodel import Session, select +from policyengine_us_data.db.create_database_tables import ( + Source, + SourceType, + VariableGroup, + VariableMetadata, +) + + +def get_or_create_source( + session: Session, + name: str, + source_type: SourceType, + vintage: Optional[str] = None, + description: Optional[str] = None, + url: Optional[str] = None, + notes: Optional[str] = None, +) -> Source: + """ + Get an existing source or create a new one. + + Args: + session: Database session + name: Name of the data source + source_type: Type of source (administrative, survey, etc.) + vintage: Version or year of the data + description: Detailed description + url: Reference URL + notes: Additional notes + + Returns: + Source object with source_id populated + """ + # Try to find existing source by name and vintage + query = select(Source).where(Source.name == name) + if vintage: + query = query.where(Source.vintage == vintage) + + source = session.exec(query).first() + + if not source: + # Create new source + source = Source( + name=name, + type=source_type, + vintage=vintage, + description=description, + url=url, + notes=notes, + ) + session.add(source) + session.flush() # Get the auto-generated ID + + return source + + +def get_or_create_variable_group( + session: Session, + name: str, + category: str, + is_histogram: bool = False, + is_exclusive: bool = False, + aggregation_method: Optional[str] = None, + display_order: Optional[int] = None, + description: Optional[str] = None, +) -> VariableGroup: + """ + Get an existing variable group or create a new one. + + Args: + session: Database session + name: Unique name of the variable group + category: High-level category (demographic, benefit, tax, income) + is_histogram: Whether this represents a distribution + is_exclusive: Whether variables are mutually exclusive + aggregation_method: How to aggregate (sum, weighted_avg, etc.) 
+ display_order: Order for display + description: Description of the group + + Returns: + VariableGroup object with group_id populated + """ + group = session.exec( + select(VariableGroup).where(VariableGroup.name == name) + ).first() + + if not group: + group = VariableGroup( + name=name, + category=category, + is_histogram=is_histogram, + is_exclusive=is_exclusive, + aggregation_method=aggregation_method, + display_order=display_order, + description=description, + ) + session.add(group) + session.flush() # Get the auto-generated ID + + return group + + +def get_or_create_variable_metadata( + session: Session, + variable: str, + group: Optional[VariableGroup] = None, + display_name: Optional[str] = None, + display_order: Optional[int] = None, + units: Optional[str] = None, + is_primary: bool = True, + notes: Optional[str] = None, +) -> VariableMetadata: + """ + Get existing variable metadata or create new. + + Args: + session: Database session + variable: PolicyEngine variable name + group: Variable group this belongs to + display_name: Human-readable name + display_order: Order within group + units: Units of measurement + is_primary: Whether this is a primary variable + notes: Additional notes + + Returns: + VariableMetadata object + """ + metadata = session.exec( + select(VariableMetadata).where(VariableMetadata.variable == variable) + ).first() + + if not metadata: + metadata = VariableMetadata( + variable=variable, + group_id=group.group_id if group else None, + display_name=display_name or variable, + display_order=display_order, + units=units, + is_primary=is_primary, + notes=notes, + ) + session.add(metadata) + session.flush() + + return metadata diff --git a/policyengine_us_data/utils/l0.py b/policyengine_us_data/utils/l0_modules.py similarity index 100% rename from policyengine_us_data/utils/l0.py rename to policyengine_us_data/utils/l0_modules.py diff --git a/pyproject.toml b/pyproject.toml index 3d00d389..5af97765 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "tqdm>=4.60.0", "microdf_python>=1.0.0", "setuptools>=60", - "microimpute>=1.1.4", + "microimpute>=1.1.4, <= 1.2.0", "pip-system-certs>=3.0", "google-cloud-storage>=2.0.0", "google-auth>=2.0.0", @@ -71,6 +71,9 @@ addopts = "-v" testpaths = [ "policyengine_us_data/tests", ] +markers = [ + "district_level_validation: tests that require generated data files from the district-level calibration pipeline", +] [tool.black] line-length = 79 diff --git a/tests/test_geo_stacking_reconciliation.py b/tests/test_geo_stacking_reconciliation.py new file mode 100644 index 00000000..fb41e173 --- /dev/null +++ b/tests/test_geo_stacking_reconciliation.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 +""" +Unit tests for geo-stacking reconciliation logic. + +These are self-contained tests that verify the reconciliation of +targets across geographic hierarchies (CD -> State -> National). 
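+
+The operation under test is proportional scaling: each lower-level target is
+multiplied by factor = higher_level_total / sum(lower_level_values), so the
+adjusted lower-level values sum to the authoritative higher-level total
+(e.g., three CDs of 10,000 each against a state total of 33,000 give a
+factor of 1.1).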
+""" + +import unittest +from unittest.mock import Mock, MagicMock, patch +import pandas as pd +import numpy as np + + +class TestReconciliationLogic(unittest.TestCase): + """Test reconciliation of hierarchical targets.""" + + def test_age_reconciliation_cd_to_state(self): + """Test that CD age targets are adjusted to match state totals.""" + # Create mock CD targets for California + cd_geoids = ["601", "602", "603"] + age_bins = ["age_0_4", "age_5_9", "age_10_14"] + + # CD targets (survey-based, undercount state totals) + cd_targets = [] + for cd in cd_geoids: + for age_bin in age_bins: + cd_targets.append( + { + "geographic_id": cd, + "stratum_group_id": 2, # Age + "variable": "person_count", + "constraint": age_bin, + "value": 10000, # Each CD has 10,000 per age bin + "source": "survey", + } + ) + + cd_df = pd.DataFrame(cd_targets) + + # State targets (administrative, authoritative) + state_targets = [] + for age_bin in age_bins: + state_targets.append( + { + "geographic_id": "6", # California FIPS + "stratum_group_id": 2, + "variable": "person_count", + "constraint": age_bin, + "value": 33000, # State total: 33,000 per age bin (10% higher) + "source": "administrative", + } + ) + + state_df = pd.DataFrame(state_targets) + + # Calculate reconciliation factors + reconciliation_factors = {} + for age_bin in age_bins: + cd_sum = cd_df[cd_df["constraint"] == age_bin]["value"].sum() + state_val = state_df[state_df["constraint"] == age_bin][ + "value" + ].iloc[0] + reconciliation_factors[age_bin] = ( + state_val / cd_sum if cd_sum > 0 else 1.0 + ) + + # Apply reconciliation + reconciled_cd_df = cd_df.copy() + reconciled_cd_df["original_value"] = reconciled_cd_df["value"] + reconciled_cd_df["reconciliation_factor"] = reconciled_cd_df[ + "constraint" + ].map(reconciliation_factors) + reconciled_cd_df["value"] = ( + reconciled_cd_df["original_value"] + * reconciled_cd_df["reconciliation_factor"] + ) + + # Verify reconciliation + for age_bin in age_bins: + reconciled_sum = reconciled_cd_df[ + reconciled_cd_df["constraint"] == age_bin + ]["value"].sum() + state_val = state_df[state_df["constraint"] == age_bin][ + "value" + ].iloc[0] + + self.assertAlmostEqual( + reconciled_sum, + state_val, + 2, + f"Reconciled CD sum for {age_bin} should match state total", + ) + + # Check factor is correct (should be 1.1 = 33000/30000) + factor = reconciliation_factors[age_bin] + self.assertAlmostEqual( + factor, + 1.1, + 4, + f"Reconciliation factor for {age_bin} should be 1.1", + ) + + def test_medicaid_reconciliation_survey_to_admin(self): + """Test Medicaid reconciliation from survey to administrative data.""" + # CD-level survey data (typically undercounts) + cd_geoids = ["601", "602", "603", "604", "605"] + + cd_medicaid = pd.DataFrame( + { + "geographic_id": cd_geoids, + "stratum_group_id": [5] * 5, # Medicaid group + "variable": ["person_count"] * 5, + "value": [45000, 48000, 42000, 50000, 40000], # Survey counts + "source": ["survey"] * 5, + } + ) + + cd_total = cd_medicaid["value"].sum() # 225,000 + + # State-level administrative data (authoritative) + state_medicaid = pd.DataFrame( + { + "geographic_id": ["6"], # California + "stratum_group_id": [5], + "variable": ["person_count"], + "value": [270000], # 20% higher than survey + "source": ["administrative"], + } + ) + + state_total = state_medicaid["value"].iloc[0] + + # Calculate reconciliation + reconciliation_factor = state_total / cd_total + expected_factor = 1.2 # 270000 / 225000 + + self.assertAlmostEqual( + reconciliation_factor, + 
expected_factor, + 4, + "Reconciliation factor should be 1.2", + ) + + # Apply reconciliation + cd_medicaid["reconciliation_factor"] = reconciliation_factor + cd_medicaid["original_value"] = cd_medicaid["value"] + cd_medicaid["value"] = cd_medicaid["value"] * reconciliation_factor + + # Verify total matches + reconciled_total = cd_medicaid["value"].sum() + self.assertAlmostEqual( + reconciled_total, + state_total, + 2, + "Reconciled CD total should match state administrative total", + ) + + # Verify each CD was scaled proportionally + for i, cd in enumerate(cd_geoids): + original = cd_medicaid.iloc[i]["original_value"] + reconciled = cd_medicaid.iloc[i]["value"] + expected_reconciled = original * expected_factor + + self.assertAlmostEqual( + reconciled, + expected_reconciled, + 2, + f"CD {cd} should be scaled by factor {expected_factor}", + ) + + def test_snap_household_reconciliation(self): + """Test SNAP household count reconciliation.""" + # CD-level SNAP household counts + cd_geoids = ["601", "602", "603"] + + cd_snap = pd.DataFrame( + { + "geographic_id": cd_geoids, + "stratum_group_id": [4] * 3, # SNAP group + "variable": ["household_count"] * 3, + "value": [20000, 25000, 18000], # Survey counts + "source": ["survey"] * 3, + } + ) + + cd_total = cd_snap["value"].sum() # 63,000 + + # State-level administrative SNAP households + state_snap = pd.DataFrame( + { + "geographic_id": ["6"], + "stratum_group_id": [4], + "variable": ["household_count"], + "value": [69300], # 10% higher + "source": ["administrative"], + } + ) + + state_total = state_snap["value"].iloc[0] + + # Calculate and apply reconciliation + factor = state_total / cd_total + cd_snap["reconciled_value"] = cd_snap["value"] * factor + + # Verify + self.assertAlmostEqual( + factor, 1.1, 4, "SNAP reconciliation factor should be 1.1" + ) + + reconciled_total = cd_snap["reconciled_value"].sum() + self.assertAlmostEqual( + reconciled_total, + state_total, + 2, + "Reconciled SNAP totals should match state administrative data", + ) + + def test_no_reconciliation_when_no_higher_level(self): + """Test that targets are not modified when no higher-level data exists.""" + # CD targets with no corresponding state data + cd_targets = pd.DataFrame( + { + "geographic_id": ["601", "602"], + "stratum_group_id": [ + 999, + 999, + ], # Some group without state targets + "variable": ["custom_var", "custom_var"], + "value": [1000, 2000], + "source": ["survey", "survey"], + } + ) + + # No state targets available + state_targets = pd.DataFrame() # Empty + + # Reconciliation should not change values + reconciled = cd_targets.copy() + reconciled["reconciliation_factor"] = 1.0 # No change + + # Verify no change + for i in range(len(cd_targets)): + self.assertEqual( + reconciled.iloc[i]["value"], + cd_targets.iloc[i]["value"], + "Values should not change when no higher-level data exists", + ) + self.assertEqual( + reconciled.iloc[i]["reconciliation_factor"], + 1.0, + "Reconciliation factor should be 1.0 when no adjustment needed", + ) + + def test_undercount_percentage_calculation(self): + """Test calculation of undercount percentages.""" + # Survey total: 900,000 + # Admin total: 1,000,000 + # Undercount: 100,000 (10%) + + survey_total = 900000 + admin_total = 1000000 + + undercount = admin_total - survey_total + undercount_pct = (undercount / admin_total) * 100 + + self.assertAlmostEqual( + undercount_pct, 10.0, 2, "Undercount percentage should be 10%" + ) + + # Alternative calculation using factor + factor = admin_total / survey_total + 
undercount_pct_alt = (1 - 1 / factor) * 100 + + self.assertAlmostEqual( + undercount_pct_alt, + 10.0, + 2, + "Alternative undercount calculation should also give 10%", + ) + + def test_hierarchical_reconciliation_order(self): + """Test that reconciliation preserves hierarchical consistency.""" + # National -> State -> CD hierarchy + + # National target + national_total = 1000000 + + # State targets (should sum to national) + state_targets = pd.DataFrame( + { + "state_fips": ["6", "36", "48"], # CA, NY, TX + "value": [400000, 350000, 250000], + } + ) + + # CD targets (should sum to respective states) + cd_targets = pd.DataFrame( + { + "cd_geoid": ["601", "602", "3601", "3602", "4801"], + "state_fips": ["6", "6", "36", "36", "48"], + "value": [ + 180000, + 200000, + 160000, + 170000, + 240000, + ], # Slightly off from state totals + } + ) + + # Step 1: Reconcile states to national + state_sum = state_targets["value"].sum() + self.assertEqual( + state_sum, national_total, "States should sum to national" + ) + + # Step 2: Reconcile CDs to states + for state_fips in ["6", "36", "48"]: + state_total = state_targets[ + state_targets["state_fips"] == state_fips + ]["value"].iloc[0] + cd_state_mask = cd_targets["state_fips"] == state_fips + cd_state_sum = cd_targets[cd_state_mask]["value"].sum() + + if cd_state_sum > 0: + factor = state_total / cd_state_sum + cd_targets.loc[cd_state_mask, "reconciled_value"] = ( + cd_targets.loc[cd_state_mask, "value"] * factor + ) + + # Verify hierarchical consistency + for state_fips in ["6", "36", "48"]: + state_total = state_targets[ + state_targets["state_fips"] == state_fips + ]["value"].iloc[0] + cd_state_mask = cd_targets["state_fips"] == state_fips + cd_reconciled_sum = cd_targets[cd_state_mask][ + "reconciled_value" + ].sum() + + self.assertAlmostEqual( + cd_reconciled_sum, + state_total, + 2, + f"Reconciled CDs in state {state_fips} should sum to state total", + ) + + # Verify grand total + total_reconciled = cd_targets["reconciled_value"].sum() + self.assertAlmostEqual( + total_reconciled, + national_total, + 2, + "All reconciled CDs should sum to national total", + ) + + +class TestReconciliationEdgeCases(unittest.TestCase): + """Test edge cases in reconciliation logic.""" + + def test_zero_survey_values(self): + """Test handling of zero values in survey data.""" + cd_targets = pd.DataFrame( + { + "geographic_id": ["601", "602", "603"], + "value": [0, 1000, 2000], # First CD has zero + } + ) + + state_total = 3300 # 10% higher than non-zero sum + + # Calculate factor based on non-zero values + non_zero_sum = cd_targets[cd_targets["value"] > 0]["value"].sum() + factor = state_total / non_zero_sum if non_zero_sum > 0 else 1.0 + + # Apply reconciliation + cd_targets["reconciled"] = cd_targets["value"] * factor + + # Zero should remain zero + self.assertEqual( + cd_targets.iloc[0]["reconciled"], + 0, + "Zero values should remain zero after reconciliation", + ) + + # Non-zero values should be scaled + self.assertAlmostEqual( + cd_targets.iloc[1]["reconciled"], + 1100, + 2, + "Non-zero values should be scaled appropriately", + ) + + def test_missing_geographic_coverage(self): + """Test when some CDs are missing from survey data.""" + # Only 3 of 5 CDs have data + cd_targets = pd.DataFrame( + { + "geographic_id": ["601", "602", "603"], + "value": [30000, 35000, 25000], + } + ) + + # State total covers all 5 CDs + state_total = 150000 # Implies 60,000 for missing CDs + + # Can only reconcile the CDs we have + cd_sum = cd_targets["value"].sum() + 
available_ratio = cd_sum / state_total # 90,000 / 150,000 = 0.6 + + self.assertAlmostEqual( + available_ratio, + 0.6, + 4, + "Available CDs represent 60% of state total", + ) + + # Options for handling: + # 1. Scale up existing CDs (not recommended - distorts distribution) + # 2. Flag as incomplete coverage (recommended) + # 3. Impute missing CDs first, then reconcile + + # Test option 2: Flag incomplete coverage + coverage_threshold = 0.8 # Require 80% coverage + has_sufficient_coverage = available_ratio >= coverage_threshold + + self.assertFalse( + has_sufficient_coverage, + "Should flag insufficient coverage when <80% of CDs present", + ) + + def test_negative_values(self): + """Test handling of negative values (should not occur but test anyway).""" + cd_targets = pd.DataFrame( + { + "geographic_id": ["601", "602"], + "value": [-1000, 2000], # Negative value (data error) + } + ) + + # Should either: + # 1. Raise an error + # 2. Treat as zero + # 3. Take absolute value + + # Test option 2: Treat negatives as zero + cd_targets["cleaned_value"] = cd_targets["value"].apply( + lambda x: max(0, x) + ) + + self.assertEqual( + cd_targets.iloc[0]["cleaned_value"], + 0, + "Negative values should be treated as zero", + ) + + self.assertEqual( + cd_targets.iloc[1]["cleaned_value"], + 2000, + "Positive values should remain unchanged", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_geo_stacking_targets.py b/tests/test_geo_stacking_targets.py new file mode 100644 index 00000000..709c9c39 --- /dev/null +++ b/tests/test_geo_stacking_targets.py @@ -0,0 +1,330 @@ +#!/usr/bin/env python3 +""" +Unit tests for geo-stacking target counts. + +These are self-contained tests that verify target count expectations +without requiring database connections or external dependencies. 
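+
+Per-CD expectations exercised below: 18 age bins, one Medicaid target,
+one SNAP household count, and approximately 76 IRS targets, plus
+state-level SNAP costs and national targets for the full 436-CD run.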
+""" + +import unittest +from unittest.mock import Mock, MagicMock, patch +import pandas as pd +import numpy as np + + +class TestGeoStackingTargets(unittest.TestCase): + """Test target count expectations for geo-stacking calibration.""" + + def setUp(self): + """Set up test fixtures with mocked components.""" + # Mock the builder class entirely + self.mock_builder = Mock() + self.mock_sim = Mock() + + def test_age_targets_per_cd(self): + """Test that each CD gets exactly 18 age bins.""" + test_cds = ["601", "652", "3601"] + + # Create expected targets DataFrame + mock_targets = [] + for cd in test_cds: + for age_bin in range(18): # 18 age bins per CD + mock_targets.append( + { + "geographic_id": cd, + "stratum_group_id": 2, # Age group + "variable": "person_count", + "value": 10000, + "description": f"age_bin_{age_bin}", + } + ) + + targets_df = pd.DataFrame(mock_targets) + + # Verify age targets per CD + age_mask = targets_df["stratum_group_id"] == 2 + age_targets = targets_df[age_mask] + + for cd in test_cds: + cd_age_targets = age_targets[age_targets["geographic_id"] == cd] + self.assertEqual( + len(cd_age_targets), + 18, + f"CD {cd} should have exactly 18 age bins", + ) + + def test_medicaid_targets_count(self): + """Test that we get one Medicaid target per CD.""" + test_cds = ["601", "652", "3601", "4801"] + + # Create expected targets with one Medicaid target per CD + mock_targets = [] + for cd in test_cds: + mock_targets.append( + { + "geographic_id": cd, + "stratum_group_id": 5, # Medicaid group + "variable": "person_count", + "value": 50000, + "description": f"medicaid_enrollment_cd_{cd}", + } + ) + + targets_df = pd.DataFrame(mock_targets) + + # Check Medicaid targets + medicaid_mask = targets_df["stratum_group_id"] == 5 + medicaid_targets = targets_df[medicaid_mask] + + self.assertEqual( + len(medicaid_targets), + len(test_cds), + f"Should have exactly one Medicaid target per CD", + ) + + # Verify each CD has exactly one + for cd in test_cds: + cd_medicaid = medicaid_targets[ + medicaid_targets["geographic_id"] == cd + ] + self.assertEqual( + len(cd_medicaid), + 1, + f"CD {cd} should have exactly one Medicaid target", + ) + + def test_snap_targets_structure(self): + """Test SNAP targets: one household_count per CD plus state costs.""" + test_cds = ["601", "602", "3601", "4801", "1201"] # CA, CA, NY, TX, FL + expected_states = ["6", "36", "48", "12"] # Unique state FIPS + + mock_targets = [] + + # CD-level SNAP household counts + for cd in test_cds: + mock_targets.append( + { + "geographic_id": cd, + "geographic_level": "congressional_district", + "stratum_group_id": 4, # SNAP group + "variable": "household_count", + "value": 20000, + "description": f"snap_households_cd_{cd}", + } + ) + + # State-level SNAP costs + for state_fips in expected_states: + mock_targets.append( + { + "geographic_id": state_fips, + "geographic_level": "state", + "stratum_group_id": 4, # SNAP group + "variable": "snap", + "value": 1000000000, # $1B + "description": f"snap_cost_state_{state_fips}", + } + ) + + targets_df = pd.DataFrame(mock_targets) + + # Check CD-level SNAP + cd_snap = targets_df[ + (targets_df["geographic_level"] == "congressional_district") + & (targets_df["variable"] == "household_count") + & (targets_df["stratum_group_id"] == 4) + ] + self.assertEqual( + len(cd_snap), + len(test_cds), + "Should have one SNAP household_count per CD", + ) + + # Check state-level SNAP costs + state_snap = targets_df[ + (targets_df["geographic_level"] == "state") + & (targets_df["variable"] == "snap") + 
& (targets_df["stratum_group_id"] == 4) + ] + self.assertEqual( + len(state_snap), + len(expected_states), + "Should have one SNAP cost per unique state", + ) + + def test_irs_targets_per_cd(self): + """Test that each CD gets approximately 76 IRS targets.""" + test_cds = ["601", "3601"] + expected_irs_per_cd = 76 + + mock_targets = [] + + # Generate IRS targets for each CD + for cd in test_cds: + # AGI bins (group 3) - 18 bins + for i in range(18): + mock_targets.append( + { + "geographic_id": cd, + "stratum_group_id": 3, + "variable": "tax_unit_count", + "value": 5000, + "description": f"agi_bin_{i}_cd_{cd}", + } + ) + + # EITC bins (group 6) - 18 bins + for i in range(18): + mock_targets.append( + { + "geographic_id": cd, + "stratum_group_id": 6, + "variable": "tax_unit_count", + "value": 2000, + "description": f"eitc_bin_{i}_cd_{cd}", + } + ) + + # IRS scalars (groups >= 100) - 40 scalars + # This gives us 18 + 18 + 40 = 76 total + scalar_count = 40 + for i in range(scalar_count): + mock_targets.append( + { + "geographic_id": cd, + "stratum_group_id": 100 + (i % 10), + "variable": "irs_scalar_" + str(i), + "value": 100000, + "description": f"irs_scalar_{i}_cd_{cd}", + } + ) + + targets_df = pd.DataFrame(mock_targets) + + # Count IRS targets per CD + for cd in test_cds: + cd_targets = targets_df[targets_df["geographic_id"] == cd] + self.assertEqual( + len(cd_targets), + expected_irs_per_cd, + f"CD {cd} should have exactly {expected_irs_per_cd} IRS targets", + ) + + def test_total_target_counts_for_full_run(self): + """Test expected total target counts for a full 436 CD run.""" + n_cds = 436 + n_states = 51 + + # Expected counts per category + expected_counts = { + "national": 30, + "age_per_cd": 18, + "medicaid_per_cd": 1, + "snap_per_cd": 1, + "irs_per_cd": 76, + "state_snap": n_states, + } + + # Calculate totals + total_cd_targets = n_cds * ( + expected_counts["age_per_cd"] + + expected_counts["medicaid_per_cd"] + + expected_counts["snap_per_cd"] + + expected_counts["irs_per_cd"] + ) + + total_expected = ( + expected_counts["national"] + + total_cd_targets + + expected_counts["state_snap"] + ) + + # Verify calculation matches known expectation (allowing some tolerance) + self.assertTrue( + 41837 <= total_expected <= 42037, + f"Total targets for 436 CDs should be approximately 41,937, got {total_expected}", + ) + + # Check individual components + age_total = expected_counts["age_per_cd"] * n_cds + self.assertEqual(age_total, 7848, "Age targets should total 7,848") + + medicaid_total = expected_counts["medicaid_per_cd"] * n_cds + self.assertEqual( + medicaid_total, 436, "Medicaid targets should total 436" + ) + + snap_cd_total = expected_counts["snap_per_cd"] * n_cds + snap_total = snap_cd_total + expected_counts["state_snap"] + self.assertEqual(snap_total, 487, "SNAP targets should total 487") + + irs_total = expected_counts["irs_per_cd"] * n_cds + self.assertEqual(irs_total, 33136, "IRS targets should total 33,136") + + +class TestTargetDeduplication(unittest.TestCase): + """Test deduplication of targets across CDs.""" + + def test_irs_scalar_deduplication_within_state(self): + """Test that IRS scalars are not duplicated for CDs in the same state.""" + # Test with two California CDs + test_cds = ["601", "602"] + + # Create mock targets with overlapping state-level IRS scalars + mock_targets_601 = [ + { + "stratum_id": 1001, + "stratum_group_id": 100, + "variable": "income_tax", + "value": 1000000, + "geographic_id": "601", + }, + { + "stratum_id": 1002, + "stratum_group_id": 100, + 
"variable": "salt", + "value": 500000, + "geographic_id": "601", + }, + ] + + mock_targets_602 = [ + { + "stratum_id": 1001, + "stratum_group_id": 100, + "variable": "income_tax", + "value": 1000000, + "geographic_id": "602", + }, + { + "stratum_id": 1002, + "stratum_group_id": 100, + "variable": "salt", + "value": 500000, + "geographic_id": "602", + }, + ] + + # The deduplication should recognize these are the same stratum_ids + seen_strata = set() + deduplicated_targets = [] + + for targets in [mock_targets_601, mock_targets_602]: + for target in targets: + if target["stratum_id"] not in seen_strata: + seen_strata.add(target["stratum_id"]) + deduplicated_targets.append(target) + + self.assertEqual( + len(deduplicated_targets), + 2, + "Should only count unique stratum_ids once across CDs", + ) + + # Verify we kept the unique targets + unique_strata_ids = {t["stratum_id"] for t in deduplicated_targets} + self.assertEqual(unique_strata_ids, {1001, 1002}) + + +if __name__ == "__main__": + unittest.main()