From 52a19154044348a8fa456d0e00f8830710bf1eb5 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Sat, 7 Feb 2026 19:16:35 +0800
Subject: [PATCH 1/2] fix: correct gradient accumulation off-by-one and
 lr_scheduler over-stepping

---
 training/train.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/training/train.py b/training/train.py
index d3139f77..bd3f336a 100644
--- a/training/train.py
+++ b/training/train.py
@@ -652,7 +652,7 @@ def wrap_ddp(model):
             list_loss.append(head_loss)
             list_loss_float.append(head_loss.item())
 
-        is_accumulation_step = global_step % args.backward_passes_per_step != 0
+        is_accumulation_step = (global_step + 1) % args.backward_passes_per_step != 0
         scaled_loss = sum(list_loss) / args.backward_passes_per_step
 
         if is_accumulation_step:
@@ -665,8 +665,7 @@ def wrap_ddp(model):
             clip_grad_norm_(pfc.parameters(), max_norm=5, norm_type=2)
            opt.step()
             opt.zero_grad()
-
-        lr_scheduler.step()
+            lr_scheduler.step()
 
         batch_end_callback(
             global_step=global_step,

From ab34a54a27828fb0ec3faaa83fea1ec5c8d411b2 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Sat, 7 Feb 2026 20:42:24 +0800
Subject: [PATCH 2/2] fix: align scheduler total_iters with optimizer steps
 under gradient accumulation

lr_scheduler total_iters was set to the micro-step count (total_steps), but
after moving lr_scheduler.step() to fire only on optimizer steps, the
scheduler would traverse only 1/backward_passes_per_step of its budget.
Divide total_iters by backward_passes_per_step so the full LR curve
(warmup + polynomial decay) completes over the actual optimizer steps.

No-op when backward_passes_per_step=1 (Stage-1).
---
 training/train.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/training/train.py b/training/train.py
index bd3f336a..203f0550 100644
--- a/training/train.py
+++ b/training/train.py
@@ -350,7 +350,8 @@ def _expand(name, v):
 
         optimizer_cls = torch.optim.AdamW
         opt = optimizer_cls(parameters, lr=args.lr, weight_decay=args.weight_decay)
-        lr_scheduler = PolynomialLRWarmup(opt, int(args.total_steps * args.warmup_ratio), args.total_steps, 2)
+        optimizer_total_steps = args.total_steps // args.backward_passes_per_step
+        lr_scheduler = PolynomialLRWarmup(opt, int(optimizer_total_steps * args.warmup_ratio), optimizer_total_steps, 2)
     else:
         raise ValueError(f"{args.opt} not support!")
 
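
For reference, below is a minimal runnable sketch of the pattern the two patches converge on: accumulate gradients for backward_passes_per_step micro-steps, then clip, step the optimizer, and advance the LR scheduler exactly once, with the scheduler budget expressed in optimizer steps rather than micro-steps. The toy model, synthetic data, and torch.optim.lr_scheduler.LinearLR (standing in for the repo's PolynomialLRWarmup) are illustrative assumptions, not code from training/train.py.

import torch
from torch.nn.utils import clip_grad_norm_

backward_passes_per_step = 4   # micro-batches accumulated per optimizer step
total_steps = 32               # micro-steps, i.e. data-loader iterations
optimizer_total_steps = total_steps // backward_passes_per_step  # as in patch 2/2

model = torch.nn.Linear(8, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Stand-in scheduler sized in optimizer steps, not micro-steps.
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
    opt, start_factor=0.1, total_iters=optimizer_total_steps)

for global_step in range(total_steps):
    x = torch.randn(16, 8)
    loss = model(x).pow(2).mean() / backward_passes_per_step
    loss.backward()  # gradients accumulate across micro-steps

    # (global_step + 1): the group boundary lands on the last micro-step of
    # each group when the step counter is 0-based (patch 1/2's off-by-one fix).
    is_accumulation_step = (global_step + 1) % backward_passes_per_step != 0
    if not is_accumulation_step:
        clip_grad_norm_(model.parameters(), max_norm=5, norm_type=2)
        opt.step()
        opt.zero_grad()
        lr_scheduler.step()  # advance the LR curve only on real optimizer steps

# The scheduler has consumed exactly its budget of optimizer steps.
assert lr_scheduler.last_epoch == optimizer_total_steps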