Skip to content

Commit adcf138

Browse files
authored
Merge pull request #383 from PolicyEngine/age-reshape
Age reshape
2 parents 045b585 + 690910f commit adcf138

File tree

8 files changed

+529
-287
lines changed

8 files changed

+529
-287
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
**/*.csv
66
**/_build
77
**/*.pkl
8+
**/*.db
89
venv
910

1011
## old (not clean) targets

Makefile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ documentation:
3838
jb clean docs && jb build docs
3939
python docs/add_plotly_to_book.py docs
4040

41+
database:
42+
python policyengine_us_data/db/create_database_tables.py
43+
python policyengine_us_data/db/load_age_targets.py
44+
45+
clean-database:
46+
rm *.db
4147

4248
data:
4349
python policyengine_us_data/utils/uprating.py

changelog_entry.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- bump: minor
2+
changes:
3+
added:
4+
- Added creation script to build relational database for targets
5+
- Refactored age targets load script to load the database

policyengine_us_data/db/__init__.py

Whitespace-only changes.
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import logging
2+
import hashlib
3+
from typing import List, Optional
4+
5+
from sqlalchemy import event, UniqueConstraint
6+
from sqlalchemy.orm.attributes import get_history
7+
8+
from sqlmodel import (
9+
Field,
10+
Relationship,
11+
SQLModel,
12+
create_engine,
13+
)
14+
15+
logging.basicConfig(
16+
level=logging.INFO,
17+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
18+
)
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class Stratum(SQLModel, table=True):
24+
"""Represents a unique population subgroup (stratum)."""
25+
26+
__tablename__ = "strata"
27+
__table_args__ = (
28+
UniqueConstraint("definition_hash", name="uq_strata_definition_hash"),
29+
)
30+
31+
stratum_id: Optional[int] = Field(
32+
default=None,
33+
primary_key=True,
34+
description="Unique identifier for the stratum.",
35+
)
36+
definition_hash: str = Field(
37+
sa_column_kwargs={
38+
"comment": "SHA-256 hash of the stratum's constraints."
39+
},
40+
max_length=64,
41+
)
42+
parent_stratum_id: Optional[int] = Field(
43+
default=None,
44+
foreign_key="strata.stratum_id",
45+
index=True,
46+
description="Identifier for a parent stratum, creating a hierarchy.",
47+
)
48+
stratum_group_id: Optional[int] = Field(
49+
default=None, description="Identifier for a group of related strata."
50+
)
51+
notes: Optional[str] = Field(
52+
default=None, description="Descriptive notes about the stratum."
53+
)
54+
55+
children_rel: List["Stratum"] = Relationship(
56+
back_populates="parent_rel",
57+
sa_relationship_kwargs={"remote_side": "Stratum.parent_stratum_id"},
58+
)
59+
parent_rel: Optional["Stratum"] = Relationship(
60+
back_populates="children_rel",
61+
sa_relationship_kwargs={"remote_side": "Stratum.stratum_id"},
62+
)
63+
constraints_rel: List["StratumConstraint"] = Relationship(
64+
back_populates="strata_rel",
65+
sa_relationship_kwargs={
66+
"cascade": "all, delete-orphan",
67+
"lazy": "joined",
68+
},
69+
)
70+
targets_rel: List["Target"] = Relationship(
71+
back_populates="strata_rel",
72+
sa_relationship_kwargs={"cascade": "all, delete-orphan"},
73+
)
74+
75+
76+
class StratumConstraint(SQLModel, table=True):
77+
"""Defines the rules that make up a stratum."""
78+
79+
__tablename__ = "stratum_constraints"
80+
81+
stratum_id: int = Field(foreign_key="strata.stratum_id", primary_key=True)
82+
constraint_variable: str = Field(
83+
primary_key=True,
84+
description="The variable the constraint applies to (e.g., 'age').",
85+
)
86+
operation: str = Field(
87+
primary_key=True,
88+
description="The comparison operator (e.g., 'greater_than_or_equal').",
89+
)
90+
value: str = Field(
91+
description="The value for the constraint rule (e.g., '25')."
92+
)
93+
notes: Optional[str] = Field(
94+
default=None, description="Optional notes about the constraint."
95+
)
96+
97+
strata_rel: Stratum = Relationship(back_populates="constraints_rel")
98+
99+
100+
class Target(SQLModel, table=True):
101+
"""Stores the data values for a specific stratum."""
102+
103+
__tablename__ = "targets"
104+
__table_args__ = (
105+
UniqueConstraint(
106+
"variable",
107+
"period",
108+
"stratum_id",
109+
"reform_id",
110+
name="_target_unique",
111+
),
112+
)
113+
114+
target_id: Optional[int] = Field(default=None, primary_key=True)
115+
variable: str = Field(
116+
description="A variable defined in policyengine-us (e.g., 'income_tax')."
117+
)
118+
period: int = Field(
119+
description="The time period for the data, typically a year."
120+
)
121+
stratum_id: int = Field(foreign_key="strata.stratum_id", index=True)
122+
reform_id: int = Field(
123+
default=0,
124+
description="Identifier for a policy reform scenario (0 for baseline).",
125+
)
126+
value: Optional[float] = Field(
127+
default=None, description="The numerical value of the target variable."
128+
)
129+
source_id: Optional[int] = Field(
130+
default=None, description="Identifier for the data source."
131+
)
132+
active: bool = Field(
133+
default=True,
134+
description="Flag to indicate if the record is currently active.",
135+
)
136+
tolerance: Optional[float] = Field(
137+
default=None,
138+
description="Allowed relative error as a percent (e.g., 25 for 25%).",
139+
)
140+
notes: Optional[str] = Field(
141+
default=None,
142+
description="Optional descriptive notes about the target row.",
143+
)
144+
145+
strata_rel: Stratum = Relationship(back_populates="targets_rel")
146+
147+
148+
# This SQLAlchemy event listener works directly with the SQLModel class
149+
@event.listens_for(Stratum, "before_insert")
150+
@event.listens_for(Stratum, "before_update")
151+
def calculate_definition_hash(mapper, connection, target: Stratum):
152+
"""
153+
Calculate and set the definition_hash before saving a Stratum instance.
154+
"""
155+
constraints_history = get_history(target, "constraints_rel")
156+
if not (
157+
constraints_history.has_changes() or target.definition_hash is None
158+
):
159+
return
160+
161+
if not target.constraints_rel: # Handle cases with no constraints
162+
target.definition_hash = hashlib.sha256(b"").hexdigest()
163+
return
164+
165+
constraint_strings = [
166+
f"{c.constraint_variable}|{c.operation}|{c.value}"
167+
for c in target.constraints_rel
168+
]
169+
170+
constraint_strings.sort()
171+
fingerprint_text = "\n".join(constraint_strings)
172+
h = hashlib.sha256(fingerprint_text.encode("utf-8"))
173+
target.definition_hash = h.hexdigest()
174+
logger.info(
175+
f"Set definition_hash for Stratum to '{target.definition_hash}'"
176+
)
177+
178+
179+
def create_database(db_uri="sqlite:///policy_data.db"):
180+
"""
181+
Creates a SQLite database and all the defined tables.
182+
183+
Args:
184+
db_uri (str): The connection string for the database.
185+
186+
Returns:
187+
An SQLAlchemy Engine instance connected to the database.
188+
"""
189+
engine = create_engine(db_uri)
190+
SQLModel.metadata.create_all(engine)
191+
logger.info(f"Database and tables created successfully at {db_uri}")
192+
return engine
193+
194+
195+
if __name__ == "__main__":
196+
engine = create_database()

0 commit comments

Comments
 (0)