diff --git a/training/train.py b/training/train.py index d3139f77..c8e66c28 100644 --- a/training/train.py +++ b/training/train.py @@ -664,7 +664,7 @@ def wrap_ddp(model): for pfc in list_module_pfc: clip_grad_norm_(pfc.parameters(), max_norm=5, norm_type=2) opt.step() - opt.zero_grad() + opt.zero_grad(set_to_none=True) lr_scheduler.step() @@ -692,8 +692,6 @@ def wrap_ddp(model): list_head_names=args.list_head_names, keep_num=20, ) - # Also save in HuggingFace format - save_hf_checkpoint(args.output, backbone, global_step=global_step, image_size=args.image_size[0]) if global_step > args.total_steps: save_checkpoint(