diff --git a/.github/workflows/code_changes.yaml b/.github/workflows/code_changes.yaml index b752e953..6c619b40 100644 --- a/.github/workflows/code_changes.yaml +++ b/.github/workflows/code_changes.yaml @@ -56,6 +56,12 @@ jobs: with: name: calibration_log.csv path: calibration_log.csv + - name: Save minimized ECPS calibration log + uses: actions/upload-artifact@v4 + with: + name: minimized_enhanced_cps_2024_calibration_log.csv + path: minimized_enhanced_cps_2024_calibration_log.csv + if-no-files-found: ignore - name: Run tests run: pytest - name: Upload data diff --git a/.github/workflows/pr_code_changes.yaml b/.github/workflows/pr_code_changes.yaml index 4e30d089..4c2d6cbf 100644 --- a/.github/workflows/pr_code_changes.yaml +++ b/.github/workflows/pr_code_changes.yaml @@ -71,6 +71,7 @@ jobs: run: make download env: HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Build datasets run: make data @@ -82,6 +83,12 @@ jobs: with: name: calibration_log.csv path: calibration_log.csv + - name: Save minimized ECPS calibration log + uses: actions/upload-artifact@v4 + with: + name: minimized_enhanced_cps_2024_calibration_log.csv + path: minimized_enhanced_cps_2024_calibration_log.csv + if-no-files-found: ignore - name: Run tests run: pytest diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..725035b9 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Minimized Enhanced CPS. \ No newline at end of file diff --git a/policyengine_us_data/datasets/__init__.py b/policyengine_us_data/datasets/__init__.py index 87461837..c0f2c8fd 100644 --- a/policyengine_us_data/datasets/__init__.py +++ b/policyengine_us_data/datasets/__init__.py @@ -14,6 +14,8 @@ CensusCPS_2023, EnhancedCPS_2024, ReweightedCPS_2024, + MinimizedEnhancedCPS_2024, + SparseEnhancedCPS_2024, ) from .puf import PUF_2015, PUF_2021, PUF_2024, IRS_PUF_2015 from .acs import ACS_2022 diff --git a/policyengine_us_data/datasets/cps/enhanced_cps.py b/policyengine_us_data/datasets/cps/enhanced_cps.py index 7a471d40..984308bc 100644 --- a/policyengine_us_data/datasets/cps/enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/enhanced_cps.py @@ -1,10 +1,8 @@ from policyengine_core.data import Dataset import pandas as pd from policyengine_us_data.utils import ( - pe_to_soi, - get_soi, build_loss_matrix, - fmt, + HardConcrete, ) import numpy as np from typing import Type @@ -15,6 +13,10 @@ CPS_2024, ) import os +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) try: import torch @@ -22,101 +24,251 @@ torch = None +bad_targets = [ + "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", + "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", + "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", + "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", + "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", + "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", + "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", + "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", +] + + def reweight( original_weights, loss_matrix, targets_array, dropout_rate=0.05, + epochs=500, log_path="calibration_log.csv", - epochs=150, + 
l0_lambda=1e-5, + init_mean=0.999, + temperature=0.5, + sparse=False, ): + if loss_matrix.shape[1] == 0: + raise ValueError("loss_matrix has no columns after filtering") + + # Store column names before converting to tensor target_names = np.array(loss_matrix.columns) is_national = loss_matrix.columns.str.startswith("nation/") - loss_matrix = torch.tensor(loss_matrix.values, dtype=torch.float32) + + # Keep numpy versions for final diagnostics + loss_matrix_numpy = loss_matrix.values + targets_array_numpy = np.array(targets_array) + + # Convert to tensors for training + loss_matrix_tensor = torch.tensor(loss_matrix_numpy, dtype=torch.float32) + targets_array_tensor = torch.tensor( + targets_array_numpy, dtype=torch.float32 + ) + + # Compute normalization factors nation_normalisation_factor = is_national * (1 / is_national.sum()) state_normalisation_factor = ~is_national * (1 / (~is_national).sum()) normalisation_factor = np.where( is_national, nation_normalisation_factor, state_normalisation_factor ) - normalisation_factor = torch.tensor( + normalisation_factor_tensor = torch.tensor( normalisation_factor, dtype=torch.float32 ) - targets_array = torch.tensor(targets_array, dtype=torch.float32) + inv_mean_normalisation = 1 / np.mean(normalisation_factor) + + # Initialize weights weights = torch.tensor( np.log(original_weights), requires_grad=True, dtype=torch.float32 ) - # TODO: replace this functionality from the microcalibrate package. def loss(weights): - # Check for Nans in either the weights or the loss matrix if torch.isnan(weights).any(): raise ValueError("Weights contain NaNs") - if torch.isnan(loss_matrix).any(): + if torch.isnan(loss_matrix_tensor).any(): raise ValueError("Loss matrix contains NaNs") - estimate = weights @ loss_matrix + + estimate = weights @ loss_matrix_tensor + if torch.isnan(estimate).any(): raise ValueError("Estimate contains NaNs") + rel_error = ( - ((estimate - targets_array) + 1) / (targets_array + 1) + ((estimate - targets_array_tensor) + 1) + / (targets_array_tensor + 1) ) ** 2 - rel_error_normalized = rel_error * normalisation_factor + rel_error_normalized = ( + inv_mean_normalisation * rel_error * normalisation_factor_tensor + ) + if torch.isnan(rel_error_normalized).any(): raise ValueError("Relative error contains NaNs") + return rel_error_normalized.mean() def dropout_weights(weights, p): if p == 0: return weights - # Replace p% of the weights with the mean value of the rest of them mask = torch.rand_like(weights) < p mean = weights[~mask].mean() masked_weights = weights.clone() masked_weights[mask] = mean return masked_weights - optimizer = torch.optim.Adam([weights], lr=3e-1) - from tqdm import trange - - start_loss = None - - iterator = trange(epochs) - performance = pd.DataFrame() - for i in iterator: - optimizer.zero_grad() - weights_ = dropout_weights(weights, dropout_rate) - l = loss(torch.exp(weights_)) - if (log_path is not None) and (i % 10 == 0): - estimates = torch.exp(weights) @ loss_matrix - estimates = estimates.detach().numpy() - df = pd.DataFrame( - { - "target_name": target_names, - "estimate": estimates, - "target": targets_array.detach().numpy(), - } + def compute_diagnostics(final_weights, label=""): + """Helper function to compute and log diagnostics""" + estimate = final_weights @ loss_matrix_numpy + rel_error = ( + ((estimate - targets_array_numpy) + 1) / (targets_array_numpy + 1) + ) ** 2 + within_10_percent_mask = np.abs(estimate - targets_array_numpy) <= ( + 0.10 * np.abs(targets_array_numpy) + ) + percent_within_10 = 
np.mean(within_10_percent_mask) * 100 + + logger.info( + f"\n\n---{label} Solutions: reweighting quick diagnostics----\n" + ) + logger.info( + f"{np.sum(final_weights == 0)} are zero, {np.sum(final_weights != 0)} weights are nonzero" + ) + logger.info( + f"rel_error: min: {np.min(rel_error):.2f}\n" + f"max: {np.max(rel_error):.2f}\n" + f"mean: {np.mean(rel_error):.2f}\n" + f"median: {np.median(rel_error):.2f}\n" + f"Within 10% of target: {percent_within_10:.2f}%" + ) + logger.info("Relative error over 100% for:") + for i in np.where(rel_error > 1)[0]: + logger.info(f"target_name: {target_names[i]}") + logger.info(f"target_value: {targets_array_numpy[i]}") + logger.info(f"estimate_value: {estimate[i]}") + logger.info(f"has rel_error: {rel_error[i]:.2f}\n") + logger.info("---End of reweighting quick diagnostics------") + + if not sparse: + # Dense training + optimizer = torch.optim.Adam([weights], lr=3e-1) + from tqdm import trange + + start_loss = None + iterator = trange(epochs) + performance = pd.DataFrame() + + for i in iterator: + optimizer.zero_grad() + weights_ = dropout_weights(weights, dropout_rate) + l = loss(torch.exp(weights_)) + + if (log_path is not None) and (i % 10 == 0): + with torch.no_grad(): + estimates = ( + torch.exp(weights) @ loss_matrix_tensor + ).numpy() + df = pd.DataFrame( + { + "target_name": target_names, + "estimate": estimates, + "target": targets_array_numpy, + } + ) + df["epoch"] = i + df["error"] = df.estimate - df.target + df["rel_error"] = df.error / df.target + df["abs_error"] = df.error.abs() + df["rel_abs_error"] = df.rel_error.abs() + df["loss"] = df.rel_abs_error**2 + performance = pd.concat([performance, df], ignore_index=True) + + if (log_path is not None) and (i % 1000 == 0): + performance.to_csv(log_path, index=False) + + if start_loss is None: + start_loss = l.item() + loss_rel_change = (l.item() - start_loss) / start_loss + + l.backward() + iterator.set_postfix( + {"loss": l.item(), "loss_rel_change": loss_rel_change} ) - df["epoch"] = i - df["error"] = df.estimate - df.target - df["rel_error"] = df.error / df.target - df["abs_error"] = df.error.abs() - df["rel_abs_error"] = df.rel_error.abs() - df["loss"] = df.rel_abs_error**2 - performance = pd.concat([performance, df], ignore_index=True) - - if (log_path is not None) and (i % 1000 == 0): + optimizer.step() + + if log_path is not None: performance.to_csv(log_path, index=False) - if start_loss is None: - start_loss = l.item() - loss_rel_change = (l.item() - start_loss) / start_loss - l.backward() - iterator.set_postfix( - {"loss": l.item(), "loss_rel_change": loss_rel_change} + + final_weights_dense = torch.exp(weights).detach().numpy() + compute_diagnostics(final_weights_dense, "Dense") + return final_weights_dense + + else: + # Sparse training + weights = torch.tensor( + np.log(original_weights), requires_grad=True, dtype=torch.float32 + ) + gates = HardConcrete( + len(original_weights), init_mean=init_mean, temperature=temperature ) - optimizer.step() + + optimizer = torch.optim.Adam( + [weights] + list(gates.parameters()), lr=3e-1 + ) + from tqdm import trange + + start_loss = None + iterator = trange(epochs) + performance = pd.DataFrame() + + for i in iterator: + optimizer.zero_grad() + weights_ = dropout_weights(weights, dropout_rate) + masked = torch.exp(weights_) * gates() + l_main = loss(masked) + l = l_main + l0_lambda * gates.get_penalty() + + if (log_path is not None) and (i % 10 == 0): + gates.eval() + with torch.no_grad(): + estimates = ( + (torch.exp(weights) * gates()) @ 
loss_matrix_tensor + ).numpy() + gates.train() + + df = pd.DataFrame( + { + "target_name": target_names, + "estimate": estimates, + "target": targets_array_numpy, + } + ) + df["epoch"] = i + df["error"] = df.estimate - df.target + df["rel_error"] = df.error / df.target + df["abs_error"] = df.error.abs() + df["rel_abs_error"] = df.rel_error.abs() + df["loss"] = df.rel_abs_error**2 + performance = pd.concat([performance, df], ignore_index=True) + + if (log_path is not None) and (i % 1000 == 0): + performance.to_csv(log_path, index=False) + + if start_loss is None: + start_loss = l.item() + loss_rel_change = (l.item() - start_loss) / start_loss + + l.backward() + iterator.set_postfix( + {"loss": l.item(), "loss_rel_change": loss_rel_change} + ) + optimizer.step() + if log_path is not None: performance.to_csv(log_path, index=False) - return torch.exp(weights).detach().numpy() + gates.eval() + final_weights_sparse = (torch.exp(weights) * gates()).detach().numpy() + compute_diagnostics(final_weights_sparse, "Sparse") + + return final_weights_sparse def train_previous_year_income_model(): @@ -180,36 +332,14 @@ def generate(self): 1, 0.1, len(original_weights) ) - bad_targets = [ - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "state/RI/adjusted_gross_income/amount/-inf_1", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", - "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", - "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", - "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", - "state/RI/adjusted_gross_income/amount/-inf_1", - "nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All", - ] - # Run the optimization procedure to get (close to) minimum loss weights for year in range(self.start_year, self.end_year + 1): loss_matrix, targets_array = build_loss_matrix( self.input_dataset, year ) - zero_mask = np.isclose(targets_array, 0.0, atol=0.1) + bad_mask = loss_matrix.columns.isin(bad_targets) - keep_mask_bool = ~(zero_mask | bad_mask) + keep_mask_bool = ~bad_mask keep_idx = np.where(keep_mask_bool)[0] loss_matrix_clean = loss_matrix.iloc[:, keep_idx] targets_array_clean = targets_array[keep_idx] @@ -224,26 +354,6 @@ def generate(self): ) data["household_weight"][year] = optimised_weights - print("\n\n---reweighting quick diagnostics----\n") - estimate = optimised_weights @ loss_matrix_clean - rel_error = ( - ((estimate - targets_array_clean) + 1) - / 
(targets_array_clean + 1) - ) ** 2 - print( - f"rel_error: min: {np.min(rel_error):.2f}, " - f"max: {np.max(rel_error):.2f} " - f"mean: {np.mean(rel_error):.2f}, " - f"median: {np.median(rel_error):.2f}" - ) - print("Relative error over 100% for:") - for i in np.where(rel_error > 1)[0]: - print(f"target_name: {loss_matrix_clean.columns[i]}") - print(f"target_value: {targets_array_clean[i]}") - print(f"estimate_value: {estimate[i]}") - print(f"has rel_error: {rel_error[i]:.2f}\n") - print("---End of reweighting quick diagnostics------") - self.save_dataset(data) @@ -276,6 +386,124 @@ def generate(self): self.save_dataset(data) +class MinimizedEnhancedCPS_2024(EnhancedCPS): + input_dataset = ExtendedCPS_2024 + start_year = 2024 + end_year = 2024 + name = "minimized_enhanced_cps_2024" + label = "Minimized Enhanced CPS 2024" + file_path = STORAGE_FOLDER / "minimized_enhanced_cps_2024.h5" + url = ( + "hf://policyengine/policyengine-us-data/minimized_enhanced_cps_2024.h5" + ) + + def generate(self): + from policyengine_us import Microsimulation + + sim = Microsimulation(dataset=self.input_dataset) + data = sim.dataset.load_dataset() + data["household_weight"] = {} + original_weights = sim.calculate("household_weight") + original_weights = original_weights.values + np.random.normal( + 1, 0.1, len(original_weights) + ) + + # Run the optimization procedure to get (close to) minimum loss weights + for year in range(self.start_year, self.end_year + 1): + loss_matrix, targets_array = build_loss_matrix( + self.input_dataset, year + ) + + bad_mask = loss_matrix.columns.isin(bad_targets) + keep_mask_bool = ~bad_mask + keep_idx = np.where(keep_mask_bool)[0] + + # Check if filtering would remove all columns + if len(keep_idx) == 0: + print( + "WARNING: bad_targets filtering would remove all columns, using all columns instead" + ) + keep_idx = np.arange(loss_matrix.shape[1]) + targets_array_clean = targets_array + loss_matrix_clean = loss_matrix + else: + loss_matrix_clean = loss_matrix.iloc[:, keep_idx] + targets_array_clean = targets_array[keep_idx] + assert loss_matrix_clean.shape[1] == targets_array_clean.size + + from policyengine_us_data.utils.minimize import ( + candidate_loss_contribution, + minimize_dataset, + ) + + minimize_dataset( + self.input_dataset, + self.file_path, + minimization_function=candidate_loss_contribution, + loss_matrix=loss_matrix_clean, + targets=targets_array_clean, + loss_rel_change_max=[0.1], # maximum relative change in loss + count_iterations=6, + view_fraction_per_iteration=0.4, + fraction_remove_per_iteration=0.1, + ) + + +class SparseEnhancedCPS_2024(EnhancedCPS): + input_dataset = ExtendedCPS_2024 + start_year = 2024 + end_year = 2024 + name = "sparse_enhanced_cps_2024" + label = "Sparse Enhanced CPS 2024" + file_path = STORAGE_FOLDER / "sparse_enhanced_cps_2024.h5" + url = "hf://policyengine/policyengine-us-data/sparse_enhanced_cps_2024.h5" + + def generate(self): + from policyengine_us import Microsimulation + from policyengine_us_data.utils.minimize import ( + create_calibration_log_file, + ) + + sim = Microsimulation(dataset=self.input_dataset) + data = sim.dataset.load_dataset() + data["household_weight"] = {} + original_weights = sim.calculate("household_weight") + original_weights = original_weights.values + np.random.normal( + 1, 0.1, len(original_weights) + ) + + # Run the optimization procedure to get (close to) minimum loss weights + for year in range(self.start_year, self.end_year + 1): + loss_matrix, targets_array = build_loss_matrix( + self.input_dataset, 
year + ) + + bad_mask = loss_matrix.columns.isin(bad_targets) + keep_mask_bool = ~bad_mask + keep_idx = np.where(keep_mask_bool)[0] + loss_matrix_clean = loss_matrix.iloc[:, keep_idx] + targets_array_clean = targets_array[keep_idx] + assert loss_matrix_clean.shape[1] == targets_array_clean.size + + optimised_weights = reweight( + original_weights, + loss_matrix_clean, + targets_array_clean, + log_path="calibration_log.csv", + epochs=150, + sparse=True, + ) + data["household_weight"][year] = optimised_weights + # Also save as sparse weights for small_enhanced_cps.py + if "household_sparse_weight" not in data: + data["household_sparse_weight"] = {} + data["household_sparse_weight"][year] = optimised_weights + + self.save_dataset(data) + + create_calibration_log_file(self.file_path) + + class EnhancedCPS_2024(EnhancedCPS): input_dataset = ExtendedCPS_2024 start_year = 2024 @@ -288,3 +516,5 @@ class EnhancedCPS_2024(EnhancedCPS): if __name__ == "__main__": EnhancedCPS_2024().generate() + # MinimizedEnhancedCPS_2024().generate() + SparseEnhancedCPS_2024().generate() diff --git a/policyengine_us_data/datasets/cps/small_enhanced_cps.py b/policyengine_us_data/datasets/cps/small_enhanced_cps.py index 976725d9..db13b770 100644 --- a/policyengine_us_data/datasets/cps/small_enhanced_cps.py +++ b/policyengine_us_data/datasets/cps/small_enhanced_cps.py @@ -1,5 +1,8 @@ +import pandas as pd import numpy as np +from policyengine_core.data.dataset import Dataset + def create_small_ecps(): from policyengine_us import Microsimulation @@ -37,6 +40,122 @@ def create_small_ecps(): grp.create_dataset(str(period), data=values) +def create_sparse_ecps(): + from policyengine_us import Microsimulation + from policyengine_us_data.datasets import SparseEnhancedCPS_2024 + from policyengine_us_data.storage import STORAGE_FOLDER + from policyengine_core.enums import Enum + + time_period = 2024 + + ecps = SparseEnhancedCPS_2024() + + # Check if sparse weights exist, if not generate them + try: + h5 = ecps.load() + sparse_weights = h5["household_sparse_weight"]["2024"][:] + hh_ids = h5["household_id"]["2024"][:] + except KeyError: + print( + "Sparse weights not found. Generating SparseEnhancedCPS_2024 dataset..." + ) + ecps.generate() + h5 = ecps.load() + sparse_weights = h5["household_sparse_weight"]["2024"][:] + hh_ids = h5["household_id"]["2024"][:] + + template_sim = Microsimulation( + dataset=EnhancedCPS_2024, + ) + template_sim.set_input("household_weight", 2024, sparse_weights) + + template_df = template_sim.to_input_dataframe() + + household_weight_column = f"household_weight__{time_period}" + df_household_id_column = f"household_id__{time_period}" + df_person_id_column = f"person_id__{time_period}" + + # Group by household ID and get the first entry for each group + df = template_df + h_df = df.groupby(df_household_id_column).first() + h_ids = pd.Series(h_df.index) + h_weights = pd.Series(h_df[household_weight_column].values) + + # Seed the random number generators for reproducibility + h_ids = h_ids[h_weights > 0] + h_weights = h_weights[h_weights > 0] + + subset_df = df[df[df_household_id_column].isin(h_ids)].copy() + + household_id_to_count = {} + for household_id in h_ids: + if household_id not in household_id_to_count: + household_id_to_count[household_id] = 0 + household_id_to_count[household_id] += 1 + + household_counts = subset_df[df_household_id_column].map( + lambda x: household_id_to_count.get(x, 0) + ) + + # NOTE: from subsample. I don't think I want to do this! 
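+    # (Presumably: in the subsample utility a household can be drawn more than
+    # once, so weights are rescaled to preserve totals; here every
+    # positive-weight household is kept exactly once with its calibrated
+    # sparse weight, so the rescaling below stays commented out.)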
+ ## Adjust household weights to maintain the total weight + # for col in subset_df.columns: + # if "weight__" in col: + # target_total_weight = df[col].values.sum() + # if not quantize_weights: + # subset_df[col] *= household_counts.values + # else: + # subset_df[col] = household_counts.values + # subset_df[col] *= ( + # target_total_weight / subset_df[col].values.sum() + # ) + + df = subset_df + + # Update the dataset and rebuild the simulation + sim = Microsimulation() + sim.dataset = Dataset.from_dataframe(df, sim.dataset.time_period) + sim.build_from_dataset() + + # Ensure the baseline branch has the new data. + if "baseline" in sim.branches: + baseline_tax_benefit_system = sim.branches[ + "baseline" + ].tax_benefit_system + sim.branches["baseline"] = sim.clone() + sim.branches["tax_benefit_system"] = baseline_tax_benefit_system + + sim.default_calculation_period = time_period + + # Get ready to write it out + simulation = sim + data = {} + for variable in simulation.tax_benefit_system.variables: + data[variable] = {} + for time_period in simulation.get_holder(variable).get_known_periods(): + values = simulation.get_holder(variable).get_array(time_period) + values = np.array(values) + if simulation.tax_benefit_system.variables.get( + variable + ).value_type in (Enum, str): + values = values.astype("S") + if values is not None: + data[variable][time_period] = values + + if len(data[variable]) == 0: + del data[variable] + + import h5py + + with h5py.File(STORAGE_FOLDER / "sparse_enhanced_cps_2024.h5", "w") as f: + for variable, periods in data.items(): + grp = f.create_group(variable) + for period, values in periods.items(): + grp.create_dataset(str(period), data=values) + + if __name__ == "__main__": create_small_ecps() print("Small CPS dataset created successfully.") + create_sparse_ecps() + print("Sparse CPS dataset created successfully.") diff --git a/policyengine_us_data/storage/upload_completed_datasets.py b/policyengine_us_data/storage/upload_completed_datasets.py index f161a9ee..16885d8c 100644 --- a/policyengine_us_data/storage/upload_completed_datasets.py +++ b/policyengine_us_data/storage/upload_completed_datasets.py @@ -15,6 +15,7 @@ def upload_datasets(): Pooled_3_Year_CPS_2023.file_path, CPS_2023.file_path, STORAGE_FOLDER / "small_enhanced_cps_2024.h5", + STORAGE_FOLDER / "enhanced_cps_2024_minified.h5", ] for file_path in dataset_files: diff --git a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py index abf67301..7c815880 100644 --- a/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py +++ b/policyengine_us_data/tests/test_datasets/test_enhanced_cps.py @@ -1,4 +1,5 @@ import pytest +import pandas as pd def test_ecps_has_mortgage_interest(): @@ -254,3 +255,61 @@ def test_medicaid_calibration(): assert ( not failed ), f"One or more states exceeded tolerance of {TOLERANCE:.0%}." + + +def test_minimized_enhanced_cps_calibration_quality(): + """ + Test that minimized Enhanced CPS datasets maintain calibration quality above 75%. 
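+
+    Worked example (hypothetical counts): 60 excellent, 30 good and 10 poor
+    targets out of 100 would score (60 * 100 + 30 * 75) / 100 = 82.5 under
+    the formula below.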
+    Quality score formula: (excellent_count * 100 + good_count * 75) / total_targets
+
+    Quality categories:
+    - Excellent (< 5% error): 100 points each
+    - Good (5-20% error): 75 points each
+    - Poor (≥ 20% error): 0 points each
+    """
+    from policyengine_us_data.datasets.cps import MinimizedEnhancedCPS_2024
+    from policyengine_us_data.utils.minimize import create_calibration_log_file
+    from policyengine_us import Microsimulation
+
+    sim = Microsimulation(dataset=MinimizedEnhancedCPS_2024)
+    assert (
+        len(sim.calculate("household_weight")) < 30_000
+    ), "Minimized Enhanced CPS should have fewer than 30,000 households."
+
+    create_calibration_log_file(MinimizedEnhancedCPS_2024.file_path)
+
+    calibration_log = pd.read_csv(
+        str(MinimizedEnhancedCPS_2024.file_path).replace(
+            ".h5", "_calibration_log.csv"
+        )
+    )
+
+    # Calculate quality categories
+    excellent_count = (
+        calibration_log["rel_abs_error"] < 0.05
+    ).sum()  # < 5% error
+    good_count = (
+        (calibration_log["rel_abs_error"] >= 0.05)
+        & (calibration_log["rel_abs_error"] < 0.20)
+    ).sum()  # 5-20% error
+    poor_count = (
+        calibration_log["rel_abs_error"] >= 0.20
+    ).sum()  # ≥ 20% error
+    total_targets = len(calibration_log)
+
+    # Calculate quality score
+    quality_score = (excellent_count * 100 + good_count * 75) / total_targets
+
+    print(f"  Total targets: {total_targets}")
+    print(f"  Excellent (< 5% error): {excellent_count}")
+    print(f"  Good (5-20% error): {good_count}")
+    print(f"  Poor (≥ 20% error): {poor_count}")
+    print(f"  Quality score: {quality_score:.1f}%")
+
+    # Assert quality score is above 75%
+    assert quality_score >= 75.0, (
+        f"Calibration quality score {quality_score:.1f}% is below the 75% threshold "
+        f"for {MinimizedEnhancedCPS_2024.label}. "
+        f"Breakdown: {excellent_count} excellent, {good_count} good, {poor_count} poor "
+        f"out of {total_targets} total targets."
+    )
diff --git a/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py
new file mode 100644
index 00000000..b807c1ef
--- /dev/null
+++ b/policyengine_us_data/tests/test_datasets/test_sparse_enhanced_cps.py
@@ -0,0 +1,85 @@
+import pytest
+
+import numpy as np
+
+from policyengine_us_data.utils import build_loss_matrix
+
+
+def test_sparse_ecps():
+    from policyengine_core.data import Dataset
+    from policyengine_us_data.storage import STORAGE_FOLDER
+    from policyengine_us import Microsimulation
+
+    # NOTE: swap in "small_enhanced_cps_2024.h5" here to see the difference.
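+    # Assumes sparse_enhanced_cps_2024.h5 already exists in STORAGE_FOLDER
+    # (e.g. built by SparseEnhancedCPS_2024().generate() or create_sparse_ecps());
+    # Dataset.from_file expects a local file.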
+    sim = Microsimulation(
+        dataset=Dataset.from_file(
+            STORAGE_FOLDER / "sparse_enhanced_cps_2024.h5",
+        )
+    )
+
+    data = sim.dataset.load_dataset()
+    bad_targets = [
+        "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household",
+        "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household",
+        "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
+        "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
+        "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household",
+        "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household",
+        "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse",
+        "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse",
+        "state/RI/adjusted_gross_income/amount/-inf_1",
+        "nation/irs/exempt interest/count/AGI in -inf-inf/taxable/All",
+    ]
+
+    year = 2024
+    loss_matrix, targets_array = build_loss_matrix(sim.dataset, year)
+    zero_mask = np.isclose(targets_array, 0.0, atol=0.1)
+    bad_mask = loss_matrix.columns.isin(bad_targets)
+    keep_mask_bool = ~(zero_mask | bad_mask)
+    keep_idx = np.where(keep_mask_bool)[0]
+    loss_matrix_clean = loss_matrix.iloc[:, keep_idx]
+    targets_array_clean = targets_array[keep_idx]
+    assert loss_matrix_clean.shape[1] == targets_array_clean.size
+
+    optimised_weights = data["household_weight"]["2024"]
+    print("\n\n---Sparse Solutions: reweighting quick diagnostics----\n")
+    print(
+        f"{np.sum(optimised_weights == 0)} weights are zero, "
+        f"{np.sum(optimised_weights != 0)} are nonzero"
+    )
+    estimate = optimised_weights @ loss_matrix_clean
+    rel_error = (
+        ((estimate - targets_array_clean) + 1) / (targets_array_clean + 1)
+    ) ** 2
+    within_10_percent_mask = np.abs(estimate - targets_array_clean) <= (
+        0.10 * np.abs(targets_array_clean)
+    )
+    percent_within_10 = np.mean(within_10_percent_mask) * 100
+    print(
+        f"rel_error: min: {np.min(rel_error):.2f}\n"
+        f"max: {np.max(rel_error):.2f}\n"
+        f"mean: {np.mean(rel_error):.2f}\n"
+        f"median: {np.median(rel_error):.2f}\n"
+        f"Within 10% of target: {percent_within_10:.2f}%"
+    )
+    print("Relative error over 100% for:")
+    for i in np.where(rel_error > 1)[0]:
+        print(f"target_name: {loss_matrix_clean.columns[i]}")
+        print(f"target_value: {targets_array_clean[i]}")
+        print(f"estimate_value: {estimate[i]}")
+        print(f"has rel_error: {rel_error[i]:.2f}\n")
+    print("---End of reweighting quick diagnostics------")
+
+    assert percent_within_10 > 70.0
+
+
+if __name__ == "__main__":
+    test_sparse_ecps()
diff --git a/policyengine_us_data/utils/__init__.py b/policyengine_us_data/utils/__init__.py
index d25c6c2f..136d2503 100644
--- 
a/policyengine_us_data/utils/__init__.py +++ b/policyengine_us_data/utils/__init__.py @@ -3,3 +3,4 @@ from .uprating import * from .loss import * from .qrf import * +from .l0 import * diff --git a/policyengine_us_data/utils/l0.py b/policyengine_us_data/utils/l0.py new file mode 100644 index 00000000..ebd89d0a --- /dev/null +++ b/policyengine_us_data/utils/l0.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +class HardConcrete(nn.Module): + """HardConcrete distribution for L0 regularization.""" + + def __init__( + self, + input_dim, + output_dim=None, + temperature=0.5, + stretch=0.1, + init_mean=0.5, + ): + super().__init__() + if output_dim is None: + self.gate_size = (input_dim,) + else: + self.gate_size = (input_dim, output_dim) + self.qz_logits = nn.Parameter(torch.zeros(self.gate_size)) + self.temperature = temperature + self.stretch = stretch + self.gamma = -0.1 + self.zeta = 1.1 + self.init_mean = init_mean + self.reset_parameters() + + def reset_parameters(self): + if self.init_mean is not None: + init_val = math.log(self.init_mean / (1 - self.init_mean)) + self.qz_logits.data.fill_(init_val) + + def forward(self, input_shape=None): + if self.training: + gates = self._sample_gates() + else: + gates = self._deterministic_gates() + if input_shape is not None and len(input_shape) > len(gates.shape): + gates = gates.unsqueeze(-1).unsqueeze(-1) + return gates + + def _sample_gates(self): + u = torch.zeros_like(self.qz_logits).uniform_(1e-8, 1.0 - 1e-8) + s = torch.log(u) - torch.log(1 - u) + self.qz_logits + s = torch.sigmoid(s / self.temperature) + s = s * (self.zeta - self.gamma) + self.gamma + gates = torch.clamp(s, 0, 1) + return gates + + def _deterministic_gates(self): + probs = torch.sigmoid(self.qz_logits) + gates = probs * (self.zeta - self.gamma) + self.gamma + return torch.clamp(gates, 0, 1) + + def get_penalty(self): + logits_shifted = self.qz_logits - self.temperature * math.log( + -self.gamma / self.zeta + ) + prob_active = torch.sigmoid(logits_shifted) + return prob_active.sum() + + def get_active_prob(self): + logits_shifted = self.qz_logits - self.temperature * math.log( + -self.gamma / self.zeta + ) + return torch.sigmoid(logits_shifted) + + +class L0Linear(nn.Module): + """Linear layer with L0 regularization using HardConcrete gates.""" + + def __init__( + self, + in_features, + out_features, + bias=True, + temperature=0.5, + init_sparsity=0.5, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", None) + self.weight_gates = HardConcrete( + out_features, + in_features, + temperature=temperature, + init_mean=init_sparsity, + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_normal_(self.weight, mode="fan_out") + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, input): + gates = self.weight_gates() + masked_weight = self.weight * gates + return F.linear(input, masked_weight, self.bias) + + def get_l0_penalty(self): + return self.weight_gates.get_penalty() + + def get_sparsity(self): + with torch.no_grad(): + prob_active = self.weight_gates.get_active_prob() + return 1.0 - prob_active.mean().item() + + +class SparseMLP(nn.Module): + """Example MLP with L0 regularization on all layers""" + + def __init__( + self, + input_dim=784, + 
hidden_dim=256, + output_dim=10, + init_sparsity=0.5, + temperature=0.5, + ): + super().__init__() + self.fc1 = L0Linear( + input_dim, + hidden_dim, + init_sparsity=init_sparsity, + temperature=temperature, + ) + self.fc2 = L0Linear( + hidden_dim, + hidden_dim, + init_sparsity=init_sparsity, + temperature=temperature, + ) + self.fc3 = L0Linear( + hidden_dim, + output_dim, + init_sparsity=init_sparsity, + temperature=temperature, + ) + + def forward(self, x): + x = x.view(x.size(0), -1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + def get_l0_loss(self): + l0_loss = 0 + for module in self.modules(): + if isinstance(module, L0Linear): + l0_loss += module.get_l0_penalty() + return l0_loss + + def get_sparsity_stats(self): + stats = {} + for name, module in self.named_modules(): + if isinstance(module, L0Linear): + stats[name] = { + "sparsity": module.get_sparsity(), + "active_params": module.get_l0_penalty().item(), + } + return stats + + +def train_with_l0(model, train_loader, epochs=10, l0_lambda=1e-3): + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + criterion = nn.CrossEntropyLoss() + for epoch in range(epochs): + total_loss = 0 + total_l0 = 0 + for batch_idx, (data, target) in enumerate(train_loader): + optimizer.zero_grad() + output = model(data) + ce_loss = criterion(output, target) + l0_loss = model.get_l0_loss() + loss = ce_loss + l0_lambda * l0_loss + loss.backward() + optimizer.step() + total_loss += ce_loss.item() + total_l0 += l0_loss.item() + if epoch % 1 == 0: + sparsity_stats = model.get_sparsity_stats() + print( + f"Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, L0={total_l0/len(train_loader):.4f}" + ) + for layer, stats in sparsity_stats.items(): + print( + f" {layer}: {stats['sparsity']*100:.1f}% sparse, {stats['active_params']:.1f} active params" + ) + + +def prune_model(model, threshold=0.05): + for module in model.modules(): + if isinstance(module, L0Linear): + with torch.no_grad(): + prob_active = module.weight_gates.get_active_prob() + mask = (prob_active > threshold).float() + module.weight.data *= mask + return model diff --git a/policyengine_us_data/utils/loss.py b/policyengine_us_data/utils/loss.py index 21abce0f..fbdbacef 100644 --- a/policyengine_us_data/utils/loss.py +++ b/policyengine_us_data/utils/loss.py @@ -552,11 +552,6 @@ def build_loss_matrix(dataset: type, time_period): # Convert to thousands for the target targets_array.append(row["enrollment"]) - print( - f"Targeting Medicaid enrollment for {row['state']} " - f"with target {row['enrollment']:.0f}k" - ) - # State 10-year age targets age_targets = pd.read_csv(STORAGE_FOLDER / "age_state.csv") diff --git a/policyengine_us_data/utils/minimize.py b/policyengine_us_data/utils/minimize.py new file mode 100644 index 00000000..ce2c6fdf --- /dev/null +++ b/policyengine_us_data/utils/minimize.py @@ -0,0 +1,444 @@ +from policyengine_us_data.utils.loss import build_loss_matrix +from policyengine_core.data import Dataset +from policyengine_us import Microsimulation +import numpy as np +import pandas as pd +import h5py +from policyengine_us_data.storage import STORAGE_FOLDER +from typing import Optional, Callable + +bad_targets = [ + "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Head of Household", + "nation/irs/adjusted gross income/total/AGI in 15k-20k/taxable/Head of Household", + "nation/irs/adjusted gross income/total/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", + "nation/irs/adjusted gross income/total/AGI in 
15k-20k/taxable/Married Filing Jointly/Surviving Spouse", + "nation/irs/count/count/AGI in 10k-15k/taxable/Head of Household", + "nation/irs/count/count/AGI in 15k-20k/taxable/Head of Household", + "nation/irs/count/count/AGI in 10k-15k/taxable/Married Filing Jointly/Surviving Spouse", + "nation/irs/count/count/AGI in 15k-20k/taxable/Married Filing Jointly/Surviving Spouse", +] + + +def create_calibration_log_file(file_path, epoch=0): + dataset = Dataset.from_file(file_path) + sim = Microsimulation(dataset=dataset) + + loss_matrix, targets = build_loss_matrix(dataset, 2024) + + bad_mask = loss_matrix.columns.isin(bad_targets) + keep_mask_bool = ~bad_mask + keep_idx = np.where(keep_mask_bool)[0] + loss_matrix_clean = loss_matrix.iloc[:, keep_idx] + targets_clean = targets[keep_idx] + + assert loss_matrix_clean.shape[1] == targets_clean.size + + estimates = ( + sim.calculate("household_weight", 2024).values @ loss_matrix_clean + ) + target_names = loss_matrix_clean.columns + + # Calculate and print some key metrics + errors = estimates - targets_clean + rel_errors = errors / targets_clean + + df = pd.DataFrame( + { + "target_name": target_names, + "estimate": estimates, + "target": targets_clean, + } + ) + df["epoch"] = epoch + df["error"] = df["estimate"] - df["target"] + df["rel_error"] = df["error"] / df["target"] + df["abs_error"] = df["error"].abs() + df["rel_abs_error"] = ( + df["abs_error"] / df["target"].abs() + if df["target"].abs().sum() > 0 + else np.nan + ) + df["loss"] = (df["rel_error"] ** 2).mean() + + df.to_csv( + str(file_path).replace(".h5", "_calibration_log.csv"), index=False + ) + + +def losses_for_candidates( + base_weights: np.ndarray, + idxs: np.ndarray, + est_mat: np.ndarray, + targets: np.ndarray, + norm: np.ndarray, + chunk_size: Optional[int] = 25_000, +) -> np.ndarray: + """ + Return the loss value *for each* candidate deletion in `idxs` + in one matrix multiplication. + + Parameters + ---------- + base_weights : (n,) original weight vector + idxs : (k,) candidate row indices to zero-out + est_mat : (n, m) estimate matrix + targets : (m,) calibration targets + norm : (m,) normalisation factors + chunk_size : max number of candidates to process at once + + Returns + ------- + losses : (k,) loss if row i were removed (and weights rescaled) + """ + W = base_weights + total = W.sum() + k = len(idxs) + losses = np.empty(k, dtype=float) + + # Work through the candidate list in blocks + for start in range(0, k, chunk_size): + stop = min(start + chunk_size, k) + part = idxs[start:stop] # (p,) where p ≤ chunk_size + p = len(part) + + # Build the delta matrix only for this chunk + delta = np.zeros((p, len(W))) + delta[np.arange(p), part] = -W[part] + + keep_total = total + delta.sum(axis=1) # (p,) + delta *= (total / keep_total)[:, None] + + # Matrix–matrix multiply → one matrix multiplication per chunk + ests = (W + delta) @ est_mat # (p, m) + rel_err = ((ests - targets) + 1) / (targets + 1) + losses[start:stop] = ((rel_err * norm) ** 2).mean(axis=1) + + return losses + + +def get_loss_from_mask( + weights, inclusion_mask, estimate_matrix, targets, normalisation_factor +): + """ + Calculate the loss based on the inclusion mask and the estimate matrix. 
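+
+    Excluded households get zero weight, the survivors are rescaled to
+    preserve the original weight total and then re-calibrated with reweight()
+    before the normalisation-weighted squared relative error against targets
+    is computed, so the loss reflects the best fit attainable with that
+    subset rather than the masked weights alone.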
+ """ + # Step 1: Apply mask and rescale weights + masked_weights = weights.copy() + original_weight_total = masked_weights.sum() + if (~inclusion_mask).sum() > 0: + masked_weights[~inclusion_mask] = 0 + masked_weight_total = masked_weights.sum() + masked_weights[inclusion_mask] *= ( + original_weight_total / masked_weight_total + ) + + # Step 2: Re-calibrate the masked weights to hit targets + # Only calibrate the included households + included_weights = masked_weights[inclusion_mask] + included_estimate_matrix = estimate_matrix.iloc[ + inclusion_mask + ] # Keep as DataFrame + + # Call reweight function to calibrate the selected households + from policyengine_us_data.datasets.cps.enhanced_cps import reweight + + calibrated_weights_included = reweight( + included_weights, + included_estimate_matrix, + targets, + epochs=250, + ) + + # Put calibrated weights back into full array + calibrated_weights = np.zeros_like(masked_weights) + calibrated_weights[inclusion_mask] = calibrated_weights_included + + # Calculate estimates and loss from calibrated weights + estimates = calibrated_weights @ estimate_matrix + rel_error = ((estimates - targets) + 1) / (targets + 1) + loss = ((rel_error * normalisation_factor) ** 2).mean() + + return loss + + +def candidate_loss_contribution( + weights: np.ndarray, + estimate_matrix: np.ndarray, + targets: np.ndarray, + normalisation_factor: np.ndarray, + loss_rel_change_max: float, + count_iterations: int = 5, + view_fraction_per_iteration: float = 0.5, + fraction_remove_per_iteration: float = 0.05, +) -> np.ndarray: + """ + Minimization approach based on candidate loss contribution. + + This function iteratively removes households that contribute least to the loss, + maintaining the calibration quality within the specified tolerance. + + Parameters + ---------- + weights : (n,) household weights + estimate_matrix : (n, m) matrix mapping weights to estimates + targets : (m,) calibration targets + normalisation_factor : (m,) normalisation factors for different targets + loss_rel_change_max : maximum allowed relative change in loss + count_iterations : number of iterations to perform + view_fraction_per_iteration : fraction of households to evaluate each iteration + fraction_remove_per_iteration : fraction of households to remove each iteration + + Returns + ------- + inclusion_mask : (n,) boolean mask of households to keep + """ + from tqdm import tqdm + + full_mask = np.ones_like(weights, dtype=bool) + + for i in range(count_iterations): + inclusion_mask = full_mask.copy() + baseline_loss = get_loss_from_mask( + weights, + inclusion_mask, + estimate_matrix, + targets, + normalisation_factor, + ) + + # Sample only households that are currently included + indices = np.random.choice( + np.where(full_mask)[0], + size=int(full_mask.sum() * view_fraction_per_iteration), + replace=False, + ) + # 2. compute losses for the batch in one shot + candidate_losses = losses_for_candidates( + weights, indices, estimate_matrix, targets, normalisation_factor + ) + # 3. convert to relative change vs. 
baseline + household_loss_rel_changes = ( + candidate_losses - baseline_loss + ) / baseline_loss + + inclusion_mask = full_mask.copy() + household_loss_rel_changes = np.array(household_loss_rel_changes) + # Sort by the relative change in loss + sorted_indices = np.argsort(household_loss_rel_changes) + + # Remove the worst households + num_to_remove = int(len(weights) * fraction_remove_per_iteration) + worst_indices = indices[sorted_indices[:num_to_remove]] + inclusion_mask[worst_indices] = False + + # Calculate the new loss + new_loss = get_loss_from_mask( + weights, + inclusion_mask, + estimate_matrix, + targets, + normalisation_factor, + ) + rel_change = (new_loss - baseline_loss) / baseline_loss + + if rel_change > loss_rel_change_max: + print( + f"Iteration {i + 1}: Loss changed from {baseline_loss} to {new_loss}, " + f"which is too high ({rel_change:.2%}). Stopping." + ) + break + + print( + f"Iteration {i + 1}: Loss changed from {baseline_loss} to {new_loss}" + ) + print( + f"Removed {num_to_remove} households with worst relative loss changes." + ) + + # Update the full mask + full_mask &= inclusion_mask + + return full_mask + + +def random_sampling_minimization( + weights, + estimate_matrix, + targets, + normalisation_factor, + random=True, + target_fractions=[0.5, 0.6, 0.7, 0.8, 0.9], +): + """A simple random sampling approach""" + n = len(weights) + + household_weights_normalized = weights / weights.sum() + + final_mask = None + lowest_loss = float("inf") + for fraction in target_fractions: + target_size = int(n * fraction) + # Random sampling with multiple attempts + best_mask = None + best_loss = float("inf") + + for _ in range(3): # Try 3 random samples + mask = np.zeros(n, dtype=bool) + mask[ + np.random.choice( + n, + target_size, + p=household_weights_normalized if random else None, + replace=False, + ) + ] = True + + loss = get_loss_from_mask( + weights, mask, estimate_matrix, targets, normalisation_factor + ) + + if loss < best_loss: + best_loss = loss + best_mask = mask + + if lowest_loss > best_loss: + lowest_loss = best_loss + final_mask = best_mask + + return final_mask + + +def minimize_dataset( + dataset, + output_path: str, + minimization_function: Callable = candidate_loss_contribution, + loss_matrix: Optional[pd.DataFrame] = None, + targets: Optional[np.ndarray] = None, + **kwargs, +) -> None: + """ + Main function to minimize a dataset using a specified minimization approach. 
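+
+    Illustrative call (paths are examples; loss_rel_change_max is the keyword
+    required by the default candidate_loss_contribution):
+
+        minimize_dataset(
+            STORAGE_FOLDER / "enhanced_cps_2024.h5",
+            STORAGE_FOLDER / "enhanced_cps_2024_minimised.h5",
+            loss_rel_change_max=0.1,
+        )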
+ + Parameters + ---------- + dataset : path to the dataset file or Dataset object + output_path : path where the minimized dataset will be saved + loss_rel_change_max : maximum allowed relative change in loss + minimization_function : function that implements the minimization logic + **kwargs : additional arguments to pass to the minimization function + """ + # Handle both dataset class and file path + if hasattr(dataset, "file_path"): + dataset_path = str(dataset.file_path) + else: + dataset_path = str(dataset) + + create_calibration_log_file(dataset_path) + + dataset = Dataset.from_file(dataset_path) + if loss_matrix is None or targets is None: + loss_matrix, targets = build_loss_matrix(dataset, 2024) + + bad_mask = loss_matrix.columns.isin(bad_targets) + keep_mask_bool = ~bad_mask + keep_idx = np.where(keep_mask_bool)[0] + loss_matrix_clean = loss_matrix.iloc[:, keep_idx] + targets_clean = targets[keep_idx] + assert loss_matrix_clean.shape[1] == targets_clean.size + else: + loss_matrix_clean = loss_matrix + targets_clean = targets + + sim = Microsimulation(dataset=dataset) + + weights = sim.calculate("household_weight", 2024).values + is_national = loss_matrix_clean.columns.str.startswith("nation/") + nation_normalisation_factor = is_national * (1 / is_national.sum()) + state_normalisation_factor = ~is_national * (1 / (~is_national).sum()) + normalisation_factor = np.where( + is_national, nation_normalisation_factor, state_normalisation_factor + ) + + # Call the minimization function + inclusion_mask = minimization_function( + weights=weights, + estimate_matrix=loss_matrix_clean, + targets=targets_clean, + normalisation_factor=normalisation_factor, + **kwargs, # Allows for passing either loss_rel_change_max OR target_fractions, depending on normalisation_factor choice. 
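+        # (loss_rel_change_max for candidate_loss_contribution,
+        # target_fractions for random_sampling_minimization)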
+ ) + + # Extract household IDs for remaining households + household_ids = sim.calculate("household_id", 2024).values + remaining_households = household_ids[inclusion_mask] + + # Create a smaller dataset with only the remaining households + df = sim.to_input_dataframe() + smaller_df = df[df["household_id__2024"].isin(remaining_households)] + + weight_rel_change = ( + smaller_df["household_weight__2024"].sum() + / df["household_weight__2024"].sum() + ) + print(f"Weight relative change: {weight_rel_change:.2%}") + + # Create new simulation with smaller dataset + sim = Microsimulation(dataset=smaller_df) + + # Rescale weights to maintain total + initial_weights = ( + sim.calculate("household_weight", 2024).values / weight_rel_change + ) + + # Re-calibrate the final selected households to hit targets + print("Re-calibrating final selected households...") + + # Build loss matrix for the smaller dataset + smaller_loss_matrix, smaller_targets = build_loss_matrix(sim.dataset, 2024) + + # Apply same filtering as before + bad_mask = smaller_loss_matrix.columns.isin(bad_targets) + keep_mask_bool = ~bad_mask + keep_idx = np.where(keep_mask_bool)[0] + smaller_loss_matrix_clean = smaller_loss_matrix.iloc[:, keep_idx] + smaller_targets_clean = smaller_targets[keep_idx] + + from policyengine_us_data.datasets.cps.enhanced_cps import reweight + + calibrated_weights = reweight( + initial_weights, + smaller_loss_matrix_clean, # Now matches the smaller dataset size + smaller_targets_clean, + epochs=250, # Reduced epochs for faster processing + ) + sim.set_input("household_weight", 2024, calibrated_weights) + print("Final calibration completed successfully") + # Prepare data for saving + data = {} + for variable in sim.input_variables: + data[variable] = {2024: sim.calculate(variable, 2024).values} + if data[variable][2024].dtype == "object": + data[variable][2024] = data[variable][2024].astype("S") + + # Save to HDF5 file + with h5py.File(output_path, "w") as f: + for variable, values in data.items(): + for year, value in values.items(): + f.create_dataset(f"{variable}/{year}", data=value) + + print(f"Saved minimised dataset to {output_path}") + create_calibration_log_file(output_path, epoch=250) + + +if __name__ == "__main__": + # Example usage + files = [ + STORAGE_FOLDER / "enhanced_cps_2024.h5", + ] + + for file in files: + output_path = file.with_name(file.stem + "_minimised.h5") + minimize_dataset( + file, + output_path, + )