@@ -173,10 +173,11 @@ def ema_reset(self):
         self.ema_buffer_model_params = None
 
     @imperative_base.no_grad()
-    def ema_accumulate(self):
+    def ema_accumulate(self, global_step, loss, zcc_ema_loss_threshold):
         """
         perform ema update: `\alpha * EMA + (1-\alpha) * model`
         build `self.ema_buffer` if necessary
+        when loss < threshold, do ema update
         """
         # logger.info(f'[ZCC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
         # do update: ema = alpha * ema + (1-alpha) * model
@@ -185,14 +186,19 @@ def ema_accumulate(self):
         cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice(
             self.master_min_offset, self.master_max_offset
         ).cpu()
-        self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
-        # logger.info(f'[ZCC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
-        for index, ema_buf in self.ema_buffer_model_params.items():
-            _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
-            updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf.cpu()
-            self.ema_buffer_model_params[index] = updated_ema
-
-        logger.info(f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")
+        if zcc_ema_loss_threshold is None or loss < zcc_ema_loss_threshold:
+            self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
+            for index, ema_buf in self.ema_buffer_model_params.items():
+                _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
+                updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
+                self.ema_buffer_model_params[index] = updated_ema
+            logger.info(
+                f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done"
+            )
+        else:
+            logger.info(
+                f"[ZCC EMA] accumulating SKIP for global_step:{global_step}, because loss:{loss} > threshold:{zcc_ema_loss_threshold}"
+            )
 
     @imperative_base.no_grad()
     def ema_state_dict(self):
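
For reference outside the diff, the hunk above gates the existing EMA rule `ema = coef * ema + (1 - coef) * param` on the step loss: when `loss` is at or above `zcc_ema_loss_threshold`, the update is skipped so a loss spike does not pollute the averaged weights. Below is a minimal standalone sketch of that rule; the function name `ema_update` and the plain Python floats are illustrative only, not part of the patch.

def ema_update(ema, params, coef, loss, loss_threshold=None):
    """Return new EMA values; skip the update when loss is at or above the threshold."""
    if loss_threshold is not None and loss >= loss_threshold:
        return ema  # loss spike: keep the previous EMA untouched
    return [coef * e + (1 - coef) * p for e, p in zip(ema, params)]

ema = [0.0, 0.0]
ema = ema_update(ema, [1.0, 2.0], coef=0.9, loss=0.5, loss_threshold=2.0)  # applied
ema = ema_update(ema, [5.0, 5.0], coef=0.9, loss=9.0, loss_threshold=2.0)  # skipped
print(ema)  # ~[0.1, 0.2]; the high-loss step left the EMA unchanged
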
@@ -790,7 +796,11 @@ def process_offload_task(self, dump, global_step):
         self.global_step.value = global_step
 
         if self.ema_coef is not None:
-            self.zcc_ema_processor.ema_accumulate()
+            self.zcc_ema_processor.ema_accumulate(
+                self.trainer_state.global_step,
+                self.trainer_state.loss,
+                self.training_args_content.zcc_ema_loss_threshold,
+            )
 
         # continue to process dumping task at the last chunk
         if self.offloaded_numels == self.all_numel:
@@ -1006,3 +1016,86 @@ def manage_offload_chunk(self):
         logger.info(
             f"[ZCC Worker{self.worker_id}] All numel: {self.all_numel}, Offload chunks: {self.offload_chunks}, Chunk size: {self.chunk_size_in_numel}]"
         )
+
+
+class EMABuffer:
+    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
+        assert sharding_io is not None, "EMA should only be enabled when save_sharded_model is True"
+        self.master_weights = {}
+        self.model_params = {}
+        self.args = args
+        self.sharding_io = sharding_io
+        self.offload = offload
+        if resume_from_checkpoint is not None:
+            self._load(resume_from_checkpoint)
+
+    def _ema_path(self, base_path):
+        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        path = path.replace("optimizer", "ema")
+        return os.path.join(base_path, path)
+
+    def _load(self, resume_from_checkpoint):
+        ema_path = self._ema_path(resume_from_checkpoint)
+        if not os.path.exists(ema_path):
+            return
+
+        logger.info(f"Loading EMA checkpoint from {resume_from_checkpoint} ...")
+        with device_guard("cpu"):
+            ema_state_dict = paddle.load(ema_path)
+        logger.info(f"Load EMA checkpoint from {resume_from_checkpoint} done")
+
+        self.master_weights = ema_state_dict.pop("master_weights")
+        self.model_params = ema_state_dict
+
+    def save(self, global_step):
+        base_path = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
+        ema_path = self._ema_path(base_path)
+        ema_state_dict = {"master_weights": self.master_weights}
+        ema_state_dict.update(self.model_params)
+        os.makedirs(base_path, exist_ok=True)
+        logger.info(f"Saving EMA checkpoint to {base_path} ...")
+        paddle.save(ema_state_dict, ema_path)
+        logger.info(f"Save EMA checkpoint to {base_path} done")
+
+    def ema_accumulate(self, global_step, loss, ema_loss_threshold):
+        if ema_loss_threshold is None or loss < ema_loss_threshold:
+            logger.info(f"EMA accumulating for step {global_step} ...")
+            self._ema_impl(
+                state_dict=self.sharding_io.optimizer.state_dict()["master_weights"],
+                ema_state_dict=self.master_weights,
+            )
+            self._ema_impl(
+                state_dict=self.sharding_io.manipulate_state_dict_and_config(
+                    unwrap_model(self.sharding_io.model),
+                    merge_tensor_parallel=False,
+                )[0],
+                ema_state_dict=self.model_params,
+            )
+            logger.info(f"EMA accumulate done for step {global_step}")
+
+    def _ema_impl(self, state_dict, ema_state_dict):
+        ema_coef = self.args.zcc_save_ema_coef
+        for k, v in state_dict.items():
+            if k in ema_state_dict:
+                ema_tensor = ema_state_dict[k]
+                ema_tensor = ema_coef * ema_tensor.cuda() + (1 - ema_coef) * v.cuda()
+                ema_tensor.name = v.name
+                v = ema_tensor
+                del ema_tensor
+
+            if self.offload:
+                v_pin = v.pin_memory()
+                v_pin.name = v.name
+                v = v_pin
+            ema_state_dict[k] = v
+
+
+class NonZCCEMACallback(TrainerCallback):
+    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
+        self.buffer = EMABuffer(resume_from_checkpoint, args, sharding_io, offload)
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if state.global_step % args.zcc_ema_interval == 0:
+            self.buffer.ema_accumulate(state.global_step, state.loss, args.zcc_ema_loss_threshold)
+        if control.should_save:
+            self.buffer.save(state.global_step)
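
For context, a hedged sketch of how the new non-ZCC path might be wired into a training script. The `trainer`, `training_args`, and `last_checkpoint` names, the `trainer.sharding_io` attribute, and the `add_callback` call are assumptions about the surrounding code, not part of this patch; only `NonZCCEMACallback` and the `zcc_*` arguments come from the diff above.

# Hypothetical wiring; names outside NonZCCEMACallback are assumed, not taken from the patch.
ema_callback = NonZCCEMACallback(
    resume_from_checkpoint=last_checkpoint,  # reloads the "ema" file saved next to the optimizer state, if present
    args=training_args,                      # must carry zcc_save_ema_coef, zcc_ema_interval, zcc_ema_loss_threshold
    sharding_io=trainer.sharding_io,         # attribute name assumed; EMA requires save_sharded_model=True
    offload=True,                            # keep EMA tensors in pinned host memory between steps
)
trainer.add_callback(ema_callback)           # standard Trainer callback registration (assumed available)

Every `zcc_ema_interval` steps the callback folds the current master weights and model params into the EMA buffers, skipping steps whose loss exceeds `zcc_ema_loss_threshold`; whenever a regular checkpoint is written, it saves the EMA state into the `checkpoint-<global_step>` directory under the optimizer file name with "optimizer" replaced by "ema".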