diff --git a/policyengine_us_data/datasets/cps/cps.py b/policyengine_us_data/datasets/cps/cps.py index 84969500..a02c4c9e 100644 --- a/policyengine_us_data/datasets/cps/cps.py +++ b/policyengine_us_data/datasets/cps/cps.py @@ -228,9 +228,25 @@ def add_takeup(self): rng = seeded_rng("takes_up_dc_ptc") data["takes_up_dc_ptc"] = rng.random(n_tax_units) < dc_ptc_rate - # SNAP + # SNAP: prioritize reported recipients rng = seeded_rng("takes_up_snap_if_eligible") - data["takes_up_snap_if_eligible"] = rng.random(n_spm_units) < snap_rate + reported_snap = data["snap_reported"] > 0 + + # Calculate adjusted rate for non-reporters to hit target + n_snap_reporters = reported_snap.sum() + n_snap_non_reporters = (~reported_snap).sum() + target_snap_takeup_count = int(snap_rate * n_spm_units) + remaining_snap_needed = max(0, target_snap_takeup_count - n_snap_reporters) + snap_non_reporter_rate = ( + remaining_snap_needed / n_snap_non_reporters + if n_snap_non_reporters > 0 + else 0 + ) + + # Assign: all reporters + adjusted rate for non-reporters + data["takes_up_snap_if_eligible"] = reported_snap | ( + (~reported_snap) & (rng.random(n_spm_units) < snap_non_reporter_rate) + ) # ACA rng = seeded_rng("takes_up_aca_if_eligible") @@ -264,9 +280,25 @@ def add_takeup(self): rng.random(n_persons) < early_head_start_rate ) - # SSI + # SSI: prioritize reported recipients rng = seeded_rng("takes_up_ssi_if_eligible") - data["takes_up_ssi_if_eligible"] = rng.random(n_persons) < ssi_rate + reported_ssi = data["ssi_reported"] > 0 + + # Calculate adjusted rate for non-reporters to hit target + n_ssi_reporters = reported_ssi.sum() + n_ssi_non_reporters = (~reported_ssi).sum() + target_ssi_takeup_count = int(ssi_rate * n_persons) + remaining_ssi_needed = max(0, target_ssi_takeup_count - n_ssi_reporters) + ssi_non_reporter_rate = ( + remaining_ssi_needed / n_ssi_non_reporters + if n_ssi_non_reporters > 0 + else 0 + ) + + # Assign: all reporters + adjusted rate for non-reporters + data["takes_up_ssi_if_eligible"] = reported_ssi | ( + (~reported_ssi) & (rng.random(n_persons) < ssi_non_reporter_rate) + ) # WIC: resolve draws to bools using category-specific rates wic_categories = baseline.calculate("wic_category_str").values