diff --git a/Makefile b/Makefile index c538a7bc..c1188cb5 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,7 @@ documentation-dev: database: rm -f policyengine_us_data/storage/calibration/policy_data.db python policyengine_us_data/db/create_database_tables.py + python policyengine_us_data/db/create_field_valid_values.py python policyengine_us_data/db/create_initial_strata.py python policyengine_us_data/db/etl_national_targets.py python policyengine_us_data/db/etl_age.py diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..bb06e171 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: minor + changes: + added: + - field_valid_values table in the targets database as source of truth for semantic or external target information. + - constraint_validation.py to ensure constraint operations result in consistent and valid sets of constraints for a given stratum. \ No newline at end of file diff --git a/policyengine_us_data/db/DATABASE_GUIDE.md b/policyengine_us_data/db/DATABASE_GUIDE.md index ac038cb7..bb3ea24d 100644 --- a/policyengine_us_data/db/DATABASE_GUIDE.md +++ b/policyengine_us_data/db/DATABASE_GUIDE.md @@ -24,15 +24,16 @@ make promote-database # Copy DB + raw inputs to HuggingFace clone | # | Script | Network? 
| What it does | |---|--------|----------|--------------| -| 1 | `create_database_tables.py` | No | Creates empty SQLite schema (7 tables) | -| 2 | `create_initial_strata.py` | Census ACS 5-year | Builds geographic hierarchy: US > 51 states > 436 CDs | -| 3 | `etl_national_targets.py` | No | Loads ~40 hardcoded national targets (CBO, Treasury, CMS) | -| 4 | `etl_age.py` | Census ACS 1-year | Age distribution: 18 bins x 488 geographies | -| 5 | `etl_medicaid.py` | Census ACS + CMS | Medicaid enrollment (admin state-level, survey district-level) | -| 6 | `etl_snap.py` | USDA FNS + Census ACS | SNAP participation (admin state-level, survey district-level) | -| 7 | `etl_state_income_tax.py` | No | State income tax collections (Census STC FY2023, hardcoded) | -| 8 | `etl_irs_soi.py` | IRS | Tax variables, EITC by child count, AGI brackets, conditional strata | -| 9 | `validate_database.py` | No | Checks all target variables exist in policyengine-us | +| 1 | `create_database_tables.py` | No | Creates SQLite schema (8 tables) + validation triggers | +| 2 | `create_field_valid_values.py` | No | Populates field_valid_values with allowed values | +| 3 | `create_initial_strata.py` | Census ACS 5-year | Builds geographic hierarchy: US > 51 states > 436 CDs | +| 4 | `etl_national_targets.py` | No | Loads ~40 hardcoded national targets (CBO, Treasury, CMS) | +| 5 | `etl_age.py` | Census ACS 1-year | Age distribution: 18 bins x 488 geographies | +| 6 | `etl_medicaid.py` | Census ACS + CMS | Medicaid enrollment (admin state-level, survey district-level) | +| 7 | `etl_snap.py` | USDA FNS + Census ACS | SNAP participation (admin state-level, survey district-level) | +| 8 | `etl_state_income_tax.py` | No | State income tax collections (Census STC FY2023, hardcoded) | +| 9 | `etl_irs_soi.py` | IRS | Tax variables, EITC by child count, AGI brackets, conditional strata | +| 10 | `validate_database.py` | No | Checks all target variables exist in policyengine-us | ### Raw Input Caching @@ 
-94,6 +95,36 @@ make database **variable_metadata** - Display info for variables (display name, units, ordering) +### Validation Table + +**field_valid_values** - Centralized registry of valid values for semantic fields + +This table is the source of truth for what values are allowed in specific fields throughout +the database. Specifically, it covers fields that carry semantic or externally defined information, rather than relationships inherent to the database design itself. SQL triggers enforce validation on INSERT and UPDATE operations. + +| Field Validated | Table | Valid Values | +|-----------------|-------|--------------| +| `operation` | stratum_constraints | `==`, `!=`, `>`, `>=`, `<`, `<=` | +| `constraint_variable` | stratum_constraints | All policyengine-us variables | +| `active` | targets | `0`, `1` | +| `period` | targets | `2022`, `2023`, `2024`, `2025` | +| `variable` | targets | All policyengine-us variables | +| `type` | sources | `administrative`, `survey`, `synthetic`, `derived`, `hardcoded` | + +**Triggers**: `validate_stratum_constraints_insert`, `validate_stratum_constraints_update`, +`validate_targets_insert`, `validate_targets_update`, `validate_sources_insert`, `validate_sources_update` + +To add a new valid value (e.g., a new year): +```sql +INSERT INTO field_valid_values (field_name, valid_value, description) +VALUES ('period', '2026', NULL); +``` + +To check what values are valid for a field: +```sql +SELECT valid_value, description FROM field_valid_values WHERE field_name = 'operation'; +``` + ## Key Concepts ### Stratum Groups @@ -153,9 +184,63 @@ ETL scripts that pull Census data receive UCGIDs and create their own domain-spe ### Constraint Operations -All constraints use standardized operators validated by the `ConstraintOperation` enum: +All constraints use standardized operators validated by the `field_valid_values` table: `==`, `!=`, `>`, `>=`, `<`, `<=` +### Constraint Validation + +ETL scripts validate constraint sets before inserting them into 
the database using `ensure_consistent_constraint_set()` from `policyengine_us_data.utils.constraint_validation`. This prevents logically inconsistent constraints from being stored. + +**Validation Rules:** + +1. **Operation Compatibility** (per constraint_variable): + +| Operation | Can combine with | Rationale | +|-----------|-----------------|-----------| +| `==` | Nothing (must be alone) | Equality is absolute | +| `!=` | Nothing (must be alone) | Exclusion is absolute | +| `>` | `<` or `<=` only | Forms valid range | +| `>=` | `<` or `<=` only | Forms valid range | +| `<` | `>` or `>=` only | Forms valid range | +| `<=` | `>` or `>=` only | Forms valid range | + +**Invalid combinations:** +- `>` with `>=` (redundant/conflicting lower bounds) +- `<` with `<=` (redundant/conflicting upper bounds) +- `==` with anything else +- `!=` with anything else + +2. **Value Checks** (if operations are compatible): +- No empty ranges: lower bound must be < upper bound +- For equal bounds, both must be inclusive (`>=` and `<=`) to be valid + +**Usage in ETL:** +```python +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) + +# Build constraint list +constraint_list = [ + Constraint(variable="age", operation=">=", value="25"), + Constraint(variable="age", operation="<", value="30"), +] + +# Validate before creating StratumConstraint objects +ensure_consistent_constraint_set(constraint_list) + +# Now safe to add to stratum +stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in constraint_list +] +``` + ### Constraint Value Types The `value` column stores all values as strings. 
Downstream code deserializes: diff --git a/policyengine_us_data/db/create_database_tables.py b/policyengine_us_data/db/create_database_tables.py index 9485d02e..08316bd0 100644 --- a/policyengine_us_data/db/create_database_tables.py +++ b/policyengine_us_data/db/create_database_tables.py @@ -1,9 +1,8 @@ import logging import hashlib from typing import List, Optional -from enum import Enum -from sqlalchemy import event, UniqueConstraint +from sqlalchemy import event, UniqueConstraint, text from sqlalchemy.orm.attributes import get_history from sqlmodel import ( Field, @@ -11,8 +10,6 @@ SQLModel, create_engine, ) -from pydantic import validator -from policyengine_us.system import system from policyengine_us_data.storage import STORAGE_FOLDER @@ -24,21 +21,29 @@ logger = logging.getLogger(__name__) -# An Enum type to ensure the variable exists in policyengine-us -USVariable = Enum( - "USVariable", {name: name for name in system.variables.keys()}, type=str -) +class FieldValidValues(SQLModel, table=True): + """Valid values for semantic fields in the database. + This table serves as the source of truth for what values are allowed + in specific fields throughout the database. SQL triggers enforce that + values inserted or updated match entries in this table. 
+ """ -class ConstraintOperation(str, Enum): - """Allowed operations for stratum constraints.""" + __tablename__ = "field_valid_values" - EQ = "==" # Equals - NE = "!=" # Not equals - GT = ">" # Greater than - GE = ">=" # Greater than or equal - LT = "<" # Less than - LE = "<=" # Less than or equal + field_name: str = Field( + primary_key=True, + description="The field/column being validated " + "(e.g., 'operation', 'variable')", + ) + valid_value: str = Field( + primary_key=True, + description="A valid value for this field", + ) + description: Optional[str] = Field( + default=None, + description="Human-readable description of this value", + ) class Stratum(SQLModel, table=True): @@ -117,16 +122,6 @@ class StratumConstraint(SQLModel, table=True): strata_rel: Stratum = Relationship(back_populates="constraints_rel") - @validator("operation") - def validate_operation(cls, v): - """Validate that the operation is one of the allowed values.""" - allowed_ops = [op.value for op in ConstraintOperation] - if v not in allowed_ops: - raise ValueError( - f"Invalid operation '{v}'. Must be one of: {', '.join(allowed_ops)}" - ) - return v - class Target(SQLModel, table=True): """Stores the data values for a specific stratum.""" @@ -143,7 +138,7 @@ class Target(SQLModel, table=True): ) target_id: Optional[int] = Field(default=None, primary_key=True) - variable: USVariable = Field( + variable: str = Field( description="A variable defined in policyengine-us (e.g., 'income_tax')." 
) period: int = Field( @@ -179,18 +174,6 @@ class Target(SQLModel, table=True): source_rel: Optional["Source"] = Relationship() -class SourceType(str, Enum): - """Types of data sources.""" - - ADMINISTRATIVE = "administrative" - SURVEY = "survey" - SYNTHETIC = "synthetic" - DERIVED = "derived" - HARDCODED = ( - "hardcoded" # Values from various sources, hardcoded into the system - ) - - class Source(SQLModel, table=True): """Metadata about data sources.""" @@ -208,7 +191,7 @@ class Source(SQLModel, table=True): description="Name of the data source (e.g., 'IRS SOI', 'Census ACS').", index=True, ) - type: SourceType = Field( + type: str = Field( description="Type of data source (administrative, survey, etc.)." ) description: Optional[str] = Field( @@ -344,6 +327,177 @@ def calculate_definition_hash(mapper, connection, target: Stratum): target.definition_hash = h.hexdigest() +def create_validation_triggers(engine): + """Create SQL triggers to enforce field validation on INSERT and UPDATE. + + These triggers check that values in semantic fields match entries in the + field_valid_values table before allowing the operation to proceed. + + Args: + engine: SQLAlchemy Engine instance connected to the database. 
+ """ + with engine.connect() as conn: + # ============================================ + # Triggers for stratum_constraints table + # ============================================ + + # INSERT trigger + conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS validate_stratum_constraints_insert + BEFORE INSERT ON stratum_constraints + BEGIN + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'operation' + AND valid_value = NEW.operation + ) + THEN RAISE(ABORT, 'Invalid operation value') + END; + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'constraint_variable' + AND valid_value = NEW.constraint_variable + ) + THEN RAISE(ABORT, 'Invalid constraint_variable value') + END; + END; + """)) + + # UPDATE trigger + conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS validate_stratum_constraints_update + BEFORE UPDATE ON stratum_constraints + BEGIN + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'operation' + AND valid_value = NEW.operation + ) + THEN RAISE(ABORT, 'Invalid operation value') + END; + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'constraint_variable' + AND valid_value = NEW.constraint_variable + ) + THEN RAISE(ABORT, 'Invalid constraint_variable value') + END; + END; + """)) + + # ============================================ + # Triggers for targets table + # ============================================ + + # INSERT trigger + conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS validate_targets_insert + BEFORE INSERT ON targets + BEGIN + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'active' + AND valid_value = CAST(NEW.active AS TEXT) + ) + THEN RAISE(ABORT, 'Invalid active value') + END; + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'period' + AND valid_value = CAST(NEW.period AS TEXT) + ) + THEN RAISE(ABORT, 
'Invalid period value') + END; + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'variable' + AND valid_value = NEW.variable + ) + THEN RAISE(ABORT, 'Invalid variable value') + END; + END; + """)) + + # UPDATE trigger + conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS validate_targets_update + BEFORE UPDATE ON targets + BEGIN + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'active' + AND valid_value = CAST(NEW.active AS TEXT) + ) + THEN RAISE(ABORT, 'Invalid active value') + END; + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'period' + AND valid_value = CAST(NEW.period AS TEXT) + ) + THEN RAISE(ABORT, 'Invalid period value') + END; + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'variable' + AND valid_value = NEW.variable + ) + THEN RAISE(ABORT, 'Invalid variable value') + END; + END; + """)) + + # ============================================ + # Triggers for sources table + # ============================================ + + # INSERT trigger + conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS validate_sources_insert + BEFORE INSERT ON sources + BEGIN + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'type' + AND valid_value = NEW.type + ) + THEN RAISE(ABORT, 'Invalid source type value') + END; + END; + """)) + + # UPDATE trigger + conn.execute(text(""" + CREATE TRIGGER IF NOT EXISTS validate_sources_update + BEFORE UPDATE ON sources + BEGIN + SELECT CASE + WHEN NOT EXISTS ( + SELECT 1 FROM field_valid_values + WHERE field_name = 'type' + AND valid_value = NEW.type + ) + THEN RAISE(ABORT, 'Invalid source type value') + END; + END; + """)) + + conn.commit() + + logger.info("Validation triggers created successfully") + + def create_database( db_uri: str = f"sqlite:///{STORAGE_FOLDER / 'calibration' / 'policy_data.db'}", ): @@ -358,6 +512,9 @@ def 
create_database( """ engine = create_engine(db_uri) SQLModel.metadata.create_all(engine) + + create_validation_triggers(engine) + logger.info(f"Database and tables created successfully at {db_uri}") return engine diff --git a/policyengine_us_data/db/create_field_valid_values.py b/policyengine_us_data/db/create_field_valid_values.py new file mode 100644 index 00000000..ce3df005 --- /dev/null +++ b/policyengine_us_data/db/create_field_valid_values.py @@ -0,0 +1,117 @@ +"""Populate the field_valid_values table with valid values for semantic fields. + +This module provides functionality to populate the field_valid_values table +with static values (operations, active flags, periods) and dynamic values +(policyengine-us variables). +""" + +import logging +from sqlmodel import Session + +from policyengine_us.system import system + +from policyengine_us_data.db.create_database_tables import FieldValidValues + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) + +logger = logging.getLogger(__name__) + + +def populate_field_valid_values(session: Session) -> None: + """Populate the field_valid_values table with valid values for fields. + + This function populates the table with: + - Static values: operation, active, period + - Dynamic values: variable, constraint_variable (from policyengine-us) + + Args: + session: SQLModel Session instance for database operations. 
+ """ + # Static values for operation field + operation_values = [ + ("operation", "==", "Equals"), + ("operation", "!=", "Not equals"), + ("operation", ">", "Greater than"), + ("operation", ">=", "Greater than or equal"), + ("operation", "<", "Less than"), + ("operation", "<=", "Less than or equal"), + ] + + # Static values for active field + active_values = [ + ("active", "0", "Inactive"), + ("active", "1", "Active"), + ] + + # Static values for period field (years) + period_values = [ + ("period", "2022", None), + ("period", "2023", None), + ("period", "2024", None), + ("period", "2025", None), + ] + + # Static values for type field (sources table) + source_type_values = [ + ("type", "administrative", "Administrative data sources"), + ("type", "survey", "Survey data sources"), + ("type", "synthetic", "Synthetic/generated data"), + ("type", "derived", "Derived from other sources"), + ("type", "hardcoded", "Values hardcoded into the system"), + ] + + # Add all static values + static_count = 0 + for field_name, valid_value, description in ( + operation_values + active_values + period_values + source_type_values + ): + session.add( + FieldValidValues( + field_name=field_name, + valid_value=valid_value, + description=description, + ) + ) + static_count += 1 + + # Dynamic values from policyengine-us + variable_count = 0 + for var_name in system.variables.keys(): + # Add for 'variable' field (targets table) + session.add( + FieldValidValues( + field_name="variable", + valid_value=var_name, + ) + ) + # Add for 'constraint_variable' field (stratum_constraints table) + session.add( + FieldValidValues( + field_name="constraint_variable", + valid_value=var_name, + ) + ) + variable_count += 1 + + session.commit() + + logger.info( + f"Populated field_valid_values with {static_count} static values " + f"and {variable_count * 2} variable values " + f"({variable_count} variables x 2 fields)" + ) + + +if __name__ == "__main__": + from sqlmodel import create_engine + + from 
policyengine_us_data.storage import STORAGE_FOLDER + + db_path = STORAGE_FOLDER / "calibration" / "policy_data.db" + engine = create_engine(f"sqlite:///{db_path}") + + with Session(engine) as session: + populate_field_valid_values(session) diff --git a/policyengine_us_data/db/create_initial_strata.py b/policyengine_us_data/db/create_initial_strata.py index f3edb1b4..4aadd7f3 100644 --- a/policyengine_us_data/db/create_initial_strata.py +++ b/policyengine_us_data/db/create_initial_strata.py @@ -15,6 +15,10 @@ save_json, load_json, ) +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) logger = logging.getLogger(__name__) @@ -163,13 +167,23 @@ def main(): notes=state_name, stratum_group_id=1, ) - state_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="state_fips", + # Validate constraints before adding + state_constraints = [ + Constraint( + variable="state_fips", operation="==", value=str(state_fips), ) ] + ensure_consistent_constraint_set(state_constraints) + state_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in state_constraints + ] session.add(state_stratum) session.flush() state_stratum_ids[state_fips] = state_stratum.stratum_id @@ -185,13 +199,23 @@ def main(): notes=f"{name} (CD GEOID {cd_geoid})", stratum_group_id=1, ) - cd_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="congressional_district_geoid", + # Validate constraints before adding + cd_constraints = [ + Constraint( + variable="congressional_district_geoid", operation="==", value=str(cd_geoid), ) ] + ensure_consistent_constraint_set(cd_constraints) + cd_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in cd_constraints + ] session.add(cd_stratum) session.commit() diff --git a/policyengine_us_data/db/etl_age.py 
b/policyengine_us_data/db/etl_age.py index 39ffedf2..ce89efda 100644 --- a/policyengine_us_data/db/etl_age.py +++ b/policyengine_us_data/db/etl_age.py @@ -8,7 +8,6 @@ Stratum, StratumConstraint, Target, - SourceType, ) from policyengine_us_data.utils.census import get_census_docs, pull_acs_table from policyengine_us_data.utils.db import parse_ucgid, get_geographic_strata @@ -17,6 +16,10 @@ get_or_create_variable_group, get_or_create_variable_metadata, ) +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) LABEL_TO_SHORT = { "Estimate!!Total!!Total population!!AGE!!Under 5 years": "0-4", @@ -116,7 +119,7 @@ def load_age_data(df_long, geo, year): census_source = get_or_create_source( session, name="Census ACS Table S0101", - source_type=SourceType.SURVEY, + source_type="survey", vintage=f"{year} ACS 5-year estimates", description="American Community Survey Age and Sex demographics", url="https://data.census.gov/", @@ -219,22 +222,22 @@ def load_age_data(df_long, geo, year): notes=note, ) - # Create constraints including both age and geographic for uniqueness - new_stratum.constraints_rel = [] + # Build constraint list for validation + constraint_list = [] # Add geographic constraints based on level if geo_info["type"] == "state": - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="state_fips", + constraint_list.append( + Constraint( + variable="state_fips", operation="==", value=str(geo_info["state_fips"]), ) ) elif geo_info["type"] == "district": - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="congressional_district_geoid", + constraint_list.append( + Constraint( + variable="congressional_district_geoid", operation="==", value=str(geo_info["congressional_district_geoid"]), ) @@ -242,9 +245,9 @@ def load_age_data(df_long, geo, year): # For national level, no geographic constraint needed # Add age constraints - 
new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="age", + constraint_list.append( + Constraint( + variable="age", operation=">", value=str(row["age_greater_than"]), ) @@ -252,14 +255,25 @@ def load_age_data(df_long, geo, year): age_lt_value = row["age_less_than"] if not np.isinf(age_lt_value): - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="age", + constraint_list.append( + Constraint( + variable="age", operation="<", value=str(row["age_less_than"]), ) ) + # Validate constraints before adding + ensure_consistent_constraint_set(constraint_list) + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in constraint_list + ] + # Create the Target and link it to the parent. new_stratum.targets_rel.append( Target( diff --git a/policyengine_us_data/db/etl_irs_soi.py b/policyengine_us_data/db/etl_irs_soi.py index ed4da4e5..3fd6593e 100644 --- a/policyengine_us_data/db/etl_irs_soi.py +++ b/policyengine_us_data/db/etl_irs_soi.py @@ -19,7 +19,6 @@ Stratum, StratumConstraint, Target, - SourceType, ) from policyengine_us_data.utils.db import ( get_stratum_by_id, @@ -38,6 +37,10 @@ from policyengine_us_data.storage.calibration_targets.make_district_mapping import ( get_district_mapping, ) +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) """See the 22incddocguide.docx manual from the IRS SOI""" # Language in the doc: '$10,000 under $25,000' means >= $10,000 and < $25,000 @@ -394,7 +397,7 @@ def load_soi_data(long_dfs, year): irs_source = get_or_create_source( session, name="IRS Statistics of Income", - source_type=SourceType.ADMINISTRATIVE, + source_type="administrative", vintage=f"{year} Tax Year", description="IRS Statistics of Income administrative tax data", url="https://www.irs.gov/statistics", @@ -589,13 +592,23 @@ def load_soi_data(long_dfs, year): 
stratum_group_id=2, # Filer population group notes="United States - Tax Filers", ) - national_filer_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + # Validate constraints before adding + nat_filer_constraints = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ) ] + ensure_consistent_constraint_set(nat_filer_constraints) + national_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in nat_filer_constraints + ] session.add(national_filer_stratum) session.flush() @@ -619,18 +632,28 @@ def load_soi_data(long_dfs, year): stratum_group_id=2, # Filer population group notes=f"State FIPS {state_fips} - Tax Filers", ) - state_filer_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + # Validate constraints before adding + state_filer_constraints = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ), - StratumConstraint( - constraint_variable="state_fips", + Constraint( + variable="state_fips", operation="==", value=str(state_fips), ), ] + ensure_consistent_constraint_set(state_filer_constraints) + state_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in state_filer_constraints + ] session.add(state_filer_stratum) session.flush() @@ -657,18 +680,28 @@ def load_soi_data(long_dfs, year): stratum_group_id=2, # Filer population group notes=f"Congressional District {district_geoid} - Tax Filers", ) - district_filer_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + # Validate constraints before adding + district_filer_constraints = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ), - StratumConstraint( - constraint_variable="congressional_district_geoid", + Constraint( + 
variable="congressional_district_geoid", operation="==", value=str(district_geoid), ), ] + ensure_consistent_constraint_set(district_filer_constraints) + district_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in district_filer_constraints + ] session.add(district_filer_stratum) session.flush() @@ -694,12 +727,14 @@ def load_soi_data(long_dfs, year): geo_info = parse_ucgid(ucgid_i) # Determine parent stratum based on geographic level - use filer strata not geo strata + # Build constraint list for validation + constraint_list = [] if geo_info["type"] == "national": parent_stratum_id = filer_strata["national"] note = f"National EITC received with {n_children} children (filers)" - constraints = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + constraint_list = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ) @@ -709,14 +744,14 @@ def load_soi_data(long_dfs, year): geo_info["state_fips"] ] note = f"State FIPS {geo_info['state_fips']} EITC received with {n_children} children (filers)" - constraints = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + constraint_list = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ), - StratumConstraint( - constraint_variable="state_fips", + Constraint( + variable="state_fips", operation="==", value=str(geo_info["state_fips"]), ), @@ -726,19 +761,40 @@ def load_soi_data(long_dfs, year): geo_info["congressional_district_geoid"] ] note = f"Congressional District {geo_info['congressional_district_geoid']} EITC received with {n_children} children (filers)" - constraints = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + constraint_list = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ), - StratumConstraint( - constraint_variable="congressional_district_geoid", + Constraint( + 
variable="congressional_district_geoid", operation="==", value=str(geo_info["congressional_district_geoid"]), ), ] + # Add EITC child count constraint + if n_children == "3+": + constraint_list.append( + Constraint( + variable="eitc_child_count", + operation=">", + value="2", + ) + ) + else: + constraint_list.append( + Constraint( + variable="eitc_child_count", + operation="==", + value=f"{n_children}", + ) + ) + + # Validate constraints before adding + ensure_consistent_constraint_set(constraint_list) + # Check if stratum already exists existing_stratum = ( session.query(Stratum) @@ -759,23 +815,14 @@ def load_soi_data(long_dfs, year): notes=note, ) - new_stratum.constraints_rel = constraints - if n_children == "3+": - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="eitc_child_count", - operation=">", - value="2", - ) - ) - else: - new_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="eitc_child_count", - operation="==", - value=f"{n_children}", - ) + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, ) + for c in constraint_list + ] session.add(new_stratum) session.flush() @@ -893,42 +940,33 @@ def load_soi_data(long_dfs, year): if existing_stratum: child_stratum = existing_stratum else: - # Create new child stratum with constraint - child_stratum = Stratum( - parent_stratum_id=parent_stratum_id, - stratum_group_id=stratum_group_id, - notes=note, - ) - - # Add constraints - filer status and this IRS variable must be positive - child_stratum.constraints_rel.extend( - [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", - operation="==", - value="1", - ), - StratumConstraint( - constraint_variable=amount_variable_name, - operation=">", - value="0", - ), - ] - ) + # Build constraint list for validation + irs_constraint_list = [ + Constraint( + variable="tax_unit_is_filer", + operation="==", + value="1", + ), + 
Constraint( + variable=amount_variable_name, + operation=">", + value="0", + ), + ] # Add geographic constraints if applicable if geo_info["type"] == "state": - child_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="state_fips", + irs_constraint_list.append( + Constraint( + variable="state_fips", operation="==", value=str(geo_info["state_fips"]), ) ) elif geo_info["type"] == "district": - child_stratum.constraints_rel.append( - StratumConstraint( - constraint_variable="congressional_district_geoid", + irs_constraint_list.append( + Constraint( + variable="congressional_district_geoid", operation="==", value=str( geo_info["congressional_district_geoid"] @@ -936,6 +974,25 @@ def load_soi_data(long_dfs, year): ) ) + # Validate constraints before adding + ensure_consistent_constraint_set(irs_constraint_list) + + # Create new child stratum with constraint + child_stratum = Stratum( + parent_stratum_id=parent_stratum_id, + stratum_group_id=stratum_group_id, + notes=note, + ) + + child_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in irs_constraint_list + ] + session.add(child_stratum) session.flush() @@ -1055,30 +1112,39 @@ def load_soi_data(long_dfs, year): ) if not nat_stratum: + # Build constraint list for validation + nat_agi_constraints = [ + Constraint( + variable="tax_unit_is_filer", + operation="==", + value="1", + ), + Constraint( + variable="adjusted_gross_income", + operation=">=", + value=str(agi_income_lower), + ), + Constraint( + variable="adjusted_gross_income", + operation="<", + value=str(agi_income_upper), + ), + ] + ensure_consistent_constraint_set(nat_agi_constraints) + nat_stratum = Stratum( parent_stratum_id=filer_strata["national"], stratum_group_id=3, # Income/AGI strata group notes=note, ) - nat_stratum.constraints_rel.extend( - [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", - operation="==", - value="1", - ), - 
StratumConstraint( - constraint_variable="adjusted_gross_income", - operation=">=", - value=str(agi_income_lower), - ), - StratumConstraint( - constraint_variable="adjusted_gross_income", - operation="<", - value=str(agi_income_upper), - ), - ] - ) + nat_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in nat_agi_constraints + ] session.add(nat_stratum) session.flush() @@ -1092,19 +1158,21 @@ def load_soi_data(long_dfs, year): geo_info = parse_ucgid(ucgid_i) person_count = agi_df.iloc[i][["target_value"]].values[0] + # Build constraint list for validation + agi_constraint_list = [] if geo_info["type"] == "state": parent_stratum_id = filer_strata["state"][ geo_info["state_fips"] ] note = f"State FIPS {geo_info['state_fips']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" - constraints = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + agi_constraint_list = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ), - StratumConstraint( - constraint_variable="state_fips", + Constraint( + variable="state_fips", operation="==", value=str(geo_info["state_fips"]), ), @@ -1114,14 +1182,14 @@ def load_soi_data(long_dfs, year): geo_info["congressional_district_geoid"] ] note = f"Congressional District {geo_info['congressional_district_geoid']} filers, AGI >= {agi_income_lower}, AGI < {agi_income_upper}" - constraints = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + agi_constraint_list = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ), - StratumConstraint( - constraint_variable="congressional_district_geoid", + Constraint( + variable="congressional_district_geoid", operation="==", value=str(geo_info["congressional_district_geoid"]), ), @@ -1129,6 +1197,25 @@ def load_soi_data(long_dfs, year): else: continue # Skip if not state or district (shouldn't happen, but defensive) + # Add AGI range 
constraints + agi_constraint_list.extend( + [ + Constraint( + variable="adjusted_gross_income", + operation=">=", + value=str(agi_income_lower), + ), + Constraint( + variable="adjusted_gross_income", + operation="<", + value=str(agi_income_upper), + ), + ] + ) + + # Validate constraints before adding + ensure_consistent_constraint_set(agi_constraint_list) + # Check if stratum already exists existing_stratum = ( session.query(Stratum) @@ -1148,21 +1235,14 @@ def load_soi_data(long_dfs, year): stratum_group_id=3, # Income/AGI strata group notes=note, ) - new_stratum.constraints_rel = constraints - new_stratum.constraints_rel.extend( - [ - StratumConstraint( - constraint_variable="adjusted_gross_income", - operation=">=", - value=str(agi_income_lower), - ), - StratumConstraint( - constraint_variable="adjusted_gross_income", - operation="<", - value=str(agi_income_upper), - ), - ] - ) + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in agi_constraint_list + ] session.add(new_stratum) session.flush() diff --git a/policyengine_us_data/db/etl_medicaid.py b/policyengine_us_data/db/etl_medicaid.py index ed184144..d31dcf7b 100644 --- a/policyengine_us_data/db/etl_medicaid.py +++ b/policyengine_us_data/db/etl_medicaid.py @@ -11,7 +11,6 @@ Stratum, StratumConstraint, Target, - SourceType, ) from policyengine_us_data.utils.census import ( STATE_ABBREV_TO_FIPS, @@ -30,6 +29,10 @@ load_json, save_bytes, ) +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) logger = logging.getLogger(__name__) @@ -166,7 +169,7 @@ def load_medicaid_data(long_state, long_cd, year): admin_source = get_or_create_source( session, name="Medicaid T-MSIS", - source_type=SourceType.ADMINISTRATIVE, + source_type="administrative", vintage=f"{year} Final Report", description="Medicaid Transformed MSIS administrative enrollment data", 
url="https://data.medicaid.gov/", @@ -176,7 +179,7 @@ def load_medicaid_data(long_state, long_cd, year): survey_source = get_or_create_source( session, name="Census ACS Table S2704", - source_type=SourceType.SURVEY, + source_type="survey", vintage=f"{year} ACS 1-year estimates", description="American Community Survey health insurance coverage data", url="https://data.census.gov/", @@ -218,12 +221,20 @@ def load_medicaid_data(long_state, long_cd, year): stratum_group_id=5, # Medicaid strata group notes="National Medicaid Enrolled", ) + # Validate constraints before adding + nat_medicaid_constraints = [ + Constraint( + variable="medicaid_enrolled", operation="==", value="True" + ), + ] + ensure_consistent_constraint_set(nat_medicaid_constraints) nat_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="medicaid_enrolled", - operation="==", - value="True", - ), + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in nat_medicaid_constraints ] # No target at the national level is provided at this time. 
@@ -250,18 +261,26 @@ def load_medicaid_data(long_state, long_cd, year): stratum_group_id=5, # Medicaid strata group notes=note, ) - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="state_fips", + # Validate constraints before adding + state_medicaid_constraints = [ + Constraint( + variable="state_fips", operation="==", value=str(state_fips), ), - StratumConstraint( - constraint_variable="medicaid_enrolled", - operation="==", - value="True", + Constraint( + variable="medicaid_enrolled", operation="==", value="True" ), ] + ensure_consistent_constraint_set(state_medicaid_constraints) + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in state_medicaid_constraints + ] new_stratum.targets_rel.append( Target( variable="person_count", @@ -297,18 +316,26 @@ def load_medicaid_data(long_state, long_cd, year): stratum_group_id=5, # Medicaid strata group notes=note, ) - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="congressional_district_geoid", + # Validate constraints before adding + cd_medicaid_constraints = [ + Constraint( + variable="congressional_district_geoid", operation="==", value=str(cd_geoid), ), - StratumConstraint( - constraint_variable="medicaid_enrolled", - operation="==", - value="True", + Constraint( + variable="medicaid_enrolled", operation="==", value="True" ), ] + ensure_consistent_constraint_set(cd_medicaid_constraints) + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in cd_medicaid_constraints + ] new_stratum.targets_rel.append( Target( variable="person_count", diff --git a/policyengine_us_data/db/etl_national_targets.py b/policyengine_us_data/db/etl_national_targets.py index 7e02d6f0..26ea5b33 100644 --- a/policyengine_us_data/db/etl_national_targets.py +++ b/policyengine_us_data/db/etl_national_targets.py @@ 
-6,11 +6,14 @@ Stratum, StratumConstraint, Target, - SourceType, ) from policyengine_us_data.utils.db_metadata import ( get_or_create_source, ) +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) def extract_national_targets(): @@ -305,11 +308,12 @@ def extract_national_targets(): # CBO projection targets - get for a specific year CBO_YEAR = 2023 # Year the CBO projections are for cbo_vars = [ - # Note: income_tax_positive matches CBO's receipts definition - # where refundable credit payments in excess of liability are + # Note: For income_tax, CBO's receipts definition counts only positive + # values - refundable credit payments in excess of liability are # classified as outlays, not negative receipts. See: # https://www.cbo.gov/publication/43767 - "income_tax_positive", + # We handle this by adding a constraint (income_tax >= 0) when loading. + "income_tax", "snap", "social_security", "ssi", @@ -318,7 +322,7 @@ def extract_national_targets(): # Mapping from target variable to CBO parameter name (when different) cbo_param_name_map = { - "income_tax_positive": "income_tax", # CBO param is income_tax + # No mapping needed - income_tax matches CBO param name } cbo_targets = [] @@ -390,17 +394,11 @@ def transform_national_targets(raw_targets): """ # Process direct sum targets (non-tax items and some CBO items) - # Note: income_tax_positive from CBO and eitc from Treasury need - # filer constraint cbo_non_tax = [ - t - for t in raw_targets["cbo_targets"] - if t["variable"] != "income_tax_positive" + t for t in raw_targets["cbo_targets"] if t["variable"] != "income_tax" ] cbo_tax = [ - t - for t in raw_targets["cbo_targets"] - if t["variable"] == "income_tax_positive" + t for t in raw_targets["cbo_targets"] if t["variable"] == "income_tax" ] all_direct_targets = raw_targets["direct_sum_targets"] + cbo_non_tax @@ -455,7 +453,7 @@ def load_national_targets( calibration_source = get_or_create_source( session, 
name="PolicyEngine Calibration Targets", - source_type=SourceType.HARDCODED, + source_type="hardcoded", vintage="Mixed (2023-2024)", description="National calibration targets from various authoritative sources", url=None, @@ -535,13 +533,23 @@ def load_national_targets( stratum_group_id=2, # Filer population group notes="United States - Tax Filers", ) - national_filer_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="tax_unit_is_filer", + # Validate constraints before adding + filer_constraints = [ + Constraint( + variable="tax_unit_is_filer", operation="==", value="1", ) ] + ensure_consistent_constraint_set(filer_constraints) + national_filer_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in filer_constraints + ] session.add(national_filer_stratum) session.flush() print("Created national filer stratum") @@ -549,12 +557,67 @@ def load_national_targets( # Add tax-related targets to filer stratum for _, target_data in tax_filer_df.iterrows(): target_year = target_data["year"] + variable_name = target_data["variable"] + + # NOTE: For income_tax, we need a special stratum with income_tax >= 0 + # constraint to match CBO's receipts definition (only positive + # values count as receipts; negative values from refundable + # credits are classified as outlays). 
See: + # https://www.cbo.gov/publication/43767 + # (unless income_tax_positive is added to policyengine-us) + if variable_name == "income_tax": + # Get or create stratum for positive income tax + positive_income_tax_stratum = ( + session.query(Stratum) + .filter( + Stratum.parent_stratum_id + == national_filer_stratum.stratum_id, + Stratum.notes + == "United States - Tax Filers with Positive Income Tax", + ) + .first() + ) + + if not positive_income_tax_stratum: + positive_income_tax_stratum = Stratum( + parent_stratum_id=national_filer_stratum.stratum_id, + stratum_group_id=2, # Filer population group + notes="United States - Tax Filers with Positive Income Tax", + ) + # Validate constraints before adding + pos_tax_constraints = [ + Constraint( + variable="income_tax", + operation=">=", + value="0", + ) + ] + ensure_consistent_constraint_set(pos_tax_constraints) + positive_income_tax_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in pos_tax_constraints + ] + session.add(positive_income_tax_stratum) + session.flush() + print( + "Created positive income tax stratum " + "(income_tax >= 0 constraint for CBO definition)" + ) + + target_stratum = positive_income_tax_stratum + else: + target_stratum = national_filer_stratum + # Check if target already exists existing_target = ( session.query(Target) .filter( - Target.stratum_id == national_filer_stratum.stratum_id, - Target.variable == target_data["variable"], + Target.stratum_id == target_stratum.stratum_id, + Target.variable == variable_name, Target.period == target_year, ) .first() @@ -573,12 +636,12 @@ def load_national_targets( # Update existing target existing_target.value = target_data["value"] existing_target.notes = combined_notes - print(f"Updated filer target: {target_data['variable']}") + print(f"Updated filer target: {variable_name}") else: # Create new target target = Target( - 
stratum_id=national_filer_stratum.stratum_id, - variable=target_data["variable"], + stratum_id=target_stratum.stratum_id, + variable=variable_name, period=target_year, value=target_data["value"], source_id=calibration_source.source_id, @@ -586,7 +649,7 @@ def load_national_targets( notes=combined_notes, ) session.add(target) - print(f"Added filer target: {target_data['variable']}") + print(f"Added filer target: {variable_name}") # Process conditional count targets (enrollment counts) for cond_target in conditional_targets: @@ -610,7 +673,7 @@ def load_national_targets( elif constraint_var == "ssn_card_type": stratum_group_id = 7 # SSN card type group stratum_notes = "National Undocumented Population" - constraint_operation = "=" + constraint_operation = "==" constraint_value = cond_target.get("constraint_value", "NONE") else: stratum_notes = f"National {constraint_var} Recipients" @@ -664,14 +727,23 @@ def load_national_targets( notes=stratum_notes, ) - # Add constraint - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable=constraint_var, + # Validate constraints before adding + cond_constraints = [ + Constraint( + variable=constraint_var, operation=constraint_operation, value=constraint_value, ) ] + ensure_consistent_constraint_set(cond_constraints) + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in cond_constraints + ] # Add target new_stratum.targets_rel = [ diff --git a/policyengine_us_data/db/etl_snap.py b/policyengine_us_data/db/etl_snap.py index 48c1eb83..611d570c 100644 --- a/policyengine_us_data/db/etl_snap.py +++ b/policyengine_us_data/db/etl_snap.py @@ -15,7 +15,6 @@ StratumConstraint, Target, Source, - SourceType, VariableGroup, VariableMetadata, ) @@ -35,6 +34,10 @@ save_bytes, load_bytes, ) +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) logger = 
logging.getLogger(__name__) @@ -169,7 +172,7 @@ def load_administrative_snap_data(df_states, year): admin_source = get_or_create_source( session, name="USDA FNS SNAP Data", - source_type=SourceType.ADMINISTRATIVE, + source_type="administrative", vintage=f"FY {year}", description="SNAP administrative data from USDA Food and Nutrition Service", url="https://www.fns.usda.gov/pd/supplemental-nutrition-assistance-program-snap", @@ -219,12 +222,18 @@ def load_administrative_snap_data(df_states, year): stratum_group_id=4, # SNAP strata group notes="National Received SNAP Benefits", ) + # Validate constraints before adding + nat_constraints = [ + Constraint(variable="snap", operation=">", value="0"), + ] + ensure_consistent_constraint_set(nat_constraints) nat_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="snap", - operation=">", - value="0", - ), + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in nat_constraints ] # No target at the national level is provided at this time. Keeping it # so that the state strata can have a parent stratum @@ -249,17 +258,23 @@ def load_administrative_snap_data(df_states, year): stratum_group_id=4, # SNAP strata group notes=note, ) - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="state_fips", + # Validate constraints before adding + state_snap_constraints = [ + Constraint( + variable="state_fips", operation="==", value=str(state_fips), ), + Constraint(variable="snap", operation=">", value="0"), + ] + ensure_consistent_constraint_set(state_snap_constraints) + new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="snap", - operation=">", - value="0", - ), + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in state_snap_constraints ] # Two targets now. Same data source. 
Same stratum new_stratum.targets_rel.append( @@ -305,7 +320,7 @@ def load_survey_snap_data(survey_df, year, snap_stratum_lookup): survey_source = get_or_create_source( session, name="Census ACS Table S2201", - source_type=SourceType.SURVEY, + source_type="survey", vintage=f"{year} ACS 5-year estimates", description="American Community Survey SNAP/Food Stamps data", url="https://data.census.gov/", @@ -333,17 +348,23 @@ def load_survey_snap_data(survey_df, year, snap_stratum_lookup): notes=note, ) - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="congressional_district_geoid", + # Validate constraints before adding + cd_snap_constraints = [ + Constraint( + variable="congressional_district_geoid", operation="==", value=str(cd_geoid), ), + Constraint(variable="snap", operation=">", value="0"), + ] + ensure_consistent_constraint_set(cd_snap_constraints) + new_stratum.constraints_rel = [ StratumConstraint( - constraint_variable="snap", - operation=">", - value="0", - ), + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in cd_snap_constraints ] new_stratum.targets_rel.append( Target( diff --git a/policyengine_us_data/db/etl_state_income_tax.py b/policyengine_us_data/db/etl_state_income_tax.py index df0f40a6..2d229779 100644 --- a/policyengine_us_data/db/etl_state_income_tax.py +++ b/policyengine_us_data/db/etl_state_income_tax.py @@ -21,7 +21,6 @@ StratumConstraint, Target, Source, - SourceType, VariableGroup, VariableMetadata, ) @@ -36,6 +35,10 @@ save_json, load_json, ) +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ensure_consistent_constraint_set, +) logger = logging.getLogger(__name__) @@ -261,7 +264,7 @@ def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict: source = get_or_create_source( session, name="Census Bureau Annual Survey of State Tax Collections", - source_type=SourceType.ADMINISTRATIVE, + source_type="administrative", 
url="https://www.census.gov/programs-surveys/stc.html", notes="Individual income tax collections by state", ) @@ -310,13 +313,23 @@ def load_state_income_tax_data(df: pd.DataFrame, year: int) -> dict: stratum_group_id=STRATUM_GROUP_ID_STATE_INCOME_TAX, notes=note, ) - new_stratum.constraints_rel = [ - StratumConstraint( - constraint_variable="state_fips", + # Validate constraints before adding + state_tax_constraints = [ + Constraint( + variable="state_fips", operation="==", value=state_fips, ), ] + ensure_consistent_constraint_set(state_tax_constraints) + new_stratum.constraints_rel = [ + StratumConstraint( + constraint_variable=c.variable, + operation=c.operation, + value=c.value, + ) + for c in state_tax_constraints + ] # Add target for state_income_tax total new_stratum.targets_rel.append( diff --git a/policyengine_us_data/tests/test_constraint_validation.py b/policyengine_us_data/tests/test_constraint_validation.py new file mode 100644 index 00000000..29920475 --- /dev/null +++ b/policyengine_us_data/tests/test_constraint_validation.py @@ -0,0 +1,238 @@ +""" +Unit tests for constraint validation logic. 
+ +Tests cover: +- Valid range constraints +- Empty range detection +- Equality operations must be alone +- Conflicting lower/upper bounds +- Multiple variables (each validated independently) +""" + +import pytest + +from policyengine_us_data.utils.constraint_validation import ( + Constraint, + ConstraintValidationError, + ensure_consistent_constraint_set, +) + + +class TestValidRanges: + """Tests for valid range constraint combinations.""" + + def test_valid_range_ge_lt(self): + """age >= 25 AND age < 30 should pass.""" + constraints = [ + Constraint(variable="age", operation=">=", value="25"), + Constraint(variable="age", operation="<", value="30"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_valid_range_gt_le(self): + """age > 20 AND age <= 65 should pass.""" + constraints = [ + Constraint(variable="age", operation=">", value="20"), + Constraint(variable="age", operation="<=", value="65"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_valid_range_gt_lt(self): + """age > 0 AND age < 100 should pass.""" + constraints = [ + Constraint(variable="age", operation=">", value="0"), + Constraint(variable="age", operation="<", value="100"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_valid_range_ge_le(self): + """age >= 0 AND age <= 85 should pass.""" + constraints = [ + Constraint(variable="age", operation=">=", value="0"), + Constraint(variable="age", operation="<=", value="85"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + +class TestEmptyRanges: + """Tests for empty range detection.""" + + def test_empty_range_lower_greater_than_upper(self): + """age >= 50 AND age < 30 should fail (empty range).""" + constraints = [ + Constraint(variable="age", operation=">=", value="50"), + Constraint(variable="age", operation="<", value="30"), + ] + with pytest.raises(ConstraintValidationError, match="empty range"): + 
ensure_consistent_constraint_set(constraints) + + def test_empty_range_equal_bounds_not_inclusive(self): + """age > 30 AND age < 30 should fail (empty range).""" + constraints = [ + Constraint(variable="age", operation=">", value="30"), + Constraint(variable="age", operation="<", value="30"), + ] + with pytest.raises(ConstraintValidationError, match="empty range"): + ensure_consistent_constraint_set(constraints) + + def test_empty_range_equal_bounds_one_inclusive(self): + """age >= 30 AND age < 30 should fail (empty range).""" + constraints = [ + Constraint(variable="age", operation=">=", value="30"), + Constraint(variable="age", operation="<", value="30"), + ] + with pytest.raises(ConstraintValidationError, match="empty range"): + ensure_consistent_constraint_set(constraints) + + def test_valid_point_range_both_inclusive(self): + """age >= 30 AND age <= 30 should pass (valid point).""" + constraints = [ + Constraint(variable="age", operation=">=", value="30"), + Constraint(variable="age", operation="<=", value="30"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + +class TestEqualityOperations: + """Tests for equality operation rules.""" + + def test_equality_alone_is_valid(self): + """state_fips == '06' alone should pass.""" + constraints = [ + Constraint(variable="state_fips", operation="==", value="06"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_not_equal_alone_is_valid(self): + """state_fips != '72' alone should pass.""" + constraints = [ + Constraint(variable="state_fips", operation="!=", value="72"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_equality_with_range_fails(self): + """state_fips == '06' AND state_fips > '05' should fail.""" + constraints = [ + Constraint(variable="state_fips", operation="==", value="06"), + Constraint(variable="state_fips", operation=">", value="05"), + ] + with pytest.raises(ConstraintValidationError, match="cannot combine"): + 
ensure_consistent_constraint_set(constraints) + + def test_not_equal_with_range_fails(self): + """state_fips != '06' AND state_fips < '10' should fail.""" + constraints = [ + Constraint(variable="state_fips", operation="!=", value="06"), + Constraint(variable="state_fips", operation="<", value="10"), + ] + with pytest.raises(ConstraintValidationError, match="cannot combine"): + ensure_consistent_constraint_set(constraints) + + +class TestConflictingBounds: + """Tests for conflicting lower/upper bound detection.""" + + def test_conflicting_lower_bounds(self): + """age > 20 AND age >= 25 should fail.""" + constraints = [ + Constraint(variable="age", operation=">", value="20"), + Constraint(variable="age", operation=">=", value="25"), + ] + with pytest.raises( + ConstraintValidationError, match="conflicting lower bounds" + ): + ensure_consistent_constraint_set(constraints) + + def test_conflicting_upper_bounds(self): + """age < 50 AND age <= 45 should fail.""" + constraints = [ + Constraint(variable="age", operation="<", value="50"), + Constraint(variable="age", operation="<=", value="45"), + ] + with pytest.raises( + ConstraintValidationError, match="conflicting upper bounds" + ): + ensure_consistent_constraint_set(constraints) + + +class TestMultipleVariables: + """Tests for constraints on multiple variables.""" + + def test_multiple_variables_independent(self): + """age >= 25 AND state_fips == '06' should pass.""" + constraints = [ + Constraint(variable="age", operation=">=", value="25"), + Constraint(variable="state_fips", operation="==", value="06"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_multiple_variables_both_ranges(self): + """age >= 25 AND age < 65 AND income > 0 AND income < 50000 should pass.""" + constraints = [ + Constraint(variable="age", operation=">=", value="25"), + Constraint(variable="age", operation="<", value="65"), + Constraint(variable="income", operation=">", value="0"), + Constraint(variable="income", 
operation="<", value="50000"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_one_variable_invalid_other_valid(self): + """Invalid on one variable should fail even if other is valid.""" + constraints = [ + Constraint(variable="age", operation=">=", value="50"), + Constraint(variable="age", operation="<", value="30"), # Invalid + Constraint(variable="state_fips", operation="==", value="06"), + ] + with pytest.raises(ConstraintValidationError, match="empty range"): + ensure_consistent_constraint_set(constraints) + + +class TestNonNumericValues: + """Tests for non-numeric constraint values.""" + + def test_string_equality_valid(self): + """medicaid_enrolled == 'True' should pass.""" + constraints = [ + Constraint( + variable="medicaid_enrolled", operation="==", value="True" + ), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_string_values_skip_range_check(self): + """Non-numeric values should skip range validation.""" + constraints = [ + Constraint(variable="ssn_card_type", operation="==", value="NONE"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + +class TestEdgeCases: + """Tests for edge cases.""" + + def test_empty_constraint_list(self): + """Empty constraint list should pass.""" + constraints = [] + ensure_consistent_constraint_set(constraints) # No exception + + def test_single_lower_bound(self): + """Single lower bound should pass.""" + constraints = [ + Constraint(variable="snap", operation=">", value="0"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_single_upper_bound(self): + """Single upper bound should pass.""" + constraints = [ + Constraint(variable="income", operation="<", value="100000"), + ] + ensure_consistent_constraint_set(constraints) # No exception + + def test_infinity_bounds(self): + """AGI >= -inf AND AGI < 1 should pass.""" + constraints = [ + Constraint(variable="agi", operation=">=", value="-inf"), + 
Constraint(variable="agi", operation="<", value="1"), + ] + ensure_consistent_constraint_set(constraints) # No exception diff --git a/policyengine_us_data/tests/test_database.py b/policyengine_us_data/tests/test_database.py index c36ef828..b0f9958d 100644 --- a/policyengine_us_data/tests/test_database.py +++ b/policyengine_us_data/tests/test_database.py @@ -11,12 +11,19 @@ Target, create_database, ) +from policyengine_us_data.db.create_field_valid_values import ( + populate_field_valid_values, +) @pytest.fixture def engine(tmp_path): db_uri = f"sqlite:///{tmp_path/'test.db'}" - return create_database(db_uri) + eng = create_database(db_uri) + # Populate field_valid_values for trigger validation + with Session(eng) as session: + populate_field_valid_values(session) + return eng # TODO: Re-enable this test once database issues are resolved in PR #437 @@ -28,15 +35,15 @@ def test_stratum_hash_and_relationships(engine): stratum = Stratum(notes="test", stratum_group_id=0) stratum.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="equals", - value="0400000US30", + constraint_variable="state_fips", + operation="==", + value="30", ), StratumConstraint( - constraint_variable="age", operation="greater_than", value="20" + constraint_variable="age", operation=">", value="20" ), StratumConstraint( - constraint_variable="age", operation="less_than", value="65" + constraint_variable="age", operation="<", value="65" ), ] stratum.targets_rel = [ @@ -48,9 +55,9 @@ def test_stratum_hash_and_relationships(engine): "\n".join( sorted( [ - "ucgid_str|equals|0400000US30", - "age|greater_than|20", - "age|less_than|65", + "state_fips|==|30", + "age|>|20", + "age|<|65", ] ) ).encode("utf-8") @@ -66,9 +73,9 @@ def test_unique_definition_hash(engine): s1 = Stratum(stratum_group_id=0) s1.constraints_rel = [ StratumConstraint( - constraint_variable="ucgid_str", - operation="equals", - value="0400000US30", + constraint_variable="state_fips", + operation="==", + 
"""
Constraint validation for stratum definitions.

This module provides validation for constraint sets BEFORE they are
inserted into the database. This prevents logically inconsistent
constraints from ever being stored.

Validation Rules:
1. Operation Compatibility (per constraint_variable):
   - `==` and `!=` must be alone (cannot combine with other operations)
   - `>` and `>=` cannot coexist (conflicting lower bounds)
   - `<` and `<=` cannot coexist (conflicting upper bounds)
   - `>` or `>=` can combine with `<` or `<=` to form valid ranges

2. Value Checks (if operations are compatible):
   - Repeated `==` constraints must all name the same value
   - No empty ranges: lower bound must be < upper bound
   - For equal bounds, both must be inclusive to be valid
"""

import math
from dataclasses import dataclass
from typing import List


@dataclass
class Constraint:
    """A constraint to validate (before creating StratumConstraint)."""

    # Name of the variable the constraint applies to (e.g. "age").
    variable: str
    # Comparison operator; the checks below know ==, !=, >, >=, <, <=.
    operation: str
    # Right-hand side, stored as text; parsed as a number where needed.
    value: str


class ConstraintValidationError(Exception):
    """Raised when constraint set is logically inconsistent."""


# Operation compatibility groups
EQUALITY_OPS = {"==", "!="}
LOWER_BOUND_OPS = {">", ">="}
UPPER_BOUND_OPS = {"<", "<="}
RANGE_OPS = LOWER_BOUND_OPS | UPPER_BOUND_OPS


def ensure_consistent_constraint_set(constraints: List[Constraint]) -> None:
    """
    Validate that a set of constraints is logically consistent.

    Call this BEFORE inserting constraints into the database.
    Constraints are grouped by variable and each group is validated
    independently; an empty list is trivially consistent.

    Args:
        constraints: List of Constraint objects to validate.

    Raises:
        ConstraintValidationError: If constraints are logically inconsistent.

    Example:
        >>> constraints = [
        ...     Constraint(variable="age", operation=">=", value="25"),
        ...     Constraint(variable="age", operation="<", value="30"),
        ... ]
        >>> ensure_consistent_constraint_set(constraints)  # No exception
    """
    # Group constraints by variable
    by_variable: dict = {}
    for c in constraints:
        by_variable.setdefault(c.variable, []).append(c)

    for var_name, var_constraints in by_variable.items():
        _validate_variable_constraints(var_name, var_constraints)


def _validate_variable_constraints(
    var_name: str, constraints: List[Constraint]
) -> None:
    """Validate all constraints on a single variable.

    Args:
        var_name: Variable shared by every constraint in ``constraints``.
        constraints: The constraints declared on that variable.

    Raises:
        ConstraintValidationError: If any rule from the module docstring
            is violated.
    """
    operations = {c.operation for c in constraints}

    # Rule 1: Check operation compatibility
    _check_operation_compatibility(var_name, operations)

    # Rule 2a: Several '==' constraints are only satisfiable together
    # when they all name the same value. The operation *set* above hides
    # repeats, so the values have to be inspected explicitly.
    if "==" in operations:
        _check_equality_values(var_name, constraints)

    # Rule 2b: If range operations, check for empty range
    if operations & RANGE_OPS:
        _check_range_validity(var_name, constraints)


def _check_operation_compatibility(var_name: str, operations: set) -> None:
    """Check that operations on a variable are compatible.

    Args:
        var_name: Variable name, used only in error messages.
        operations: Distinct operation strings used on the variable.

    Raises:
        ConstraintValidationError: If an equality operation is mixed with
            any other operation, or if both the strict and the inclusive
            form of a lower (or upper) bound are present.
    """
    # Equality ops must be alone
    if operations & EQUALITY_OPS and len(operations) > 1:
        raise ConstraintValidationError(
            f"{var_name}: '==' or '!=' cannot combine with other "
            f"operations, found: {operations}"
        )

    # Cannot have both > and >= (conflicting lower bounds)
    if ">" in operations and ">=" in operations:
        raise ConstraintValidationError(
            f"{var_name}: cannot have both '>' and '>=' "
            "(conflicting lower bounds)"
        )

    # Cannot have both < and <= (conflicting upper bounds)
    if "<" in operations and "<=" in operations:
        raise ConstraintValidationError(
            f"{var_name}: cannot have both '<' and '<=' "
            "(conflicting upper bounds)"
        )


def _comparable_value(value: str) -> object:
    """Return ``value`` as a float when it is numeric text, else unchanged.

    This lets "30" and "30.0" count as the same value. NaN is returned as
    the original text because NaN never compares equal to itself.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return value
    return value if math.isnan(number) else number


def _check_equality_values(
    var_name: str, constraints: List[Constraint]
) -> None:
    """Check that all '==' constraints on a variable agree on one value.

    Args:
        var_name: Variable name, used only in error messages.
        constraints: The constraints declared on that variable.

    Raises:
        ConstraintValidationError: If two '==' constraints require
            different values, which no record can satisfy.
    """
    required = {
        _comparable_value(c.value) for c in constraints if c.operation == "=="
    }
    if len(required) > 1:
        raise ConstraintValidationError(
            f"{var_name}: conflicting '==' values: "
            f"{sorted(str(v) for v in required)}"
        )


def _check_range_validity(
    var_name: str, constraints: List[Constraint]
) -> None:
    """Check that range constraints don't create an empty range.

    The tightest lower and upper bounds are accumulated over all range
    constraints whose value parses as a number; other constraints are
    ignored.

    Args:
        var_name: Variable name, used only in error messages.
        constraints: The constraints declared on that variable.

    Raises:
        ConstraintValidationError: If the tightest lower bound exceeds
            the tightest upper bound, or they are equal and at least one
            of them is exclusive.
    """
    lower_bound = float("-inf")
    upper_bound = float("inf")
    lower_inclusive = False
    upper_inclusive = False

    for c in constraints:
        if c.operation not in RANGE_OPS:
            continue
        try:
            val = float(c.value)
        except (TypeError, ValueError):
            # Non-numeric value - skip range check
            continue

        inclusive = c.operation in (">=", "<=")
        if c.operation in LOWER_BOUND_OPS:
            # A larger value is tighter; at the same value an exclusive
            # bound is tighter than an inclusive one.
            if val > lower_bound or (val == lower_bound and not inclusive):
                lower_bound, lower_inclusive = val, inclusive
        elif val < upper_bound or (val == upper_bound and not inclusive):
            # Mirror image of the lower-bound rule.
            upper_bound, upper_inclusive = val, inclusive

    # Check for empty range
    if lower_bound > upper_bound:
        raise ConstraintValidationError(
            f"{var_name}: empty range - lower bound {lower_bound} > "
            f"upper bound {upper_bound}"
        )
    if lower_bound == upper_bound and not (
        lower_inclusive and upper_inclusive
    ):
        raise ConstraintValidationError(
            f"{var_name}: empty range - bounds equal at {lower_bound} "
            "but not both inclusive"
        )
None, url: Optional[str] = None,