Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
4c0f1c6
Shrink datasets
nikhilwoodruff Jul 10, 2025
6b2a56f
Move to package
nikhilwoodruff Jul 10, 2025
05ee7e4
Try L0
nikhilwoodruff Jul 10, 2025
e38c647
Format
nikhilwoodruff Jul 10, 2025
bdf3d6d
attempting to vectorize minimizing of ecps
juaristi22 Jul 11, 2025
03e5d0d
adding random sampling minimization strategy
juaristi22 Jul 11, 2025
cd0776c
add notebook with testing functionality (havent tested locally)
juaristi22 Jul 11, 2025
2c050fc
lint
juaristi22 Jul 11, 2025
ee98fc3
debugged 2nd cell: created path & removed optional parameters.
eccuraa Jul 12, 2025
f6d7f0f
few updates to the testing framework
juaristi22 Jul 14, 2025
a042a01
added CPS_2023 to lite mode generation
baogorek Jul 11, 2025
cabeb56
Fixed manual test
baogorek Jul 11, 2025
7b76afb
try again with locked version
baogorek Jul 11, 2025
4056df4
trying things
baogorek Jul 11, 2025
96c4c25
lint
baogorek Jul 11, 2025
e20c75c
trying 3.11.12
baogorek Jul 11, 2025
776eda8
now actually specifying py version
baogorek Jul 11, 2025
cd77179
pandas v
baogorek Jul 11, 2025
d0ce44d
small runner
baogorek Jul 11, 2025
eb96cd5
trying everything
baogorek Jul 11, 2025
59ff94e
relaxing python version in pyproject.toml
baogorek Jul 11, 2025
d3fa67b
putting things back in order.
baogorek Jul 11, 2025
273c48d
Use normal runner in PR tests
nikhilwoodruff Jul 12, 2025
8c2fbda
added the 3.11.12 pin
baogorek Jul 12, 2025
edb0945
cps.py
baogorek Jul 14, 2025
994ac15
adding diagnostics
baogorek Jul 14, 2025
341a355
lint
baogorek Jul 14, 2025
c2ab4b6
taking out bad targets
baogorek Jul 14, 2025
6f7a03a
fixing workflow arg passthrough
baogorek Jul 14, 2025
3dba2a2
deps and defaults
baogorek Jul 14, 2025
7710a4c
wrong pipeline for manual test
baogorek Jul 14, 2025
27f46fd
trying again to get the manual test to work
baogorek Jul 14, 2025
fef1eca
reverting to older workflow code
baogorek Jul 14, 2025
5eb1050
cleaning up enhanced_cps.py
baogorek Jul 14, 2025
1fb4318
Update package version
MaxGhenis Jul 14, 2025
a62328a
attempting to vectorize minimizing of ecps
juaristi22 Jul 11, 2025
6d3f8b4
add notebook with testing functionality (havent tested locally)
juaristi22 Jul 11, 2025
94cacde
few updates to the testing framework
juaristi22 Jul 14, 2025
a71530b
fix calibration for each approach
juaristi22 Jul 14, 2025
f146620
fixed testing framework
juaristi22 Jul 14, 2025
7a2d074
Merge branch 'main' into maria/shrink
juaristi22 Jul 14, 2025
68349f8
starting to collect results
eccuraa Jul 14, 2025
dedf49a
Resolved merge conflict by accepting incoming changes
eccuraa Jul 14, 2025
4d593b9
added functionality for running multiple L0/L1 penalty values & dataf…
eccuraa Jul 14, 2025
a8af62a
pulling data from files for plotting
eccuraa Jul 15, 2025
a917d35
deleted testing cell
eccuraa Jul 15, 2025
734f54f
current testing arena for Ben
eccuraa Jul 15, 2025
64c8149
not much new
eccuraa Jul 15, 2025
226b2d9
synthetic dataset
eccuraa Jul 15, 2025
6a8160b
committing before changing file
eccuraa Jul 15, 2025
842dfa6
Merge minimize.py from maria/ecps_minimization branch
eccuraa Jul 15, 2025
f815c7e
renaming to american naming (maria started it haha)
eccuraa Jul 15, 2025
096fb0f
more american spelling for debugging
eccuraa Jul 15, 2025
41980ac
initial visualization with synthetic data
eccuraa Jul 15, 2025
b3208db
full test arena?? (trying it now)
eccuraa Jul 15, 2025
4d8f60c
forgot a file
eccuraa Jul 15, 2025
112658f
added some headers, just need to add pruning
eccuraa Jul 15, 2025
fa9aa02
fixed a scraping bug & deleted synthetic data
eccuraa Jul 16, 2025
791f0d9
fixed a scraping bug
eccuraa Jul 16, 2025
9520b16
added pruning to L0, L1 approaches (and discovered candidate_loss app…
eccuraa Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
added:
- Enhanced CPS minimizing tests.
74 changes: 67 additions & 7 deletions policyengine_us_data/datasets/cps/enhanced_cps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,27 @@
torch = None


bad_targets = [
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household",
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household",
"nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household",
"nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household",
"nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
"nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
]


def reweight(
original_weights,
loss_matrix,
targets_array,
dropout_rate=0.05,
log_path="calibration_log.csv",
epochs=150,
log_path="calibration_log.csv",
penalty_approach=None,
penalty_weight=None,
):
target_names = np.array(loss_matrix.columns)
is_national = loss_matrix.columns.str.startswith("nation/")
Expand All @@ -46,8 +60,12 @@ def reweight(
np.log(original_weights), requires_grad=True, dtype=torch.float32
)

# TODO: replace this functionality from the microcalibrate package.
def loss(weights):
# TO DO: replace this with a call to the python reweight.py package.
def loss(
weights,
penalty_approach=penalty_approach,
penalty_weight=penalty_weight,
):
# Check for Nans in either the weights or the loss matrix
if torch.isnan(weights).any():
raise ValueError("Weights contain NaNs")
Expand All @@ -60,9 +78,51 @@ def loss(weights):
((estimate - targets_array) + 1) / (targets_array + 1)
) ** 2
rel_error_normalized = rel_error * normalisation_factor

if torch.isnan(rel_error_normalized).any():
raise ValueError("Relative error contains NaNs")
return rel_error_normalized.mean()

if penalty_approach is not None and penalty_weight is not None:
# L0 penalty (approximated with smooth function)
# Since L0 is non-differentiable, we use a smooth approximation
# Common approaches:

epsilon = 1e-3 # Threshold for "near zero"

# Option 1: Sigmoid approximation
if penalty_approach == "l0_sigmoid":
smoothed_l0 = torch.sigmoid(
(weights - epsilon) / (epsilon * 0.1)
).mean()

# Option 2: Log-sum penalty (smoother)
if penalty_approach == "l0_log":
smoothed_l0 = torch.log(1 + weights / epsilon).sum() / len(
weights
)

# Option 3: Exponential penalty
if penalty_approach == "l0_exp":
smoothed_l0 = (1 - torch.exp(-weights / epsilon)).mean()

if penalty_approach == "l1":
l1 = torch.mean(weights)
return rel_error_normalized.mean() + penalty_weight * l1

return rel_error_normalized.mean() + penalty_weight * smoothed_l0

else:
return rel_error_normalized.mean()

def prune_dataset(weights, epsilon=1e-3):
"""
Prune dataset samples based on learned weights.
Returns indices of samples to keep.
"""
importance_scores = weights.detach().cpu().numpy()
keep_indices = np.where(importance_scores > epsilon)[0]

return keep_indices

def dropout_weights(weights, p):
if p == 0:
Expand Down Expand Up @@ -207,9 +267,9 @@ def generate(self):
loss_matrix, targets_array = build_loss_matrix(
self.input_dataset, year
)
zero_mask = np.isclose(targets_array, 0.0, atol=0.1)

bad_mask = loss_matrix.columns.isin(bad_targets)
keep_mask_bool = ~(zero_mask | bad_mask)
keep_mask_bool = ~bad_mask
keep_idx = np.where(keep_mask_bool)[0]
loss_matrix_clean = loss_matrix.iloc[:, keep_idx]
targets_array_clean = targets_array[keep_idx]
Expand All @@ -220,7 +280,7 @@ def generate(self):
loss_matrix_clean,
targets_array_clean,
log_path="calibration_log.csv",
epochs=150,
epochs= 150,
)
data["household_weight"][year] = optimised_weights

Expand Down
1 change: 1 addition & 0 deletions policyengine_us_data/storage/upload_completed_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def upload_datasets():
Pooled_3_Year_CPS_2023.file_path,
CPS_2023.file_path,
STORAGE_FOLDER / "small_enhanced_cps_2024.h5",
STORAGE_FOLDER / "enhanced_cps_2024_minified.h5",
]

for file_path in dataset_files:
Expand Down
5 changes: 0 additions & 5 deletions policyengine_us_data/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,11 +552,6 @@ def build_loss_matrix(dataset: type, time_period):
# Convert to thousands for the target
targets_array.append(row["enrollment"])

print(
f"Targeting Medicaid enrollment for {row['state']} "
f"with target {row['enrollment']:.0f}k"
)

# State 10-year age targets

age_targets = pd.read_csv(STORAGE_FOLDER / "age_state.csv")
Expand Down
Loading