Commit 8446d38

Add Non ZCC EMA callback (#2923)
1 parent 0d51c12 commit 8446d38

File tree: 3 files changed, +125 -12 lines

paddleformers/trainer/trainer.py

Lines changed: 17 additions & 1 deletion
@@ -195,12 +195,13 @@
 try:
     from .utils.zero_cost_checkpoint import (
+        NonZCCEMACallback,
         ZeroCostCheckpointCallback,
         ZeroCostCheckpointManager,
         get_fused_param_mappings,
     )
 except (ImportError, ModuleNotFoundError):
-    ZeroCostCheckpointManager, get_fused_param_mappings = None, None
+    ZeroCostCheckpointManager, NonZCCEMACallback, get_fused_param_mappings = None, None, None
 from .utils.helper import (  # nested_truncate,
     broadcast_dataset_rank0_model,
     broadcast_dp_optimizer,
@@ -873,6 +874,9 @@ def create_zcc_manager(self, unwrapped_model, resume_from_checkpoint=None):
 
         logger.info("Create zero cost checkpoint manager done.")
 
+    def add_non_zcc_ema_callback(self, resume_from_checkpoint):
+        self.add_callback(NonZCCEMACallback(resume_from_checkpoint, self.args, self.sharding_io))
+
     def _save_flex_model_state(self, output_dir):
         model_sharded_state_dict = self.model.sharded_state_dict()
         model_state_dict_path = os.path.join(output_dir, MODEL_STATE_DIC)
@@ -1135,6 +1139,8 @@ def train(
 
         if self.args.enable_zero_cost_checkpoint:
            self.create_zcc_manager(model, resume_from_checkpoint)
+        elif self.args.zcc_save_ema_coef is not None:
+            self.add_non_zcc_ema_callback(resume_from_checkpoint)
 
         logger.info(f"{self.runtime_timer.log()}")
         logger.info("***** Running training *****")
@@ -1365,6 +1371,16 @@ def _inner_training_loop(
                     self._skip_steps_since_last_logged += 1
 
                 self.state.epoch = epoch + (step + 1) / steps_in_epoch
+
+                # For ZCC EMA
+                if self.args.enable_zero_cost_checkpoint or self.args.zcc_save_ema_coef is not None:
+                    tr_loss_for_zcc = tr_loss.clone()
+                    dist.all_reduce(
+                        tr_loss_for_zcc, dist.ReduceOp.SUM
+                    )  # With 3D parallelism the loss is broadcast within each pp group; in the global reduce-mean both numerator and denominator are scaled by pp_world_size, so the factor cancels out
+                    tr_loss_for_zcc_scalar = tr_loss_for_zcc.item() / dist.get_world_size()
+                    self.state.loss = tr_loss_for_zcc_scalar
+
                 self.state.consumed_samples = (
                     self.state.global_step
                     * args.per_device_train_batch_size
paddleformers/trainer/training_args.py

Lines changed: 5 additions & 1 deletion
@@ -1090,6 +1090,10 @@ class TrainingArguments:
         default=1,
         metadata={"help": "Interval between updating EMA parameters."},
     )
+    zcc_ema_loss_threshold: Optional[float] = field(
+        default=None,
+        metadata={"help": "If not None, only perform the EMA update when the training loss is smaller than this threshold."},
+    )
     save_tokenizer: Optional[bool] = field(
         default=True,
         metadata={"help": "Save tokenizer to output_dir."},
@@ -2099,7 +2103,7 @@ def is_context_parallel_supported():
         assert (
             self.save_steps % self.zcc_ema_interval == 0
         ), f"save_steps[{self.save_steps}] must be divisible by zcc_ema_interval[{self.zcc_ema_interval}]"
-        if self.zcc_save_ema_coef is not None:
+        if self.enable_zero_cost_checkpoint and self.zcc_save_ema_coef is not None:
             assert (
                 self.zcc_workers_num == 1
             ), "EMA function in zero cost checkpoint mode does not support zcc_workers_num > 1 for now."

paddleformers/trainer/utils/zero_cost_checkpoint.py

Lines changed: 103 additions & 10 deletions
@@ -173,10 +173,11 @@ def ema_reset(self):
         self.ema_buffer_modele_params = None
 
     @imperative_base.no_grad()
-    def ema_accumulate(self):
+    def ema_accumulate(self, global_step, loss, zcc_ema_loss_threshold):
         """
         perform ema update : ` \alpha * EMA + (1-\alpha) + model`
         build `self.ema_buffer` if necessary
+        when loss < threshold, do ema update
         """
         # logger.info(f'[ZCC EMA] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
         # do update: ema = alpha * ema + (1-alpha) * model
@@ -185,14 +186,19 @@ def ema_accumulate(self):
         cpu_master_weights = self.optimizer_fusion_storage_helper.cpu_buffer._slice(
             self.master_min_offset, self.master_max_offset
         ).cpu()
-        self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
-        # logger.info(f'[ZCC EMA2] wait all done, doing EMA w/ coef: {self.ema_coef}, status:{self.status()}')
-        for index, ema_buf in self.ema_buffer_model_params.items():
-            _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
-            updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf.cpu()
-            self.ema_buffer_model_params[index] = updated_ema
-
-        logger.info(f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done")
+        if zcc_ema_loss_threshold is None or loss < zcc_ema_loss_threshold:
+            self.ema_buffer = self.ema_coef * self.ema_buffer + (1 - self.ema_coef) * cpu_master_weights
+            for index, ema_buf in self.ema_buffer_model_params.items():
+                _, cpu_buf = self.param_fusion_storage_helper.inited_buffers[index]
+                updated_ema = self.ema_coef * ema_buf + (1 - self.ema_coef) * cpu_buf
+                self.ema_buffer_model_params[index] = updated_ema
+            logger.info(
+                f"[ZCC EMA] accumulating, buffer type:{self.ema_buffer.place} {self.ema_buffer.dtype}, done"
+            )
+        else:
+            logger.info(
+                f"[ZCC EMA] accumulating SKIP for global_step:{global_step}, because loss:{loss} > threshold:{zcc_ema_loss_threshold}"
+            )
 
     @imperative_base.no_grad()
     def ema_state_dict(self):
@@ -790,7 +796,11 @@ def process_offload_task(self, dump, global_step):
         self.global_step.value = global_step
 
         if self.ema_coef is not None:
-            self.zcc_ema_processor.ema_accumulate()
+            self.zcc_ema_processor.ema_accumulate(
+                self.trainer_state.global_step,
+                self.trainer_state.loss,
+                self.training_args_content.zcc_ema_loss_threshold,
+            )
 
         # continue to process dumping task at the last chunk
         if self.offloaded_numels == self.all_numel:
@@ -1006,3 +1016,86 @@ def manage_offload_chunk(self):
         logger.info(
             f"[ZCC Worker{self.worker_id}] All numel: {self.all_numel}, Offload chunks: {self.offload_chunks}, Chunk size: {self.chunk_size_in_numel}]"
         )
+
+
+class EMABuffer:
+    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
+        assert sharding_io is not None, "EMA should be only enabled when save_sharded_model is True"
+        self.master_weights = {}
+        self.model_params = {}
+        self.args = args
+        self.sharding_io = sharding_io
+        self.offload = offload
+        if resume_from_checkpoint is not None:
+            self._load(resume_from_checkpoint)
+
+    def _ema_path(self, base_path):
+        path = _add_variant(PADDLE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        path = path.replace("optimizer", "ema")
+        return os.path.join(base_path, path)
+
+    def _load(self, resume_from_checkpoint):
+        ema_path = self._ema_path(resume_from_checkpoint)
+        if not os.path.exists(ema_path):
+            return
+
+        logger.info(f"Loading EMA checkpoint from {resume_from_checkpoint} ...")
+        with device_guard("cpu"):
+            ema_state_dict = paddle.load(ema_path)
+        logger.info(f"Load EMA checkpoint from {resume_from_checkpoint} done")
+
+        self.master_weights = ema_state_dict.pop("master_weights")
+        self.model_params = ema_state_dict
+
+    def save(self, global_step):
+        base_path = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
+        ema_path = self._ema_path(base_path)
+        ema_state_dict = {"master_weights": self.master_weights}
+        ema_state_dict.update(self.model_params)
+        os.makedirs(base_path, exist_ok=True)
+        logger.info(f"Saving EMA checkpoint to {base_path} ...")
+        paddle.save(ema_state_dict, ema_path)
+        logger.info(f"Save EMA checkpoint to {base_path} done")
+
+    def ema_accumulate(self, global_step, loss, ema_loss_threshold):
+        if ema_loss_threshold is None or loss < ema_loss_threshold:
+            logger.info(f"EMA accumulating for step {global_step} ...")
+            self._ema_impl(
+                state_dict=self.sharding_io.optimizer.state_dict()["master_weights"],
+                ema_state_dict=self.master_weights,
+            )
+            self._ema_impl(
+                state_dict=self.sharding_io.manipulate_state_dict_and_config(
+                    unwrap_model(self.sharding_io.model),
+                    merge_tensor_parallel=False,
+                )[0],
+                ema_state_dict=self.model_params,
+            )
+            logger.info(f"EMA accumulate done for step {global_step}")
+
+    def _ema_impl(self, state_dict, ema_state_dict):
+        ema_coef = self.args.zcc_save_ema_coef
+        for k, v in state_dict.items():
+            if k in ema_state_dict:
+                ema_tensor = ema_state_dict[k]
+                ema_tensor = ema_coef * ema_tensor.cuda() + (1 - ema_coef) * v.cuda()
+                ema_tensor.name = v.name
+                v = ema_tensor
+                del ema_tensor
+
+            if self.offload:
+                v_pin = v.pin_memory()
+                v_pin.name = v.name
+                v = v_pin
+            ema_state_dict[k] = v
+
+
+class NonZCCEMACallback(TrainerCallback):
+    def __init__(self, resume_from_checkpoint, args, sharding_io, offload=True):
+        self.buffer = EMABuffer(resume_from_checkpoint, args, sharding_io, offload)
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if state.global_step % args.zcc_ema_interval == 0:
+            self.buffer.ema_accumulate(state.global_step, state.loss, args.zcc_ema_loss_threshold)
+        if control.should_save:
+            self.buffer.save(state.global_step)
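
Both the ZCC EMA processor's `ema_accumulate` and the new `EMABuffer.ema_accumulate` reduce to the same thresholded exponential moving average. A framework-free sketch of that update rule (plain floats stand in for the parameter tensors; `live_state` and `loss_threshold` are illustrative names, not identifiers from the commit):

    def ema_accumulate(ema_state, live_state, coef, loss, loss_threshold=None):
        """ema = coef * ema + (1 - coef) * live, skipped when loss >= threshold."""
        if loss_threshold is not None and loss >= loss_threshold:
            return ema_state  # skip the update, mirroring the SKIP branch above
        for key, value in live_state.items():
            prev = ema_state.get(key, value)  # a key seen for the first time is seeded with the live value
            ema_state[key] = coef * prev + (1 - coef) * value
        return ema_state

    ema = {}
    ema = ema_accumulate(ema, {"w": 1.0}, coef=0.999, loss=1.5, loss_threshold=2.0)  # applied
    ema = ema_accumulate(ema, {"w": 0.0}, coef=0.999, loss=2.5, loss_threshold=2.0)  # skipped
    print(ema["w"])  # 1.0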
