diff --git a/training/train.py b/training/train.py
index c8e66c2..6cb82b2 100644
--- a/training/train.py
+++ b/training/train.py
@@ -350,7 +350,8 @@ def _expand(name, v):

         optimizer_cls = torch.optim.AdamW
         opt = optimizer_cls(parameters, lr=args.lr, weight_decay=args.weight_decay)
-        lr_scheduler = PolynomialLRWarmup(opt, int(args.total_steps * args.warmup_ratio), args.total_steps, 2)
+        optimizer_total_steps = args.total_steps // args.backward_passes_per_step
+        lr_scheduler = PolynomialLRWarmup(opt, int(optimizer_total_steps * args.warmup_ratio), optimizer_total_steps, 2)
     else:
         raise ValueError(f"{args.opt} not support!")

@@ -652,7 +653,7 @@ def wrap_ddp(model):
             list_loss.append(head_loss)
             list_loss_float.append(head_loss.item())

-        is_accumulation_step = global_step % args.backward_passes_per_step != 0
+        is_accumulation_step = (global_step + 1) % args.backward_passes_per_step != 0
         scaled_loss = sum(list_loss) / args.backward_passes_per_step

         if is_accumulation_step:
@@ -665,8 +666,7 @@ def wrap_ddp(model):
             clip_grad_norm_(pfc.parameters(), max_norm=5, norm_type=2)
             opt.step()
             opt.zero_grad(set_to_none=True)
-
-        lr_scheduler.step()
+            lr_scheduler.step()

         batch_end_callback(
             global_step=global_step,
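
Taken together, the three hunks make the learning-rate schedule consistent with gradient accumulation: the scheduler now advances once per optimizer step rather than once per micro-batch, so its horizon shrinks to args.total_steps // args.backward_passes_per_step, and the accumulation test gains a +1 so that, with backward_passes_per_step = 4, the first opt.step() lands on global_step 3 (after four backward passes) instead of global_step 0 (after one). The sketch below reproduces only that step-counting pattern; the model, data, and LambdaLR schedule are illustrative stand-ins, not the repository's PolynomialLRWarmup, PFC head, or callback machinery.

# Minimal, self-contained sketch of the accumulation pattern the patch converges on.
# Everything here except the step-counting arithmetic is a stand-in for illustration.
import torch
from torch.nn.functional import cross_entropy
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import LambdaLR

backward_passes_per_step = 4                      # micro-batches per optimizer step
total_steps = 100                                 # micro-batch steps (args.total_steps)
optimizer_total_steps = total_steps // backward_passes_per_step  # scheduler horizon
warmup_steps = int(optimizer_total_steps * 0.1)   # mirrors optimizer_total_steps * warmup_ratio

model = torch.nn.Linear(8, 2)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)

# Stand-in warmup-then-decay schedule; the repo uses PolynomialLRWarmup with the same horizon.
def lr_lambda(step):
    if step < warmup_steps:
        return (step + 1) / max(1, warmup_steps)
    remaining = optimizer_total_steps - step
    return max(0.0, remaining / max(1, optimizer_total_steps - warmup_steps))

lr_scheduler = LambdaLR(opt, lr_lambda)

for global_step in range(total_steps):
    x = torch.randn(16, 8)
    y = torch.randint(0, 2, (16,))
    loss = cross_entropy(model(x), y)

    # (global_step + 1) % N != 0  ->  optimizer steps land on micro-batches N-1, 2N-1, ...
    # so the first opt.step() happens only after a full window of N backward passes.
    is_accumulation_step = (global_step + 1) % backward_passes_per_step != 0

    # Scale each micro-batch loss so the accumulated gradient averages over the window.
    (loss / backward_passes_per_step).backward()

    if not is_accumulation_step:
        clip_grad_norm_(model.parameters(), max_norm=5, norm_type=2)
        opt.step()
        opt.zero_grad(set_to_none=True)
        lr_scheduler.step()           # advance the schedule once per optimizer step

Stepping the scheduler in the same branch as opt.step() keeps warmup and decay measured in optimizer updates, which is why the total passed to the scheduler is divided by backward_passes_per_step.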