From f761fa2dde96b2a237b91beaabb19efb40b90560 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Fri, 23 Jan 2026 09:29:21 -0800 Subject: [PATCH] Drop '_pg_collection' in model config when ckpting Signed-off-by: Asha Anoosheh --- modelopt/torch/opt/plugins/mcore_dist_checkpointing.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py index 3e5b35946..87072ff67 100644 --- a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py +++ b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py @@ -39,7 +39,6 @@ DROP_SUBSTRINGS = [ "fp4", "fp8", - "tp_", "parallel", "cuda_graph", "init_", @@ -49,6 +48,10 @@ "pipeline", "comm", "batch", + "pg_collection", +] +DROP_STARTSWITH = [ + "tp_", # would drop 'mtp_*' otherwise ] @@ -145,6 +148,8 @@ def _parse_transformer_config(transformer_config: dict) -> dict: for k, v in transformer_config.items(): if any(substring in k for substring in DROP_SUBSTRINGS): continue + if any(k.startswith(prefix) for prefix in DROP_STARTSWITH): + continue if isinstance(v, (bool, int, str)): config[k] = v else: