Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions analysis/clustering/prepare_clustering_data_households.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def get_metadata_and_samples( # noqa: C901
sample_days: int,
day_strategy: Literal["stratified", "random"],
seed: int = 42,
year: int | None = None,
month: int | None = None,
) -> dict[str, Any]:
"""
Get summary statistics and sample households + dates using MANIFESTS.
Expand Down Expand Up @@ -163,10 +165,17 @@ def get_metadata_and_samples( # noqa: C901
accounts_df = pl.concat([accounts_df, pl.read_parquet(acc_manifest)]).unique()
dates_df = pl.concat([dates_df, pl.read_parquet(date_manifest_extra)]).unique()

# Apply July-only filter (after all dates are assembled)
# TODO: temporary workaround -- the date window is hard-coded to July 2023 and should be made configurable.
dates_df = dates_df.filter((pl.col("date") >= pl.date(2023, 7, 1)) & (pl.col("date") <= pl.date(2023, 7, 31)))
logger.info(" Dates available after July filter: %d", dates_df.height)
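# Apply a calendar-month filter only when BOTH year and month are specified (after all dates are assembled).
# If either one is missing, no filter is applied and all available dates are kept.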
if year is not None and month is not None:
from calendar import monthrange

_, last_day = monthrange(year, month)
start_date = pl.date(year, month, 1)
end_date = pl.date(year, month, last_day)
dates_df = dates_df.filter((pl.col("date") >= start_date) & (pl.col("date") <= end_date))
logger.info(" Dates available after %d-%02d filter: %d", year, month, dates_df.height)
else:
logger.info(" No month filter applied (using all available dates): %d", dates_df.height)

if accounts_df.height == 0:
raise ValueError("No account_identifier values found in manifest.")
Expand Down Expand Up @@ -416,6 +425,8 @@ def prepare_clustering_data(
streaming: bool = False,
chunk_size: int = 5000,
seed: int = 42,
year: int | None = None,
month: int | None = None,
) -> dict[str, Any]:
"""Prepare household-level clustering data from interval parquet."""
logger.info("=" * 70)
Expand All @@ -437,6 +448,8 @@ def prepare_clustering_data(
sample_days=sample_days,
day_strategy=day_strategy,
seed=seed,
year=year,
month=month,
)

accounts = metadata["accounts"]
Expand Down Expand Up @@ -520,6 +533,8 @@ def main() -> int:
parser.add_argument(
"--chunk-size", type=int, default=5000, help="Households per chunk when --streaming is enabled."
)
parser.add_argument("--year", type=int, default=None, help="Year to filter dates (e.g., 2023).")
parser.add_argument("--month", type=int, default=None, help="Month to filter dates (1-12).")

args = parser.parse_args()

Expand All @@ -538,6 +553,8 @@ def main() -> int:
streaming=args.streaming,
chunk_size=args.chunk_size,
seed=args.seed,
year=args.year,
month=args.month,
)
return 0

Expand Down
Loading
Loading