From c335f6ef7751f379aac9f48f6c26cafc90f52103 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 01:01:18 +0000 Subject: [PATCH 01/51] train with only layer distillation losses --- fast_llm/layers/language_model/head.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2acd..db768ca12 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -409,14 +409,23 @@ def _logits_cross_entropy_forward_backward( else: distillation_loss, distillation_grad = None, None - # TODO: de-allocate earlier. - del logits - # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) # TODO: Return individual losses? loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + + # When using only activation distillation, loss and grad are None. + # Create zero tensors to allow activation distillation gradients to flow through. + if loss is None: + loss = torch.zeros(1, device=input_.device, dtype=input_.dtype, requires_grad=True) + if grad is None: + # Zero gradient means no loss at the head, but activation distillation gradients + grad = torch.zeros_like(logits) + + # TODO: de-allocate earlier. + del logits + if self.training and losses is not None: if dpo_loss is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) @@ -502,11 +511,12 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: +def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor | None: tensors = [tensor for tensor in tensors if tensor is not None] if len(tensors) > 1: return sum(tensors) elif len(tensors) == 1: return tensors[0] else: - raise RuntimeError() + # All tensors are None - this is valid when using only activation distillation + return None From e06a4b2ca02b22dc56e798aabf0b8c30fe280417 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 16 Dec 2025 14:15:45 +0000 Subject: [PATCH 02/51] unscaled loss llogging + training with distillation loss factor = 0 --- fast_llm/layers/language_model/head.py | 53 +++++++++++++++++++------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index db768ca12..733311d39 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -370,11 +370,13 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) + if self.training and losses is not None: + losses[self._ce_loss_name_unscaled].append(lm_loss.detach()) lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._config.distillation_loss_factor > 0.0: + if distillation_target is not None: if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), @@ -405,9 +407,9 @@ def _logits_cross_entropy_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) + if self.training and losses is not None: # we keep track of unscaled losses for model comparison purposes + 
losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) distillation_loss = distillation_loss * self._config.distillation_loss_factor - else: - distillation_loss, distillation_grad = None, None # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) @@ -415,14 +417,6 @@ def _logits_cross_entropy_forward_backward( # TODO: Return individual losses? loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) - # When using only activation distillation, loss and grad are None. - # Create zero tensors to allow activation distillation gradients to flow through. - if loss is None: - loss = torch.zeros(1, device=input_.device, dtype=input_.dtype, requires_grad=True) - if grad is None: - # Zero gradient means no loss at the head, but activation distillation gradients - grad = torch.zeros_like(logits) - # TODO: de-allocate earlier. del logits @@ -443,6 +437,13 @@ def _loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _ce_loss_name_unscaled(self) -> str: + name = "language_model_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + @functools.cached_property def _z_loss_name(self) -> str: name = "z_loss" @@ -471,8 +472,24 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name + @functools.cached_property + def _distillation_loss_name_unscaled(self) -> str: + name = "distillation_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] + if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: + # unscaled CE loss (NTP) + loss_defs = [ + LossDef( + name=self._ce_loss_name_unscaled, + formatted_name=_format_name(self._ce_loss_name_unscaled), + count=count, + ) + ] if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) @@ -490,6 +507,15 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) + # unscaled distillation loss for comparison purposes + loss_defs.append( + LossDef( + name=self._distillation_loss_name_unscaled, + formatted_name=_format_name(self._distillation_loss_name_unscaled), + count=count, + ) + ) + # if we mix distillation loss and CE loss for NTP, we want to log both if self._config.language_model_loss_factor > 0.0: loss_defs.append( LossDef( @@ -511,12 +537,11 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor | None: +def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: tensors = [tensor for tensor in tensors if tensor is not None] if len(tensors) > 1: return sum(tensors) elif len(tensors) == 1: return tensors[0] else: - # All tensors are None - this is valid when using only activation distillation - return None + raise RuntimeError() From 179ae25e9db3ecda3c75762288abe824c31e65fd Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 21:07:54 +0000 Subject: [PATCH 03/51] make logging more explicit --- fast_llm/layers/language_model/config.py | 12 ++ fast_llm/layers/language_model/head.py | 217 +++++++++++++++-------- 2 files changed, 153 
insertions(+), 76 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 53dac2892..13c6d87eb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -168,11 +168,21 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Factor to scale the language modeling loss by when using distillation.", hint=FieldHint.feature, ) + track_language_model_loss: bool = Field( + default=False, + desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", + hint=FieldHint.feature, + ) distillation_loss_factor: float = Field( default=1.0, desc="Factor to scale the distillation loss by when using distillation.", hint=FieldHint.feature, ) + track_distillation_loss: bool = Field( + default=False, + desc="Track the unscaled distillation loss for logging purposes. Will always do if distillation_loss_factor > 0.", + hint=FieldHint.feature, + ) logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -243,6 +253,8 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() + if self.distillation_model is None: + Assert.is_(self.track_distillation_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 733311d39..e785c09e5 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -113,6 +113,12 @@ def __init__( peft=self._peft, ) + self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss + self._compute_dpo_loss = self._config.enable_dpo + self._compute_distillation_loss = self._config.distillation_model is not None and ( + self._config.distillation_loss_factor > 0.0 or self._config.track_distillation_loss + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: @@ -137,8 +143,6 @@ def forward( # TODO: Drop autograd entirely. # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) - if losses is not None and language_model_loss is not None: - losses[self._loss_name].append(language_model_loss.detach()) # TODO: Return the model output when needed. if self._is_last_head: # Last head should return the loss for backward. 
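# Illustrative usage sketch, not taken from the diff: the tracking flags added above let a loss be
# logged without contributing to training. The field names come from this patch; the reference
# model name "teacher" is made up.
example_head_config = {
    "distillation_model": "teacher",       # hypothetical reference model name
    "language_model_loss_factor": 0.0,     # CE term contributes no gradient (its grad_output is scaled by 0)...
    "track_language_model_loss": True,     # ...but the unscaled CE loss is still computed and logged
    "distillation_loss_factor": 1.0,       # the training signal comes from distillation only
}
# With these values `_compute_lm_loss` is True because tracking is requested, so the CE loss appears
# in the logs for run-to-run comparison while the backward pass is driven by the distillation term.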
@@ -205,25 +209,22 @@ def _get_targets( if loss_mask is not None: loss_mask = loss_mask.flatten() - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - else: - lm_target = None + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice( + self._prediction_distance, self._prediction_distance + lm_target_sequence_length + ) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: @@ -246,7 +247,7 @@ def _logits_cross_entropy_forward_backward_split( losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: if self._config.cross_entropy_splits is None or targets is None: - loss, logit_input_grad = self._logits_cross_entropy_forward_backward( + loss, logit_input_grad = self._logits_loss_forward_backward( input_, targets, weight, grad_output, kwargs, losses ) if targets is None: @@ -279,7 +280,7 @@ def _logits_cross_entropy_forward_backward_split( for tensor in [logit_input, *targets, logit_input_grad] ] for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): - loss_, grad_ = self._logits_cross_entropy_forward_backward( + loss_, grad_ = self._logits_loss_forward_backward( logit_input_, targets_, weight, @@ -301,7 +302,7 @@ def _logits_cross_entropy_forward_backward_split( all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - def _logits_cross_entropy_forward_backward( + def _logits_loss_forward_backward( self, input_: torch.Tensor, targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], @@ -359,7 +360,7 @@ def _logits_cross_entropy_forward_backward( else: dpo_loss, dpo_grad = None, None - if lm_target is not None: + if lm_target is not None and self._compute_lm_loss: lm_loss, lm_grad = cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, @@ -370,13 +371,10 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) - if self.training and losses is not None: - losses[self._ce_loss_name_unscaled].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None: + if distillation_target is not None and self._compute_distillation_loss: if 
self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), @@ -407,39 +405,121 @@ def _logits_cross_entropy_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - if self.training and losses is not None: # we keep track of unscaled losses for model comparison purposes - losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) - distillation_loss = distillation_loss * self._config.distillation_loss_factor - - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - - # TODO: Return individual losses? - loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + else: + distillation_loss, distillation_grad = None, None # TODO: de-allocate earlier. del logits + loss, grad = self._post_process_loss_and_grad( + dpo_loss, + dpo_grad, + lm_loss, + lm_grad, + distillation_loss, + distillation_grad, + losses, + loss_mask, + kwargs, + ) + + return loss, output_parallel_linear_backward(grad, context) if self.training else None - if self.training and losses is not None: - if dpo_loss is not None: + def _post_process_loss_and_grad( + self, + dpo_loss: torch.Tensor | None, + dpo_grad: torch.Tensor | None, + lm_loss: torch.Tensor | None, + lm_grad: torch.Tensor | None, + distillation_loss: torch.Tensor | None, + distillation_grad: torch.Tensor | None, + losses: dict | None, + loss_mask: torch.Tensor | None, + kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + If loss is provided (i.e. not None) it will be logged in scaled and unscaled version. The total loss is also logged. + + Arguments: + - Losses: unscaled losses from different components (DPO, LM CE, Distillation) + - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. + """ + # Extremely explicit but easier to follow. + ############ + if dpo_loss is not None: + if self.training and losses is not None: losses[self._dpo_loss_name].append(dpo_loss.detach()) - if self._config.distillation_model is not None and distillation_loss is not None: + else: + Assert.is_(dpo_grad, None) + + if lm_loss is not None: + if self.training and losses is not None: + losses[self._lm_loss_name_unscaled].append(lm_loss.detach()) + lm_loss = lm_loss * self._config.language_model_loss_factor # does not need scaling by loss_scalor_df + if self.training and losses is not None: + losses[self._lm_loss_name].append(lm_loss.detach()) + else: + Assert.is_(lm_grad, None) + + if distillation_loss is not None: + # We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to correctly average the loss over micro-batches. + # The runner averages losses by dividing by num_micro_batches, so we need to account for that. + # Note: for grads this scaling is already in the 'grad_output' + total_valid_tokens = kwargs.get( + LanguageModelKwargs.total_valid_tokens + ) # number of not masked tokens across all micro-batches. 
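                # Worked example (illustrative, not part of the patch): with 2 micro-batches,
                # 100 valid tokens in this micro-batch and 200 in the other (300 in total), the
                # factor is (100 * 2) / 300 = 2/3 here and (200 * 2) / 300 = 4/3 there, so after
                # the runner divides the accumulated sum by num_micro_batches the result is
                # (100 * loss_1 + 200 * loss_2) / 300, i.e. a per-valid-token average over the
                # whole batch rather than a plain mean of per-micro-batch means.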
+ num_micro_batches = kwargs.get("num_micro_batches", 1) + + if loss_mask is None or total_valid_tokens is None: + loss_scalor_df = 1 + else: + valid_tokens = loss_mask.sum() + # Scale by (valid_tokens * num_micro_batches) / total_valid_tokens + # This accounts for the runner dividing by num_micro_batches + loss_scalor_df = (valid_tokens * num_micro_batches) / total_valid_tokens + distillation_loss = distillation_loss * loss_scalor_df + if self.training and losses is not None: + losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) + distillation_loss = distillation_loss * self._config.distillation_loss_factor + if self.training and losses is not None: losses[self._distillation_loss_name].append(distillation_loss.detach()) - if self._config.distillation_model is not None and lm_loss is not None: - losses[self._distillation_language_model_loss_name].append(lm_loss.detach()) + else: + Assert.is_(distillation_grad, None) - return loss, output_parallel_linear_backward(grad, context) if self.training else None + ############ + # TODO: Accumulate grads in-place to reduce memory and compute overhead. + grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) + total_loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + if losses is not None and total_loss is not None: + losses[self._total_loss_name].append(total_loss.detach()) + + return total_loss, grad @functools.cached_property - def _loss_name(self) -> str: - name = "language_model_loss" + def _total_loss_name(self) -> str: + """ + Combined total scaled loss used for training. + """ + name = "lm_head_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _ce_loss_name_unscaled(self) -> str: - name = "language_model_loss_unscaled" + def _lm_loss_name_unscaled(self) -> str: + """ + Unscaled language model cross-entropy loss. + """ + name = "lm_loss_unscaled" + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + + @functools.cached_property + def _lm_loss_name(self) -> str: + """ + Scaled language model cross-entropy loss. 
+ """ + name = "lm_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -459,8 +539,8 @@ def _dpo_loss_name(self) -> str: return name @functools.cached_property - def _distillation_language_model_loss_name(self) -> str: - name = "distillation_language_model_loss" + def _distillation_loss_name_unscaled(self) -> str: + name = "distillation_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -472,34 +552,28 @@ def _distillation_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _distillation_loss_name_unscaled(self) -> str: - name = "distillation_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [LossDef(name=self._loss_name, formatted_name=_format_name(self._loss_name), count=count)] - if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: - # unscaled CE loss (NTP) - loss_defs = [ + loss_defs = [ + LossDef(name=self._total_loss_name, formatted_name=_format_name(self._total_loss_name), count=count) + ] + if self._compute_lm_loss: + loss_defs.append( LossDef( - name=self._ce_loss_name_unscaled, - formatted_name=_format_name(self._ce_loss_name_unscaled), + name=self._lm_loss_name_unscaled, + formatted_name=_format_name(self._lm_loss_name_unscaled), count=count, ) - ] + ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) - if self._config.enable_dpo: + if self._compute_dpo_loss: loss_defs.append( LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) ) - if self._config.distillation_model is not None: + if self._compute_distillation_loss: loss_defs.append( LossDef( name=self._distillation_loss_name, @@ -515,15 +589,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) - # if we mix distillation loss and CE loss for NTP, we want to log both - if self._config.language_model_loss_factor > 0.0: - loss_defs.append( - LossDef( - name=self._distillation_language_model_loss_name, - formatted_name=_format_name(self._distillation_language_model_loss_name), - count=count, - ) - ) return loss_defs @@ -544,4 +609,4 @@ def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: elif len(tensors) == 1: return tensors[0] else: - raise RuntimeError() + raise RuntimeError("No tensors to add.") From 9968aac14c439823c6850e0dcc4e2210b5ad2cf3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 22:38:28 +0000 Subject: [PATCH 04/51] clean + tests --- fast_llm/layers/language_model/head.py | 24 ++---- tests/layers/test_lm_head.py | 107 +++++++++++++++++++++---- 2 files changed, 98 insertions(+), 33 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e785c09e5..8a4601941 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -461,22 +461,7 @@ def _post_process_loss_and_grad( Assert.is_(lm_grad, None) if distillation_loss is not None: - # We need to scale the loss by (valid_tokens * num_micro_batches) / total_valid_tokens to correctly average the loss over micro-batches. - # The runner averages losses by dividing by num_micro_batches, so we need to account for that. 
- # Note: for grads this scaling is already in the 'grad_output' - total_valid_tokens = kwargs.get( - LanguageModelKwargs.total_valid_tokens - ) # number of not masked tokens across all micro-batches. - num_micro_batches = kwargs.get("num_micro_batches", 1) - - if loss_mask is None or total_valid_tokens is None: - loss_scalor_df = 1 - else: - valid_tokens = loss_mask.sum() - # Scale by (valid_tokens * num_micro_batches) / total_valid_tokens - # This accounts for the runner dividing by num_micro_batches - loss_scalor_df = (valid_tokens * num_micro_batches) / total_valid_tokens - distillation_loss = distillation_loss * loss_scalor_df + distillation_loss = distillation_loss if self.training and losses is not None: losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) distillation_loss = distillation_loss * self._config.distillation_loss_factor @@ -564,6 +549,13 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) + loss_defs.append( + LossDef( + name=self._lm_loss_name, + formatted_name=_format_name(self._lm_loss_name), + count=count, + ) + ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 623a30d82..88ff9d612 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -55,6 +55,8 @@ def _lm_head( logit_scale_factor: float = 1.0, logit_z_loss=0.0, distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, + language_model_loss_factor: float = 1.0, + distillation_loss_factor: float = 1.0, ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -69,23 +71,31 @@ def _lm_head( loss = _reverse_kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) - loss.backward(torch.full_like(loss, grad_output)) - return loss, None + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) + # Return scaled loss + return loss * distillation_loss_factor, None if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None if target.ndim == logits.ndim: + # Distillation loss (cross-entropy with soft targets) loss = torch.nn.functional.cross_entropy( logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" ) if loss_mask is not None: loss = loss * loss_mask.flatten() loss = loss.mean() + # Apply distillation_loss_factor + loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) + return loss * distillation_loss_factor, z_loss else: + # Language model loss (cross-entropy with hard labels) loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - loss.backward(torch.full_like(loss, grad_output)) - return loss, z_loss + # Apply language_model_loss_factor + loss.backward(torch.full_like(loss, grad_output * language_model_loss_factor)) + return loss * language_model_loss_factor, z_loss SEQUENCE_LENGTH = 200 @@ -154,6 +164,54 @@ def _lm_head( True, 1, ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "track_language_model_loss": True, + "distillation_loss_factor": 1.0, + } + }, + {}, + False, + 1, + id="track_lm_zero_factor", + ), + pytest.param( + { + "head": { + 
"distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "distillation_loss_factor": 0.0, + "track_language_model_loss": True, + "track_distillation_loss": True, + } + }, + {}, + False, + 1, + id="track_both_zero_factors", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "language_model_loss_factor": 0.0, + "distillation_loss_factor": 0.0, + "track_language_model_loss": False, + "track_distillation_loss": False, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", + strict=True, + ), + id="zero_factors_no_tracking", + ), ), ) def test_lm_head( @@ -292,6 +350,10 @@ def test_lm_head( logit_scale_factor=head_config.logits_scale_factor, logit_z_loss=head_config.logit_z_loss, distillation_loss_implementation=head_config.distillation_loss_implementation, + language_model_loss_factor=( + head_config.language_model_loss_factor if head_config.language_model_loss_factor is not None else 1.0 + ), + distillation_loss_factor=head_config.distillation_loss_factor, ) # Prepare LM head inputs @@ -303,20 +365,27 @@ def test_lm_head( head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_() output_grad = torch.randn_like(shared_hidden) - loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" - loss_keys = {loss_name} + lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" + expected_loss_keys = {lm_head_loss_name} + if head._compute_lm_loss: + lm_loss_name_unscaled = ( + f"lm_loss_unscaled_{prediction_distance}" if prediction_distance > 0 else "lm_loss_unscaled" + ) + lm_loss_name = f"lm_loss_{prediction_distance}" if prediction_distance > 0 else "lm_loss" + + expected_loss_keys.add(lm_loss_name_unscaled) + expected_loss_keys.add(lm_loss_name) if ref_z_loss is not None: - loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head_config.distillation_model is not None: - loss_keys.add("distillation_loss") - if head_config.language_model_loss_factor > 0: - loss_keys.add("distillation_language_model_loss") + expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + if head._compute_distillation_loss: + expected_loss_keys.add("distillation_loss") + expected_loss_keys.add("distillation_loss_unscaled") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, - {loss_key: 1 for loss_key in loss_keys}, + {loss_key: 1 for loss_key in expected_loss_keys}, ) - losses = {key: [] for key in loss_keys} + losses = {key: [] for key in expected_loss_keys} output, context = stage.forward(head_input, kwargs, losses) stage.backward(output_grad, context) @@ -325,16 +394,16 @@ def test_lm_head( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 ) * head_config.logits_scale_factor - Assert.eq(losses.keys(), loss_keys) - Assert.eq(len(losses[loss_name]), 1) + Assert.eq(losses.keys(), expected_loss_keys) + Assert.eq(len(losses[lm_head_loss_name]), 1) if ref_z_loss is not None: Assert.eq(len(losses["z_loss"]), 1) Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) - Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) + Assert.rms_close_relative(losses[lm_head_loss_name][0], ref_loss, threshold, min_threshold) if head._is_last_head: - 
Assert.all_equal(output, losses[loss_name][0]) + Assert.all_equal(output, losses[lm_head_loss_name][0]) input_grad = head_input.grad else: Assert.all_equal(output, shared_hidden) @@ -344,3 +413,7 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 945c5a774bf30fbb088a818f12f5510e98f99bbb Mon Sep 17 00:00:00 2001 From: oleksost Date: Wed, 17 Dec 2025 22:38:54 +0000 Subject: [PATCH 05/51] nvm --- tests/layers/test_lm_head.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 88ff9d612..c6d806db8 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -413,7 +413,3 @@ def test_lm_head( Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From 4b6e3d7503b0cf8a93aef156a0328c2b6dc67cc8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 21:28:55 +0000 Subject: [PATCH 06/51] forward KL --- fast_llm/functional/config.py | 1 + fast_llm/functional/cross_entropy.py | 128 +++++++++++++++++++++++++ fast_llm/layers/language_model/head.py | 21 +++- 3 files changed, 149 insertions(+), 1 deletion(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 4cfc3b61d..20ed99fde 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -102,6 +102,7 @@ class CrossEntropyImpl(str, enum.Enum): class DistillationLossImpl(str, enum.Enum): reverse_kl = "reverse_kl" + forward_kl = "forward_kl" cross_entropy = "cross_entropy" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 8c9ea9399..5a618eea0 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -359,3 +359,131 @@ def reverse_kl_forward_backward( group=group, ) return distillation_loss, distillation_grad + + +@torch.compile +def _forward_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Forward KL: KL(p||q) where p=teacher, q=student. + This is reverse KL with roles swapped in the loss computation. + + Key insight: KL(p||q) = sum_i p_i * log(p_i/q_i) + = sum_i p_i * (log(p_i) - log(q_i)) + which is reverse KL with p and q swapped. + + However, we still need grad w.r.t. 
student logits, so gradient is different: + d/d(student_logits) KL(p||q) = student_probs - teacher_probs + """ + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel forward KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel forward KL") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + # Compute log softmax for both teacher and student + teacher_log_probs = distributed_log_softmax(target.float(), group=group) + student_log_probs = distributed_log_softmax(logits, group=group) + + teacher_probs = teacher_log_probs.exp() + # Forward KL: p * log(p/q) = p * (log_p - log_q) + log_ratio = teacher_log_probs - student_log_probs + del teacher_log_probs + + # Compute loss: sum over vocab of teacher_probs * log_ratio + loss_terms = (teacher_probs * log_ratio).sum(dim=-1) + del log_ratio + + if loss_mask is not None: + valid = loss_mask.to(loss_terms.dtype) + loss_terms = loss_terms * valid + valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) + loss = loss_terms.sum() + + if group is not None: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= valid_tokens + + if grad_output is not None: + # Gradient: d/d(student_logits) KL(p||q) = student_probs - teacher_probs + student_probs = student_log_probs.exp() + grad_base = student_probs - teacher_probs + del student_probs, teacher_probs, student_log_probs + + if loss_mask is not None: + grad_base.mul_(loss_mask.to(logits.dtype).unsqueeze(-1)) + + grad_base.mul_(grad_output / valid_tokens) + grad = grad_base.to(logits.dtype) + else: + grad = None + + return loss.detach_(), grad + + +def forward_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + teacher_softmax_temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + sequence_parallel_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). + This is mode-covering (vs. mode-seeking for reverse KL) and useful for: + - Encouraging the model to cover all modes of the target distribution + - Spreading probability mass broadly across the target support + - Standard distillation scenarios where you want to match the full teacher distribution + + Key differences from reverse KL: + - Forward KL: KL(p||q) = mode-covering (spreads mass broadly) + - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) + + Takes: + logits: [BxS, V] or [B, S, V], where V is local vocab size + target: [BxS, V] or [B, S, V] (logits format) + loss_mask: [BxS] or [B, S] or None + ... + + Returns: + loss: Forward KL divergence loss + grad: Gradients w.r.t. 
logits + """ + + if sequence_parallel_logits: + # TODO: see hybrid dev branch where it is implemented + raise NotImplementedError("Sequence-parallel forward KL is not implemented yet, set vocab_parallel true") + + Assert.eq(target_format, TargetFormat.logits, msg="Forward KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + # TODO: implement fused? + distillation_loss, distillation_grad = _forward_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=teacher_softmax_temperature, + group=group, + ) + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8a4601941..b8a8f0cbb 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -14,7 +14,11 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward +from fast_llm.functional.cross_entropy import ( + cross_entropy_forward_backward, + forward_kl_forward_backward, + reverse_kl_forward_backward, +) from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block @@ -390,6 +394,21 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) + elif self._config.distillation_loss_implementation == DistillationLossImpl.forward_kl: + distillation_loss, distillation_grad = forward_kl_forward_backward( + logits.flatten(0, -2), + distillation_target, + loss_mask, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + group=group, + logits_scale_factor=self._config.logits_scale_factor, + teacher_softmax_temperature=self._config.teacher_softmax_temperature, + target_format=( + TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits + ), + sequence_parallel_logits=self._sequence_parallel_logits, + ) + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), From c5fefa0a13b1903bf88e7187790a94211b8d40cb Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 22:19:52 +0000 Subject: [PATCH 07/51] test forward kl --- tests/functional/test_cross_entropy.py | 43 ++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 72644d061..716c56ba3 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -8,7 +8,11 @@ import torch from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward +from fast_llm.functional.cross_entropy import ( + cross_entropy_forward_backward, + 
forward_kl_forward_backward, + reverse_kl_forward_backward, +) from fast_llm.utils import Assert from tests.utils.utils import requires_cuda @@ -127,6 +131,41 @@ def test_reverse_kl(loss_masking, target_format): _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) +def _forward_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): + # Manual reference: sum over vocab then average over all tokens (not just valid ones). + # Forward KL: KL(p||q) where p=teacher, q=student + logits = logits.detach().requires_grad_(True) + per_sample = torch.nn.functional.kl_div( + torch.log_softmax(logits.float(), dim=-1), + torch.log_softmax(target.float(), dim=-1), + reduction="none", + log_target=True, + ).sum(dim=-1) + if loss_mask is not None: + per_sample = per_sample * loss_mask + output = per_sample.sum() / per_sample.numel() + output.backward() + return output, logits.grad + + +@requires_cuda +@pytest.mark.slow +# TODO: Support the same parameterization as above in the reference implementation. +@pytest.mark.parametrize("loss_masking", [False, True]) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +def test_forward_kl(loss_masking, target_format): + logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) + out_ref, grad_ref = _forward_kl_forward_backward_torch(logits, target, loss_mask) + out, grad = forward_kl_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=1.0, + target_format=TargetFormat.logits, + ) + _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) + + def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): try: torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) @@ -189,7 +228,7 @@ def _compare_parallel_cross_entropy( def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): success = True - for function in (reverse_kl_forward_backward, cross_entropy_forward_backward): + for function in (reverse_kl_forward_backward, forward_kl_forward_backward, cross_entropy_forward_backward): for target_format in (TargetFormat.logits,): for loss_masking in [False, True]: try: From 411959616793a78f49e76b9c0767d055ba2c1971 Mon Sep 17 00:00:00 2001 From: oleksost Date: Fri, 19 Dec 2025 22:48:44 +0000 Subject: [PATCH 08/51] wip: report unscaled + kl loss --- fast_llm/layers/language_model/config.py | 35 ++++- fast_llm/layers/language_model/head.py | 158 +++++++++++++---------- 2 files changed, 122 insertions(+), 71 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 13c6d87eb..807b39703 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -173,16 +173,37 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", hint=FieldHint.feature, ) - distillation_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the distillation loss by when using distillation.", + track_forward_kl_loss: bool = Field( + default=False, + desc="Track the unscaled forward KL loss for logging purposes. 
Will always do if distillation_loss_implementation is forward_kl.", + hint=FieldHint.feature, + ) + track_reverse_kl_loss: bool = Field( + default=False, + desc="Track the unscaled reverse KL loss for logging purposes. Will always do if distillation_loss_implementation is reverse_kl.", hint=FieldHint.feature, ) - track_distillation_loss: bool = Field( + track_distillation_ce_loss: bool = Field( default=False, - desc="Track the unscaled distillation loss for logging purposes. Will always do if distillation_loss_factor > 0.", + desc="Track the unscaled distillation cross-entropy loss for logging purposes. Will always do if distillation_loss_implementation is cross_entropy.", + hint=FieldHint.feature, + ) + forward_kl_loss_factor: float = Field( + default=0.0, + desc="Factor to scale the forward KL loss by when using distillation with forward KL.", hint=FieldHint.feature, ) + reverse_kl_loss_factor: float = Field( + default=1.0, + desc="Factor to scale the reverse KL loss by when using distillation with reverse KL.", + hint=FieldHint.feature, + ) + distillation_ce_loss_factor: float = Field( + default=0.0, + desc="Factor to scale the distillation cross-entropy loss by when using distillation with cross-entropy.", + hint=FieldHint.feature, + ) + logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -254,7 +275,9 @@ def _validate(self) -> None: self.language_model_loss_factor = 0.0 super()._validate() if self.distillation_model is None: - Assert.is_(self.track_distillation_loss, False) + Assert.is_(self.track_forward_kl_loss, False) + Assert.is_(self.track_reverse_kl_loss, False) + Assert.is_(self.track_distillation_ce_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b8a8f0cbb..040dc55dc 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -13,7 +13,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import ( cross_entropy_forward_backward, forward_kl_forward_backward, @@ -119,8 +119,18 @@ def __init__( self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss self._compute_dpo_loss = self._config.enable_dpo - self._compute_distillation_loss = self._config.distillation_model is not None and ( - self._config.distillation_loss_factor > 0.0 or self._config.track_distillation_loss + self._compute_rkl_loss = self._config.distillation_model is not None and ( + self._config.reverse_kl_loss_factor > 0.0 or self._config.track_reverse_kl_loss + ) + self._compute_kl_loss = self._config.distillation_model is not None and ( + self._config.forward_kl_loss_factor > 0.0 or self._config.track_forward_kl_loss + ) + self._compute_dist_ce_loss = self._config.distillation_model is not None and ( + self._config.distillation_ce_loss_factor > 0.0 or self._config.track_distillation_ce_loss + ) + + self._compute_distillation_loss = any( + [self._compute_rkl_loss, self._compute_kl_loss, 
self._compute_dist_ce_loss] ) def forward( @@ -378,13 +388,16 @@ def _logits_loss_forward_backward( else: lm_loss, lm_grad = None, None + distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad = None, None, None + distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss = None, None, None + if distillation_target is not None and self._compute_distillation_loss: - if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: - distillation_loss, distillation_grad = reverse_kl_forward_backward( + if self._compute_rkl_loss: + distillation_rkl_loss, distillation_rkl_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.reverse_kl_loss_factor, group=group, logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, @@ -394,12 +407,12 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.forward_kl: - distillation_loss, distillation_grad = forward_kl_forward_backward( + if self._compute_kl_loss: + distillation_kl_loss, distillation_kl_grad = forward_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.forward_kl_loss_factor, group=group, logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, @@ -409,13 +422,13 @@ def _logits_loss_forward_backward( sequence_parallel_logits=self._sequence_parallel_logits, ) - elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: - distillation_loss, distillation_grad = cross_entropy_forward_backward( + if self._compute_dist_ce_loss: + distillation_ce_loss, distillation_ce_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_ce_loss_factor, implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.logits, @@ -424,8 +437,6 @@ def _logits_loss_forward_backward( raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - else: - distillation_loss, distillation_grad = None, None # TODO: de-allocate earlier. 
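# Reference sketch, not part of the patch: the three distillation objectives selected above, written
# out with plain torch on a single [tokens, vocab] pair so the direction of each KL term is explicit.
# `student_logits` / `teacher_logits` are stand-in names; masking, temperature and parallelism are
# not handled here.
import torch

def _distillation_losses_reference(student_logits: torch.Tensor, teacher_logits: torch.Tensor):
    log_q = torch.log_softmax(student_logits.float(), dim=-1)  # student distribution q
    log_p = torch.log_softmax(teacher_logits.float(), dim=-1)  # teacher distribution p
    p, q = log_p.exp(), log_q.exp()
    reverse_kl = (q * (log_q - log_p)).sum(-1).mean()  # KL(q||p): mode-seeking
    forward_kl = (p * (log_p - log_q)).sum(-1).mean()  # KL(p||q): mode-covering
    soft_ce = -(p * log_q).sum(-1).mean()              # cross-entropy against soft teacher targets
    return reverse_kl, forward_kl, soft_ce

# Forward KL and the soft cross-entropy differ only by the teacher entropy, which is constant with
# respect to the student, so both yield the gradient softmax(student) - softmax(teacher).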
del logits @@ -434,10 +445,13 @@ def _logits_loss_forward_backward( dpo_grad, lm_loss, lm_grad, - distillation_loss, - distillation_grad, + distillation_rkl_loss, + distillation_rkl_grad, + distillation_kl_loss, + distillation_kl_grad, + distillation_ce_loss, + distillation_ce_grad, losses, - loss_mask, kwargs, ) @@ -449,10 +463,13 @@ def _post_process_loss_and_grad( dpo_grad: torch.Tensor | None, lm_loss: torch.Tensor | None, lm_grad: torch.Tensor | None, - distillation_loss: torch.Tensor | None, - distillation_grad: torch.Tensor | None, + distillation_rkl_loss: torch.Tensor | None, + distillation_rkl_grad: torch.Tensor | None, + distillation_kl_loss: torch.Tensor | None, + distillation_kl_grad: torch.Tensor | None, + distillation_ce_loss: torch.Tensor | None, + distillation_ce_grad: torch.Tensor | None, losses: dict | None, - loss_mask: torch.Tensor | None, kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -463,6 +480,7 @@ def _post_process_loss_and_grad( - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. """ # Extremely explicit but easier to follow. + # TODO: simplify / shrten / make seperate dataclass? ############ if dpo_loss is not None: if self.training and losses is not None: @@ -471,28 +489,38 @@ def _post_process_loss_and_grad( Assert.is_(dpo_grad, None) if lm_loss is not None: - if self.training and losses is not None: - losses[self._lm_loss_name_unscaled].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor # does not need scaling by loss_scalor_df if self.training and losses is not None: losses[self._lm_loss_name].append(lm_loss.detach()) + lm_loss = lm_loss * self._config.language_model_loss_factor else: Assert.is_(lm_grad, None) - if distillation_loss is not None: - distillation_loss = distillation_loss + if distillation_rkl_loss is not None: + distillation_rkl_loss = distillation_rkl_loss if self.training and losses is not None: - losses[self._distillation_loss_name_unscaled].append(distillation_loss.detach()) - distillation_loss = distillation_loss * self._config.distillation_loss_factor + losses[self._distillation_rkl_loss_name].append(distillation_rkl_loss.detach()) + distillation_rkl_loss = distillation_rkl_loss * self._config.distillation_loss_factor + else: + Assert.is_(distillation_rkl_grad, None) + if distillation_kl_loss is not None: + distillation_kl_loss = distillation_kl_loss + if self.training and losses is not None: + losses[self._distillation_kl_loss_name].append(distillation_kl_loss.detach()) + distillation_kl_loss = distillation_kl_loss * self._config.distillation_loss_factor + else: + Assert.is_(distillation_kl_grad, None) + if distillation_ce_loss is not None: + distillation_ce_loss = distillation_ce_loss if self.training and losses is not None: - losses[self._distillation_loss_name].append(distillation_loss.detach()) + losses[self._distillation_ce_loss_name].append(distillation_ce_loss.detach()) + distillation_ce_loss = distillation_ce_loss * self._config.distillation_loss_factor else: - Assert.is_(distillation_grad, None) + Assert.is_(distillation_ce_grad, None) ############ # TODO: Accumulate grads in-place to reduce memory and compute overhead. 
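# Possible in-place accumulation, sketching one way to address the TODO above (illustrative only,
# not part of the patch). It keeps the semantics of `_add_tensors` but reuses the first non-None
# tensor's buffer instead of allocating a new sum; the caller must accept that this first tensor
# is overwritten.
import torch

def _accumulate_tensors_(*tensors: torch.Tensor | None) -> torch.Tensor:
    result = None
    for tensor in tensors:
        if tensor is None:
            continue
        if result is None:
            result = tensor   # reuse the first buffer
        else:
            result += tensor  # in-place add, no extra temporaries
    if result is None:
        raise RuntimeError("No tensors to add.")
    return result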
- grad = _add_tensors(dpo_grad, lm_grad, distillation_grad) - total_loss = _add_tensors(dpo_loss, lm_loss, distillation_loss) + grad = _add_tensors(dpo_grad, lm_grad, distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad) + total_loss = _add_tensors(dpo_loss, lm_loss, distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss) if losses is not None and total_loss is not None: losses[self._total_loss_name].append(total_loss.detach()) @@ -509,7 +537,7 @@ def _total_loss_name(self) -> str: return name @functools.cached_property - def _lm_loss_name_unscaled(self) -> str: + def _lm_loss_name(self) -> str: """ Unscaled language model cross-entropy loss. """ @@ -519,39 +547,36 @@ def _lm_loss_name_unscaled(self) -> str: return name @functools.cached_property - def _lm_loss_name(self) -> str: - """ - Scaled language model cross-entropy loss. - """ - name = "lm_loss" + def _z_loss_name(self) -> str: + name = "z_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _z_loss_name(self) -> str: - name = "z_loss" + def _dpo_loss_name(self) -> str: + name = "dpo_loss" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" + def _distillation_kl_loss_name(self) -> str: + name = "distillation_kl_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _distillation_loss_name_unscaled(self) -> str: - name = "distillation_loss_unscaled" + def _distillation_rkl_loss_name(self) -> str: + name = "distillation_rkl_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @functools.cached_property - def _distillation_loss_name(self) -> str: - name = "distillation_loss" + def _distillation_ce_loss_name(self) -> str: + name = "distillation_ce_loss_unscaled" if self._prediction_distance > 0: name = f"{name}_{self._prediction_distance}" return name @@ -568,13 +593,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, ) ) - loss_defs.append( - LossDef( - name=self._lm_loss_name, - formatted_name=_format_name(self._lm_loss_name), - count=count, - ) - ) if self._config.logit_z_loss: loss_defs.append( LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) @@ -585,21 +603,31 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: ) if self._compute_distillation_loss: - loss_defs.append( - LossDef( - name=self._distillation_loss_name, - formatted_name=_format_name(self._distillation_loss_name), - count=count, - ) - ) # unscaled distillation loss for comparison purposes - loss_defs.append( - LossDef( - name=self._distillation_loss_name_unscaled, - formatted_name=_format_name(self._distillation_loss_name_unscaled), - count=count, + if self._compute_kl_loss: + loss_defs.append( + LossDef( + name=self._distillation_kl_loss_name, + formatted_name=_format_name(self._distillation_kl_loss_name), + count=count, + ) + ) + if self._compute_rkl_loss: + loss_defs.append( + LossDef( + name=self._distillation_rkl_loss_name, + formatted_name=_format_name(self._distillation_rkl_loss_name), + count=count, + ) + ) + if self._compute_dist_ce_loss: + loss_defs.append( + LossDef( + name=self._distillation_ce_loss_name, + formatted_name=_format_name(self._distillation_ce_loss_name), + count=count, + ) ) - ) return loss_defs From 
b55a0a428fb85dc3ce16ec061d1bed5ea2ac619a Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 13:42:48 +0000 Subject: [PATCH 09/51] loss config --- fast_llm/functional/cross_entropy.py | 2 + fast_llm/layers/language_model/config.py | 97 +---- fast_llm/layers/language_model/head.py | 408 +++++------------- .../layers/language_model/lm_head_losses.py | 280 ++++++++++++ 4 files changed, 405 insertions(+), 382 deletions(-) create mode 100644 fast_llm/layers/language_model/lm_head_losses.py diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 5a618eea0..f534d8a78 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -314,6 +314,7 @@ def reverse_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -443,6 +444,7 @@ def forward_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 807b39703..6fc92eaa4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,11 +5,11 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.lm_head_losses import LossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -135,75 +135,22 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) + losses: dict[str, LossConfig] = Field( + default_factory=dict, + desc="A dictionary of loss names and their configurations.", + hint=FieldHint.core, + ) # TODO: Cleanup output_weight: ParameterConfig = Field( desc="Configuration for the LM output layer (weight). 
Ignored for tied embeddings", hint=FieldHint.architecture, ) - cross_entropy_implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - distillation_loss_implementation: DistillationLossImpl = Field( - default=DistillationLossImpl.cross_entropy, - desc="Implementation for the distillation cross-entropy computation.", - hint=FieldHint.performance, - ) cross_entropy_splits: int | None = Field( default=None, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - language_model_loss_factor: float = Field( - default=None, - desc="Factor to scale the language modeling loss by when using distillation.", - hint=FieldHint.feature, - ) - track_language_model_loss: bool = Field( - default=False, - desc="Track the unscaled language modeling loss for logging purposes. Will always do if language_model_loss_factor > 0.", - hint=FieldHint.feature, - ) - track_forward_kl_loss: bool = Field( - default=False, - desc="Track the unscaled forward KL loss for logging purposes. Will always do if distillation_loss_implementation is forward_kl.", - hint=FieldHint.feature, - ) - track_reverse_kl_loss: bool = Field( - default=False, - desc="Track the unscaled reverse KL loss for logging purposes. Will always do if distillation_loss_implementation is reverse_kl.", - hint=FieldHint.feature, - ) - track_distillation_ce_loss: bool = Field( - default=False, - desc="Track the unscaled distillation cross-entropy loss for logging purposes. 
Will always do if distillation_loss_implementation is cross_entropy.", - hint=FieldHint.feature, - ) - forward_kl_loss_factor: float = Field( - default=0.0, - desc="Factor to scale the forward KL loss by when using distillation with forward KL.", - hint=FieldHint.feature, - ) - reverse_kl_loss_factor: float = Field( - default=1.0, - desc="Factor to scale the reverse KL loss by when using distillation with reverse KL.", - hint=FieldHint.feature, - ) - distillation_ce_loss_factor: float = Field( - default=0.0, - desc="Factor to scale the distillation cross-entropy loss by when using distillation with cross-entropy.", - hint=FieldHint.feature, - ) - logits_scale_factor: float = Field( default=1.0, desc="Multiply output logits by scale factor.", @@ -212,10 +159,10 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - teacher_softmax_temperature: float = Field( - default=1.0, - desc="Divides distillation target logits by this factor.", - doc="Divides distillation target logits by this factor.", + logit_z_loss: float = Field( + default=0.0, + desc="Regularize the logits with Z-loss.", + doc="We recommend 1e-4 for stability, as used for training PaLM.", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) @@ -224,11 +171,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Name of the reference model to use for dpo.", hint=FieldHint.feature, ) - dpo_beta: float | None = Field( - default=1.0, - desc="Beta value for DPO loss.", - hint=FieldHint.feature, - ) distillation_model: str | None = Field( default=None, desc="Name of the reference model to use for knowledge distillation." @@ -268,16 +210,17 @@ def layer_class(self) -> "type[LanguageModelHead]": def _validate(self) -> None: with self._set_implicit_default(): - if self.language_model_loss_factor is None: - if self.distillation_model is None: - self.language_model_loss_factor = 1.0 - else: - self.language_model_loss_factor = 0.0 + if not self.losses: + self.losses = { + "lm_loss": LossConfig._from_dict( + {"type": "cross_entropy_lm_loss", "weight_scalor": 1.0, "log_it": True} + ) + } + + for loss_config in self.losses.values(): + if "dist" in loss_config.type: + assert self.distillation_model is not None, "Distillation loss requires a distillation model." 
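As a quick illustration of the per-loss configuration this patch introduces (a sketch only: it uses the field names of this patch, `weight_scalor` and `log_it`, which a later patch in the series renames to `factor`, and the `dynamic_type` names registered in lm_head_losses.py: `cross_entropy_lm_loss`, `fkl_dist`, `revkl_dist`, `dpo`):

from fast_llm.layers.language_model.lm_head_losses import LossConfig

# Sketch: a head `losses` dict mixing next-token cross-entropy with reverse-KL
# distillation. Any "dist"-typed loss requires `distillation_model` to be set on
# the head config, which is what the assert above enforces at validation time.
losses = {
    # Standard next-token prediction loss on labels.
    "lm_loss": LossConfig._from_dict(
        {"type": "cross_entropy_lm_loss", "weight_scalor": 1.0, "log_it": True}
    ),
    # Mode-seeking reverse-KL distillation against the teacher logits.
    "distillation_loss": LossConfig._from_dict(
        {"type": "revkl_dist", "weight_scalor": 1.0, "log_it": True}
    ),
}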
super()._validate() - if self.distillation_model is None: - Assert.is_(self.track_forward_kl_loss, False) - Assert.is_(self.track_reverse_kl_loss, False) - Assert.is_(self.track_distillation_ce_loss, False) assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both @property diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 040dc55dc..f23bb6f1c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -13,13 +13,6 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import ( - cross_entropy_forward_backward, - forward_kl_forward_backward, - reverse_kl_forward_backward, -) -from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames @@ -31,6 +24,7 @@ LanguageModelHeadConfig, LanguageModelKwargs, ) +from fast_llm.layers.language_model.lm_head_losses import Targets, _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -91,16 +85,6 @@ def __init__( if self._config.cross_entropy_splits is not None and self._sequence_parallel: assert not self._vocab_parallel - if not self._config.enable_dpo: - self._cross_entropy_impl = self._config.cross_entropy_implementation - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._vocab_parallel: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused - self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) self.final_norm = self._config.normalization.get_layer( @@ -116,22 +100,10 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - - self._compute_lm_loss = self.config.language_model_loss_factor > 0.0 or self.config.track_language_model_loss - self._compute_dpo_loss = self._config.enable_dpo - self._compute_rkl_loss = self._config.distillation_model is not None and ( - self._config.reverse_kl_loss_factor > 0.0 or self._config.track_reverse_kl_loss - ) - self._compute_kl_loss = self._config.distillation_model is not None and ( - self._config.forward_kl_loss_factor > 0.0 or self._config.track_forward_kl_loss - ) - self._compute_dist_ce_loss = self._config.distillation_model is not None and ( - self._config.distillation_ce_loss_factor > 0.0 or self._config.track_distillation_ce_loss - ) - - self._compute_distillation_loss = any( - [self._compute_rkl_loss, self._compute_kl_loss, self._compute_dist_ce_loss] - ) + self._formatted_loss_names = { + loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) + for loss_name, loss_config in self._config.losses.items() + } def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -203,22 +175,25 @@ def _forward_backward( else: return loss, None - def _get_targets( - self, kwargs: dict - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, 
torch.Tensor | None] | None: - # Loss mask for distillation. (Labels are already masked.) + def _get_targets(self, kwargs: dict) -> Targets | None: + ( + lm_target, + dpo_target, + reference_model_logits, + loss_mask, + chosen_spans, + rejected_spans, + dpo_reference_model_logits, + ) = (None, None, None, None, None, None, None) if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) - lm_target = None - distillation_target = None - loss_mask = None + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) + dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) else: - dpo_target = None - if self._config.distillation_model is None: - distillation_target, loss_mask = None, None - else: + if self._config.distillation_model is not None: # Target is reference model logits. - distillation_target = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) + reference_model_logits = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) if loss_mask is not None: loss_mask = loss_mask.flatten() @@ -240,12 +215,29 @@ def _get_targets( else lm_target[:, lm_target_slice] ).flatten() - targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: - targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] - if not any(target is not None for target in targets): - # Simplify so we don't have to check every time. - targets = None + if dpo_target is not None: + dpo_target = split_op(dpo_target, self._parallel_dim.group, 0) + if lm_target is not None: + lm_target = split_op(lm_target, self._parallel_dim.group, 0) + if loss_mask is not None: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + if reference_model_logits is not None: + reference_model_logits = split_op(reference_model_logits, self._parallel_dim.group, 0) + + targets = Targets( + dpo_target=dpo_target, + lm_target=lm_target, + loss_mask=loss_mask, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + reference_model_logits=reference_model_logits, + dpo_reference_model_logits=dpo_reference_model_logits, + ) + + # Return None if no targets are set + if not targets.has_any_target(): + return None return targets def get_output_weights(self) -> list[torch.Tensor]: @@ -254,7 +246,7 @@ def get_output_weights(self) -> list[torch.Tensor]: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None, + targets: Targets | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -285,15 +277,34 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None + + # Extract target tensors for splitting (keep same order as original tuple) + target_tensors = [ + targets.lm_target, + targets.dpo_target, + targets.reference_model_logits, + targets.loss_mask, + ] split_size = div( - get_unique(target.size(0) for target in targets if target is not None), + get_unique(target.size(0) for target in target_tensors if target is not None), self._config.cross_entropy_splits, ) tensors_split = [ [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *targets, logit_input_grad] + for tensor in 
[logit_input, *target_tensors, logit_input_grad] ] - for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): + for logit_input_, lm_target_, dpo_target_, reference_model_logits_, loss_mask_, logit_input_grad_ in zip( + *tensors_split, strict=True + ): + targets_ = Targets( + lm_target=lm_target_, + dpo_target=dpo_target_, + reference_model_logits=reference_model_logits_, + loss_mask=loss_mask_, + chosen_spans=targets.chosen_spans, + rejected_spans=targets.rejected_spans, + dpo_reference_model_logits=targets.dpo_reference_model_logits, + ) loss_, grad_ = self._logits_loss_forward_backward( logit_input_, targets_, @@ -319,7 +330,7 @@ def _logits_cross_entropy_forward_backward_split( def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None], + targets: Targets | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -334,6 +345,7 @@ def _logits_loss_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) + # TODO: also move to lm_head_losses? if self._config.logit_z_loss > 0.0: logits = z_loss( logits, @@ -359,175 +371,48 @@ def _logits_loss_forward_backward( if targets is None: return logits * self._config.logits_scale_factor, None - dpo_target, lm_target, distillation_target, loss_mask = targets - if dpo_target is not None: - dpo_loss, dpo_grad = compute_dpo_loss( + total_loss, grad = None, None + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight_scalor == 0.0 and not loss_config.log_it: + continue + # losses are returned unscaled but the grads are already scaled + # we log unscaled losses seperately and the scaled total loss + loss_unscaled_, grad_ = loss_config.compute_loss( logits, - dpo_target, - kwargs.get(f"{self._config.dpo_reference_model}_logits"), - kwargs[LanguageModelKwargs.chosen_spans], - kwargs[LanguageModelKwargs.rejected_spans], - self._config.dpo_beta, - grad_output * self._loss_coefficient, - ) - else: - dpo_loss, dpo_grad = None, None - - if lm_target is not None and self._compute_lm_loss: - lm_loss, lm_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - lm_target, - None, + targets, + grad_output=( + grad_output * self._loss_coefficient * loss_config.weight_scalor + if grad_output is not None + else None + ), group=group, - grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, - implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.labels, + vocab_parallel=self._vocab_parallel, ) - else: - lm_loss, lm_grad = None, None - - distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad = None, None, None - distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss = None, None, None - - if distillation_target is not None and self._compute_distillation_loss: - if self._compute_rkl_loss: - distillation_rkl_loss, distillation_rkl_grad = reverse_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.reverse_kl_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + loss_ = 
loss_unscaled_ * loss_config.weight_scalor * self._loss_coefficient - if self._compute_kl_loss: - distillation_kl_loss, distillation_kl_grad = forward_kl_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - grad_output=grad_output * self._loss_coefficient * self._config.forward_kl_loss_factor, - group=group, - logits_scale_factor=self._config.logits_scale_factor, - teacher_softmax_temperature=self._config.teacher_softmax_temperature, - target_format=( - TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits - ), - sequence_parallel_logits=self._sequence_parallel_logits, - ) + if losses is not None and loss_config.log_it: + losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) - if self._compute_dist_ce_loss: - distillation_ce_loss, distillation_ce_grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - distillation_target, - loss_mask, - group=group, - grad_output=grad_output * self._loss_coefficient * self._config.distillation_ce_loss_factor, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._config.logits_scale_factor, - target_format=TargetFormat.logits, - ) + if total_loss is None: + total_loss = loss_ else: - raise ValueError( - f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" - ) - - # TODO: de-allocate earlier. - del logits - loss, grad = self._post_process_loss_and_grad( - dpo_loss, - dpo_grad, - lm_loss, - lm_grad, - distillation_rkl_loss, - distillation_rkl_grad, - distillation_kl_loss, - distillation_kl_grad, - distillation_ce_loss, - distillation_ce_grad, - losses, - kwargs, - ) - - return loss, output_parallel_linear_backward(grad, context) if self.training else None - - def _post_process_loss_and_grad( - self, - dpo_loss: torch.Tensor | None, - dpo_grad: torch.Tensor | None, - lm_loss: torch.Tensor | None, - lm_grad: torch.Tensor | None, - distillation_rkl_loss: torch.Tensor | None, - distillation_rkl_grad: torch.Tensor | None, - distillation_kl_loss: torch.Tensor | None, - distillation_kl_grad: torch.Tensor | None, - distillation_ce_loss: torch.Tensor | None, - distillation_ce_grad: torch.Tensor | None, - losses: dict | None, - kwargs, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - If loss is provided (i.e. not None) it will be logged in scaled and unscaled version. The total loss is also logged. - - Arguments: - - Losses: unscaled losses from different components (DPO, LM CE, Distillation) - - Grads: gradients of the losses w.r.t. logits from different components, already scaled by loss factors. - """ - # Extremely explicit but easier to follow. - # TODO: simplify / shrten / make seperate dataclass? 
- ############ - if dpo_loss is not None: - if self.training and losses is not None: - losses[self._dpo_loss_name].append(dpo_loss.detach()) - else: - Assert.is_(dpo_grad, None) + total_loss = total_loss + loss_ - if lm_loss is not None: - if self.training and losses is not None: - losses[self._lm_loss_name].append(lm_loss.detach()) - lm_loss = lm_loss * self._config.language_model_loss_factor - else: - Assert.is_(lm_grad, None) - - if distillation_rkl_loss is not None: - distillation_rkl_loss = distillation_rkl_loss - if self.training and losses is not None: - losses[self._distillation_rkl_loss_name].append(distillation_rkl_loss.detach()) - distillation_rkl_loss = distillation_rkl_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_rkl_grad, None) - if distillation_kl_loss is not None: - distillation_kl_loss = distillation_kl_loss - if self.training and losses is not None: - losses[self._distillation_kl_loss_name].append(distillation_kl_loss.detach()) - distillation_kl_loss = distillation_kl_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_kl_grad, None) - if distillation_ce_loss is not None: - distillation_ce_loss = distillation_ce_loss - if self.training and losses is not None: - losses[self._distillation_ce_loss_name].append(distillation_ce_loss.detach()) - distillation_ce_loss = distillation_ce_loss * self._config.distillation_loss_factor - else: - Assert.is_(distillation_ce_grad, None) + if grad_ is not None: + if grad is None: + grad = grad_ + else: + grad = grad + grad_ - ############ - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = _add_tensors(dpo_grad, lm_grad, distillation_rkl_grad, distillation_kl_grad, distillation_ce_grad) - total_loss = _add_tensors(dpo_loss, lm_loss, distillation_rkl_loss, distillation_kl_loss, distillation_ce_loss) if losses is not None and total_loss is not None: - losses[self._total_loss_name].append(total_loss.detach()) + losses[self._total_head_loss_name].append(total_loss.detach()) - return total_loss, grad + return total_loss, output_parallel_linear_backward(grad, context) if self.training else None @functools.cached_property - def _total_loss_name(self) -> str: + def _total_head_loss_name(self) -> str: """ Combined total scaled loss used for training. """ @@ -536,16 +421,6 @@ def _total_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _lm_loss_name(self) -> str: - """ - Unscaled language model cross-entropy loss. 
- """ - name = "lm_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - @functools.cached_property def _z_loss_name(self) -> str: name = "z_loss" @@ -553,81 +428,18 @@ def _z_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _dpo_loss_name(self) -> str: - name = "dpo_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_kl_loss_name(self) -> str: - name = "distillation_kl_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_rkl_loss_name(self) -> str: - name = "distillation_rkl_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - - @functools.cached_property - def _distillation_ce_loss_name(self) -> str: - name = "distillation_ce_loss_unscaled" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_defs = [ - LossDef(name=self._total_loss_name, formatted_name=_format_name(self._total_loss_name), count=count) - ] - if self._compute_lm_loss: - loss_defs.append( - LossDef( - name=self._lm_loss_name_unscaled, - formatted_name=_format_name(self._lm_loss_name_unscaled), - count=count, - ) + LossDef( + name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) - if self._config.logit_z_loss: - loss_defs.append( - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) - ) - if self._compute_dpo_loss: - loss_defs.append( - LossDef(name=self._dpo_loss_name, formatted_name=_format_name(self._dpo_loss_name), count=count) - ) - - if self._compute_distillation_loss: - # unscaled distillation loss for comparison purposes - if self._compute_kl_loss: - loss_defs.append( - LossDef( - name=self._distillation_kl_loss_name, - formatted_name=_format_name(self._distillation_kl_loss_name), - count=count, - ) - ) - if self._compute_rkl_loss: - loss_defs.append( - LossDef( - name=self._distillation_rkl_loss_name, - formatted_name=_format_name(self._distillation_rkl_loss_name), - count=count, - ) - ) - if self._compute_dist_ce_loss: - loss_defs.append( - LossDef( - name=self._distillation_ce_loss_name, - formatted_name=_format_name(self._distillation_ce_loss_name), - count=count, - ) + ] + for loss_name, loss_config in self._config.losses.items(): + if loss_config.log_it: + loss_def: LossDef = loss_config.get_loss_def( + name=loss_name, count=count, prediction_distance=self._prediction_distance ) + loss_defs.append(loss_def) return loss_defs @@ -635,17 +447,3 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: def heads(self): # For compatibility with MTP. 
return [self] - - -def _format_name(name: str) -> str: - return name.replace("_", " ") - - -def _add_tensors(*tensors: torch.Tensor | None) -> torch.Tensor: - tensors = [tensor for tensor in tensors if tensor is not None] - if len(tensors) > 1: - return sum(tensors) - elif len(tensors) == 1: - return tensors[0] - else: - raise RuntimeError("No tensors to add.") diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py new file mode 100644 index 000000000..cc8e5ebc5 --- /dev/null +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -0,0 +1,280 @@ +import abc +import dataclasses +import logging +import typing + +import torch + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.core.distributed import ProcessGroup +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# +# CE loss on lm_targets for standard LM training. Here targets are already masked. +# CE loss for distillation: cross entropuy that uses reference_model_logits as soft targets, not implemented, TODO. +# Forward KL divergence loss on reference_model_logits for distillation (mode-covering). +# Reverse KL divergence loss on reference_model_logits for distillation (mode-seeking). +# DPO loss for alignment using chosen and rejected spans. +# + + +def _format_name(name: str) -> str: + return name.replace("_", " ") + + +@dataclasses.dataclass +class Targets: + lm_target: torch.Tensor | None = None + dpo_target: torch.Tensor | None = None + loss_mask: torch.Tensor | None = None + chosen_spans: list[list[tuple[int, int]]] | None = None + rejected_spans: list[list[tuple[int, int]]] | None = None + reference_model_logits: torch.Tensor | None = None + dpo_reference_model_logits: torch.Tensor | None = None + + def has_any_target(self) -> bool: + return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) + + +@config_class(registry=True) +class LossConfig(Config): + """ + Losses canm register themselves + using @config_class(dynamic_type={LossConfig: "loss_type_name"}) + """ + + _name: typing.ClassVar[str] + _abstract: typing.ClassVar[bool] = True + + weight_scalor: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + log_it: bool = Field( + default=True, + hint=FieldHint.optional, + desc="Whether to log this loss.", + ) + + @abc.abstractmethod + def compute_loss( + self, + logits: torch.Tensor, + target: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + name = self.get_formatted_name(name, prediction_distance) + return LossDef( + name=name, + formatted_name=_format_name(name), + count=count, + dtype=DataType.float32, + ) + + def _validate(self): + Assert.geq(self.weight_scalor, 0.0) + if self.weight_scalor > 0.0: + with self._set_implicit_default(): + if "log_it" not in self._explicit_fields: + self.log_it = True + super()._validate() + + def get_formatted_name(self, 
name=None, prediction_distance: int | None = None) -> str: + name = f"{self._name}({name})" + if prediction_distance is not None: + name = f"{name}_{prediction_distance}" + return name + + +@config_class(dynamic_type={LossConfig: "cross_entropy_lm_loss"}) +class CrossEntropyLMLossConfig(LossConfig): + _name: typing.ClassVar[str] = "CE" + _abstract: typing.ClassVar[bool] = False + + implementation: CrossEntropyImpl = Field( + default=CrossEntropyImpl.auto, + desc="Implementation for the cross-entropy computation.", + hint=FieldHint.performance, + ) + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax (used in distillation losses).", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import cross_entropy_forward_backward + + target = targets.lm_target + if target is None: + raise ValueError("CrossEntropyLoss requires lm_target to be set in Targets") + implementation = self.implementation + if implementation == CrossEntropyImpl.auto: + if vocab_parallel: + implementation = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + implementation = CrossEntropyImpl.triton + else: + implementation = CrossEntropyImpl.fused + + return cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.labels, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "fkl_dist"}) +class ForwardKLLossConfig(LossConfig): + """Forward KL divergence KL(p||q) for distillation (mode-covering).""" + + _name: typing.ClassVar[str] = "FwdKL" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import forward_kl_forward_backward + + target = targets.reference_model_logits + if target is None: + raise ValueError("ForwardKLLoss requires distillation_target to be set in Targets") + + return forward_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=targets.loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "revkl_dist"}) +class ReverseKLLossConfig(LossConfig): + """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" + + _name: typing.ClassVar[str] = "RevKL" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 
0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + logits_scale_factor: float | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.cross_entropy import reverse_kl_forward_backward + + # Use distillation_target for KL losses + target = targets.reference_model_logits + if target is None: + raise ValueError("ReverseKLLoss requires distillation_target to be set in Targets") + + return reverse_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=targets.loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + **kwargs, + ) + + +@config_class(dynamic_type={LossConfig: "dpo"}) +class DPOLossConfig(LossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _name: typing.ClassVar[str] = "DPO" + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + def compute_loss( + self, + logits: torch.Tensor, + targets: Targets, + grad_output: float | None = None, + group: ProcessGroup | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + from fast_llm.functional.dpo import compute_dpo_loss + + return compute_dpo_loss( + logits=logits, + targets=targets.dpo_target, + reference_model_logits=targets.dpo_reference_model_logits, + chosen_spans=targets.chosen_spans, + rejected_spans=targets.rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) From 097baeb4c2396575066f96ced831771e0054ea76 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 14:24:57 +0000 Subject: [PATCH 10/51] wip --- fast_llm/functional/config.py | 6 - fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/head.py | 8 +- .../layers/language_model/lm_head_losses.py | 6 +- tests/layers/test_lm_head.py | 188 +++++++++--------- tests/utils/model_configs.py | 8 +- 6 files changed, 108 insertions(+), 112 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 20ed99fde..511c2d9f3 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -100,12 +100,6 @@ class CrossEntropyImpl(str, enum.Enum): triton = "triton" -class DistillationLossImpl(str, enum.Enum): - reverse_kl = "reverse_kl" - forward_kl = "forward_kl" - cross_entropy = "cross_entropy" - - class TargetFormat(enum.StrEnum): labels = "labels" logits = "logits" diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 6fc92eaa4..786d312d8 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -212,9 +212,7 @@ def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: self.losses = { - "lm_loss": LossConfig._from_dict( - {"type": "cross_entropy_lm_loss", "weight_scalor": 1.0, "log_it": True} - ) + "lm_loss": LossConfig._from_dict({"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}) } for loss_config in self.losses.values(): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index f23bb6f1c..c8c3be797 100644 --- a/fast_llm/layers/language_model/head.py +++ 
b/fast_llm/layers/language_model/head.py @@ -374,7 +374,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.weight_scalor == 0.0 and not loss_config.log_it: + if loss_config.factor == 0.0 and not loss_config.log_it: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -382,15 +382,13 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - grad_output * self._loss_coefficient * loss_config.weight_scalor - if grad_output is not None - else None + grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, ) - loss_ = loss_unscaled_ * loss_config.weight_scalor * self._loss_coefficient + loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient if losses is not None and loss_config.log_it: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index cc8e5ebc5..a231efa5a 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -54,7 +54,7 @@ class LossConfig(Config): _name: typing.ClassVar[str] _abstract: typing.ClassVar[bool] = True - weight_scalor: float = Field( + factor: float = Field( default=1.0, hint=FieldHint.core, desc="Weight for this loss in the total loss computation.", @@ -90,8 +90,8 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non ) def _validate(self): - Assert.geq(self.weight_scalor, 0.0) - if self.weight_scalor > 0.0: + Assert.geq(self.factor, 0.0) + if self.factor > 0.0: with self._set_implicit_default(): if "log_it" not in self._explicit_fields: self.log_it = True diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c6d806db8..917bb7efd 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -5,7 +5,7 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead @@ -119,99 +119,99 @@ def _lm_head( ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - False, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - "language_model_loss_factor": 1.0, - } - }, - {}, - True, - 1, - ), - ( - { - "head": { - "distillation_model": "distillation", - "distillation_loss_implementation": DistillationLossImpl.reverse_kl, - } - }, - {}, - True, - 1, - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - 
"language_model_loss_factor": 0.0, - "track_language_model_loss": True, - "distillation_loss_factor": 1.0, - } - }, - {}, - False, - 1, - id="track_lm_zero_factor", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "language_model_loss_factor": 0.0, - "distillation_loss_factor": 0.0, - "track_language_model_loss": True, - "track_distillation_loss": True, - } - }, - {}, - False, - 1, - id="track_both_zero_factors", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "language_model_loss_factor": 0.0, - "distillation_loss_factor": 0.0, - "track_language_model_loss": False, - "track_distillation_loss": False, - } - }, - {}, - False, - 1, - marks=pytest.mark.xfail( - reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", - strict=True, - ), - id="zero_factors_no_tracking", - ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + # } + # }, + # {}, + # False, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # } + # }, + # {}, + # False, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, + # "language_model_loss_factor": 1.0, + # } + # }, + # {}, + # True, + # 1, + # ), + # ( + # { + # "head": { + # "distillation_model": "distillation", + # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # } + # }, + # {}, + # True, + # 1, + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "track_language_model_loss": True, + # "distillation_loss_factor": 1.0, + # } + # }, + # {}, + # False, + # 1, + # id="track_lm_zero_factor", + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "distillation_loss_factor": 0.0, + # "track_language_model_loss": True, + # "track_distillation_loss": True, + # } + # }, + # {}, + # False, + # 1, + # id="track_both_zero_factors", + # ), + # pytest.param( + # { + # "head": { + # "distillation_model": "distillation", + # "language_model_loss_factor": 0.0, + # "distillation_loss_factor": 0.0, + # "track_language_model_loss": False, + # "track_distillation_loss": False, + # } + # }, + # {}, + # False, + # 1, + # marks=pytest.mark.xfail( + # reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", + # strict=True, + # ), + # id="zero_factors_no_tracking", + # ), ), ) def test_lm_head( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6156cb709..f4e3ecea7 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -552,6 +552,12 @@ def _update_and_add_testing_config( "mistral_distill_logits", updates={ ("model", "base_model", "head", "distillation_model"): "teacher", + ("model", "base_model", "head", "losses"): { + "distillation_loss": { + "type": "revkl_dist", + "factor": 1.0, + }, + }, ("batch", "use_loss_masking_spans"): True, ("reference_models"): { "teacher": { @@ -599,7 +605,7 @@ def _update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ - ("model", "base_model", "head", "distillation_loss_factor"): 0.001, + ("model", "base_model", "head", "losses", 
"distillation_loss", "factor"): 0.001, ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1, ("reference_models"): { From d773d986d54ed3cc1729d9bd8992af116c8f20de Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 16:47:11 +0000 Subject: [PATCH 11/51] tests --- fast_llm/layers/language_model/head.py | 4 + tests/layers/test_lm_head.py | 340 +++++++++++++++---------- 2 files changed, 214 insertions(+), 130 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c8c3be797..c47a87de1 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -432,6 +432,10 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) ] + if self._config.logit_z_loss > 0.0: + loss_defs.append( + LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) + ) for loss_name, loss_config in self._config.losses.items(): if loss_config.log_it: loss_def: LossDef = loss_config.get_loss_def( diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 917bb7efd..5835b6673 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,6 +9,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.lm_head_losses import LossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -43,6 +44,20 @@ def _reverse_kl_loss( return loss +def _kl_loss( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + teacher_softmax_temperature: float = 1.0, +): + return _reverse_kl_loss( + target, + logits, + loss_mask, + teacher_softmax_temperature, + ) + + def _lm_head( input_: torch.Tensor, target: torch.Tensor, @@ -54,9 +69,7 @@ def _lm_head( grad_output: float = 1.0, logit_scale_factor: float = 1.0, logit_z_loss=0.0, - distillation_loss_implementation: DistillationLossImpl = DistillationLossImpl.cross_entropy, - language_model_loss_factor: float = 1.0, - distillation_loss_factor: float = 1.0, + losses: dict[str, LossConfig], ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -66,36 +79,34 @@ def _lm_head( ) logits = torch.nn.functional.linear(hidden, logit_weight).float() - if distillation_loss_implementation == DistillationLossImpl.reverse_kl: - Assert.eq(logits.shape, target.shape) - loss = _reverse_kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) - # Return scaled loss - return loss * distillation_loss_factor, None + if "dist_loss" in losses: + if losses["dist_loss"].type == "revkl_dist": + Assert.eq(logits.shape, target.shape) + loss = _reverse_kl_loss( + (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, 
grad_output * losses["dist_loss"].factor)) + # Return scaled loss + return loss * losses["dist_loss"].factor, None + elif losses["dist_loss"].type == "fkl_dist": + Assert.eq(logits.shape, target.shape) + loss = _kl_loss( + (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask + ) + # Apply distillation_loss_factor to grad_output for backward + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + # Return scaled loss + return loss * losses["dist_loss"].factor, None if logit_scale_factor != 1.0: logits *= logit_scale_factor z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None - if target.ndim == logits.ndim: - # Distillation loss (cross-entropy with soft targets) - loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), target.float().softmax(-1).flatten(0, -2), reduction="none" - ) - if loss_mask is not None: - loss = loss * loss_mask.flatten() - loss = loss.mean() - # Apply distillation_loss_factor - loss.backward(torch.full_like(loss, grad_output * distillation_loss_factor)) - return loss * distillation_loss_factor, z_loss - else: - # Language model loss (cross-entropy with hard labels) - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - # Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * language_model_loss_factor)) - return loss * language_model_loss_factor, z_loss + # Language model loss (cross-entropy with hard labels) + loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + # Apply language_model_loss_factor + loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].factor)) + return loss * losses["lm_loss"].factor, z_loss SEQUENCE_LENGTH = 200 @@ -119,99 +130,169 @@ def _lm_head( ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), ({}, {}, True, 1), + # Skip CE distillation for now - not yet implemented in new losses system # ( # { # "head": { # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - # } - # }, - # {}, - # False, - # 1, - # ), - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # "losses": { + # "lm_loss": { + # "type": "cross_entropy_lm_loss", + # "weight_scalor": 0.0, + # "log_it": False, + # }, + # "dist_loss": { + # "type": "cross_entropy_dist", # TODO: Not implemented yet + # "weight_scalor": 1.0, + # "log_it": True, + # } + # } # } # }, # {}, # False, # 1, # ), + ( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + False, + 1, + ), + # Skip - CE distillation not implemented # ( # { # "head": { # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.cross_entropy, - # "language_model_loss_factor": 1.0, - # } - # }, - # {}, - # True, - # 1, - # ), - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "distillation_loss_implementation": DistillationLossImpl.reverse_kl, + # "losses": { + # "lm_loss": { + # "type": "cross_entropy_lm_loss", + # "weight_scalor": 1.0, + # "log_it": True, + # }, + # "dist_loss": { + # "type": "cross_entropy_dist", # TODO + # "weight_scalor": 1.0, + # "log_it": True, + # } + 
# } # } # }, # {}, # True, # 1, # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "track_language_model_loss": True, - # "distillation_loss_factor": 1.0, - # } - # }, - # {}, - # False, - # 1, - # id="track_lm_zero_factor", - # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "distillation_loss_factor": 0.0, - # "track_language_model_loss": True, - # "track_distillation_loss": True, - # } - # }, - # {}, - # False, - # 1, - # id="track_both_zero_factors", - # ), - # pytest.param( - # { - # "head": { - # "distillation_model": "distillation", - # "language_model_loss_factor": 0.0, - # "distillation_loss_factor": 0.0, - # "track_language_model_loss": False, - # "track_distillation_loss": False, - # } - # }, - # {}, - # False, - # 1, - # marks=pytest.mark.xfail( - # reason="No losses computed when all factors=0 and tracking=False, raises RuntimeError in _add_tensors", - # strict=True, - # ), - # id="zero_factors_no_tracking", - # ), + ( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + True, + 1, + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": True, # tracking even with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, + }, + }, + } + }, + {}, + False, + 1, + id="track_lm_zero_factor", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": True, # tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 0.0, + "log_it": True, # tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + id="track_both_zero_factors", + ), + pytest.param( + { + "head": { + "distillation_model": "distillation", + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 0.0, + "log_it": False, # not tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 0.0, + "log_it": False, # not tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="No losses computed when all factors=0 and log_it=False", + strict=True, + ), + id="zero_factors_no_tracking", + ), ), ) def test_lm_head( @@ -222,8 +303,15 @@ def test_lm_head( prediction_heads: int, ): head_config = { - "cross_entropy_implementation": cross_entropy_impl, "normalization": {"type": "rms_norm"}, + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "implementation": cross_entropy_impl, + "factor": 1.0, + "log_it": True, + } + }, } config = GPTBaseModelConfig.from_dict( { @@ -280,19 +368,19 @@ def test_lm_head( AttentionKwargs.sequence_first: sequence_first, AttentionKwargs.grad_output: 1.0, } - if head_config.distillation_model is None: - target = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.int64, - device=distributed.device, - ) - if loss_mask is not None: - target *= loss_mask + # always set lm targets + target = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=distributed.device, + ) + if loss_mask is not None: + target *= loss_mask - 
kwargs[LanguageModelKwargs.labels] = target - else: + kwargs[LanguageModelKwargs.labels] = target + if head_config.distillation_model is not None: assert config.head.max_prediction_distance == 1 target = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), @@ -349,11 +437,7 @@ def test_lm_head( logit_weight=ref_logit_weight, logit_scale_factor=head_config.logits_scale_factor, logit_z_loss=head_config.logit_z_loss, - distillation_loss_implementation=head_config.distillation_loss_implementation, - language_model_loss_factor=( - head_config.language_model_loss_factor if head_config.language_model_loss_factor is not None else 1.0 - ), - distillation_loss_factor=head_config.distillation_loss_factor, + losses=head_config.losses, ) # Prepare LM head inputs @@ -367,19 +451,15 @@ def test_lm_head( lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" expected_loss_keys = {lm_head_loss_name} - if head._compute_lm_loss: - lm_loss_name_unscaled = ( - f"lm_loss_unscaled_{prediction_distance}" if prediction_distance > 0 else "lm_loss_unscaled" - ) - lm_loss_name = f"lm_loss_{prediction_distance}" if prediction_distance > 0 else "lm_loss" - expected_loss_keys.add(lm_loss_name_unscaled) - expected_loss_keys.add(lm_loss_name) + # Get expected loss names from the loss configs + for loss_name, loss_config in head._config.losses.items(): + if loss_config.log_it: + formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + expected_loss_keys.add(formatted_name) + if ref_z_loss is not None: expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - if head._compute_distillation_loss: - expected_loss_keys.add("distillation_loss") - expected_loss_keys.add("distillation_loss_unscaled") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, From 282925c5bcd6f3b2648aa1cfd4d40bed4058a739 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 16:51:37 +0000 Subject: [PATCH 12/51] test --- tests/layers/test_lm_head.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 5835b6673..6bdaf3f67 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -293,6 +293,32 @@ def _lm_head( ), id="zero_factors_no_tracking", ), + pytest.param( + { + "head": { + "losses": { + "lm_loss": { + "type": "cross_entropy_lm_loss", + "factor": 1.0, + "log_it": False, # not tracking with zero weight + }, + "dist_loss": { + "type": "revkl_dist", + "factor": 1.0, + "log_it": True, # not tracking with zero weight + }, + }, + } + }, + {}, + False, + 1, + marks=pytest.mark.xfail( + reason="Cannot track distillation loss without distillation model being set", + strict=True, + ), + id="track_distillation_without_model", + ), ), ) def test_lm_head( From 0f73ea23d62e43c41c45a9e755e9e3db38a3a5a3 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 17:54:53 +0000 Subject: [PATCH 13/51] tests --- fast_llm/layers/language_model/config.py | 13 ++--- fast_llm/layers/language_model/head.py | 1 + .../layers/language_model/lm_head_losses.py | 47 +++++++++---------- tests/test_config.py | 1 + tests/utils/model_configs.py | 28 +++-------- 5 files changed, 35 insertions(+), 55 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 786d312d8..411e98f4c 100644 --- a/fast_llm/layers/language_model/config.py +++ 
b/fast_llm/layers/language_model/config.py @@ -209,17 +209,12 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead def _validate(self) -> None: - with self._set_implicit_default(): - if not self.losses: - self.losses = { - "lm_loss": LossConfig._from_dict({"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}) - } - - for loss_config in self.losses.values(): - if "dist" in loss_config.type: - assert self.distillation_model is not None, "Distillation loss requires a distillation model." + for loss_config in self.losses.values(): + if "dist" in loss_config.type: + assert self.distillation_model is not None, "Distillation loss requires a distillation model." super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + # Note: Default loss is handled at runtime in head.py if losses dict is empty @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c47a87de1..e1f303323 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -100,6 +100,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) + assert self._config.losses, "At least one loss must be configured." self._formatted_loss_names = { loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) for loss_name, loss_config in self._config.losses.items() diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index a231efa5a..9fd946625 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -3,17 +3,16 @@ import logging import typing -import torch - from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass + import torch + + from fast_llm.core.distributed import ProcessGroup logger = logging.getLogger(__name__) @@ -32,13 +31,13 @@ def _format_name(name: str) -> str: @dataclasses.dataclass class Targets: - lm_target: torch.Tensor | None = None - dpo_target: torch.Tensor | None = None - loss_mask: torch.Tensor | None = None + lm_target: "torch.Tensor | None" = None + dpo_target: "torch.Tensor | None" = None + loss_mask: "torch.Tensor | None" = None chosen_spans: list[list[tuple[int, int]]] | None = None rejected_spans: list[list[tuple[int, int]]] | None = None - reference_model_logits: torch.Tensor | None = None - dpo_reference_model_logits: torch.Tensor | None = None + reference_model_logits: "torch.Tensor | None" = None + dpo_reference_model_logits: "torch.Tensor | None" = None def has_any_target(self) -> bool: return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) @@ -70,14 +69,14 @@ class LossConfig(Config): @abc.abstractmethod def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", target: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | 
None]": pass def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: @@ -124,14 +123,14 @@ class CrossEntropyLMLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward target = targets.lm_target @@ -176,13 +175,13 @@ class ForwardKLLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward target = targets.reference_model_logits @@ -218,13 +217,13 @@ class ReverseKLLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, logits_scale_factor: float | None = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import reverse_kl_forward_backward # Use distillation_target for KL losses @@ -261,12 +260,12 @@ class DPOLossConfig(LossConfig): def compute_loss( self, - logits: torch.Tensor, + logits: "torch.Tensor", targets: Targets, grad_output: float | None = None, - group: ProcessGroup | None = None, + group: "ProcessGroup" = None, **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss return compute_dpo_loss( diff --git a/tests/test_config.py b/tests/test_config.py index 4020b6fbc..8d6f39249 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -147,6 +147,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, }, "num_blocks": 12, + "head": {}, }, "hidden_size": 512, "tied_embedding_weight": False, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f4e3ecea7..3cadb4e20 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -240,7 +240,12 @@ def _update_and_add_testing_config( }, "num_blocks": 2, }, - "head": {"output_weight": init_1}, + "head": { + "output_weight": init_1, + "losses": { + "lm_loss": {"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}, + }, + }, "hidden_size": 256, "tied_embedding_weight": True, }, @@ -580,27 +585,6 @@ def _update_and_add_testing_config( skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) -_update_and_add_testing_config( - "mistral_distill_logits", - "mistral_reverse_kl", - updates={ - ("model", "base_model", "head", "distillation_loss_implementation"): "reverse_kl", - }, - megatron_args=None, - checkpoint_format=MistralCheckpointFormat, - groups={ - ModelTestingGroup.basic: ModelTestingGroupAction.normal, - ModelTestingGroup.checkpoint: ModelTestingGroupAction.unimportant, - ModelTestingGroup.convert: 
ModelTestingGroupAction.unimportant, - ModelTestingGroup.generate: ModelTestingGroupAction.unimportant, - ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, - ModelTestingGroup.distributed: ModelTestingGroupAction.broken, # failing: fp16, tp2, stp2, stp2_ce4 - }, - compare_factor=2, - # Modes not supported with reference models - skip_tests=("sdp", "ms", "pp"), -) - _update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", From fa85c415abd4481baba7ac9b9e037854e72cea82 Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 22 Dec 2025 22:27:28 +0000 Subject: [PATCH 14/51] wip --- fast_llm/functional/cross_entropy.py | 104 +++----------- fast_llm/layers/language_model/config.py | 4 +- fast_llm/layers/language_model/head.py | 13 +- .../layers/language_model/lm_head_losses.py | 30 ++-- tests/layers/test_lm_head.py | 132 +++--------------- tests/utils/model_configs.py | 4 +- 6 files changed, 55 insertions(+), 232 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index f534d8a78..06c85848c 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -85,6 +85,7 @@ def _fused_cross_entropy_forward_backward( target_format: TargetFormat, group: ProcessGroup | None = None, teacher_softmax_temperature: float = 1.0, + return_target_entropy: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -158,6 +159,16 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) + if return_target_entropy and target_format == TargetFormat.logits: + # Compute teacher entropy + teacher_log_prob = torch.log(target + 1e-20) + target_entropy = -(target * teacher_log_prob).sum(dim=-1) + if loss_mask is not None: + target_entropy = target_entropy * loss_mask.squeeze(-1) + target_entropy = target_entropy.mean() + if group is not None: + all_reduce(target_entropy, op=ReduceOp.SUM, group=group) + return loss, grad, target_entropy return loss, grad @@ -362,78 +373,6 @@ def reverse_kl_forward_backward( return distillation_loss, distillation_grad -@torch.compile -def _forward_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Forward KL: KL(p||q) where p=teacher, q=student. - This is reverse KL with roles swapped in the loss computation. - - Key insight: KL(p||q) = sum_i p_i * log(p_i/q_i) - = sum_i p_i * (log(p_i) - log(q_i)) - which is reverse KL with p and q swapped. - - However, we still need grad w.r.t. 
student logits, so gradient is different: - d/d(student_logits) KL(p||q) = student_probs - teacher_probs - """ - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel forward KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel forward KL") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # Compute log softmax for both teacher and student - teacher_log_probs = distributed_log_softmax(target.float(), group=group) - student_log_probs = distributed_log_softmax(logits, group=group) - - teacher_probs = teacher_log_probs.exp() - # Forward KL: p * log(p/q) = p * (log_p - log_q) - log_ratio = teacher_log_probs - student_log_probs - del teacher_log_probs - - # Compute loss: sum over vocab of teacher_probs * log_ratio - loss_terms = (teacher_probs * log_ratio).sum(dim=-1) - del log_ratio - - if loss_mask is not None: - valid = loss_mask.to(loss_terms.dtype) - loss_terms = loss_terms * valid - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) - loss = loss_terms.sum() - - if group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= valid_tokens - - if grad_output is not None: - # Gradient: d/d(student_logits) KL(p||q) = student_probs - teacher_probs - student_probs = student_log_probs.exp() - grad_base = student_probs - teacher_probs - del student_probs, teacher_probs, student_log_probs - - if loss_mask is not None: - grad_base.mul_(loss_mask.to(logits.dtype).unsqueeze(-1)) - - grad_base.mul_(grad_output / valid_tokens) - grad = grad_base.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - - def forward_kl_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -467,25 +406,20 @@ def forward_kl_forward_backward( loss: Forward KL divergence loss grad: Gradients w.r.t. logits """ - - if sequence_parallel_logits: - # TODO: see hybrid dev branch where it is implemented - raise NotImplementedError("Sequence-parallel forward KL is not implemented yet, set vocab_parallel true") - - Assert.eq(target_format, TargetFormat.logits, msg="Forward KL only supports logits format") + assert target_format == TargetFormat.logits, "Forward KL only supports logits format" Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # TODO: implement fused? 
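
Aside: the rewrite below replaces this hand-written forward-KL kernel with the fused soft-target cross-entropy path and then subtracts the teacher entropy. A standalone check of the two identities that change relies on (plain PyTorch, not the Fast-LLM API): forward KL(p||q) equals the soft-target cross-entropy H(p, q) minus the teacher entropy H(p), and its gradient w.r.t. the student logits is softmax(student) - softmax(teacher). Names below are illustrative only.

import torch

torch.manual_seed(0)
student = torch.randn(4, 11, requires_grad=True)   # student logits
teacher = torch.randn(4, 11)                        # teacher logits (no grad)

p = torch.softmax(teacher, dim=-1)                  # teacher probabilities
log_q = torch.log_softmax(student, dim=-1)          # student log-probabilities

cross_entropy = -(p * log_q).sum(-1).mean()                           # H(p, q), soft-target CE
teacher_entropy = -(p * p.clamp_min(1e-20).log()).sum(-1).mean()      # H(p)
forward_kl = (p * (p.clamp_min(1e-20).log() - log_q)).sum(-1).mean()  # KL(p || q)

# KL(p||q) = H(p, q) - H(p)
assert torch.allclose(forward_kl, cross_entropy - teacher_entropy, atol=1e-6)

# d KL(p||q) / d student_logits = softmax(student) - softmax(teacher)  (per sample, mean over batch)
forward_kl.backward()
expected_grad = (torch.softmax(student, dim=-1) - p) / student.shape[0]
assert torch.allclose(student.grad, expected_grad, atol=1e-6)
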
- distillation_loss, distillation_grad = _forward_kl_forward_backward( + distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward( logits=logits, target=target, loss_mask=loss_mask, grad_output=grad_output, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=teacher_softmax_temperature, + target_format=target_format, group=group, + teacher_softmax_temperature=teacher_softmax_temperature, + return_target_entropy=True, + **kwargs, ) + distillation_loss -= teacher_entropy + return distillation_loss, distillation_grad diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 411e98f4c..e2ce6ae19 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -9,7 +9,7 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.lm_head_losses import LossConfig +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -135,7 +135,7 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the final normalization layer.", hint=FieldHint.architecture, ) - losses: dict[str, LossConfig] = Field( + losses: dict[str, LanguageModelLossConfig] = Field( default_factory=dict, desc="A dictionary of loss names and their configurations.", hint=FieldHint.core, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index e1f303323..6ba45c242 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -375,7 +375,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.factor == 0.0 and not loss_config.log_it: + if loss_config.factor == 0.0: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -391,7 +391,7 @@ def _logits_loss_forward_backward( ) loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient - if losses is not None and loss_config.log_it: + if losses is not None: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) if total_loss is None: @@ -438,11 +438,10 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) for loss_name, loss_config in self._config.losses.items(): - if loss_config.log_it: - loss_def: LossDef = loss_config.get_loss_def( - name=loss_name, count=count, prediction_distance=self._prediction_distance - ) - loss_defs.append(loss_def) + loss_def: LossDef = loss_config.get_loss_def( + name=loss_name, count=count, prediction_distance=self._prediction_distance + ) + loss_defs.append(loss_def) return loss_defs diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 9fd946625..3695954bd 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -44,10 +44,10 @@ def has_any_target(self) -> bool: @config_class(registry=True) -class LossConfig(Config): +class LanguageModelLossConfig(Config): """ Losses canm register themselves - using 
@config_class(dynamic_type={LossConfig: "loss_type_name"}) + using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}) """ _name: typing.ClassVar[str] @@ -60,12 +60,6 @@ class LossConfig(Config): valid=check_field(Assert.geq, 0.0), ) - log_it: bool = Field( - default=True, - hint=FieldHint.optional, - desc="Whether to log this loss.", - ) - @abc.abstractmethod def compute_loss( self, @@ -90,10 +84,6 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non def _validate(self): Assert.geq(self.factor, 0.0) - if self.factor > 0.0: - with self._set_implicit_default(): - if "log_it" not in self._explicit_fields: - self.log_it = True super()._validate() def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: @@ -103,8 +93,8 @@ def get_formatted_name(self, name=None, prediction_distance: int | None = None) return name -@config_class(dynamic_type={LossConfig: "cross_entropy_lm_loss"}) -class CrossEntropyLMLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) +class CrossEntropyLMLossConfig(LanguageModelLossConfig): _name: typing.ClassVar[str] = "CE" _abstract: typing.ClassVar[bool] = False @@ -159,8 +149,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "fkl_dist"}) -class ForwardKLLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) +class ForwardKLLossConfig(LanguageModelLossConfig): """Forward KL divergence KL(p||q) for distillation (mode-covering).""" _name: typing.ClassVar[str] = "FwdKL" @@ -201,8 +191,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "revkl_dist"}) -class ReverseKLLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) +class ReverseKLLossConfig(LanguageModelLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL" @@ -244,8 +234,8 @@ def compute_loss( ) -@config_class(dynamic_type={LossConfig: "dpo"}) -class DPOLossConfig(LossConfig): +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class DPOLossConfig(LanguageModelLossConfig): """Direct Preference Optimization (DPO) loss for alignment.""" _name: typing.ClassVar[str] = "DPO" diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 6bdaf3f67..ddfc2fc12 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -9,7 +9,7 @@ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.language_model.lm_head_losses import LossConfig +from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -69,7 +69,7 @@ def _lm_head( grad_output: float = 1.0, logit_scale_factor: float = 1.0, logit_z_loss=0.0, - losses: dict[str, LossConfig], + losses: dict[str, LanguageModelLossConfig], ): hidden = torch.rms_norm( input_.to(rms_weight.dtype), @@ -80,7 +80,7 @@ def _lm_head( logits = torch.nn.functional.linear(hidden, logit_weight).float() if "dist_loss" in losses: - if losses["dist_loss"].type == "revkl_dist": + if losses["dist_loss"].type == "reverse_kl_distillation": 
Assert.eq(logits.shape, target.shape) loss = _reverse_kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask @@ -89,7 +89,7 @@ def _lm_head( loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) # Return scaled loss return loss * losses["dist_loss"].factor, None - elif losses["dist_loss"].type == "fkl_dist": + elif losses["dist_loss"].type == "forward_kl_distillation": Assert.eq(logits.shape, target.shape) loss = _kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask @@ -137,14 +137,12 @@ def _lm_head( # "distillation_model": "distillation", # "losses": { # "lm_loss": { - # "type": "cross_entropy_lm_loss", + # "type": "cross_entropy", # "weight_scalor": 0.0, - # "log_it": False, # }, # "dist_loss": { # "type": "cross_entropy_dist", # TODO: Not implemented yet # "weight_scalor": 1.0, - # "log_it": True, # } # } # } @@ -153,87 +151,18 @@ def _lm_head( # False, # 1, # ), - ( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 1.0, - "log_it": True, - }, - }, - } - }, - {}, - False, - 1, - ), - # Skip - CE distillation not implemented - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "losses": { - # "lm_loss": { - # "type": "cross_entropy_lm_loss", - # "weight_scalor": 1.0, - # "log_it": True, - # }, - # "dist_loss": { - # "type": "cross_entropy_dist", # TODO - # "weight_scalor": 1.0, - # "log_it": True, - # } - # } - # } - # }, - # {}, - # True, - # 1, - # ), - ( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 1.0, - "log_it": True, - }, - }, - } - }, - {}, - True, - 1, - ), pytest.param( { "head": { "distillation_model": "distillation", "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 0.0, - "log_it": True, # tracking even with zero weight }, "dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, - "log_it": True, }, }, } @@ -249,37 +178,12 @@ def _lm_head( "distillation_model": "distillation", "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 0.0, - "log_it": True, # tracking with zero weight }, "dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 0.0, - "log_it": True, # tracking with zero weight - }, - }, - } - }, - {}, - False, - 1, - id="track_both_zero_factors", - ), - pytest.param( - { - "head": { - "distillation_model": "distillation", - "losses": { - "lm_loss": { - "type": "cross_entropy_lm_loss", - "factor": 0.0, - "log_it": False, # not tracking with zero weight - }, - "dist_loss": { - "type": "revkl_dist", - "factor": 0.0, - "log_it": False, # not tracking with zero weight }, }, } @@ -288,24 +192,22 @@ def _lm_head( False, 1, marks=pytest.mark.xfail( - reason="No losses computed when all factors=0 and log_it=False", + reason="Cannot track both losses with zero factor", strict=True, ), - id="zero_factors_no_tracking", + id="track_both_zero_factors", ), pytest.param( { "head": { "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "factor": 1.0, - "log_it": False, # not tracking with zero weight }, 
"dist_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, - "log_it": True, # not tracking with zero weight }, }, } @@ -332,10 +234,9 @@ def test_lm_head( "normalization": {"type": "rms_norm"}, "losses": { "lm_loss": { - "type": "cross_entropy_lm_loss", + "type": "cross_entropy", "implementation": cross_entropy_impl, "factor": 1.0, - "log_it": True, } }, } @@ -480,9 +381,8 @@ def test_lm_head( # Get expected loss names from the loss configs for loss_name, loss_config in head._config.losses.items(): - if loss_config.log_it: - formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) - expected_loss_keys.add(formatted_name) + formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + expected_loss_keys.add(formatted_name) if ref_z_loss is not None: expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 3cadb4e20..93c78b58f 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -243,7 +243,7 @@ def _update_and_add_testing_config( "head": { "output_weight": init_1, "losses": { - "lm_loss": {"type": "cross_entropy_lm_loss", "factor": 1.0, "log_it": True}, + "lm_loss": {"type": "cross_entropy", "factor": 1.0}, }, }, "hidden_size": 256, @@ -559,7 +559,7 @@ def _update_and_add_testing_config( ("model", "base_model", "head", "distillation_model"): "teacher", ("model", "base_model", "head", "losses"): { "distillation_loss": { - "type": "revkl_dist", + "type": "reverse_kl_distillation", "factor": 1.0, }, }, From 31cfb84dd2081c0d1c40f31dee20859105e50146 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 02:22:15 +0000 Subject: [PATCH 15/51] wip --- fast_llm/data/dataset/gpt/config.py | 1 - fast_llm/layers/language_model/config.py | 14 ++++++++++++-- fast_llm/layers/language_model/head.py | 2 +- tests/test_config.py | 8 +++++++- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 41a2fe7ff..5e978ac2b 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -65,7 +65,6 @@ def build(self, preprocessing: PreprocessingConfig) -> SamplableDataset[SampleTy def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." config = yaml.safe_load(self.path.open("r")) - Assert.eq(config.keys(), {"config", "metadata"}) if config.keys() == {"config", "metadata"}: # Newer format with metadata config = config["config"] diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e2ce6ae19..58e85f5d8 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -209,12 +209,22 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead def _validate(self) -> None: + with self._set_implicit_default(): + if not self.losses: + if "losses" not in self._explicit_fields: + self.losses = { + "lm_loss": LanguageModelLossConfig._from_dict( + { + "type": "cross_entropy", + "factor": 1.0, + } + ) + } for loss_config in self.losses.values(): - if "dist" in loss_config.type: + if "distillation" in loss_config.type: assert self.distillation_model is not None, "Distillation loss requires a distillation model." 
super()._validate() assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both - # Note: Default loss is handled at runtime in head.py if losses dict is empty @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 6ba45c242..a67869f8b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -100,7 +100,7 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) - assert self._config.losses, "At least one loss must be configured." + self._formatted_loss_names = { loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) for loss_name, loss_config in self._config.losses.items() diff --git a/tests/test_config.py b/tests/test_config.py index 8d6f39249..81137b587 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -147,14 +147,16 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, }, "num_blocks": 12, - "head": {}, }, + "head": {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, } else: expected_config["base_model"] = base_model_update + # added by default + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) @@ -297,3 +299,7 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) + + +if __name__ == "__main__": + pytest.main([__file__]) From 24fe67bbebbdd9a8aa5ad1393b43250ced3b8629 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 15:43:26 +0000 Subject: [PATCH 16/51] no grad if factor 0 --- fast_llm/layers/language_model/head.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index a67869f8b..50240f49c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -383,7 +383,9 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None + (grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None) + if loss_config.factor != 0.0 + else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, From 0e562e99198e8414b1c026d17cd3383c7acc2f55 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 17:00:00 +0000 Subject: [PATCH 17/51] addressed comments --- fast_llm/layers/language_model/config.py | 2 +- fast_llm/layers/language_model/head.py | 8 +++--- .../layers/language_model/lm_head_losses.py | 4 +-- tests/layers/test_lm_head.py | 26 +++++++++---------- tests/test_config.py | 4 +-- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 58e85f5d8..4bd8a592c 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -216,7 +216,7 @@ def _validate(self) -> None: "lm_loss": LanguageModelLossConfig._from_dict( { "type": "cross_entropy", - "factor": 
1.0, + "weight": 1.0, } ) } diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 50240f49c..40c099617 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -375,7 +375,7 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.factor == 0.0: + if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled # we log unscaled losses seperately and the scaled total loss @@ -383,15 +383,15 @@ def _logits_loss_forward_backward( logits, targets, grad_output=( - (grad_output * self._loss_coefficient * loss_config.factor if grad_output is not None else None) - if loss_config.factor != 0.0 + (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) + if loss_config.weight != 0.0 else None ), group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, ) - loss_ = loss_unscaled_ * loss_config.factor * self._loss_coefficient + loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient if losses is not None: losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 3695954bd..dc367be65 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -53,7 +53,7 @@ class LanguageModelLossConfig(Config): _name: typing.ClassVar[str] _abstract: typing.ClassVar[bool] = True - factor: float = Field( + weight: float = Field( default=1.0, hint=FieldHint.core, desc="Weight for this loss in the total loss computation.", @@ -83,7 +83,7 @@ def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | Non ) def _validate(self): - Assert.geq(self.factor, 0.0) + Assert.geq(self.weight, 0.0) super()._validate() def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ddfc2fc12..7f9e55b79 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -86,18 +86,18 @@ def _lm_head( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) # Return scaled loss - return loss * losses["dist_loss"].factor, None + return loss * losses["dist_loss"].weight, None elif losses["dist_loss"].type == "forward_kl_distillation": Assert.eq(logits.shape, target.shape) loss = _kl_loss( (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask ) # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].factor)) + loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) # Return scaled loss - return loss * losses["dist_loss"].factor, None + return loss * losses["dist_loss"].weight, None if logit_scale_factor != 1.0: logits *= logit_scale_factor @@ -105,8 +105,8 @@ def _lm_head( # Language model loss (cross-entropy with hard labels) loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), 
target.flatten()) # Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].factor)) - return loss * losses["lm_loss"].factor, z_loss + loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].weight)) + return loss * losses["lm_loss"].weight, z_loss SEQUENCE_LENGTH = 200 @@ -158,11 +158,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 0.0, + "weight": 0.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 1.0, + "weight": 1.0, }, }, } @@ -179,11 +179,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 0.0, + "weight": 0.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 0.0, + "weight": 0.0, }, }, } @@ -203,11 +203,11 @@ def _lm_head( "losses": { "lm_loss": { "type": "cross_entropy", - "factor": 1.0, + "weight": 1.0, }, "dist_loss": { "type": "reverse_kl_distillation", - "factor": 1.0, + "weight": 1.0, }, }, } @@ -236,7 +236,7 @@ def test_lm_head( "lm_loss": { "type": "cross_entropy", "implementation": cross_entropy_impl, - "factor": 1.0, + "weight": 1.0, } }, } diff --git a/tests/test_config.py b/tests/test_config.py index 81137b587..3c6a76a35 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -148,7 +148,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, - "head": {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}}, + "head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, @@ -156,7 +156,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): else: expected_config["base_model"] = base_model_update # added by default - expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "factor": 1.0}}} + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) From 52c1c113d1fe32732b7bc2c666c0cfd6303abca8 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 23 Dec 2025 17:44:53 +0000 Subject: [PATCH 18/51] addressed comments --- fast_llm/functional/cross_entropy.py | 4 --- fast_llm/layers/language_model/head.py | 11 ++----- .../layers/language_model/lm_head_losses.py | 29 ++++++++++--------- tests/utils/model_configs.py | 2 +- 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 06c85848c..03f7a88ef 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -247,7 +247,6 @@ def _reverse_kl_forward_backward( group: ProcessGroup | None = None, logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. @@ -325,7 +324,6 @@ def reverse_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). 
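
Aside: a reference-style sketch of masked reverse KL, KL(q||p), in plain PyTorch, where q is the student distribution and p the teacher. This is written from the definition to mirror what the test helper checks against; it is not the fused or tensor-parallel kernel used by the repo, and the masked-mean reduction shown is just one common convention.

import torch


def reverse_kl(student_logits, teacher_logits, loss_mask=None):
    # Per-token KL(q || p) = sum_v q * (log q - log p).
    log_q = torch.log_softmax(student_logits.float(), dim=-1)
    log_p = torch.log_softmax(teacher_logits.float(), dim=-1)
    per_token = (log_q.exp() * (log_q - log_p)).sum(dim=-1)
    if loss_mask is not None:
        # Masked tokens contribute zero before averaging over all tokens.
        per_token = per_token * loss_mask.to(per_token.dtype)
    return per_token.mean()


student = torch.randn(8, 32000)
teacher = torch.randn(8, 32000)
mask = torch.rand(8) > 0.25
print(reverse_kl(student, teacher, mask))
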
@@ -383,7 +381,6 @@ def forward_kl_forward_backward( teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, sequence_parallel_logits: bool = False, - **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). @@ -418,7 +415,6 @@ def forward_kl_forward_backward( group=group, teacher_softmax_temperature=teacher_softmax_temperature, return_target_entropy=True, - **kwargs, ) distillation_loss -= teacher_entropy diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 40c099617..bce20c83f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -182,14 +182,10 @@ def _get_targets(self, kwargs: dict) -> Targets | None: dpo_target, reference_model_logits, loss_mask, - chosen_spans, - rejected_spans, dpo_reference_model_logits, - ) = (None, None, None, None, None, None, None) + ) = (None, None, None, None, None) if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) - chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) - rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) else: if self._config.distillation_model is not None: @@ -230,8 +226,6 @@ def _get_targets(self, kwargs: dict) -> Targets | None: dpo_target=dpo_target, lm_target=lm_target, loss_mask=loss_mask, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, reference_model_logits=reference_model_logits, dpo_reference_model_logits=dpo_reference_model_logits, ) @@ -302,8 +296,6 @@ def _logits_cross_entropy_forward_backward_split( dpo_target=dpo_target_, reference_model_logits=reference_model_logits_, loss_mask=loss_mask_, - chosen_spans=targets.chosen_spans, - rejected_spans=targets.rejected_spans, dpo_reference_model_logits=targets.dpo_reference_model_logits, ) loss_, grad_ = self._logits_loss_forward_backward( @@ -390,6 +382,7 @@ def _logits_loss_forward_backward( group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, + kwargs=kwargs, ) loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index dc367be65..4be129a28 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -34,8 +34,6 @@ class Targets: lm_target: "torch.Tensor | None" = None dpo_target: "torch.Tensor | None" = None loss_mask: "torch.Tensor | None" = None - chosen_spans: list[list[tuple[int, int]]] | None = None - rejected_spans: list[list[tuple[int, int]]] | None = None reference_model_logits: "torch.Tensor | None" = None dpo_reference_model_logits: "torch.Tensor | None" = None @@ -64,12 +62,12 @@ class LanguageModelLossConfig(Config): def compute_loss( self, logits: "torch.Tensor", - target: Targets, + targets: Targets, grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, - **kwargs, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass @@ -119,7 +117,7 @@ def compute_loss( group: "ProcessGroup" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, - **kwargs, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, 
torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward @@ -145,7 +143,6 @@ def compute_loss( logits_scale_factor=logits_scale_factor, teacher_softmax_temperature=self.teacher_softmax_temperature, target_format=TargetFormat.labels, - **kwargs, ) @@ -170,7 +167,8 @@ def compute_loss( grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, - **kwargs, + vocab_parallel: bool = False, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward @@ -187,7 +185,6 @@ def compute_loss( logits_scale_factor=logits_scale_factor, teacher_softmax_temperature=self.teacher_softmax_temperature, target_format=TargetFormat.logits, - **kwargs, ) @@ -212,7 +209,8 @@ def compute_loss( grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, - **kwargs, + vocab_parallel: bool = False, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import reverse_kl_forward_backward @@ -230,7 +228,6 @@ def compute_loss( logits_scale_factor=logits_scale_factor, teacher_softmax_temperature=self.teacher_softmax_temperature, target_format=TargetFormat.logits, - **kwargs, ) @@ -254,16 +251,22 @@ def compute_loss( targets: Targets, grad_output: float | None = None, group: "ProcessGroup" = None, - **kwargs, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss + from fast_llm.layers.language_model.config import LanguageModelKwargs + + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) return compute_dpo_loss( logits=logits, targets=targets.dpo_target, reference_model_logits=targets.dpo_reference_model_logits, - chosen_spans=targets.chosen_spans, - rejected_spans=targets.rejected_spans, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, beta=self.beta, grad_output=grad_output, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6cda07ad0..f3d4659cd 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -247,7 +247,7 @@ def _update_and_add_testing_config( "head": { "output_weight": init_1, "losses": { - "lm_loss": {"type": "cross_entropy", "factor": 1.0}, + "lm_loss": {"type": "cross_entropy", "weight": 1.0}, }, }, "hidden_size": 256, From 406d0a2eaf355488a699220ad4198371585effa2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:13:50 +0000 Subject: [PATCH 19/51] Removed Targets class Removed the targets, class, moved tragets processing to losses, made loss masks more explicit --- fast_llm/layers/language_model/config.py | 17 +- fast_llm/layers/language_model/embedding.py | 3 +- fast_llm/layers/language_model/head.py | 139 ++++++----------- fast_llm/layers/language_model/kwargs.py | 23 +++ .../layers/language_model/lm_head_losses.py | 147 +++++++++++++----- fast_llm/models/gpt/model.py | 2 +- fast_llm/models/multimodal/model.py | 2 +- tests/layers/test_lm_head.py | 3 +- 8 files changed, 185 insertions(+), 151 deletions(-) create mode 100644 fast_llm/layers/language_model/kwargs.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4bd8a592c..9f6cbf4ca 100644 --- 
a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -19,21 +19,6 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class LanguageModelKwargs(BlockKwargs): - token_ids = "token_ids" - position_ids = "position_ids" - token_map = "token_map" - sample_map = "sample_map" - embedding_map = "embedding_map" - # TODO: These are generic - labels = "labels" - phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" - loss_mask = "loss_mask" - mask_inputs = "mask_inputs" - - @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 93850d24c..fda5e3387 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,7 +10,8 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index bce20c83f..27b090c1f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,9 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, - LanguageModelKwargs, ) -from fast_llm.layers.language_model.lm_head_losses import Targets, _format_name +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.lm_head_losses import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -101,10 +101,12 @@ def __init__( peft=self._peft, ) - self._formatted_loss_names = { - loss_name: loss_config.get_formatted_name(loss_name, self._prediction_distance) - for loss_name, loss_config in self._config.losses.items() - } + self._formatted_loss_names = {} + for loss_name, loss_config in self._config.losses.items(): + if loss_config.weight > 0.0: + self._formatted_loss_names[loss_name] = loss_config.get_formatted_name( + loss_name, self._prediction_distance + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -154,6 +156,12 @@ def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: targets = self._get_targets(kwargs) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if 
loss_mask is not None: + loss_mask = loss_mask.flatten() + if self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) + input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) @@ -167,7 +175,7 @@ def _forward_backward( output_weights = self.output_weights loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), targets, output_weights, grad_output, kwargs, losses + ln_output.detach(), targets, loss_mask, output_weights, grad_output, kwargs, losses ) if do_grad: @@ -176,62 +184,20 @@ def _forward_backward( else: return loss, None - def _get_targets(self, kwargs: dict) -> Targets | None: - ( - lm_target, - dpo_target, - reference_model_logits, - loss_mask, - dpo_reference_model_logits, - ) = (None, None, None, None, None) - if self._config.enable_dpo: - dpo_target = kwargs.get(LanguageModelKwargs.labels) - dpo_reference_model_logits = (kwargs.get(f"{self._config.dpo_reference_model}_logits"),) - else: - if self._config.distillation_model is not None: - # Target is reference model logits. - reference_model_logits = kwargs[f"{self._config.distillation_model}_logits"].flatten(0, -2) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice( - self._prediction_distance, self._prediction_distance + lm_target_sequence_length - ) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - - if self._sequence_parallel_logits: - if dpo_target is not None: - dpo_target = split_op(dpo_target, self._parallel_dim.group, 0) - if lm_target is not None: - lm_target = split_op(lm_target, self._parallel_dim.group, 0) - if loss_mask is not None: - loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) - if reference_model_logits is not None: - reference_model_logits = split_op(reference_model_logits, self._parallel_dim.group, 0) - - targets = Targets( - dpo_target=dpo_target, - lm_target=lm_target, - loss_mask=loss_mask, - reference_model_logits=reference_model_logits, - dpo_reference_model_logits=dpo_reference_model_logits, - ) - - # Return None if no targets are set - if not targets.has_any_target(): + def _get_targets(self, kwargs: dict) -> dict | None: + targets = {} + for loss_config in self._config.losses.values(): + if loss_config.weight == 0.0: + continue + loss_targets = loss_config.extract_targets_from_global_kwargs( + kwargs, + prediction_distance=self._prediction_distance, + prediction_heads=self._prediction_heads, + head_config=self._config, + sequence_parallel_logits=self._sequence_parallel_logits, + ) + targets.update({k: v for k, v in loss_targets.items() if v is not None}) + if len(targets) == 0: return None return targets @@ -241,15 +207,16 @@ def get_output_weights(self) -> list[torch.Tensor]: def _logits_cross_entropy_forward_backward_split( self, input_: torch.Tensor, - targets: Targets | None, + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, 
weight: torch.Tensor, grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._config.cross_entropy_splits is None or targets is None: + if self._config.cross_entropy_splits is None: loss, logit_input_grad = self._logits_loss_forward_backward( - input_, targets, weight, grad_output, kwargs, losses + input_, targets, loss_mask, weight, grad_output, kwargs, losses ) if targets is None: # TODO: Make a proper way of returning the model output. @@ -273,34 +240,28 @@ def _logits_cross_entropy_forward_backward_split( else: logit_input_grad = None - # Extract target tensors for splitting (keep same order as original tuple) - target_tensors = [ - targets.lm_target, - targets.dpo_target, - targets.reference_model_logits, - targets.loss_mask, - ] split_size = div( - get_unique(target.size(0) for target in target_tensors if target is not None), + get_unique(target.size(0) for target in targets.values() if target is not None), self._config.cross_entropy_splits, ) tensors_split = [ [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, *target_tensors, logit_input_grad] + for tensor in [logit_input, loss_mask, logit_input_grad] ] - for logit_input_, lm_target_, dpo_target_, reference_model_logits_, loss_mask_, logit_input_grad_ in zip( - *tensors_split, strict=True - ): - targets_ = Targets( - lm_target=lm_target_, - dpo_target=dpo_target_, - reference_model_logits=reference_model_logits_, - loss_mask=loss_mask_, - dpo_reference_model_logits=targets.dpo_reference_model_logits, + target_split = { + name: ( + [None] * self._config.cross_entropy_splits + if targets[name] is None + else targets[name].split(split_size) ) + for name in targets + } + + for i, (logit_input_, loss_mask_, logit_input_grad_) in enumerate(zip(*tensors_split, strict=True)): loss_, grad_ = self._logits_loss_forward_backward( logit_input_, - targets_, + {name: target_split[name][i] for name in target_split}, + loss_mask_, weight, grad_output, kwargs, @@ -323,7 +284,8 @@ def _logits_cross_entropy_forward_backward_split( def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: Targets | None, + targets: dict[str, "torch.Tensor"] | None, + loss_mask: torch.Tensor | None, weight: torch.Tensor, grad_output: float, kwargs: dict, @@ -370,10 +332,9 @@ def _logits_loss_forward_backward( if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled - # we log unscaled losses seperately and the scaled total loss loss_unscaled_, grad_ = loss_config.compute_loss( logits, - targets, + loss_mask, grad_output=( (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) if loss_config.weight != 0.0 @@ -382,7 +343,7 @@ def _logits_loss_forward_backward( group=group, logits_scale_factor=self._config.logits_scale_factor, vocab_parallel=self._vocab_parallel, - kwargs=kwargs, + kwargs={**kwargs, **targets}, ) loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient diff --git a/fast_llm/layers/language_model/kwargs.py b/fast_llm/layers/language_model/kwargs.py new file mode 100644 index 000000000..4f6203881 --- /dev/null +++ b/fast_llm/layers/language_model/kwargs.py @@ -0,0 +1,23 @@ +from fast_llm.layers.block.config import BlockKwargs + + +class TargetsKwargs: + lm_target = "preprocessed_lm_target" + dpo_target = "preprocessed_dpo_target" + reference_model_logits = "reference_model_logits" + 
dpo_reference_model_logits = "dpo_reference_model_logits" + + +class LanguageModelKwargs(BlockKwargs): + token_ids = "token_ids" + position_ids = "position_ids" + token_map = "token_map" + sample_map = "sample_map" + embedding_map = "embedding_map" + # TODO: These are generic + labels = "labels" + phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + loss_mask = "loss_mask" + mask_inputs = "mask_inputs" diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 4be129a28..088e55042 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -1,18 +1,20 @@ import abc -import dataclasses import logging import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: import torch from fast_llm.core.distributed import ProcessGroup + from fast_llm.layers.language_model.config import LanguageModelHeadConfig logger = logging.getLogger(__name__) @@ -29,23 +31,10 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -@dataclasses.dataclass -class Targets: - lm_target: "torch.Tensor | None" = None - dpo_target: "torch.Tensor | None" = None - loss_mask: "torch.Tensor | None" = None - reference_model_logits: "torch.Tensor | None" = None - dpo_reference_model_logits: "torch.Tensor | None" = None - - def has_any_target(self) -> bool: - return any(getattr(self, field.name) is not None for field in dataclasses.fields(self)) - - @config_class(registry=True) class LanguageModelLossConfig(Config): """ - Losses canm register themselves - using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}) + Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). 
""" _name: typing.ClassVar[str] @@ -62,7 +51,7 @@ class LanguageModelLossConfig(Config): def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -90,6 +79,18 @@ def get_formatted_name(self, name=None, prediction_distance: int | None = None) name = f"{name}_{prediction_distance}" return name + @abc.abstractmethod + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + pass + @config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) class CrossEntropyLMLossConfig(LanguageModelLossConfig): @@ -109,10 +110,40 @@ class CrossEntropyLMLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() + if sequence_parallel_logits: + lm_target = split_op(lm_target, group, 0) + return {TargetsKwargs.lm_target: lm_target} + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -121,9 +152,7 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import cross_entropy_forward_backward - target = targets.lm_target - if target is None: - raise ValueError("CrossEntropyLoss requires lm_target to be set in Targets") + target = kwargs.get(TargetsKwargs.lm_target) implementation = self.implementation if implementation == CrossEntropyImpl.auto: if vocab_parallel: @@ -160,10 +189,29 @@ class ForwardKLLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.distillation_model}_logits") + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = 
split_op(reference_model_logits, group, 0) + return {TargetsKwargs.reference_model_logits: reference_model_logits} + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -172,14 +220,12 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.cross_entropy import forward_kl_forward_backward - target = targets.reference_model_logits - if target is None: - raise ValueError("ForwardKLLoss requires distillation_target to be set in Targets") + target = kwargs.get(TargetsKwargs.reference_model_logits) return forward_kl_forward_backward( logits=logits.flatten(0, -2), target=target, - loss_mask=targets.loss_mask, + loss_mask=loss_mask, grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, @@ -189,23 +235,16 @@ def compute_loss( @config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(LanguageModelLossConfig): +class ReverseKLLossConfig(ForwardKLLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL" _abstract: typing.ClassVar[bool] = False - teacher_softmax_temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax.", - valid=check_field(Assert.gt, 0.0), - ) - def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -215,14 +254,12 @@ def compute_loss( from fast_llm.functional.cross_entropy import reverse_kl_forward_backward # Use distillation_target for KL losses - target = targets.reference_model_logits - if target is None: - raise ValueError("ReverseKLLoss requires distillation_target to be set in Targets") + target = kwargs.get(TargetsKwargs.reference_model_logits) return reverse_kl_forward_backward( logits=logits.flatten(0, -2), target=target, - loss_mask=targets.loss_mask, + loss_mask=loss_mask, grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, @@ -245,10 +282,35 @@ class DPOLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) + def extract_targets_from_global_kwargs( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + head_config: "LanguageModelHeadConfig | None" = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") + dpo_target = kwargs.get(LanguageModelKwargs.labels) + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) + return { + TargetsKwargs.dpo_reference_model_logits: reference_model_logits, + TargetsKwargs.dpo_target: dpo_target, + } + def compute_loss( self, logits: "torch.Tensor", - targets: Targets, + loss_mask: "torch.Tensor | None", grad_output: float | None = None, group: "ProcessGroup" = None, logits_scale_factor: float | None = None, @@ -256,15 +318,16 @@ def compute_loss( kwargs: dict | None = None, ) -> 
"tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss - from fast_llm.layers.language_model.config import LanguageModelKwargs + dpo_target = kwargs.get(TargetsKwargs.dpo_target) + dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) return compute_dpo_loss( logits=logits, - targets=targets.dpo_target, - reference_model_logits=targets.dpo_reference_model_logits, + targets=dpo_target, + reference_model_logits=dpo_reference_model_logits, chosen_spans=chosen_spans, rejected_spans=rejected_spans, beta=self.beta, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2f43d1e41..846c65646 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -12,7 +12,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 890d5760e..88da79e65 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel from fast_llm.models.gpt.config import GPTBatchConfig diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 7f9e55b79..ed639db93 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,8 +7,9 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert From f25380a191fd53bdc0427bc3592c3a026ad3fd22 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:39:22 +0000 Subject: [PATCH 20/51] fixes --- fast_llm/layers/language_model/head.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 27b090c1f..cb2312d75 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -195,6 
+195,7 @@ def _get_targets(self, kwargs: dict) -> dict | None: prediction_heads=self._prediction_heads, head_config=self._config, sequence_parallel_logits=self._sequence_parallel_logits, + group=self._parallel_dim.group, ) targets.update({k: v for k, v in loss_targets.items() if v is not None}) if len(targets) == 0: @@ -240,8 +241,14 @@ def _logits_cross_entropy_forward_backward_split( else: logit_input_grad = None + # Collect all tensors that need to be split to determine the split size + tensors_to_check = [logit_input] + if loss_mask is not None: + tensors_to_check.append(loss_mask) + tensors_to_check.extend(target for target in targets.values() if target is not None) + split_size = div( - get_unique(target.size(0) for target in targets.values() if target is not None), + get_unique(tensor.size(0) for tensor in tensors_to_check), self._config.cross_entropy_splits, ) tensors_split = [ From 8adb7ddb9da22eba3f9a4e8a3cbff0e86ca2f214 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 30 Dec 2025 20:51:52 +0000 Subject: [PATCH 21/51] imports --- .../layers/language_model/lm_head_losses.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 088e55042..f6e69b4fa 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -3,7 +3,6 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.core.ops import split_op from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig @@ -137,6 +136,8 @@ def extract_targets_from_global_kwargs( else lm_target[:, lm_target_slice] ).flatten() if sequence_parallel_logits: + from fast_llm.core.ops import split_op + lm_target = split_op(lm_target, group, 0) return {TargetsKwargs.lm_target: lm_target} @@ -205,6 +206,8 @@ def extract_targets_from_global_kwargs( if reference_model_logits is not None: reference_model_logits = reference_model_logits.flatten(0, -2) if sequence_parallel_logits: + from fast_llm.core.ops import split_op + reference_model_logits = split_op(reference_model_logits, group, 0) return {TargetsKwargs.reference_model_logits: reference_model_logits} @@ -296,12 +299,15 @@ def extract_targets_from_global_kwargs( reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") dpo_target = kwargs.get(LanguageModelKwargs.labels) - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - reference_model_logits = split_op(reference_model_logits, group, 0) - if dpo_target is not None: - dpo_target = split_op(dpo_target, group, 0) + if reference_model_logits is not None or dpo_target is not None: + from fast_llm.core.ops import split_op + + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) return { TargetsKwargs.dpo_reference_model_logits: reference_model_logits, TargetsKwargs.dpo_target: dpo_target, From 1ce641d85ea418077865a080b4470ff9947fad85 Mon Sep 17 00:00:00 2001 From: oleksost Date: Tue, 6 Jan 2026 20:21:09 +0000 Subject: [PATCH 22/51] polish 
naming --- fast_llm/layers/language_model/head.py | 6 +++--- fast_llm/layers/language_model/lm_head_losses.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index cb2312d75..f05da5534 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -102,10 +102,10 @@ def __init__( ) self._formatted_loss_names = {} - for loss_name, loss_config in self._config.losses.items(): + for registered_loss_name, loss_config in self._config.losses.items(): if loss_config.weight > 0.0: - self._formatted_loss_names[loss_name] = loss_config.get_formatted_name( - loss_name, self._prediction_distance + self._formatted_loss_names[registered_loss_name] = loss_config.get_formatted_name( + registered_loss_name, self._prediction_distance ) def forward( diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index f6e69b4fa..49dbb3ced 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -72,8 +72,11 @@ def _validate(self): Assert.geq(self.weight, 0.0) super()._validate() - def get_formatted_name(self, name=None, prediction_distance: int | None = None) -> str: - name = f"{self._name}({name})" + def get_formatted_name(self, registered_loss_name=None, prediction_distance: int | None = None) -> str: + """ + Retruns loss name for logging as '()', e.g. lm_loss(CE_loss), distillation(FwdKL_loss) + """ + name = f"{registered_loss_name}({self._name})" if prediction_distance is not None: name = f"{name}_{prediction_distance}" return name @@ -93,7 +96,7 @@ def extract_targets_from_global_kwargs( @config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) class CrossEntropyLMLossConfig(LanguageModelLossConfig): - _name: typing.ClassVar[str] = "CE" + _name: typing.ClassVar[str] = "CE_loss" _abstract: typing.ClassVar[bool] = False implementation: CrossEntropyImpl = Field( @@ -180,7 +183,7 @@ def compute_loss( class ForwardKLLossConfig(LanguageModelLossConfig): """Forward KL divergence KL(p||q) for distillation (mode-covering).""" - _name: typing.ClassVar[str] = "FwdKL" + _name: typing.ClassVar[str] = "FwdKL_loss" _abstract: typing.ClassVar[bool] = False teacher_softmax_temperature: float = Field( @@ -241,7 +244,7 @@ def compute_loss( class ReverseKLLossConfig(ForwardKLLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" - _name: typing.ClassVar[str] = "RevKL" + _name: typing.ClassVar[str] = "RevKL_loss" _abstract: typing.ClassVar[bool] = False def compute_loss( @@ -275,7 +278,7 @@ def compute_loss( class DPOLossConfig(LanguageModelLossConfig): """Direct Preference Optimization (DPO) loss for alignment.""" - _name: typing.ClassVar[str] = "DPO" + _name: typing.ClassVar[str] = "DPO_loss" _abstract: typing.ClassVar[bool] = False beta: float = Field( From 95f14afc76b4d3639d45dde7228951ba7de4c666 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 18:44:01 +0000 Subject: [PATCH 23/51] addresseing comments --- fast_llm/functional/cross_entropy.py | 13 +++- fast_llm/layers/language_model/config.py | 78 +++++++++++++------ fast_llm/layers/language_model/head.py | 11 +-- .../layers/language_model/lm_head_losses.py | 54 +++++++++---- tests/layers/test_lm_head.py | 5 +- tests/test_config.py | 8 +- tests/utils/model_configs.py | 2 +- 7 files changed, 109 insertions(+), 62 deletions(-) diff --git 
a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 03f7a88ef..6b0a4e92f 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -98,7 +98,10 @@ def _fused_cross_entropy_forward_backward( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target = _fused_softmax(target, logits_scale_factor / teacher_softmax_temperature, group) + target_logits, exp_logits, sum_exp_target_logits = _fused_softmax_base( + target, logits_scale_factor / teacher_softmax_temperature, group + ) + target = exp_logits / sum_exp_target_logits if target_format == TargetFormat.labels: target = target.unsqueeze(-1) @@ -159,9 +162,11 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) - if return_target_entropy and target_format == TargetFormat.logits: - # Compute teacher entropy - teacher_log_prob = torch.log(target + 1e-20) + if return_target_entropy: + if target_format == TargetFormat.logits: + teacher_log_prob = target_logits - sum_exp_target_logits.log() + else: + teacher_log_prob = torch.log(target + 1e-20) target_entropy = -(target * teacher_log_prob).sum(dim=-1) if loss_mask is not None: target_entropy = target_entropy * loss_mask.squeeze(-1) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9f6cbf4ca..a74489005 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,5 +1,7 @@ import abc import typing +import warnings +from functools import cached_property from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales @@ -9,7 +11,13 @@ from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig +from fast_llm.layers.language_model.lm_head_losses import ( + CrossEntropyLMLossConfig, + DPOLossConfig, + ForwardKLLossConfig, + LanguageModelLossConfig, + ReverseKLLossConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -151,17 +159,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." - "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) def get_layer( self, @@ -193,23 +190,37 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + removed_fields = ["distillation_loss_factor", "distillation_model", "language_model_loss_factor"] + for field in removed_fields: + if field in default: + warnings.warn( + f"Field `{field}` has been removed from {cls.__name__}. 
" + "Loss configuration should now be done via the `losses` field.", + DeprecationWarning, + ) + default.pop(field) + return super()._from_dict(default, strict=strict) + def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: if "losses" not in self._explicit_fields: - self.losses = { - "lm_loss": LanguageModelLossConfig._from_dict( - { - "type": "cross_entropy", - "weight": 1.0, - } - ) - } - for loss_config in self.losses.values(): - if "distillation" in loss_config.type: - assert self.distillation_model is not None, "Distillation loss requires a distillation model." + self.losses = {"lm_loss": CrossEntropyLMLossConfig()} super()._validate() - assert self.dpo_reference_model is None or self.distillation_model is None # currently don't support both + if DPOLossConfig in self._loss_configs: + assert ForwardKLLossConfig not in self._loss_configs.keys() # currently don't support both + assert ReverseKLLossConfig not in self._loss_configs.keys() # currently don't support both + if ForwardKLLossConfig in self._loss_configs.keys() and ReverseKLLossConfig in self._loss_configs.keys(): + assert ( + self._loss_configs[ForwardKLLossConfig].distillation_model + == self._loss_configs[ReverseKLLossConfig].distillation_model + ), "Distillation losses must use the same teacher." + + @cached_property + def _loss_configs(self) -> dict[type, LanguageModelLossConfig]: + return {loss.__class__: loss for loss in self.losses.values()} @property def max_prediction_distance(self) -> int: @@ -217,7 +228,24 @@ def max_prediction_distance(self) -> int: @property def enable_dpo(self) -> bool: - return self.dpo_reference_model is not None + return DPOLossConfig in self._loss_configs.keys() + + @property + def enable_distillation(self) -> bool: + return ForwardKLLossConfig in self._loss_configs.keys() or ReverseKLLossConfig in self._loss_configs.keys() + + @property + def distillation_model(self) -> str | None: + for loss_type in [ForwardKLLossConfig, ReverseKLLossConfig]: + if loss_type in self._loss_configs: + return self._loss_configs[loss_type].distillation_model + return None + + @property + def dpo_reference_model(self) -> str | None: + if DPOLossConfig in self._loss_configs: + return self._loss_configs[DPOLossConfig].dpo_reference_model + return None @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index f05da5534..465984e01 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -67,9 +67,7 @@ def __init__( lr_scale=lr_scale, peft=peft, ) - if prediction_distance > 0 and ( - self._config.distillation_model is not None or self._config.dpo_reference_model is not None - ): + if prediction_distance > 0 and (self._config.enable_dpo or self._config.enable_distillation): raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") Assert.in_range(prediction_distance, 0, prediction_heads) @@ -189,11 +187,10 @@ def _get_targets(self, kwargs: dict) -> dict | None: for loss_config in self._config.losses.values(): if loss_config.weight == 0.0: continue - loss_targets = loss_config.extract_targets_from_global_kwargs( + loss_targets = loss_config.get_targets( kwargs, prediction_distance=self._prediction_distance, prediction_heads=self._prediction_heads, - head_config=self._config, sequence_parallel_logits=self._sequence_parallel_logits, group=self._parallel_dim.group, ) @@ -339,7 +336,7 @@ def 
_logits_loss_forward_backward( if loss_config.weight == 0.0: continue # losses are returned unscaled but the grads are already scaled - loss_unscaled_, grad_ = loss_config.compute_loss( + loss_unscaled_, grad_ = loss_config.get_loss( logits, loss_mask, grad_output=( @@ -401,7 +398,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) ) for loss_name, loss_config in self._config.losses.items(): - loss_def: LossDef = loss_config.get_loss_def( + loss_def: LossDef = loss_config.get_loss_definitions( name=loss_name, count=count, prediction_distance=self._prediction_distance ) loss_defs.append(loss_def) diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index 49dbb3ced..e1004b5c8 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -13,7 +13,6 @@ import torch from fast_llm.core.distributed import ProcessGroup - from fast_llm.layers.language_model.config import LanguageModelHeadConfig logger = logging.getLogger(__name__) @@ -46,8 +45,15 @@ class LanguageModelLossConfig(Config): valid=check_field(Assert.geq, 0.0), ) + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for knowledge distillation." + "If provided, replace the loss with a distillation loss.", + hint=FieldHint.feature, + ) + @abc.abstractmethod - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -59,7 +65,7 @@ def compute_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass - def get_loss_def(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + def get_loss_definitions(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: name = self.get_formatted_name(name, prediction_distance) return LossDef( name=name, @@ -82,12 +88,11 @@ def get_formatted_name(self, registered_loss_name=None, prediction_distance: int return name @abc.abstractmethod - def extract_targets_from_global_kwargs( + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: @@ -112,12 +117,11 @@ class CrossEntropyLMLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) - def extract_targets_from_global_kwargs( + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: @@ -144,7 +148,7 @@ def extract_targets_from_global_kwargs( lm_target = split_op(lm_target, group, 0) return {TargetsKwargs.lm_target: lm_target} - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -193,19 +197,22 @@ class ForwardKLLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) - def extract_targets_from_global_kwargs( + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by ForwardKL Loss." 
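# A minimal illustrative sketch, not part of the patch: with `distillation_model` now living on
# the loss config itself (rather than on the head), a head mixing CE and reverse-KL distillation
# is configured roughly like the test cases further down, e.g.:
_example_head_config = {  # hypothetical name, for illustration only
    "losses": {
        "lm_loss": {"type": "cross_entropy", "weight": 0.5},
        "dist_loss": {
            "type": "reverse_kl_distillation",
            "weight": 1.0,
            "distillation_model": "distillation",  # name of the teacher/reference model
        },
    }
}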
+ super()._validate() + + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} - reference_model_logits = kwargs.get(f"{head_config.distillation_model}_logits") + reference_model_logits = kwargs.get(f"{self.distillation_model}_logits") if reference_model_logits is not None: reference_model_logits = reference_model_logits.flatten(0, -2) if sequence_parallel_logits: @@ -214,7 +221,7 @@ def extract_targets_from_global_kwargs( reference_model_logits = split_op(reference_model_logits, group, 0) return {TargetsKwargs.reference_model_logits: reference_model_logits} - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -247,7 +254,11 @@ class ReverseKLLossConfig(ForwardKLLossConfig): _name: typing.ClassVar[str] = "RevKL_loss" _abstract: typing.ClassVar[bool] = False - def compute_loss( + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by Reverse KL Loss." + super()._validate() + + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", @@ -288,19 +299,28 @@ class DPOLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) - def extract_targets_from_global_kwargs( + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) + + def _validate(self): + assert self.dpo_reference_model is not None, "DPO loss requires a reference model." + super()._validate() + + def get_targets( self, kwargs: dict | None = None, prediction_distance: int | None = None, prediction_heads: int | None = None, - head_config: "LanguageModelHeadConfig | None" = None, sequence_parallel_logits: bool | None = None, group: "ProcessGroup" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} - reference_model_logits = kwargs.get(f"{head_config.dpo_reference_model}_logits") + reference_model_logits = kwargs.get(f"{self.dpo_reference_model}_logits") dpo_target = kwargs.get(LanguageModelKwargs.labels) if reference_model_logits is not None or dpo_target is not None: from fast_llm.core.ops import split_op @@ -316,7 +336,7 @@ def extract_targets_from_global_kwargs( TargetsKwargs.dpo_target: dpo_target, } - def compute_loss( + def get_loss( self, logits: "torch.Tensor", loss_mask: "torch.Tensor | None", diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ed639db93..f25aba1e7 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -155,7 +155,6 @@ def _lm_head( pytest.param( { "head": { - "distillation_model": "distillation", "losses": { "lm_loss": { "type": "cross_entropy", @@ -164,6 +163,7 @@ def _lm_head( "dist_loss": { "type": "reverse_kl_distillation", "weight": 1.0, + "distillation_model": "distillation", }, }, } @@ -176,7 +176,6 @@ def _lm_head( pytest.param( { "head": { - "distillation_model": "distillation", "losses": { "lm_loss": { "type": "cross_entropy", @@ -185,6 +184,7 @@ def _lm_head( "dist_loss": { "type": "reverse_kl_distillation", "weight": 0.0, + "distillation_model": "distillation", }, }, } @@ -209,6 +209,7 @@ def _lm_head( "dist_loss": { "type": "reverse_kl_distillation", "weight": 1.0, + "distillation_model": "distillation", }, }, } diff --git a/tests/test_config.py 
b/tests/test_config.py index 3c6a76a35..2e900cb14 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -148,7 +148,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, - "head": {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}}, + "head": {"losses": {"lm_loss": {"type": "cross_entropy"}}}, "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, @@ -156,7 +156,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): else: expected_config["base_model"] = base_model_update # added by default - expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy", "weight": 1.0}}} + expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy"}}} check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config)) @@ -299,7 +299,3 @@ def test_distributed_global_ranks(bdp: int, sdp: int, tp: int, pp: int, pipeline Assert.eq(len({global_rank for global_ranks in global_ranks_set for global_rank in global_ranks}), world_size) Assert.eq(len(rank_breakdowns), world_size) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f3d4659cd..a9a2e65bf 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -247,7 +247,7 @@ def _update_and_add_testing_config( "head": { "output_weight": init_1, "losses": { - "lm_loss": {"type": "cross_entropy", "weight": 1.0}, + "lm_loss": {"type": "cross_entropy"}, }, }, "hidden_size": 256, From 5ad4c0c98ffc96a58f226376d16a93f77c4e61d2 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 21:59:24 +0000 Subject: [PATCH 24/51] explicit z_loss grads --- fast_llm/layers/common/auxiliary_loss.py | 42 +++++++++----- fast_llm/layers/language_model/head.py | 40 ++++++------- .../layers/language_model/lm_head_losses.py | 36 ++++++++++++ tests/layers/test_lm_head.py | 57 ++++++++++++++----- 4 files changed, 125 insertions(+), 50 deletions(-) diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index 44c2d2088..335debb12 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -21,18 +21,34 @@ def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> def z_loss( logits: torch.Tensor, - z_loss_factor: float, - training: bool, grad_scale: float | None = None, - losses: dict | None = None, - loss_name: str | None = None, logits_scale_factor: float = 1.0, -) -> torch.Tensor: - if losses is not None or (training and grad_scale is not None): - loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) - if losses is not None and loss_name is not None: - losses[loss_name].append(loss.detach()) - if training and grad_scale is not None: - logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale) - - return logits +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Compute z-loss and its gradient. + + Z-loss = mean(logsumexp(logits, dim=-1) ** 2) + + Returns: + loss: The z-loss value (unscaled) + grad: The gradient w.r.t. 
logits (scaled by grad_scale), or None if grad_scale is None + """ + if logits_scale_factor != 1.0: + scaled_logits = logits * logits_scale_factor + else: + scaled_logits = logits + + # Forward: z_loss = mean(logsumexp^2) + lse = torch.logsumexp(scaled_logits, dim=-1) # (N,) + loss = torch.mean(lse**2) + + # Backward: grad = (2/N) * lse * softmax(scaled_logits) + grad = None + if grad_scale is not None: + N = scaled_logits.shape[0] + softmax_logits = torch.softmax(scaled_logits, dim=-1) + grad = (2.0 / N) * lse.unsqueeze(-1) * softmax_logits * grad_scale + if logits_scale_factor != 1.0: + grad = grad * logits_scale_factor # Chain rule for logits_scale_factor + + return loss, grad diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 465984e01..f4c38abed 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -16,7 +16,7 @@ from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LanguageModelEmbeddingsConfig, @@ -101,10 +101,9 @@ def __init__( self._formatted_loss_names = {} for registered_loss_name, loss_config in self._config.losses.items(): - if loss_config.weight > 0.0: - self._formatted_loss_names[registered_loss_name] = loss_config.get_formatted_name( - registered_loss_name, self._prediction_distance - ) + self._formatted_loss_names[registered_loss_name] = loss_config.get_formatted_name( + registered_loss_name, self._prediction_distance + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -185,8 +184,6 @@ def _forward_backward( def _get_targets(self, kwargs: dict) -> dict | None: targets = {} for loss_config in self._config.losses.values(): - if loss_config.weight == 0.0: - continue loss_targets = loss_config.get_targets( kwargs, prediction_distance=self._prediction_distance, @@ -304,17 +301,17 @@ def _logits_loss_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - # TODO: also move to lm_head_losses? - if self._config.logit_z_loss > 0.0: - logits = z_loss( - logits, - self._config.logit_z_loss, - self.training, - grad_output, - losses, - self._z_loss_name, - logits_scale_factor=self._config.logits_scale_factor, - ) + # # TODO: also move to lm_head_losses? 
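# A quick numerical check (illustrative sketch, not part of the patch) that the analytic z-loss
# gradient introduced above, (2/N) * logsumexp * softmax, matches autograd on
# mean(logsumexp(logits, dim=-1) ** 2) when logits_scale_factor == 1.0 and grad_scale == 1.0:
import torch

torch.manual_seed(0)
logits = torch.randn(8, 32, dtype=torch.float64, requires_grad=True)

ref_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2)
ref_loss.backward()  # reference gradient via autograd

with torch.no_grad():
    lse = torch.logsumexp(logits, dim=-1)
    analytic_grad = (2.0 / logits.shape[0]) * lse.unsqueeze(-1) * torch.softmax(logits, dim=-1)

torch.testing.assert_close(analytic_grad, logits.grad)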
+ # if self._config.logit_z_loss > 0.0: + # logits = z_loss( + # logits, + # self._config.logit_z_loss, + # self.training, + # grad_output, + # losses, + # self._z_loss_name, + # logits_scale_factor=self._config.logits_scale_factor, + # ) sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q if LanguageModelKwargs.hidden_dims in kwargs: @@ -333,8 +330,6 @@ def _logits_loss_forward_backward( total_loss, grad = None, None for loss_name, loss_config in self._config.losses.items(): - if loss_config.weight == 0.0: - continue # losses are returned unscaled but the grads are already scaled loss_unscaled_, grad_ = loss_config.get_loss( logits, @@ -349,6 +344,7 @@ def _logits_loss_forward_backward( vocab_parallel=self._vocab_parallel, kwargs={**kwargs, **targets}, ) + loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient if losses is not None: @@ -393,10 +389,6 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count ) ] - if self._config.logit_z_loss > 0.0: - loss_defs.append( - LossDef(name=self._z_loss_name, formatted_name=_format_name(self._z_loss_name), count=count) - ) for loss_name, loss_config in self._config.losses.items(): loss_def: LossDef = loss_config.get_loss_definitions( name=loss_name, count=count, prediction_distance=self._prediction_distance diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py index e1004b5c8..327dee560 100644 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ b/fast_llm/layers/language_model/lm_head_losses.py @@ -362,3 +362,39 @@ def get_loss( beta=self.beta, grad_output=grad_output, ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) +class ZLossConfig(LanguageModelLossConfig): + """Z-loss regularization to prevent overconfidence.""" + + _name: typing.ClassVar[str] = "Z_loss" + _abstract: typing.ClassVar[bool] = False + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + return {} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.layers.common.auxiliary_loss import z_loss + + return z_loss( + logits=logits.flatten(0, -2), + grad_scale=grad_output, + logits_scale_factor=logits_scale_factor, + ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index f25aba1e7..9c81ba0a4 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -69,7 +69,6 @@ def _lm_head( logit_weight: torch.Tensor, grad_output: float = 1.0, logit_scale_factor: float = 1.0, - logit_z_loss=0.0, losses: dict[str, LanguageModelLossConfig], ): hidden = torch.rms_norm( @@ -102,12 +101,31 @@ def _lm_head( if logit_scale_factor != 1.0: logits *= logit_scale_factor - z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None + + # Compute z_loss if configured + if "z_loss" in losses: + z_loss_unscaled = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) + # Backward through z_loss (retain_graph 
since we need to also backward through ce_loss) + z_loss_unscaled.backward( + torch.full_like(z_loss_unscaled, grad_output * losses["z_loss"].weight), retain_graph=True + ) + z_loss_scaled = z_loss_unscaled * losses["z_loss"].weight + else: + z_loss_unscaled = None + z_loss_scaled = None + # Language model loss (cross-entropy with hard labels) - loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - # Apply language_model_loss_factor - loss.backward(torch.full_like(loss, grad_output * losses["lm_loss"].weight)) - return loss * losses["lm_loss"].weight, z_loss + ce_loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + # Backward through ce_loss + ce_loss.backward(torch.full_like(ce_loss, grad_output * losses["lm_loss"].weight)) + ce_loss_scaled = ce_loss * losses["lm_loss"].weight + + # Total loss = ce_loss + z_loss (both scaled) + total_loss = ce_loss_scaled + if z_loss_scaled is not None: + total_loss = total_loss + z_loss_scaled + + return total_loss, z_loss_unscaled SEQUENCE_LENGTH = 200 @@ -126,7 +144,21 @@ def _lm_head( ({}, {"compute_dtype": DataType.bfloat16}, False, 1), ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), ({"sequence_first": True}, {}, False, 1), - ({"head": {"logit_z_loss": 1e-3}}, {}, False, 1), + ( + { + "head": { + "losses": { + "z_loss": { + "type": "z_loss", + "weight": 1e-3, + }, + }, + } + }, + {}, + False, + 1, + ), ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), ({"tied_embedding_weight": True}, {}, False, 1), ({}, {}, False, 2), @@ -365,7 +397,6 @@ def test_lm_head( rms_weight=ref_rms_weight, logit_weight=ref_logit_weight, logit_scale_factor=head_config.logits_scale_factor, - logit_z_loss=head_config.logit_z_loss, losses=head_config.losses, ) @@ -386,8 +417,8 @@ def test_lm_head( formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) expected_loss_keys.add(formatted_name) - if ref_z_loss is not None: - expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") + # if ref_z_loss is not None: + # expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") Assert.eq( {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, @@ -404,9 +435,9 @@ def test_lm_head( Assert.eq(losses.keys(), expected_loss_keys) Assert.eq(len(losses[lm_head_loss_name]), 1) - if ref_z_loss is not None: - Assert.eq(len(losses["z_loss"]), 1) - Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) + # if ref_z_loss is not None: + # Assert.eq(len(losses["z_loss"]), 1) + # Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) Assert.rms_close_relative(losses[lm_head_loss_name][0], ref_loss, threshold, min_threshold) From 0a66e145fe903f03ecf124e46ea70331a04cb8da Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 22:03:07 +0000 Subject: [PATCH 25/51] removed z_loss as aux loss --- fast_llm/layers/language_model/head.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index f4c38abed..b3e0e47b6 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -301,18 +301,6 @@ def _logits_loss_forward_backward( sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - # # TODO: also move to lm_head_losses? 
- # if self._config.logit_z_loss > 0.0: - # logits = z_loss( - # logits, - # self._config.logit_z_loss, - # self.training, - # grad_output, - # losses, - # self._z_loss_name, - # logits_scale_factor=self._config.logits_scale_factor, - # ) - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q if LanguageModelKwargs.hidden_dims in kwargs: batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] From f8f70415b5a9c647359b8a9754aca5f13638a927 Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 22:14:50 +0000 Subject: [PATCH 26/51] move loss configs to the lm config --- fast_llm/layers/language_model/config.py | 392 ++++++++++++++++- fast_llm/layers/language_model/head.py | 2 +- .../layers/language_model/lm_head_losses.py | 400 ------------------ tests/layers/test_lm_head.py | 3 +- 4 files changed, 386 insertions(+), 411 deletions(-) delete mode 100644 fast_llm/layers/language_model/lm_head_losses.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index a74489005..adf8dd86e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -3,30 +3,406 @@ import warnings from functools import cached_property -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.lm_head_losses import ( - CrossEntropyLMLossConfig, - DPOLossConfig, - ForwardKLLossConfig, - LanguageModelLossConfig, - ReverseKLLossConfig, -) +from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: + import torch + + from fast_llm.core.distributed import ProcessGroup from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction +def _format_name(name: str) -> str: + return name.replace("_", " ") + + +@config_class(registry=True) +class LanguageModelLossConfig(Config): + """ + Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). 
+ """ + + _name: typing.ClassVar[str] + _abstract: typing.ClassVar[bool] = True + + weight: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + distillation_model: str | None = Field( + default=None, + desc="Name of the reference model to use for knowledge distillation." + "If provided, replace the loss with a distillation loss.", + hint=FieldHint.feature, + ) + + @abc.abstractmethod + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + pass + + def get_loss_definitions(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: + name = self.get_formatted_name(name, prediction_distance) + return LossDef( + name=name, + formatted_name=_format_name(name), + count=count, + dtype=DataType.float32, + ) + + def _validate(self): + Assert.geq(self.weight, 0.0) + super()._validate() + + def get_formatted_name(self, registered_loss_name=None, prediction_distance: int | None = None) -> str: + """ + Returns loss name for logging as '()', + e.g. lm_loss(CE_loss), distillation(FwdKL_loss) + """ + name = f"{registered_loss_name}({self._name})" + if prediction_distance is not None: + name = f"{name}_{prediction_distance}" + return name + + @abc.abstractmethod + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + pass + + +@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) +class CrossEntropyLMLossConfig(LanguageModelLossConfig): + _name: typing.ClassVar[str] = "CE_loss" + _abstract: typing.ClassVar[bool] = False + + implementation: CrossEntropyImpl = Field( + default=CrossEntropyImpl.auto, + desc="Implementation for the cross-entropy computation.", + hint=FieldHint.performance, + ) + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax (used in distillation losses).", + valid=check_field(Assert.gt, 0.0), + ) + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + lm_target = kwargs.get(LanguageModelKwargs.labels) + if lm_target is not None: + # MTP: Shift the labels + lm_target_sequence_length = ( + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads + ) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) + lm_target = ( + lm_target[lm_target_slice] + if kwargs[LanguageModelKwargs.sequence_first] + else lm_target[:, lm_target_slice] + ).flatten() + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + lm_target = split_op(lm_target, group, 0) + return {TargetsKwargs.lm_target: lm_target} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor 
| None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import cross_entropy_forward_backward + + target = kwargs.get(TargetsKwargs.lm_target) + implementation = self.implementation + if implementation == CrossEntropyImpl.auto: + if vocab_parallel: + implementation = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + implementation = CrossEntropyImpl.triton + else: + implementation = CrossEntropyImpl.fused + + return cross_entropy_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.labels, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) +class ForwardKLLossConfig(LanguageModelLossConfig): + """Forward KL divergence KL(p||q) for distillation (mode-covering).""" + + _name: typing.ClassVar[str] = "FwdKL_loss" + _abstract: typing.ClassVar[bool] = False + + teacher_softmax_temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by ForwardKL Loss." + super()._validate() + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{self.distillation_model}_logits") + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + reference_model_logits = split_op(reference_model_logits, group, 0) + return {TargetsKwargs.reference_model_logits: reference_model_logits} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import forward_kl_forward_backward + + target = kwargs.get(TargetsKwargs.reference_model_logits) + + return forward_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) +class ReverseKLLossConfig(ForwardKLLossConfig): + """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" + + _name: typing.ClassVar[str] = "RevKL_loss" + _abstract: typing.ClassVar[bool] = False + + def _validate(self): + assert self.distillation_model is not None, "Distillation loss required by Reverse KL Loss." 
+ super()._validate() + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import reverse_kl_forward_backward + + # Use distillation_target for KL losses + target = kwargs.get(TargetsKwargs.reference_model_logits) + + return reverse_kl_forward_backward( + logits=logits.flatten(0, -2), + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + teacher_softmax_temperature=self.teacher_softmax_temperature, + target_format=TargetFormat.logits, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class DPOLossConfig(LanguageModelLossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _name: typing.ClassVar[str] = "DPO_loss" + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) + + def _validate(self): + assert self.dpo_reference_model is not None, "DPO loss requires a reference model." + super()._validate() + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + if kwargs is None: + kwargs = {} + + reference_model_logits = kwargs.get(f"{self.dpo_reference_model}_logits") + dpo_target = kwargs.get(LanguageModelKwargs.labels) + if reference_model_logits is not None or dpo_target is not None: + from fast_llm.core.ops import split_op + + if reference_model_logits is not None: + reference_model_logits = reference_model_logits.flatten(0, -2) + if sequence_parallel_logits: + reference_model_logits = split_op(reference_model_logits, group, 0) + if dpo_target is not None: + dpo_target = split_op(dpo_target, group, 0) + return { + TargetsKwargs.dpo_reference_model_logits: reference_model_logits, + TargetsKwargs.dpo_target: dpo_target, + } + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.dpo import compute_dpo_loss + + dpo_target = kwargs.get(TargetsKwargs.dpo_target) + dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) + chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) + rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) + + return compute_dpo_loss( + logits=logits, + targets=dpo_target, + reference_model_logits=dpo_reference_model_logits, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) +class ZLossConfig(LanguageModelLossConfig): + """Z-loss regularization to prevent overconfidence.""" + + _name: 
typing.ClassVar[str] = "Z_loss" + _abstract: typing.ClassVar[bool] = False + + def get_targets( + self, + kwargs: dict | None = None, + prediction_distance: int | None = None, + prediction_heads: int | None = None, + sequence_parallel_logits: bool | None = None, + group: "ProcessGroup" = None, + ) -> dict[str, "torch.Tensor"]: + return {} + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + group: "ProcessGroup" = None, + logits_scale_factor: float | None = None, + vocab_parallel: bool = False, + kwargs: dict | None = None, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.layers.common.auxiliary_loss import z_loss + + return z_loss( + logits=logits.flatten(0, -2), + grad_scale=grad_output, + logits_scale_factor=logits_scale_factor, + ) + + @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b3e0e47b6..7f303684f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,9 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, + _format_name, ) from fast_llm.layers.language_model.kwargs import LanguageModelKwargs -from fast_llm.layers.language_model.lm_head_losses import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique diff --git a/fast_llm/layers/language_model/lm_head_losses.py b/fast_llm/layers/language_model/lm_head_losses.py deleted file mode 100644 index 327dee560..000000000 --- a/fast_llm/layers/language_model/lm_head_losses.py +++ /dev/null @@ -1,400 +0,0 @@ -import abc -import logging -import typing - -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.base_model.config import LossDef -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs -from fast_llm.utils import Assert - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.core.distributed import ProcessGroup - -logger = logging.getLogger(__name__) - -# -# CE loss on lm_targets for standard LM training. Here targets are already masked. -# CE loss for distillation: cross entropuy that uses reference_model_logits as soft targets, not implemented, TODO. -# Forward KL divergence loss on reference_model_logits for distillation (mode-covering). -# Reverse KL divergence loss on reference_model_logits for distillation (mode-seeking). -# DPO loss for alignment using chosen and rejected spans. -# - - -def _format_name(name: str) -> str: - return name.replace("_", " ") - - -@config_class(registry=True) -class LanguageModelLossConfig(Config): - """ - Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). - """ - - _name: typing.ClassVar[str] - _abstract: typing.ClassVar[bool] = True - - weight: float = Field( - default=1.0, - hint=FieldHint.core, - desc="Weight for this loss in the total loss computation.", - valid=check_field(Assert.geq, 0.0), - ) - - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." 
- "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) - - @abc.abstractmethod - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - pass - - def get_loss_definitions(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: - name = self.get_formatted_name(name, prediction_distance) - return LossDef( - name=name, - formatted_name=_format_name(name), - count=count, - dtype=DataType.float32, - ) - - def _validate(self): - Assert.geq(self.weight, 0.0) - super()._validate() - - def get_formatted_name(self, registered_loss_name=None, prediction_distance: int | None = None) -> str: - """ - Retruns loss name for logging as '()', e.g. lm_loss(CE_loss), distillation(FwdKL_loss) - """ - name = f"{registered_loss_name}({self._name})" - if prediction_distance is not None: - name = f"{name}_{prediction_distance}" - return name - - @abc.abstractmethod - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - pass - - -@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) -class CrossEntropyLMLossConfig(LanguageModelLossConfig): - _name: typing.ClassVar[str] = "CE_loss" - _abstract: typing.ClassVar[bool] = False - - implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - - teacher_softmax_temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax (used in distillation losses).", - valid=check_field(Assert.gt, 0.0), - ) - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - lm_target = split_op(lm_target, group, 0) - return {TargetsKwargs.lm_target: lm_target} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import cross_entropy_forward_backward - - target = kwargs.get(TargetsKwargs.lm_target) - implementation = 
self.implementation - if implementation == CrossEntropyImpl.auto: - if vocab_parallel: - implementation = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - implementation = CrossEntropyImpl.triton - else: - implementation = CrossEntropyImpl.fused - - return cross_entropy_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=None, # Labels are already masked - grad_output=grad_output, - group=group, - implementation=implementation, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, - target_format=TargetFormat.labels, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) -class ForwardKLLossConfig(LanguageModelLossConfig): - """Forward KL divergence KL(p||q) for distillation (mode-covering).""" - - _name: typing.ClassVar[str] = "FwdKL_loss" - _abstract: typing.ClassVar[bool] = False - - teacher_softmax_temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax.", - valid=check_field(Assert.gt, 0.0), - ) - - def _validate(self): - assert self.distillation_model is not None, "Distillation loss required by ForwardKL Loss." - super()._validate() - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - reference_model_logits = kwargs.get(f"{self.distillation_model}_logits") - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - reference_model_logits = split_op(reference_model_logits, group, 0) - return {TargetsKwargs.reference_model_logits: reference_model_logits} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import forward_kl_forward_backward - - target = kwargs.get(TargetsKwargs.reference_model_logits) - - return forward_kl_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - group=group, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, - target_format=TargetFormat.logits, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(ForwardKLLossConfig): - """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" - - _name: typing.ClassVar[str] = "RevKL_loss" - _abstract: typing.ClassVar[bool] = False - - def _validate(self): - assert self.distillation_model is not None, "Distillation loss required by Reverse KL Loss." 
- super()._validate() - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import reverse_kl_forward_backward - - # Use distillation_target for KL losses - target = kwargs.get(TargetsKwargs.reference_model_logits) - - return reverse_kl_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - group=group, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, - target_format=TargetFormat.logits, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) -class DPOLossConfig(LanguageModelLossConfig): - """Direct Preference Optimization (DPO) loss for alignment.""" - - _name: typing.ClassVar[str] = "DPO_loss" - _abstract: typing.ClassVar[bool] = False - - beta: float = Field( - default=1.0, - hint=FieldHint.core, - desc="Beta parameter for DPO loss (controls strength of preference optimization).", - valid=check_field(Assert.gt, 0.0), - ) - - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - - def _validate(self): - assert self.dpo_reference_model is not None, "DPO loss requires a reference model." - super()._validate() - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - reference_model_logits = kwargs.get(f"{self.dpo_reference_model}_logits") - dpo_target = kwargs.get(LanguageModelKwargs.labels) - if reference_model_logits is not None or dpo_target is not None: - from fast_llm.core.ops import split_op - - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - reference_model_logits = split_op(reference_model_logits, group, 0) - if dpo_target is not None: - dpo_target = split_op(dpo_target, group, 0) - return { - TargetsKwargs.dpo_reference_model_logits: reference_model_logits, - TargetsKwargs.dpo_target: dpo_target, - } - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.dpo import compute_dpo_loss - - dpo_target = kwargs.get(TargetsKwargs.dpo_target) - dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) - chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) - rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) - - return compute_dpo_loss( - logits=logits, - targets=dpo_target, - reference_model_logits=dpo_reference_model_logits, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, - beta=self.beta, - grad_output=grad_output, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) -class ZLossConfig(LanguageModelLossConfig): - """Z-loss regularization to prevent overconfidence.""" - - _name: 
typing.ClassVar[str] = "Z_loss" - _abstract: typing.ClassVar[bool] = False - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, - ) -> dict[str, "torch.Tensor"]: - return {} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.layers.common.auxiliary_loss import z_loss - - return z_loss( - logits=logits.flatten(0, -2), - grad_scale=grad_output, - logits_scale_factor=logits_scale_factor, - ) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9c81ba0a4..aca378418 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,10 +7,9 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelLossConfig from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.language_model.kwargs import LanguageModelKwargs -from fast_llm.layers.language_model.lm_head_losses import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda From ab9c9176efae53d0c5d5c5db47b96804ffe1b4ba Mon Sep 17 00:00:00 2001 From: oleksost Date: Thu, 8 Jan 2026 22:30:42 +0000 Subject: [PATCH 27/51] tests --- fast_llm/functional/cross_entropy.py | 4 ++-- tests/layers/test_lm_head.py | 23 ++++++++++++++++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 6b0a4e92f..6204ce316 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -98,10 +98,10 @@ def _fused_cross_entropy_forward_backward( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target_logits, exp_logits, sum_exp_target_logits = _fused_softmax_base( + target_logits, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( target, logits_scale_factor / teacher_softmax_temperature, group ) - target = exp_logits / sum_exp_target_logits + target = exp_logits_targets / sum_exp_target_logits if target_format == TargetFormat.labels: target = target.unsqueeze(-1) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index aca378418..6929784f5 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -204,6 +204,27 @@ def _lm_head( 1, id="track_lm_zero_factor", ), + pytest.param( + { + "head": { + "losses": { + "lm_loss": { + "type": "cross_entropy", + "weight": 0.0, + }, + "dist_loss": { + "type": "forward_kl_distillation", + "weight": 1.0, + "distillation_model": "distillation", + }, + }, + } + }, + {}, + False, + 1, + id="forward_kl_distillation", + ), pytest.param( { "head": { @@ -224,7 +245,7 @@ def _lm_head( False, 1, marks=pytest.mark.xfail( - reason="Cannot 
track both losses with zero factor", + reason="At least one loss has to have non-zero factor to track gradients", strict=True, ), id="track_both_zero_factors", From 66f16d5e235ae778878da8bd614ace549dc4bdd3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 9 Jan 2026 01:58:06 -0500 Subject: [PATCH 28/51] stuff --- fast_llm/core/distributed.py | 6 +- fast_llm/core/kernels.py | 78 ++++++++++++------- fast_llm/engine/checkpoint/convert.py | 7 -- fast_llm/engine/config_utils/run.py | 2 +- fast_llm/engine/distributed/config.py | 5 ++ fast_llm/engine/distributed/distributed.py | 23 +++--- fast_llm/engine/inference/huggingface.py | 2 - fast_llm/engine/multi_stage/fast_llm_model.py | 3 +- fast_llm/engine/schedule/config.py | 18 +++++ fast_llm/engine/schedule/runner.py | 61 +++++++++------ fast_llm/engine/schedule/schedule.py | 6 +- fast_llm/engine/training/trainer.py | 3 +- fast_llm/functional/triton/mlp.py | 15 ++-- fast_llm/functional/triton/normalization.py | 2 +- fast_llm/functional/triton/pointwise.py | 6 +- fast_llm/layers/attention/attention.py | 6 ++ fast_llm/layers/attention/rotary/config.py | 18 ----- fast_llm/layers/attention/rotary/rotary.py | 46 ++++++++--- fast_llm/utils.py | 9 +++ tests/data/common.py | 6 +- tests/functional/test_cross_entropy.py | 9 +-- tests/functional/test_functional.py | 27 ++++--- tests/functional/test_triton_kernels.py | 31 ++++---- tests/layers/test_attention.py | 14 ++-- tests/layers/test_lm_head.py | 9 ++- tests/layers/test_rotary.py | 11 ++- tests/layers/test_ssm.py | 29 +++---- tests/layers/test_varlen.py | 24 ++++-- tests/models/test_model.py | 1 - tests/test_multi_stage.py | 2 - tests/utils/model_configs.py | 12 ++- tests/utils/utils.py | 2 +- 32 files changed, 290 insertions(+), 203 deletions(-) diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 4dcc53d55..9d1f16fbe 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -185,7 +185,11 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta @contextlib.contextmanager def set_generator(generator: torch.Generator) -> typing.Generator[None, None, None]: """Use the generator as default, for ops that don't support a generator argument.""" - default_generator: torch.Generator = torch.cuda.default_generators[torch.cuda.current_device()] + default_generator: torch.Generator = ( + torch.cuda.default_generators[generator.device.index] + if generator.device.type == "cuda" + else torch.default_generator + ) assert generator is not default_generator old_state = default_generator.get_state() default_generator.set_state(generator.get_state()) diff --git a/fast_llm/core/kernels.py b/fast_llm/core/kernels.py index 9ead051d7..93371a654 100644 --- a/fast_llm/core/kernels.py +++ b/fast_llm/core/kernels.py @@ -18,24 +18,29 @@ def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor) -> torch.Tensor: - assert _apex_available - norm, _ = _multi_tensor_applier( - _multi_tensor_l2norm, - noop_flag, - [tensors], - False, # no per-parameter norm - ) + if _apex_available: + norm, _ = _multi_tensor_applier( + _multi_tensor_l2norm, + noop_flag, + [tensors], + False, # no per-parameter norm + ) + else: + norm = sum(torch.norm(tensor) ** 2 for tensor in tensors) ** 0.5 return norm def scale_(tensors: list[torch.Tensor], noop_flag: torch.Tensor, scale: torch.Tensor | float) -> None: - assert _apex_available - _multi_tensor_applier( - _multi_tensor_scale, - noop_flag, - [tensors, tensors], - scale, - ) + if _apex_available: + 
_multi_tensor_applier( + _multi_tensor_scale, + noop_flag, + [tensors, tensors], + scale, + ) + else: + for tensor in tensors: + tensor.mul_(scale) # TODO: Same as torch._fused_adam_? @@ -52,16 +57,35 @@ def fused_adam( eps: float, step: int, ) -> None: - _multi_tensor_applier( - _multi_tensor_adam, - noop_flag, - [grads, params, exp_avgs, exp_avg_sqs], - lr, - beta1, - beta2, - eps, - step, - 1, # adamw - 1, # bias correction - wd, - ) + if _apex_available: + _multi_tensor_applier( + _multi_tensor_adam, + noop_flag, + [grads, params, exp_avgs, exp_avg_sqs], + lr, + beta1, + beta2, + eps, + step, + 1, # adamw + 1, # bias correction + wd, + ) + else: + import torch.optim.adamw as adamw + + adamw.adamw( + params, + grads, + exp_avgs, + exp_avg_sqs, + None, + lr=lr, + beta1=beta1, + beta2=beta2, + eps=eps, + state_steps=torch.full([len(params)], step, dtype=torch.int64, device=params[0].device).unbind(), + weight_decay=wd, + amsgrad=False, + maximize=False, + ) diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 4ab7b3d54..b40d8a1b3 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -8,7 +8,6 @@ from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode -from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -21,7 +20,6 @@ class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() - use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) model: type[FastLLMModelConfig] = Field(default=None) @@ -65,7 +63,6 @@ def _convert_model_partial( model = model_class.from_pretrained( self.input, mode=StageMode.weights, - use_cpu=self.use_cpu, stage_filter=stage_filter, ) logger.info(f"Saving {output.format} checkpoint to {output.path}...") @@ -78,9 +75,6 @@ def run(self): # TODO: Set logging in tests logging.getLogger().setLevel(logging.INFO) self.to_logs() - # Disable Triton to convert model on CPU - if self.use_cpu: - TritonConfig.TRITON_ENABLED = False # Skip on exist_ok=False if the model has already been processed if not self.exist_ok and (self.output.path / "ok").exists(): logger.info( @@ -101,7 +95,6 @@ def run(self): model = model_class.from_pretrained( self.input.to_copy({"model_weights": False}), mode=StageMode.off_device, - use_cpu=self.use_cpu, ) stages_per_step = math.ceil(self.layers_per_step / model._config.multi_stage.layers_per_stage) num_stages = len(model.stages) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1849a2316..415147d06 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -101,7 +101,7 @@ def configure_logging( def get_run(self, distributed: "Distributed") -> "Run": from fast_llm.functional.config import TritonConfig - TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels + TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels and distributed.config.use_cuda TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) set_global_variables(not self.run.torch_dynamo_enable) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 7f4b7bc38..d0a078812 
100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -201,6 +201,11 @@ class DistributedConfig(Config): hint=FieldHint.optional, valid=check_field(Assert.gt, 0), ) + use_cuda: bool = Field( + default=True, + desc="Enable CUDA device.", + hint=FieldHint.expert, + ) seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional) # TODO: Rename to compute_dtype (not just for training), move elsewhere compute_dtype: DataType = Field( diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index aa2be6ce7..e2f8daaa4 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -27,7 +27,7 @@ def __init__( world_size: int | None = None, local_world_size: int | None = None, timeout: float = 60, - use_cpu: bool = False, + use_cuda: bool = True, backend: DistributedBackend = DistributedBackend.nccl, ): @@ -37,19 +37,20 @@ def __init__( DistributedConfig.default_local_world_size if local_world_size is None else local_world_size ) self._timeout = timeout - self._use_cpu = use_cpu + self._use_cuda = use_cuda self._backend = backend self._process_groups = {} - if self._use_cpu: - if backend == DistributedBackend.nccl: - Assert.eq(self._world_size, 1) - self._device = torch.device("cpu") - else: + if self._use_cuda: + assert torch.cuda.is_available() Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count()) torch.cuda.init() self._device = torch.device(self._rank % self._local_world_size) torch.cuda.set_device(self._device) + else: + if backend == DistributedBackend.nccl: + Assert.eq(self._world_size, 1) + self._device = torch.device("cpu") if self._world_size > 1: if self._rank == 0: @@ -152,7 +153,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): TODO: Clarify cpu support. """ - def __init__(self, config: DistributedConfig, use_cpu: bool = False): + def __init__(self, config: DistributedConfig): super().__init__(config) assert self._config.reference_config is None @@ -163,7 +164,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self._config.world_size, self._config.local_world_size, self._config.timeout, - use_cpu, + self._config.use_cuda, self._config.backend, ) else: @@ -171,7 +172,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): Assert.geq(self._pool.world_size, self._config.world_size) Assert.eq(self._pool.rank, self._config.rank) Assert.geq(self._pool.local_world_size, self._config.local_world_size) - Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda") + Assert.eq(self._pool.device.type, "cuda" if self._config.use_cuda else "cpu") Assert.eq(self._pool.backend, self._config.backend) self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world]) @@ -259,5 +260,5 @@ def set_step(self, step: int, phase: PhaseType) -> None: self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED) def __del__(self): - if self._local_pool: + if getattr(self, "_local_pool", False) and hasattr(self, "_pool"): self._pool.shutdown() diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 3ffed4533..aa1eaa401 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -85,7 +85,6 @@ def from_pretrained( optimizer_state_names: tuple[str, ...] 
| None = None, # setup: bool = True, mode: StageMode = StageMode.training, - use_cpu: bool = False, stage_filter: set | None = None, **kwargs, ) -> typing.Self: @@ -104,7 +103,6 @@ def from_pretrained( optimizer_state_names=optimizer_state_names, setup=True, mode=mode, - use_cpu=use_cpu, stage_filter=stage_filter, ) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 6a6223cb7..ccde838e8 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -48,7 +48,6 @@ def from_pretrained( optimizer_state_names: tuple[str, ...] | None = None, setup: bool = True, mode: StageMode = StageMode.training, - use_cpu: bool = False, stage_filter: set | None = None, ) -> typing.Self: metadata = cls.config_class.load_metadata(pretrained_config) @@ -69,7 +68,7 @@ def from_pretrained( ) if setup: - model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode) + model.setup(Distributed(config.distributed), mode=mode) if mode.on_device: if pretrained_config.model_weights: diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 272b7c6ae..8696f0a59 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -191,3 +191,21 @@ class EventType(str, enum.Enum): send = "send" recv = "recv" pipe_wait_compute = "pipe_wait_compute" + + +class MockStream: + stream_id: int = 0 + + def wait_stream(self, stream): + pass + + def __eq__(self, other): + return isinstance(other, MockStream) + + +class MockEvent: + def record(self, stream=None): + pass + + def wait(self): + pass diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 133b3206b..5ddf2ff98 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -1,4 +1,5 @@ import collections +import contextlib import dataclasses import logging import time @@ -16,7 +17,7 @@ from fast_llm.engine.multi_stage.multi_stage import MultiStageModel from fast_llm.engine.multi_stage.stage import Stage from fast_llm.engine.optimizer.optimizer import Optimizer -from fast_llm.engine.schedule.config import EventType, ScheduleConfig, StepType, StreamType +from fast_llm.engine.schedule.config import EventType, MockEvent, MockStream, ScheduleConfig, StepType, StreamType from fast_llm.engine.schedule.schedule import Schedule, Step from fast_llm.logging import log_memory_usage from fast_llm.utils import Assert @@ -36,7 +37,7 @@ class BatchContext: # Dictionary of losses, purely for logging purposes. # Losses will be reduced over DP and PP, and aggregated over steps. losses: dict | None = None - profile: list[tuple[EventType, Step | None, torch.cuda.Event, StreamType, float]] = dataclasses.field( + profile: list[tuple[EventType, Step | None, torch.cuda.Event | MockEvent, StreamType, float]] = dataclasses.field( default_factory=list ) # Store metrics like: grad norm, loss scale, learning-rate, etc. 
@@ -65,15 +66,15 @@ def __repr__(self): class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ConfigType]): _is_setup: bool = False - _compute_stream: torch.cuda.Stream - _data_stream: torch.cuda.Stream - _pipeline_stream: torch.cuda.Stream + _compute_stream: torch.cuda.Stream | MockStream + _data_stream: torch.cuda.Stream | MockStream + _pipeline_stream: torch.cuda.Stream | MockStream _streams: dict[int, StreamType] - _compute_event: torch.cuda.Event - _reduce_event: torch.cuda.Event - _send_event: torch.cuda.Event + _compute_event: torch.cuda.Event | MockEvent + _reduce_event: torch.cuda.Event | MockEvent + _send_event: torch.cuda.Event | MockEvent _data_stream_needs_sync: bool - _profile_events: dict[tuple[EventType, tuple | None], torch.cuda.Event] + _profile_events: dict[tuple[EventType, tuple | None], torch.cuda.Event | MockEvent] _distributed: Distributed _optimizer: Optimizer | None _stages_on_device: list[Stage] @@ -111,12 +112,16 @@ def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> self._stages_owned = [stage.mode.on_device and not stage.is_tied_weight_copy for stage in self._stages] # Setup the streams - self._compute_stream = torch.cuda.current_stream(self._distributed.device) + self._compute_stream = self._get_current_stream() self._data_stream = ( - torch.cuda.Stream(self._distributed.device) if self._config.data_overlap else self._compute_stream + torch.cuda.Stream(self._distributed.device) + if self._config.data_overlap and self._distributed_config.use_cuda + else self._compute_stream ) self._pipeline_stream = ( - torch.cuda.Stream(self._distributed.device) if self._config.pipeline_overlap else self._compute_stream + torch.cuda.Stream(self._distributed.device) + if self._config.pipeline_overlap and self._distributed_config.use_cuda + else self._compute_stream ) # Putting compute stream last in the dict in case it's the same id. 
self._streams = { @@ -126,10 +131,12 @@ def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> } # Setup the synchronization and profiling events - self._profile_events = collections.defaultdict(lambda: torch.cuda.Event(enable_timing=True)) - self._compute_event = torch.cuda.Event() - self._reduce_event = torch.cuda.Event() - self._send_event = torch.cuda.Event() + self._profile_events = collections.defaultdict( + lambda: torch.cuda.Event(enable_timing=True) if self._distributed_config.use_cuda else MockEvent() + ) + self._compute_event = torch.cuda.Event() if self._distributed_config.use_cuda else MockEvent() + self._reduce_event = torch.cuda.Event() if self._distributed_config.use_cuda else MockEvent() + self._send_event = torch.cuda.Event() if self._distributed_config.use_cuda else MockEvent() self._data_stream_needs_sync = False def run_step( @@ -164,7 +171,7 @@ def run_step( self._distributed.set_step(iteration, schedule.phase) # Synchronize streams - Assert.eq(torch.cuda.current_stream(self._distributed.device), self._compute_stream) + Assert.eq(self._get_current_stream(), self._compute_stream) if self._config.profile_schedule: # Synchronize clocks safe_barrier(self._distributed.world_group, f"clock sync {iteration}") @@ -354,7 +361,7 @@ def _preprocess_data( def _restore(self, context: BatchContext, step: Step) -> None: if step.restore_launch: - with torch.cuda.stream(self._data_stream): + with self._with_stream(self._data_stream): self._sync_data_stream(context, step) for restore_step in step.restore_launch: self._stages[restore_step.stage].restore_parameters() @@ -368,7 +375,7 @@ def _restore(self, context: BatchContext, step: Step) -> None: def _recv(self, context: BatchContext, step: Step) -> None: if step.recv_launch: - with torch.cuda.stream(self._pipeline_stream): + with self._with_stream(self._pipeline_stream): for recv_step in step.recv_launch: # TODO: Pre-allocated buffers context.inputs[recv_step.global_index] = torch.empty_like( @@ -432,7 +439,7 @@ def _send(self, context: BatchContext, step: Step, output: torch.Tensor) -> None if step.next_step.recv_step is None: context.inputs[step.next_step.global_index] = output else: - with torch.cuda.stream(self._pipeline_stream): + with self._with_stream(self._pipeline_stream): self._compute_event.wait() self._record_event(context, EventType.pipe_wait_compute, step, self._pipeline_stream) if self._config.debug_send_recv: @@ -452,7 +459,7 @@ def _send(self, context: BatchContext, step: Step, output: torch.Tensor) -> None def _reduce(self, context: BatchContext, step: Step) -> None: if step.reduce: - with torch.cuda.stream(self._data_stream): + with self._with_stream(self._data_stream): self._sync_data_stream(context, step) stage = self._stages[step.stage] if not self._config.skip_step: @@ -462,12 +469,12 @@ def _reduce(self, context: BatchContext, step: Step) -> None: self._record_event(context, EventType.reduce, step) def _record_event( - self, context: BatchContext, type_: EventType, step: Step | None, stream: torch.cuda.Stream = None + self, context: BatchContext, type_: EventType, step: Step | None, stream: torch.cuda.Stream | MockStream = None ) -> None: if not self._config.profile_schedule: return if stream is None: - stream = torch.cuda.current_stream() + stream = self._get_current_stream() event = self._profile_events[(type_, None if step is None else step.map_index)] event.record(stream) cpu_time = time.perf_counter() @@ -529,3 +536,11 @@ def _record_compute(self, context: BatchContext, step: 
Step) -> None: self._record_event(context, EventType.run, step) if self._config.data_overlap: self._data_stream_needs_sync = True + + def _get_current_stream(self): + return ( + torch.cuda.current_stream(self._distributed.device) if self._distributed_config.use_cuda else MockStream() + ) + + def _with_stream(self, stream: torch.cuda.Stream | MockStream): + return torch.cuda.stream(stream) if self._distributed_config.use_cuda else contextlib.nullcontext() diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 18ca44b78..fa25c914d 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -281,7 +281,7 @@ def _setup_restore_steps(self, weight_buffer_indices: dict[int, int]) -> None: for step in device_steps: buffer_index = weight_buffer_indices[step.stage] if buffer_contents.get(buffer_index) != step.stage: - if self._schedule_config.data_overlap: + if self._schedule_config.data_overlap and self._distributed_config.use_cuda: step.restore_step = device_steps[buffer_last_used.get(buffer_index, -1) + 1] step.restore_event = torch.cuda.Event() else: @@ -378,7 +378,7 @@ def _setup_send_recv_steps(self) -> None: launch_step.recv_launch.append(recv_step) send_step.send_to = launch_step recv_step.recv_step = launch_step - if self._schedule_config.pipeline_overlap: + if self._schedule_config.pipeline_overlap and self._distributed_config.use_cuda: recv_step.recv_event = torch.cuda.Event() def _validate_send_recv_steps(self) -> None: @@ -449,7 +449,7 @@ def _validate_send_recv_steps(self) -> None: raise RuntimeError(f"Cannot find valid timeline for {self}, \nStatuses:{msg}") def _setup_throttle_steps(self) -> None: - if not self._schedule_config.throttle_cpu: + if not self._schedule_config.throttle_cpu or not self._distributed_config.use_cuda: return for device_steps in self._device_steps: for i, step in enumerate(device_steps): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 7225ed20a..b35733cc7 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -358,7 +358,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Synchronization is probably unnecessary. 
         safe_barrier(self._distributed.world_group, "train begin")
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         start_time = time.perf_counter()
         last_time = start_time
         start_iteration = self._completed_steps
diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py
index 1d2d0b3d6..dc3ee0f04 100644
--- a/fast_llm/functional/triton/mlp.py
+++ b/fast_llm/functional/triton/mlp.py
@@ -235,14 +235,13 @@ def mlp_forward(
         input_ = None
 
     # Activation
-    if TritonConfig.TRITON_ENABLED:
+    if TritonConfig.TRITON_ENABLED and intermediate_1.device.type == "cuda":
         intermediate_2, _ = triton_mlp_activation_forward(intermediate_1, gated, activation_type)
     else:
         do_grad = training and not recompute_level.recompute_activation
         with torch.set_grad_enabled(do_grad):
-            intermediate_2 = torch_mlp_activation(
-                intermediate_1.detach().requires_grad_(do_grad), gated, activation_type
-            )
+            intermediate_1 = intermediate_1.detach().requires_grad_(do_grad)
+            intermediate_2 = torch_mlp_activation(intermediate_1, gated, activation_type)
 
     if recompute_level.recompute_layer_1:
         intermediate_1 = None
@@ -345,20 +344,19 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[
     )[0]
 
     # Activation recomputation and/or backward
-    if TritonConfig.TRITON_ENABLED:
+    if TritonConfig.TRITON_ENABLED and grad_output.device.type == "cuda":
         grad_intermediate_1, intermediate_2_ = triton_mlp_activation_backward(
             grad_intermediate_2, (intermediate_1, gated, activation_type), intermediate_2 is None
         )
     else:
         if intermediate_2 is None:
+            intermediate_1 = intermediate_1.detach().requires_grad_(True)
             with torch.set_grad_enabled(True):
-                intermediate_2_ = torch_mlp_activation(
-                    intermediate_1.detach().requires_grad_(True), gated, activation_type
-                )
+                intermediate_2_ = torch_mlp_activation(intermediate_1, gated, activation_type)
         else:
             intermediate_2_ = intermediate_2
         intermediate_2_.backward(grad_intermediate_2)
         grad_intermediate_1 = intermediate_1.grad
 
     # Layer 2 parameter grad
     del grad_intermediate_2, intermediate_1
diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py
index 96d1663f7..a018ad44b 100644
--- a/fast_llm/functional/triton/normalization.py
+++ b/fast_llm/functional/triton/normalization.py
@@ -187,7 +187,7 @@ def triton_normalization_forward(
     n_cols = weight.numel()
 
     output = torch.empty_like(input_, dtype=weight.dtype)
-    inv_var = torch.empty(n_rows, dtype=torch.float32, device="cuda")
+    inv_var = torch.empty(n_rows, dtype=torch.float32, device=input_.device)
     block_size = triton.next_power_of_2(n_cols)
     assert block_size * input_.element_size() <= TritonConfig.MAX_BLOCK_SIZE_BYTES
 
diff --git a/fast_llm/functional/triton/pointwise.py b/fast_llm/functional/triton/pointwise.py
index bd14de9e2..22676ae1a 100644
--- a/fast_llm/functional/triton/pointwise.py
+++ b/fast_llm/functional/triton/pointwise.py
@@ -32,7 +32,7 @@ def triton_copy(
     """
     A triton implementation of tensor copying (`torch.Tensor.copy_()`).
     """
-    if not TritonConfig.TRITON_ENABLED:
+    if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda":
         return out.copy_(input_)
     # TODO: Improve assumptions.
     assert input_.is_contiguous()
@@ -65,7 +65,7 @@ def triton_fill(
     """
     A faster triton implementation of tensor copying (`torch.Tensor.fill_()`).
""" - if not TritonConfig.TRITON_ENABLED: + if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": return input_.fill_(value) # TODO: Improve assumptions. assert input_.is_contiguous() @@ -106,7 +106,7 @@ def triton_add( """ A faster triton implementation of tensor addition (`torch.Tensor.add()`). """ - if not TritonConfig.TRITON_ENABLED: + if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": return torch.add(input_, other, out=out) # TODO: Improve assumptions. assert input_.is_contiguous() diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 073599479..902352c25 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -459,6 +459,11 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size if self._config.causal: + print( + "WWWWW", + kwargs[AttentionKwargs.sequence_length], + self._backup_attention_tensor_cache_max_sequence_length, + ) if ( sequence_length := kwargs[AttentionKwargs.sequence_length] ) > self._backup_attention_tensor_cache_max_sequence_length: @@ -491,6 +496,7 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non if attention_mask is None: attention_mask = document_mask else: + print("AAAA", attention_mask.shape, document_mask.shape, kwargs) attention_mask = attention_mask & document_mask kwargs[AttentionKwargs.attention_mask] = attention_mask diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 92adc880e..1ec35ae0c 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -1,12 +1,10 @@ import abc import math import typing -import warnings from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -59,22 +57,6 @@ class DefaultRotaryConfig(RotaryConfig): desc="Scale for the rotary positional embeddings", hint=FieldHint.architecture, ) - # TODO: Make a backup implementation that doesn't affect the layout. - triton: bool = Field( - default=True, - desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.", - hint=FieldHint.architecture, - ) - - @property - def complex_format(self) -> bool: - # TODO: Make a backup implementation that doesn't affect the layout. 
- return not self.triton - - def _validate(self) -> None: - super()._validate() - if self.triton and not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") def _get_configurable_class(self) -> "type[DefaultRotary]": from fast_llm.layers.attention.rotary.rotary import DefaultRotary diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 258f9d8bc..304f96b83 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -6,6 +6,7 @@ from fast_llm.config import Configurable from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.attention.rotary.config import ( @@ -28,7 +29,7 @@ def convert_rotary_real_to_complex(tensor: torch.Tensor, head_size: int, dim: in return tensor.unflatten(dim, (-1, 2, div(head_size, 2))).movedim(dim + 1, dim + 2).flatten(dim, dim + 2) -def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: +def rotary_embeddings_complex(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: """ Apply rotary embeddings to a tensor: * Convert it to a complex, full-precision tensor @@ -41,6 +42,23 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) +@torch.compile +def rotary_embeddings_real(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: + """ + Apply rotary embeddings to a tensor. 
+ """ + tensor_re, tensor_im = torch.chunk(tensor, 2, dim=-1) + frequencies_re, frequencies_im = torch.chunk(rope_frequencies, 2, dim=-1) + + return torch.cat( + [ + tensor_re * frequencies_re - tensor_im * frequencies_im, + tensor_im * frequencies_re + tensor_re * frequencies_im, + ], + dim=-1, + ) + + class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module): def __init__( self, @@ -82,7 +100,11 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings + rotary_fn = ( + triton_rotary_autograd_ + if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" + else rotary_embeddings_real + ) query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key @@ -107,10 +129,9 @@ def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.d positions = torch.arange(sequence_length, device=device, dtype=torch.float64) angles = torch.outer(positions, self._get_angle_scales(head_size, device)) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) - if not self._config.complex_format: - frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), head_size, 3 - ).contiguous() + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), head_size, 3 + ).contiguous() return frequencies def _get_angle_scales(self, head_size: int, device: torch.device) -> torch.Tensor: @@ -207,10 +228,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: out=angles.view(-1, 2, self._head_size // 4).permute(1, 0, 2), ) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) - if not self._config.complex_format: - frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), self._head_size, 3 - ).contiguous() + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), self._head_size, 3 + ).contiguous() # TODO: Support different q and k frequencies. kwargs[AttentionKwargs.rotary_freq_q] = frequencies kwargs[AttentionKwargs.rotary_freq_k] = frequencies @@ -218,7 +238,11 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings + rotary_fn = ( + triton_rotary_autograd_ + if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" + else rotary_embeddings_real + ) query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 2ca61aa0e..408441b95 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -421,6 +421,15 @@ def get_and_reset_memory_usage_mib( global _global_max_allocated, _global_max_reserved import torch + if not torch.cuda.is_available(): + return { + "reserved": 0.0, + "allocated": 0.0, + "max_reserved": 0.0, + "max_allocated": 0.0, + "global_max_reserved": 0.0, + } + if clear_cache: # Free memory for more accurate reporting, and to reduce OOM risk with lots of workers. 
# Cublas workspace can unnecessarily keep 100s of MBs of reserved memory. diff --git a/tests/data/common.py b/tests/data/common.py index 34fdba321..7ec4a9018 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -32,7 +32,7 @@ def get_sampling_data( preprocessing: LanguageModelPreprocessingConfig | None = None, ) -> GPTSamplingData: # Config with convenient defaults. - distributed = Distributed(DistributedConfig(), use_cpu=True) + distributed = Distributed(DistributedConfig(use_cuda=torch.cuda.is_available())) if preprocessing is None: preprocessing = LanguageModelPreprocessingConfig() return GPTSamplingData( @@ -71,8 +71,8 @@ def get_test_data_and_compare_samples( expected_samples: dict[str, list[list[int]]] | list[list[int]], preprocessing: LanguageModelPreprocessingConfig, ) -> GPTData: - distributed_config = DistributedConfig(seed=87522) - distributed = Distributed(distributed_config, use_cpu=True) + distributed_config = DistributedConfig(seed=87522, use_cuda=torch.cuda.is_available()) + distributed = Distributed(distributed_config) if isinstance(samples_per_dataset, int): samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset} diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 20d16bb96..1bfde36ed 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -10,12 +10,12 @@ from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda def _get_cross_entropy_inputs( - num_columns: int, loss_masking: bool, target_format: TargetFormat, device="cuda" + num_columns: int, loss_masking: bool, target_format: TargetFormat ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = "cuda" if torch.cuda.is_available() else "cpu" # We want something moderately close to the target for the test to be meaningful logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) / 3 loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None @@ -49,7 +49,6 @@ def _compare_cross_entropy_outputs( assert ref_grad is None -@requires_cuda @pytest.mark.slow @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), @@ -85,6 +84,8 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) + if not torch.cuda.is_available(): + return if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) @@ -111,7 +112,6 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso return output, logits.grad -@requires_cuda @pytest.mark.slow # TODO: Support the same parameterization as above in the reference implementation. 
@pytest.mark.parametrize("loss_masking", [False, True]) @@ -206,7 +206,6 @@ def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGr raise RuntimeError("Test failed") -@requires_cuda @pytest.mark.slow def test_distillation_losses(): _spawn_dist(2, compare_parallel_cross_entropy) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index c48a0a531..76c0841d9 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -8,7 +8,6 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert from tests.utils.dataset import get_random_spans -from tests.utils.utils import requires_cuda def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): @@ -78,22 +77,22 @@ def test_dpo_loss(): Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) -@requires_cuda @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] ) def test_mlp_recomputation(gated, activation): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokens = 1024 hidden_size = 2048 intermediate_size = 4096 std = 1 / 64 - input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) - output_grad = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) - weight_1 = torch.normal(0, std, (intermediate_size * (gated + 1), hidden_size), device="cuda", requires_grad=True) - bias_1 = torch.normal(0, std, (intermediate_size * (gated + 1),), device="cuda", requires_grad=True) - weight_2 = torch.normal(0, std, (intermediate_size, hidden_size), device="cuda", requires_grad=True) - bias_2 = torch.normal(0, std, (hidden_size,), device="cuda", requires_grad=True) + input_ = torch.randn(tokens, hidden_size, device=device, requires_grad=True) + output_grad = torch.randn(tokens, hidden_size, device=device, requires_grad=True) + weight_1 = torch.normal(0, std, (intermediate_size * (gated + 1), hidden_size), device=device, requires_grad=True) + bias_1 = torch.normal(0, std, (intermediate_size * (gated + 1),), device=device, requires_grad=True) + weight_2 = torch.normal(0, std, (intermediate_size, hidden_size), device=device, requires_grad=True) + bias_2 = torch.normal(0, std, (hidden_size,), device=device, requires_grad=True) params = (weight_1, bias_1, weight_2, bias_2) output_ref = torch.nn.functional.linear( @@ -137,27 +136,27 @@ def test_mlp_recomputation(gated, activation): # Takes ~6s, much more if it needs to compile, reducing the hidden size doesn't help. 
@pytest.mark.slow @pytest.mark.skip("Dropless MoE is broken") -@requires_cuda def test_dropless_mlp(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num_experts = 4 experts_per_token = 4 tokens = 256 hidden_size = 512 intermediate_size = 1024 std = 1 / 64 - input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) - router_weight = torch.normal(0, std, (num_experts, hidden_size), device="cuda") + input_ = torch.randn(tokens, hidden_size, device=device, requires_grad=True) + router_weight = torch.normal(0, std, (num_experts, hidden_size), device=device) top_logits, top_experts = torch.topk( torch.nn.functional.linear(input_.detach(), router_weight), k=experts_per_token, dim=-1 ) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).detach().requires_grad_() - output_grad = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) + output_grad = torch.randn(tokens, hidden_size, device=device, requires_grad=True) weight_1 = torch.normal( - 0, std, (intermediate_size * 2 * num_experts, hidden_size), device="cuda", requires_grad=True + 0, std, (intermediate_size * 2 * num_experts, hidden_size), device=device, requires_grad=True ) - weight_2 = torch.normal(0, std, (intermediate_size * num_experts, hidden_size), device="cuda", requires_grad=True) + weight_2 = torch.normal(0, std, (intermediate_size * num_experts, hidden_size), device=device, requires_grad=True) params = (weight_1, weight_2) for param in params: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index a5a693be6..79817bb03 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -19,9 +19,10 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.layers.attention.rotary.rotary import ( - apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, + rotary_embeddings_complex, + rotary_embeddings_real, ) from fast_llm.utils import Assert, rms_diff from tests.utils.utils import requires_cuda @@ -81,30 +82,32 @@ def test_triton_add(): ) def test_triton_rotary(batch_size, sequence_length, num_heads, head_size): assert TritonConfig.TRITON_ENABLED - x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.bfloat16, device="cuda") - - y1 = apply_rotary_embeddings( - x, - DefaultRotaryConfig(triton=False) + x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.float32, device="cuda") + frequencies = ( + DefaultRotaryConfig() .get_layer(TensorDim("", head_size)) ._get_frequencies( sequence_length, head_size, device="cuda", - ), + ) ) - y2 = convert_rotary_real_to_complex( - triton_rotary_( - convert_rotary_complex_to_real(x, head_size, 3), - DefaultRotaryConfig(triton=True) - .get_layer(TensorDim("", head_size)) - ._get_frequencies(sequence_length, head_size, device="cuda"), + y_real = rotary_embeddings_real(x, frequencies) + + y_complex = convert_rotary_complex_to_real( + rotary_embeddings_complex( + convert_rotary_real_to_complex(x, head_size, 3), + torch.view_as_complex(convert_rotary_real_to_complex(frequencies, head_size, 3).unflatten(-1, (-1, 2))), ), head_size, 3, ) - Assert.rms_close(y1, y2, 1e-3) + + y_triton = triton_rotary_(x, frequencies) + + Assert.rms_close(y_real, y_complex, 1e-4) + Assert.rms_close(y_real, y_triton, 1e-4) @requires_cuda diff --git a/tests/layers/test_attention.py 
b/tests/layers/test_attention.py index 508597173..924c2cc7f 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -3,19 +3,19 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.attention import Attention, _flash_available from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda -@requires_cuda @pytest.mark.parametrize("cross_document_attention", (True, False)) @pytest.mark.parametrize(("causal", "window_size"), ((True, None), (True, 50), (False, None))) +@pytest.mark.skipif(not _flash_available, reason="Flash attention not available") def test_attention_implementations(cross_document_attention: bool, causal: bool, window_size: int | None): """ Check that the flash and backup attention implementation give the same result. """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") attention: Attention = AttentionConfig( head_size=32, heads=4, @@ -29,11 +29,11 @@ def test_attention_implementations(cross_document_attention: bool, causal: bool, lr_scale=None, peft=None, ) - query = torch.empty(4, 100, 4, 32, dtype=torch.bfloat16, device="cuda").normal_() - key = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device="cuda").normal_() - value = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device="cuda").normal_() + query = torch.empty(4, 100, 4, 32, dtype=torch.bfloat16, device=device).normal_() + key = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device=device).normal_() + value = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device=device).normal_() kwargs = { - AttentionKwargs.device: torch.device("cuda"), + AttentionKwargs.device: device, AttentionKwargs.sequence_length: 100, AttentionKwargs.sequence_lengths: [ [20, 32, 10, 11, 9, 18], diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 623a30d82..86ce0253d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -11,7 +11,7 @@ from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda +from tests.utils.utils import get_base_model, get_stage def _reverse_kl_loss( @@ -94,7 +94,6 @@ def _lm_head( VOCAB_SIZE = 500 -@requires_cuda @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( @@ -163,9 +162,11 @@ def test_lm_head( loss_masking: bool, prediction_heads: int, ): + if cross_entropy_impl in (CrossEntropyImpl.auto, CrossEntropyImpl.triton) and not torch.cuda.is_available(): + pytest.skip("Cuda is not available") head_config = { "cross_entropy_implementation": cross_entropy_impl, - "normalization": {"type": "rms_norm"}, + "normalization": {"type": "rms_norm", "implementation": "auto" if torch.cuda.is_available() else "torch"}, } config = GPTBaseModelConfig.from_dict( { @@ -191,7 +192,7 @@ def test_lm_head( GPTModelConfig.from_dict( { "base_model": config, - "distributed": distributed_config_dict, + "distributed": {**distributed_config_dict, "use_cuda": torch.cuda.is_available()}, }, ) ) diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py index 85d72b316..112c88a66 100644 --- a/tests/layers/test_rotary.py +++ 
b/tests/layers/test_rotary.py @@ -6,27 +6,26 @@ from fast_llm.layers.attention.rotary.config import Rotary2DConfig from fast_llm.layers.vision.config import VisionKwargs from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda -@requires_cuda def test_rotary_2d(): """ Compare Fast-LLM's implementation of 2d rotary embeddings with Pixtral. """ head_dim = 16 num_heads = 8 + device = "cuda" if torch.cuda.is_available() else "cpu" patch_positions = torch.tensor( [[h, w] for h in range(4) for w in range(4)], dtype=torch.int64, - device="cuda", + device=device, ) - query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device="cuda").normal_() + query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device=device).normal_() key = torch.empty_like(query).normal_() pixtral_config = transformers.PixtralVisionConfig(hidden_size=head_dim * num_heads, num_attention_heads=num_heads) - pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to("cuda") + pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to(device) # Convert patch positions (h, w) to Pixtral's linear position IDs # Pixtral expects: position_id = h * max_patches_per_side + w position_ids = ( @@ -38,7 +37,7 @@ def test_rotary_2d(): ) fast_llm_rotary = Rotary2DConfig().get_layer(TensorDim("head_dim", head_dim)) - kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: "cuda"} + kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: device} fast_llm_rotary.preprocess(kwargs) output_fast_llm_query, output_fast_llm_key = fast_llm_rotary.forward(query, key, kwargs) diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index b371ba086..f6be506ca 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -1,5 +1,6 @@ import pytest import torch +import transformers from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_dim import TensorDim @@ -11,18 +12,7 @@ from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.utils import Assert from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba, KimiDeltaAttention -from tests.utils.utils import get_stage, requires_cuda - -try: - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba -except ImportError: - Apriel2GatedDeltaNet = None - Apriel2Mamba = None - -try: - from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention -except ImportError: - KimiDeltaAttention = None +from tests.utils.utils import get_stage HIDDEN_SIZE = 16 SEQ_LEN = 65 @@ -31,7 +21,9 @@ def _compare_mixers( fast_llm_config: MixerConfig, hf_layer: torch.nn.Module, param_map: dict[str, str], threshold=1e-5 ): - distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.bfloat16)) + distributed = Distributed( + distributed_config := DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=torch.cuda.is_available()) + ) fast_llm_layer = fast_llm_config.get_layer( distributed_config, TensorDim("", HIDDEN_SIZE), @@ -82,10 +74,9 @@ def _compare_mixers( @pytest.mark.slow -@pytest.mark.skipif(Apriel2GatedDeltaNet is None, reason="Apriel GDN deps missing") -@requires_cuda +# Arguments ('seq_idx',) not implemented for torch implementation 
of 1d convolution. +@pytest.mark.skipif(not transformers.utils.import_utils.is_causal_conv1d_available(), reason="GDN deps missing") def test_gdn(): - device = torch.device("cuda") dtype = torch.bfloat16 NUM_V_HEADS = 4 @@ -103,7 +94,7 @@ def test_gdn(): hf_layer = ( Apriel2GatedDeltaNet(HIDDEN_SIZE, {**config_common, "norm_eps": 1e-5}, layer_idx=0, dtype=dtype) - .to(device=device, dtype=dtype) + .to(device="cuda" if torch.cuda.is_available() else "cpu", dtype=dtype) .eval() ) fast_llm_config = GatedDeltaNetConfig.from_dict(config_common, {"normalization": {"epsilon": 1e-5}}) @@ -111,7 +102,6 @@ def test_gdn(): @pytest.mark.slow -@requires_cuda @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") def test_kda(): NUM_HEADS = 4 @@ -133,10 +123,9 @@ def test_kda(): @pytest.mark.slow -@requires_cuda @pytest.mark.parametrize("add_linear_biases", [True, False]) @pytest.mark.parametrize("repeat_kv_before_conv", [True, False]) -@pytest.mark.skipif(Apriel2Mamba is None, reason="Apriel2 Mamba not available") +@pytest.mark.skipif(not transformers.utils.import_utils.is_mamba_ssm_available(), reason="Mamba not available") def test_mamba(add_linear_biases, repeat_kv_before_conv): D_INNER = 128 D_XB = 64 diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index c8d962f40..ff27a8e8d 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -12,12 +12,11 @@ from fast_llm.layers.ssm import kda as kda_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.utils import Assert -from tests.utils.utils import get_stage, requires_cuda +from tests.utils.utils import get_stage # TODO: include mamba varlen @pytest.mark.slow -@requires_cuda @pytest.mark.parametrize( "config", [ @@ -50,7 +49,9 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): """ hidden_size = 32 hidden_dim = TensorDim("hidden", hidden_size) - distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.float16)) + distributed = Distributed( + distributed_config := DistributedConfig(compute_dtype=DataType.float16, use_cuda=torch.cuda.is_available()) + ) mixer = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) stage = get_stage([mixer], distributed) @@ -71,11 +72,15 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): BlockKwargs.device: distributed.device, BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (hidden_dim,), + } + + kwargs_packed = { + **kwargs, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.sequence_length: seq_len, BlockKwargs.sequence_q_dim: TensorDim("", seq_len), BlockKwargs.sequence_k_dim: TensorDim("", seq_len), } - - kwargs_packed = {**kwargs, BlockKwargs.sequence_lengths: sequence_lengths} mixer.preprocess(kwargs_packed) out_packed, context = stage.forward(hidden_states, kwargs_packed) @@ -89,7 +94,14 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): out_refs = [] for i in range(batch_size): for seq in torch.split(hidden_states[i], sequence_lengths[i], dim=0): - kwargs_seq = {**kwargs, BlockKwargs.sequence_lengths: [[len(seq)]]} + seq_len_ = len(seq) + kwargs_seq = { + **kwargs, + BlockKwargs.sequence_lengths: [[seq_len_]], + BlockKwargs.sequence_length: seq_len_, + BlockKwargs.sequence_q_dim: TensorDim("", seq_len_), + BlockKwargs.sequence_k_dim: TensorDim("", seq_len_), + } mixer.preprocess(kwargs_seq) out, context = 
stage.forward(seq.unsqueeze(0), kwargs_seq) stage.backward(torch.ones_like(out), context) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index d14721142..f80b2b25f 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -14,7 +14,6 @@ logger = logging.getLogger(__name__) -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_simple(run_test_script_for_all_models, run_test_script_base_path): # A simple config to prevent unnecessary testing and creation of dependency group diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index e3870a7b1..2f476ae52 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -7,7 +7,6 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.utils import Assert from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import requires_cuda def _get_model(config_dict: dict, model_type: str = "gpt") -> FastLLMModel: @@ -18,7 +17,6 @@ def _get_model(config_dict: dict, model_type: str = "gpt") -> FastLLMModel: return model -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): model_testing_config.get_dataset() diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1248a1117..41f209b58 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -8,6 +8,7 @@ import typing import pytest +import torch import transformers from fast_llm.config import set_nested_dict_value @@ -97,6 +98,7 @@ class ModelTestingConfig: auto_model_class: type["transformers.models.auto.auto_factory._BaseAutoModelClass"] = ( transformers.AutoModelForCausalLM ) + requires_cuda: bool = False def __post_init__(self): _, config, _, _ = self.get_dataset(config_only=True) @@ -259,10 +261,11 @@ def _update_and_add_testing_config( "reproducible_init": True, "timeout": 20, "backend": "nccl", + "use_cuda": torch.cuda.is_available(), }, }, "batch": {"batch_size": 8, "sequence_length": 512}, - "data": {}, + "data": {"sampling": {"gpu": torch.cuda.is_available()}}, "optimizer": {"learning_rate": {"base": 0.0001}}, }, megatron_args=[ @@ -698,6 +701,7 @@ def _update_and_add_testing_config( compare_factor=2.0, # Micro-sequence split not supported. skip_tests=("sdp", "ms"), + requires_cuda=True, ) _update_and_add_testing_config( @@ -792,6 +796,7 @@ def _update_and_add_testing_config( # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=("sdp", "ms", TP_NO_STP), + requires_cuda=True, ) _update_and_add_testing_config( @@ -914,6 +919,7 @@ def _update_and_add_testing_config( # Pipeline-parallel gives a different mixer selection. # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). skip_tests=("sdp", "ms", "pp", TP_NO_STP), + requires_cuda=True, ) @@ -957,6 +963,7 @@ def _update_and_add_testing_config( # Micro-sequence split and sequence-first not supported for Mamba. # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). skip_tests=("sdp", "ms", GRAD_ACC, TP_NO_STP), + requires_cuda=True, ) @@ -996,6 +1003,7 @@ def _update_and_add_testing_config( # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! 
skip_tests=("sdp", "ms", TP_NO_STP), + requires_cuda=True, ) @@ -1013,6 +1021,8 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo groups: tuple[ModelTestingGroup] = item.keywords["model_testing_group"].args model_testing_config = item.callspec.params["model_testing_config"] model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config] + if model_config.requires_cuda and not torch.cuda.is_available(): + item.add_marker(pytest.mark.skip(reason=f"Cuda not available.")) for group in groups: action = model_config.groups[group] if action == ModelTestingGroupAction.main: diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 3b79f7607..e176d9b32 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +requires_cuda = pytest.mark.skipif(False, reason="CUDA is not available") @pytest.fixture(scope="session") From f144b87e6cfa83bd5e8dd80cb74b472af0c53f5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 9 Jan 2026 22:56:47 -0500 Subject: [PATCH 29/51] fixes --- fast_llm/engine/checkpoint/convert.py | 3 ++ fast_llm/engine/config_utils/run.py | 2 +- fast_llm/functional/triton/mlp.py | 1 - fast_llm/layers/attention/attention.py | 8 +--- fast_llm/layers/attention/rotary/config.py | 7 ++++ .../common/normalization/normalization.py | 18 +++++++- fast_llm/layers/language_model/head.py | 2 +- fast_llm/logging.py | 2 +- fast_llm/models/gpt/conversion/llama.py | 9 ---- fast_llm/models/gpt/megatron.py | 9 ---- tests/conftest.py | 3 ++ tests/functional/test_cross_entropy.py | 2 +- tests/models/test_checkpoint.py | 42 +++++++++---------- tests/models/test_generate.py | 8 ---- tests/models/test_lm_eval.py | 5 --- tests/models/test_model.py | 9 ++-- tests/utils/compare_tensor_logs.py | 1 + tests/utils/distributed_configs.py | 16 +++++-- tests/utils/utils.py | 2 +- 19 files changed, 73 insertions(+), 76 deletions(-) diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index b40d8a1b3..103d9488c 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -20,6 +20,7 @@ class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() + use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) model: type[FastLLMModelConfig] = Field(default=None) @@ -62,6 +63,7 @@ def _convert_model_partial( logger.info(f"Loading {self.input.format} checkpoint from {self.input.path}...") model = model_class.from_pretrained( self.input, + {("distributed", "use_cuda"): not self.use_cpu}, mode=StageMode.weights, stage_filter=stage_filter, ) @@ -94,6 +96,7 @@ def run(self): # Create a dummy version to determine the stage split. 
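For reference, the collection-time skip wired into `testing_group_enabled` above reduces to the pattern below. This is a condensed sketch: the `MODEL_CONFIGS` registry, its entries, and the hook wiring are simplified stand-ins for the real test utilities, and it assumes a parametrized test item so `item.callspec` exists.

import dataclasses

import pytest
import torch


@dataclasses.dataclass
class ModelTestingConfig:
    name: str
    requires_cuda: bool = False


# Illustrative registry; the real one lives in tests/utils/model_configs.py.
MODEL_CONFIGS = {
    "llama": ModelTestingConfig("llama"),
    "gdn": ModelTestingConfig("gdn", requires_cuda=True),
}


def testing_group_enabled(item: pytest.Function) -> None:
    model_config = MODEL_CONFIGS[item.callspec.params["model_testing_config"]]
    if model_config.requires_cuda and not torch.cuda.is_available():
        # CUDA-only models are skipped at collection time instead of failing later on CPU.
        item.add_marker(pytest.mark.skip(reason="Cuda not available."))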
model = model_class.from_pretrained( self.input.to_copy({"model_weights": False}), + {("distributed", "use_cuda"): not self.use_cpu}, mode=StageMode.off_device, ) stages_per_step = math.ceil(self.layers_per_step / model._config.multi_stage.layers_per_stage) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 415147d06..77507afa8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -101,7 +101,7 @@ def configure_logging( def get_run(self, distributed: "Distributed") -> "Run": from fast_llm.functional.config import TritonConfig - TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels and distributed.config.use_cuda + TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels # and distributed.config.use_cuda TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) set_global_variables(not self.run.torch_dynamo_enable) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index dc3ee0f04..286e7159a 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -357,7 +357,6 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ intermediate_2_ = intermediate_2 intermediate_2_.backward(grad_intermediate_2) grad_intermediate_1 = intermediate_1.grad - print("AAAAA", intermediate_2 is None, grad_intermediate_1) # Layer 2 parameter grad del grad_intermediate_2, intermediate_1 diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 902352c25..be58724ea 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -222,7 +222,7 @@ def _attn_backup( attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( - attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value + attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk).to(value.dtype), value ) if self._local_head_groups == 1: @@ -459,11 +459,6 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size if self._config.causal: - print( - "WWWWW", - kwargs[AttentionKwargs.sequence_length], - self._backup_attention_tensor_cache_max_sequence_length, - ) if ( sequence_length := kwargs[AttentionKwargs.sequence_length] ) > self._backup_attention_tensor_cache_max_sequence_length: @@ -496,7 +491,6 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non if attention_mask is None: attention_mask = document_mask else: - print("AAAA", attention_mask.shape, document_mask.shape, kwargs) attention_mask = attention_mask & document_mask kwargs[AttentionKwargs.attention_mask] = attention_mask diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 1ec35ae0c..80f499748 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -58,6 +58,13 @@ class DefaultRotaryConfig(RotaryConfig): hint=FieldHint.architecture, ) + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if "complex_format" in default: + Assert.is_(default["complex_format"], False) + del default["complex_format"] + return super()._from_dict(default, strict=strict) + def 
_get_configurable_class(self) -> "type[DefaultRotary]": from fast_llm.layers.attention.rotary.rotary import DefaultRotary diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index b1e875707..4bd1343aa 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -183,6 +183,12 @@ class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[Con def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config, hidden_dim, lr_scale) implementation = self._config.implementation + print( + "IKUEGBNHIUWGBN", + implementation, + TritonConfig.TRITON_ENABLED, + (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered, + ) if implementation == NormalizationImplementation.auto: if ( _fast_normalization_available @@ -190,7 +196,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | and not self._config.zero_centered ): implementation = NormalizationImplementation.fast - elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: + elif (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton elif _fused_normalization_available: @@ -199,6 +205,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") implementation = NormalizationImplementation.torch + print("BNHTHERDGRG", implementation) if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: @@ -258,8 +265,14 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel implementation = self._config.implementation + print( + "IKUEGBNHIUWGBN", + implementation, + TritonConfig.TRITON_ENABLED, + (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered, + ) if implementation == NormalizationImplementation.auto: - if TritonConfig.TRITON_ENABLED or self._config.zero_centered: + if (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") @@ -267,6 +280,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: log_main_rank("Fused RMS norm unavailable, using backup implementation.") implementation = NormalizationImplementation.torch + print("BNHTHERDGRG", implementation) if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2acd..9f3b6506f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -92,7 +92,7 @@ def __init__( if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._vocab_parallel: self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: + elif 
TritonConfig.TRITON_ENABLED and torch.cuda.is_available(): self._cross_entropy_impl = CrossEntropyImpl.triton else: self._cross_entropy_impl = CrossEntropyImpl.fused diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 931c7f644..5a2ff2dac 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -196,7 +196,7 @@ def log_tensor[ step = max(tensor.numel() // target_samples, 1) while step > 1 and any(step % s == 0 and s > 1 for s in shape): step -= 1 - samples = tensor.flatten()[: target_samples * step : step].cpu() + samples = tensor.flatten()[: target_samples * step : step].to("cpu", copy=True) stats.update(samples=samples, step=step) # Crop the list in the logs. The full tensor is still in stats. samples = [format_number(x) for x in samples.tolist()[: TensorLogs.config.max_elements]] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index bc75f6236..00d871dbf 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -15,7 +15,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -314,16 +313,12 @@ def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.head_size, 0) return (query,) def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.head_size, 0) return (query,) @@ -336,16 +331,12 @@ def export_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (key_value,) = weight key, value = key_value[:].chunk(2) - if self._config.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.head_size, 0) return key, value def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: key, value = weight - if self._config.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.head_size, 0) key_value = torch.cat([key[:], value[:]]) return (key_value,) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index f63bd76f8..3b97df3d1 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,6 +1,5 @@ import typing -from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MoEMLPConfig from fast_llm.utils import Assert, div @@ -84,12 +83,10 @@ def _init_attention_megatron( generator, ) if "dense" in meta.tensor_name: - kv_dim = 1 tensor_ = dense_tensor_ else: # Keep the original random state for key_value and dense. 
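The implementation-selection changes above (layer/RMS normalization and the LM-head cross entropy) all apply the same rule for the `auto` case: pick the Triton kernel only when the Triton flag is set and a GPU is actually present, otherwise fall back to a fused or plain PyTorch path. A minimal sketch of that rule, returning plain strings instead of the real enum values to stay self-contained:

import torch

from fast_llm.functional.config import TritonConfig


def pick_cross_entropy_impl(vocab_parallel: bool) -> str:
    # Mirrors the auto-selection logic above: Triton only with the flag on *and* CUDA available.
    if vocab_parallel:
        return "fused"
    if TritonConfig.TRITON_ENABLED and torch.cuda.is_available():
        return "triton"
    return "fused"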
generator.set_state(state) - kv_dim = 0 if "query" in meta.tensor_name: # We want to generate the same tensor for key_value. tensor_ = qkv_tensor_[:, :heads_per_group] @@ -98,12 +95,6 @@ def _init_attention_megatron( else: raise NotImplementedError(meta.tensor_name) - if isinstance(config.mixer.rotary, DefaultRotaryConfig) and config.mixer.rotary.complex_format: - from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex - - # Megatron uses (2, head_size/2) for the complex split; we use (head_size/2, 2). - # TODO: Avoid unnecessarily changing the value and dense tensors. - tensor_ = convert_rotary_real_to_complex(tensor_.view_as(meta), config.mixer.head_size, kv_dim) return tensor_ diff --git a/tests/conftest.py b/tests/conftest.py index ba2927c64..28bab0ad5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import pytest import xdist.scheduler +from fast_llm.functional.config import TritonConfig from fast_llm.utils import get_and_reset_memory_usage_mib from tests.utils.depends import DependencyManager from tests.utils.global_variables import TEST_RESULTS_PATH, set_testing_global_variables @@ -259,6 +260,8 @@ def pytest_runtest_call(item: pytest.Function): except RuntimeError: pytest.skip("Cuda runtime unavailable due to an error in an earlier test.") manager.handle_missing(item) + # Some tests may modify this global variable. + TritonConfig.TRITON_ENABLED = True def pytest_unconfigure(): diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 1bfde36ed..420316ce3 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -66,7 +66,6 @@ def _compare_cross_entropy_outputs( @pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format): # TODO: Test tensor-parallel implementation. - assert TritonConfig.TRITON_ENABLED logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) kwargs = { "logits": logits, @@ -86,6 +85,7 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski if not torch.cuda.is_available(): return + assert TritonConfig.TRITON_ENABLED if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index bb53de29e..5c31dde16 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -22,7 +22,6 @@ from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig -from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -35,7 +34,6 @@ ] -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_config): # A baseline config (single-gpu, bf16, flash-attn). 
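Most of the test edits in these patches reuse one device-fallback idiom. A fixture-based variant is sketched below; the fixture name is hypothetical and not part of the suite, which mostly inlines the conditional.

import pytest
import torch


@pytest.fixture()
def default_device() -> torch.device:
    # Fall back to CPU so the same test body runs on machines without a GPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def test_tensor_on_default_device(default_device):
    x = torch.randn(4, 8, device=default_device)
    assert x.device.type in ("cuda", "cpu")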
@@ -59,7 +57,6 @@ def do_prepare_resume(distributed_testing_config: DistributedTestingConfig): return do_prepare_resume -@requires_cuda @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume(run_test_script_for_all_models, compare_results_for_all_models, prepare_resume): @@ -75,7 +72,6 @@ def test_resume(run_test_script_for_all_models, compare_results_for_all_models, compare_results_for_all_models(distributed_testing_config) -@requires_cuda @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume_frozen(run_test_script_for_all_models, prepare_resume): @@ -102,13 +98,13 @@ def do_run_conversion( path=get_convert_path(save_format, load_format), format=save_format, ), + use_cpu=not torch.cuda.is_available(), model=model_testing_config.model_config_class, ).run() return do_run_conversion -@requires_cuda @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_conversion(model_testing_config, run_conversion, get_convert_path): @@ -171,7 +167,6 @@ def _compare_safetensor_files( Assert.all_equal(reference[key], other[key]) -@requires_cuda @pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_converted_round_trip(model_testing_config, get_convert_path): @@ -218,7 +213,8 @@ def do_load_and_compare_checkpoints( CheckpointLoadConfig( path=load_path, format=load_format, - ) + ), + {("distributed", "use_cuda"): torch.cuda.is_available()}, ) if reference_config is not None: _compare_model_configs(reference_config, model.config) @@ -228,7 +224,6 @@ def do_load_and_compare_checkpoints( return do_load_and_compare_checkpoints -@requires_cuda @pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_pretrained( @@ -238,9 +233,9 @@ def test_load_pretrained( reference_config = model_testing_config.model_config_class.from_dict( yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] ) - reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ - _WEIGHT_SHARD_SAVE_NAME - ] + reference_shard = safetensors.torch.load_file( + get_convert_path() / "rank_0.safetensors", device="cuda" if torch.cuda.is_available() else "cpu" + )[_WEIGHT_SHARD_SAVE_NAME] load_and_compare_checkpoints( FastLLMCheckpointFormat, get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), @@ -303,10 +298,11 @@ def test_load_pretrained( ) -@requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): + device = "cuda" if torch.cuda.is_available() else "cpu" + distributed_update = {("distributed", "use_cuda"): torch.cuda.is_available()} if model_testing_config.checkpoint_format is None: return # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. 
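The reference-shard loading above relies on safetensors choosing the target device at load time, which is what lets the checkpoint comparison run on CPU-only machines. A usage sketch with a placeholder path and key name:

import safetensors.torch
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Path and key are illustrative; the tests build the path from get_convert_path()
# and index with a named constant.
shards = safetensors.torch.load_file("/path/to/rank_0.safetensors", device=device)
reference_shard = shards["weight_shard"]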
@@ -323,18 +319,19 @@ def test_huggingface_model(model_testing_config, get_convert_path): path=get_convert_path(), format=DistributedCheckpointFormat, load_config=ModelConfigType.model, - ) + ), + distributed_update, ).eval() test_input = torch.randint( 0, 384, size=(4, 100), dtype=torch.int64, - device="cuda", + device=device, ) kwargs = {} if model_testing_config.model_type == "multimodal": - kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).cuda() + kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(device) kwargs["image_sizes"] = torch.tensor( [ [20, 20], # Full image, 25 patches @@ -360,16 +357,21 @@ def test_huggingface_model(model_testing_config, get_convert_path): # Last one cropped out. output_ref = model_ref(test_input, **kwargs) - model_from_fast_llm = hf_class.from_pretrained(fast_llm_path).eval() + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path, distributed_update).eval() model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, - ) + ), + distributed_update, ).eval() errors = [] - model_as_hf = model_testing_config.auto_model_class.from_pretrained(hf_path, trust_remote_code=True).cuda().eval() + model_as_hf = ( + model_testing_config.auto_model_class.from_pretrained(hf_path, trust_remote_code=True) + .to("cuda" if torch.cuda.is_available() else "cpu") + .eval() + ) for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), @@ -391,7 +393,6 @@ def test_huggingface_model(model_testing_config, get_convert_path): raise ValueError(f"Comparison failed ({len(errors)} errors)") -@requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_path, model_testing_config, request): @@ -430,7 +431,6 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: # We don't want to depend on `test_save_and_load_in_parallel` because we still want to run this in cas of failure. 
# This should still run after `test_save_and_load_in_parallel` -@requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( @@ -464,7 +464,6 @@ def test_load_parallel_checkpoint_in_single_gpu( ) -@requires_cuda @pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_parallel_checkpoint_consistency(model_testing_config, run_test_script_base_path): @@ -496,7 +495,6 @@ def reference_fast_llm_shard(get_convert_path) -> dict[str, torch.Tensor] | None return None -@requires_cuda @pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_multi_gpu_fast_llm_checkpoint( diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index bce77d4f2..c595b5148 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -11,7 +11,6 @@ from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import requires_cuda def _prepare_data(tokenizer, use_batch_size2: bool): @@ -206,7 +205,6 @@ def _test_generate( @pytest.mark.extra_slow -@requires_cuda @pytest.mark.parametrize( "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", [ @@ -238,7 +236,6 @@ def test_generate( ) -@pytest.mark.slow @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_export_for_generate(run_test_script_for_all_models, model_testing_config): # Not really testing, anything, but handles dependencies more easily than a fixture. 
@@ -254,7 +251,6 @@ def test_export_for_generate(run_test_script_for_all_models, model_testing_confi @pytest.mark.slow -@requires_cuda @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) @pytest.mark.parametrize( "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", @@ -307,7 +303,6 @@ def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format) ) -@requires_cuda @pytest.mark.extra_slow def test_generate_from_model( model_path, @@ -315,7 +310,6 @@ def test_generate_from_model( _test_generate_from_model(model_path, AutoTokenizer.from_pretrained(model_path), LlamaCheckpointFormat) -@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.generate) @@ -356,7 +350,6 @@ def _test_forward_return_hidden_states( @pytest.mark.extra_slow -@requires_cuda def test_forward_return_hidden_states(model_path): _test_forward_return_hidden_states( model_path, LlamaCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size @@ -364,7 +357,6 @@ def test_forward_return_hidden_states(model_path): @pytest.mark.slow -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.generate) @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) def test_small_forward_return_hidden_states(model_testing_config, run_test_script_base_path): diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index 8011b5bbc..7ae26c2d6 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -7,7 +7,6 @@ from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import requires_cuda # NOTE: These tests only verify that the functionality runs without crashing. 
# NOTE: The tokenizer is from a LLaMA-style model, which may not be suitable for all models, @@ -55,7 +54,6 @@ def do_get_lm_eval_config(base_path): # "gsm8k,xnli_en,wikitext" -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_lm_eval_in_training(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): run_test_script_for_all_models( @@ -76,7 +74,6 @@ def do_copy_training_output(distributed_testing_config: DistributedTestingConfig return do_copy_training_output -@requires_cuda @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_lm_eval_evaluation_last_checkpoint( @@ -91,7 +88,6 @@ def test_lm_eval_evaluation_last_checkpoint( run_test_script_for_all_models(distributed_testing_config=distributed_testing_config, runnable_type="evaluate") -@requires_cuda @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_lm_eval_evaluation_from_pretrained( @@ -111,7 +107,6 @@ def test_lm_eval_evaluation_from_pretrained( # TODO: rewrite for a new distributed test function -# @requires_cuda # @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) # @pytest.mark.model_testing_group(ModelTestingGroup.generate, ModelTestingGroup.distributed) # def test_lm_eval_in_training_dp2(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): diff --git a/tests/models/test_model.py b/tests/models/test_model.py index f80b2b25f..b3247102b 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -9,7 +9,7 @@ SINGLE_GPU_TESTING_CONFIGS, ) from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import check_subtest_success, requires_cuda, set_subtest_success +from tests.utils.utils import check_subtest_success, set_subtest_success logger = logging.getLogger(__name__) @@ -21,7 +21,6 @@ def test_model_simple(run_test_script_for_all_models, run_test_script_base_path) set_subtest_success(run_test_script_base_path / SIMPLE_TESTING_CONFIG.name) -@requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.basic) # Parametrize with config name so it shows in test name. @@ -46,9 +45,9 @@ def test_and_compare_model( if config.compare is not None: compare_results_for_all_models(config) + # raise ValueError() -@requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group( ModelTestingGroup.distributed, @@ -56,6 +55,9 @@ def test_and_compare_model( def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request): import tests.models.distributed_test_model + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2") + script = [ "-m", tests.models.distributed_test_model.__name__, @@ -73,7 +75,6 @@ def test_run_model_distributed(run_distributed_script, model_testing_config, run # We don't want to depend on `test_model_distributed` because we still want to run this in cas of failure. 
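The inline GPU-count guard added to `test_run_model_distributed` above can also be expressed as a reusable skip marker; the marker name is hypothetical, and the condition is evaluated at import time (`torch.cuda.device_count()` returns 0 when CUDA is unavailable).

import pytest
import torch

requires_multi_gpu = pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason=f"Not enough GPUs: {torch.cuda.device_count()} < 2"
)


@requires_multi_gpu
def test_needs_two_gpus():
    assert torch.cuda.device_count() >= 2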
# This should still run after `test_model_distributed` -@requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.distributed) @pytest.mark.parametrize("config_name", list(DISTRIBUTED_TESTING_CONFIGS)) diff --git a/tests/utils/compare_tensor_logs.py b/tests/utils/compare_tensor_logs.py index 1c8ebd76a..9a13fd13f 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/tests/utils/compare_tensor_logs.py @@ -140,6 +140,7 @@ def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_nam [ f" Test samples: " + "".join(f"{x:12.4e}" for x in samples_test[: self.show_samples].tolist()), f" Ref samples: " + "".join(f"{x:12.4e}" for x in samples_ref[: self.show_samples].tolist()), + f"scale={sub_config.scale}", ] ) errors.append("\n".join([f">>>> [{step_name}] Excessive diff for tensor {tensor_name}:"] + tensor_errors)) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 9c1cc9369..5b45371e6 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -2,6 +2,8 @@ import dataclasses import logging +import torch + from tests.utils.compare_tensor_logs import CompareConfig logger = logging.getLogger(__name__) @@ -58,8 +60,9 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ("init", None): get_config(), (None, "fw"): get_config(1.5e-2, 1.5e-3), (None, "bw"): get_config(1.5e-2, 1e-5), - (None, "bias"): get_config(2e-2, 1e-3), - (None, "gradient"): get_config(2e-2, 5e-5), + # Error is higher on cpu. TODO: Diff too big, especially for bias. + (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(0.25, 2e-3), + (None, "gradient"): get_config(2e-2, 5e-5) if torch.cuda.is_available() else get_config(8e-2, 1e-4), } ) @@ -69,8 +72,13 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Saved gradient include the gradient scaling by 2**16 (default initial value) (None, "fw"): get_config(1.2e-3, 3e-4), (None, "bw"): get_config(3e-3, 1e-5, scale=2**16), - (None, "bias"): get_config(3e-3, 1e-4, scale=2**16), - (None, "gradient"): get_config(3e-3, 5e-5, scale=2**16), + # Error is higher on cpu. 
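The threshold changes above follow one idea: keep the existing (relative, absolute) tolerances for GPU runs and loosen them when the comparison runs on CPU, where the logged diffs are larger. A sketch using the bias numbers from this diff; the helper name is illustrative, the suite inlines the conditional.

import torch


def pick_tolerances(gpu: tuple[float, float], cpu: tuple[float, float]) -> tuple[float, float]:
    # One set of (relative, absolute) thresholds for GPU runs, a looser one for CPU runs.
    return gpu if torch.cuda.is_available() else cpu


relative, absolute = pick_tolerances(gpu=(2e-2, 1e-3), cpu=(0.25, 2e-3))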
+ (None, "bias"): ( + get_config(3e-3, 1e-4, scale=2**16) if torch.cuda.is_available() else get_config(1e-2, 2e-4, scale=2**16) + ), + (None, "gradient"): ( + get_config(3e-3, 5e-5, scale=2**16) if torch.cuda.is_available() else get_config(1e-2, 1e-4, scale=2**16) + ), } ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index e176d9b32..3b79f7607 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -requires_cuda = pytest.mark.skipif(False, reason="CUDA is not available") +requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") @pytest.fixture(scope="session") From 6e54c93bace0c52837724c08d3f510118a31316b Mon Sep 17 00:00:00 2001 From: oleksost Date: Mon, 12 Jan 2026 18:25:23 +0000 Subject: [PATCH 30/51] comments --- fast_llm/layers/language_model/config.py | 29 +++++++++++++++++++-- fast_llm/layers/language_model/embedding.py | 3 +-- fast_llm/layers/language_model/head.py | 2 +- fast_llm/layers/language_model/kwargs.py | 23 ---------------- fast_llm/models/gpt/model.py | 14 +++++----- fast_llm/models/multimodal/model.py | 2 +- tests/layers/test_lm_head.py | 3 +-- 7 files changed, 38 insertions(+), 38 deletions(-) delete mode 100644 fast_llm/layers/language_model/kwargs.py diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index adf8dd86e..ab8848d99 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -10,11 +10,10 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs, TargetsKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -27,6 +26,28 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction +class TargetsKwargs: + lm_target = "preprocessed_lm_target" + dpo_target = "preprocessed_dpo_target" + reference_model_logits = "reference_model_logits" + dpo_reference_model_logits = "dpo_reference_model_logits" + + +class LanguageModelKwargs(BlockKwargs): + token_ids = "token_ids" + position_ids = "position_ids" + token_map = "token_map" + sample_map = "sample_map" + embedding_map = "embedding_map" + # TODO: These are generic + labels = "labels" + phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + loss_mask = "loss_mask" + mask_inputs = "mask_inputs" + + def _format_name(name: str) -> str: return name.replace("_", " ") @@ -610,6 +631,10 @@ def enable_dpo(self) -> bool: def enable_distillation(self) -> bool: return ForwardKLLossConfig in self._loss_configs.keys() or ReverseKLLossConfig in self._loss_configs.keys() + @property + def requires_loss_masks(self) -> bool: + return self.enable_distillation + @property def distillation_model(self) -> str | None: for loss_type in [ForwardKLLossConfig, ReverseKLLossConfig]: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 
fda5e3387..93850d24c 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,8 +10,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 7f303684f..2fa2dffe0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,9 +22,9 @@ LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, + LanguageModelKwargs, _format_name, ) -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert, div, get_unique diff --git a/fast_llm/layers/language_model/kwargs.py b/fast_llm/layers/language_model/kwargs.py deleted file mode 100644 index 4f6203881..000000000 --- a/fast_llm/layers/language_model/kwargs.py +++ /dev/null @@ -1,23 +0,0 @@ -from fast_llm.layers.block.config import BlockKwargs - - -class TargetsKwargs: - lm_target = "preprocessed_lm_target" - dpo_target = "preprocessed_dpo_target" - reference_model_logits = "reference_model_logits" - dpo_reference_model_logits = "dpo_reference_model_logits" - - -class LanguageModelKwargs(BlockKwargs): - token_ids = "token_ids" - position_ids = "position_ids" - token_map = "token_map" - sample_map = "sample_map" - embedding_map = "embedding_map" - # TODO: These are generic - labels = "labels" - phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" - loss_mask = "loss_mask" - mask_inputs = "mask_inputs" diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 846c65646..f83d12ca4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -12,7 +12,7 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -263,7 +263,6 @@ def preprocess_batch( if phase != PhaseType.inference: labels_begin = tokens_begin + 1 labels_end = tokens_end + self._config.head.max_prediction_distance - labels = batch.tokens.crop(labels_begin, labels_end).tokens if batch.loss_masking_spans is not None: @@ -272,13 +271,14 @@ def preprocess_batch( for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): for begin, end in loss_masking_spans: loss_mask[sample_index, begin:end] = False - if ( - self._config.head.distillation_model is not None - or self._config.decoder.block.distillation_model is not None - ): - kwargs[LanguageModelKwargs.loss_mask] = loss_mask labels = torch.where(loss_mask, labels, -100) + if ( + 
self._config.head.requires_loss_masks is not None + ): # loss masks only used for distillation currently + # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders + kwargs[LanguageModelKwargs.loss_mask] = labels >= 0 + kwargs[LanguageModelKwargs.labels] = ( labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels ).contiguous() diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 88da79e65..890d5760e 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel from fast_llm.models.gpt.config import GPTBatchConfig diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index c98c2780a..e01beb031 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,9 +7,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelLossConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs, LanguageModelLossConfig from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.language_model.kwargs import LanguageModelKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda From 305244fd89a9f3cb2ac798dc8170ac15c8061f86 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 12 Jan 2026 18:50:22 -0500 Subject: [PATCH 31/51] fixes --- fast_llm/core/kernels.py | 2 +- fast_llm/layers/attention/attention.py | 2 +- fast_llm/layers/common/linear/convolution.py | 2 +- .../common/normalization/normalization.py | 18 ++------------- fast_llm/layers/ssm/gdn.py | 10 ++++----- fast_llm/layers/ssm/mamba.py | 8 ++----- tests/layers/test_ssm.py | 4 +++- tests/layers/test_varlen.py | 7 ++++-- tests/utils/distributed_configs.py | 22 ++++++++++++++----- 9 files changed, 36 insertions(+), 39 deletions(-) diff --git a/fast_llm/core/kernels.py b/fast_llm/core/kernels.py index 93371a654..33ab4349c 100644 --- a/fast_llm/core/kernels.py +++ b/fast_llm/core/kernels.py @@ -12,7 +12,7 @@ from amp_C import multi_tensor_scale as _multi_tensor_scale # noqa from apex.multi_tensor_apply import multi_tensor_applier as _multi_tensor_applier # noqa - _apex_available = True + _apex_available = torch.cuda.is_available() except ImportError: _apex_available = False diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index be58724ea..d6eab0eb2 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -22,7 +22,7 @@ from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa from flash_attn.flash_attn_interface import flash_attn_varlen_func as _flash_attn_varlen_func - _flash_available = True 
+ _flash_available = torch.cuda.is_available() except ImportError: _flash_available = False diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index c336a7e99..e8b00fb3c 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -7,7 +7,7 @@ try: from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa - _causal_conv1d_available = True + _causal_conv1d_available = torch.cuda.is_available() except (ImportError, RuntimeError): _causal_conv1d_available = False diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 4bd1343aa..55e62af22 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -22,14 +22,14 @@ try: import fused_layer_norm_cuda # noqa - _fused_normalization_available = True + _fused_normalization_available = torch.cuda.is_available() except ImportError: _fused_normalization_available = False try: import fast_layer_norm # noqa - _fast_normalization_available = True + _fast_normalization_available = torch.cuda.is_available() except ImportError: _fast_normalization_available = False @@ -183,12 +183,6 @@ class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[Con def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config, hidden_dim, lr_scale) implementation = self._config.implementation - print( - "IKUEGBNHIUWGBN", - implementation, - TritonConfig.TRITON_ENABLED, - (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered, - ) if implementation == NormalizationImplementation.auto: if ( _fast_normalization_available @@ -205,7 +199,6 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") implementation = NormalizationImplementation.torch - print("BNHTHERDGRG", implementation) if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: @@ -265,12 +258,6 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel implementation = self._config.implementation - print( - "IKUEGBNHIUWGBN", - implementation, - TritonConfig.TRITON_ENABLED, - (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered, - ) if implementation == NormalizationImplementation.auto: if (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: implementation = NormalizationImplementation.triton @@ -280,7 +267,6 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: log_main_rank("Fused RMS norm unavailable, using backup implementation.") implementation = NormalizationImplementation.torch - print("BNHTHERDGRG", implementation) if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 474108482..3103fefe2 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -22,12 +22,12 @@ logger = logging.getLogger(__name__) try: + from causal_conv1d import 
causal_conv1d_fn as _causal_conv1d_fn # noqa from fla.ops.gated_delta_rule import chunk_gated_delta_rule -except ImportError: - chunk_gated_delta_rule = None - -is_fast_path_available = chunk_gated_delta_rule is not None + _fast_path_available = torch.cuda.is_available() +except (ImportError, RuntimeError): + _fast_path_available = False def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: @@ -240,7 +240,7 @@ def __init__( self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule - if not is_fast_path_available: + if not _fast_path_available: logger.warning( "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." ) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index e90c1e01f..81b82d08e 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -21,14 +21,10 @@ try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa - _mamba_available = True - sig = inspect.signature(selective_scan_fn) + _mamba_available = torch.cuda.is_available() # for training with packing install https://github.com/jxiw/varlen_mamba # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md - if "position_indices" in sig.parameters: - _mamba_varlen_available = True - else: - _mamba_varlen_available = False + _mamba_varlen_available = "position_indices" in inspect.signature(selective_scan_fn).parameters except (ImportError, RuntimeError): _mamba_available = False diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index f6be506ca..64dd1afeb 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -102,7 +102,9 @@ def test_gdn(): @pytest.mark.slow -@pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") +@pytest.mark.skipif( + kda_module.chunk_kda is None or not torch.cuda.is_available(), reason="KDA fused kernels not available" +) def test_kda(): NUM_HEADS = 4 HEAD_DIM = 4 diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index ff27a8e8d..770e24290 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -34,12 +34,15 @@ pytest.param( GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), marks=pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None, reason="GDN fused kernels not available" + gdn_module.chunk_gated_delta_rule is None or not torch.cuda.is_available(), + reason="GDN fused kernels not available", ), ), pytest.param( KimiDeltaAttentionConfig(heads=4, head_dim=16), - marks=pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available"), + marks=pytest.mark.skipif( + kda_module.chunk_kda is None or not torch.cuda.is_available(), reason="KDA fused kernels not available" + ), ), ], ) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 5b45371e6..0b54f63f7 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -60,9 +60,17 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ("init", None): get_config(), (None, "fw"): get_config(1.5e-2, 1.5e-3), (None, "bw"): get_config(1.5e-2, 1e-5), - # Error is higher on cpu. TODO: Diff too big, especially for bias. 
- (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(0.25, 2e-3), - (None, "gradient"): get_config(2e-2, 5e-5) if torch.cuda.is_available() else get_config(8e-2, 1e-4), + # TODO: Diff too big for normalization gradients on CPU. + **( + {} + if torch.cuda.is_available() + else { + (None, "norm"): get_config(0.25, 2e-3), + (None, "word_embeddings_weight"): get_config(0.08, 1e-4), + } + ), + (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(2e-2, 2e-3), + (None, "gradient"): get_config(2e-2, 5e-5) if torch.cuda.is_available() else get_config(2e-2, 1e-4), } ) @@ -72,12 +80,14 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Saved gradient include the gradient scaling by 2**16 (default initial value) (None, "fw"): get_config(1.2e-3, 3e-4), (None, "bw"): get_config(3e-3, 1e-5, scale=2**16), - # Error is higher on cpu. + # TODO: Diff too big on CPU, especially for bias and normalization. + # TODO: Diff too big for normalization gradients on CPU. + **({} if torch.cuda.is_available() else {(None, "norm"): get_config(0.25, 2e-3, scale=2**16)}), (None, "bias"): ( - get_config(3e-3, 1e-4, scale=2**16) if torch.cuda.is_available() else get_config(1e-2, 2e-4, scale=2**16) + get_config(3e-3, 1e-4, scale=2**16) if torch.cuda.is_available() else get_config(6e-3, 2e-4, scale=2**16) ), (None, "gradient"): ( - get_config(3e-3, 5e-5, scale=2**16) if torch.cuda.is_available() else get_config(1e-2, 1e-4, scale=2**16) + get_config(3e-3, 5e-5, scale=2**16) if torch.cuda.is_available() else get_config(6e-3, 1e-4, scale=2**16) ), } ) From f71319f4dca82ae0794c91f134345ae7758bd9fa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 12 Jan 2026 19:38:05 -0500 Subject: [PATCH 32/51] fix --- fast_llm/layers/ssm/gdn.py | 23 +++++++++++++++++------ tests/layers/test_varlen.py | 2 +- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 3103fefe2..86e7bcf56 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -23,11 +23,18 @@ try: from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa + + _causal_conv1d_available = torch.cuda.is_available() +except (ImportError, RuntimeError): + _causal_conv1d_available = False + + +try: from fla.ops.gated_delta_rule import chunk_gated_delta_rule - _fast_path_available = torch.cuda.is_available() + _fla_available = torch.cuda.is_available() except (ImportError, RuntimeError): - _fast_path_available = False + _fla_available = False def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: @@ -238,12 +245,16 @@ def __init__( self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft ) - self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule - - if not _fast_path_available: + if _fla_available: + self.chunk_gated_delta_rule = chunk_gated_delta_rule + else: logger.warning( - "Fast paths for GatedDeltaNet are not available. Please ensure that 'causal_conv1d' and 'fla' are properly installed." + "Fast implementation for GatedDeltaNet is not available. Please ensure that 'fla' is properly installed." 
) + self.chunk_gated_delta_rule = torch_chunk_gated_delta_rule + + if not _causal_conv1d_available: + raise RuntimeError("Gated delta net requires `causal_conv1d`.") def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index 770e24290..371801e71 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -34,7 +34,7 @@ pytest.param( GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), marks=pytest.mark.skipif( - gdn_module.chunk_gated_delta_rule is None or not torch.cuda.is_available(), + not gdn_module._fla_available, reason="GDN fused kernels not available", ), ), From 43c58bf0e698a25b6c87efc85d1a2356cc3b1ad2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 12 Jan 2026 19:40:14 -0500 Subject: [PATCH 33/51] fix --- fast_llm/layers/ssm/kda.py | 9 +++++---- tests/layers/test_varlen.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 270ac65bf..3781792e3 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -22,9 +22,10 @@ try: from fla.ops.kda import chunk_kda from fla.ops.kda.gate import fused_kda_gate -except ImportError: - chunk_kda = None - fused_kda_gate = None + + _fla_available = torch.cuda.is_available() +except (ImportError, RuntimeError): + _fla_available = False def index_first_axis(x, indices): @@ -58,7 +59,7 @@ def __init__( super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) - if chunk_kda is None or fused_kda_gate is None: + if not _fla_available: raise ImportError( "KimiDeltaAttention requires the `fla-core` package. " "Please install it with `pip install -U fla-core`." 
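# A minimal sketch (illustrative only, not part of the patch) of the availability-flag
# pattern the changes above converge on: a fused-kernel dependency counts as available
# only when its import succeeds *and* a CUDA device is present, so CPU-only runs fall
# back to a slower path or skip tests instead of failing inside the kernel call.
# The module and flag names below are placeholders, not real fast_llm identifiers.
import torch

try:
    from some_fused_kernels import fused_op  # noqa  # hypothetical dependency

    _fused_op_available = torch.cuda.is_available()
except (ImportError, RuntimeError):
    _fused_op_available = False

# Tests then key their skip condition on the same flag, e.g.:
#   @pytest.mark.skipif(not _fused_op_available, reason="fused kernels not available")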
diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index 371801e71..090276eed 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -34,8 +34,8 @@ pytest.param( GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), marks=pytest.mark.skipif( - not gdn_module._fla_available, - reason="GDN fused kernels not available", + not gdn_module._causal_conv1d_available, + reason="GDN not available", ), ), pytest.param( From 3c8f3c265abc71bb216bdb5ce0b1004f36c888da Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 12 Jan 2026 22:26:38 -0500 Subject: [PATCH 34/51] misc --- fast_llm/functional/cross_entropy.py | 57 ++++++--------- fast_llm/layers/common/auxiliary_loss.py | 57 +++++++++------ .../layers/decoder/mlp/mixture_of_experts.py | 4 +- fast_llm/layers/language_model/config.py | 73 ++++++++++--------- tests/utils/model_configs.py | 7 +- 5 files changed, 97 insertions(+), 101 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 9c4b7fcfc..c21b49a6c 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -75,7 +75,7 @@ def _fused_softmax( return exp_logits / sum_exp_logits -# @torch.compile +@torch.compile def _fused_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -85,7 +85,7 @@ def _fused_cross_entropy_forward_backward( target_format: TargetFormat, group: ProcessGroup | None = None, teacher_softmax_temperature: float = 1.0, - return_target_entropy: bool = False, + return_kl_loss: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -108,14 +108,14 @@ def _fused_cross_entropy_forward_backward( loss_mask = target >= 0 if group is None: # Keep values within range for scatter and gather ops to work. - target = target * loss_mask + target_masked = target * loss_mask target_mask = None else: # Mask the target (fused) # TODO: Could mask earlier on cpu or overlap with reduce? vocab_start_index = logits.size(-1) * group.rank() target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask + target_masked = (target - vocab_start_index) * target_mask else: # Target should be tensor-parallel already, no further manipulation needed. target_mask = None @@ -128,10 +128,10 @@ def _fused_cross_entropy_forward_backward( # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. 
if target_format == TargetFormat.labels: grad_base = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) + 1, target_masked, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) ) else: - grad_base = exp_logits - sum_exp_logits * target + grad_base = exp_logits - sum_exp_logits * target_masked grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) if logits_scale_factor != 1.0: @@ -142,13 +142,13 @@ def _fused_cross_entropy_forward_backward( # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) if target_format == TargetFormat.labels: - predicted_logits = logits_norm.gather(1, target) + predicted_logits = logits_norm.gather(1, target_masked) if group is not None: predicted_logits = target_mask * predicted_logits all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) else: - predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + predicted_logits = (target_masked * logits_norm).sum(dim=-1, keepdim=True) if group is not None and target_format != TargetFormat.labels: # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit. # Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i) @@ -162,7 +162,7 @@ def _fused_cross_entropy_forward_backward( loss = per_sample_loss.mean() if target_format != TargetFormat.labels and group is not None: all_reduce(loss, op=ReduceOp.AVG, group=group) - if return_target_entropy: + if return_kl_loss: if target_format == TargetFormat.logits: teacher_log_prob = target_logits - sum_exp_target_logits.log() else: @@ -173,7 +173,7 @@ def _fused_cross_entropy_forward_backward( target_entropy = target_entropy.mean() if group is not None: all_reduce(target_entropy, op=ReduceOp.SUM, group=group) - return loss, grad, target_entropy + loss -= target_entropy return loss, grad @@ -249,10 +249,7 @@ def _reverse_kl_forward_backward( target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, - target_format: TargetFormat, group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. @@ -264,13 +261,6 @@ def _reverse_kl_forward_backward( loss_mask: [BxS] or [B, S] or None ... """ - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for sequence-tensor-parallel reverse KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for sequence-tensor-parallel reverse KL") - Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) @@ -326,7 +316,6 @@ def reverse_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, - sequence_parallel_logits: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). @@ -349,12 +338,13 @@ def reverse_kl_forward_backward( loss: Reverse KL divergence loss grad: Gradients w.r.t. 
logits """ - - if sequence_parallel_logits: - # TODO: see hybrid dev branch where it is implemented - raise NotImplementedError("Sequence-parallel reverse KL is not implemented yet, set vocab_parallel true") - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq( + teacher_softmax_temperature, + 1, + msg="Teacher softmax temperature must be 1 for reverse KL", + ) + Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for reverse KL") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: @@ -366,9 +356,6 @@ def reverse_kl_forward_backward( target=target, loss_mask=loss_mask, grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - teacher_softmax_temperature=teacher_softmax_temperature, group=group, ) return distillation_loss, distillation_grad @@ -383,7 +370,6 @@ def forward_kl_forward_backward( logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, - sequence_parallel_logits: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). @@ -408,7 +394,11 @@ def forward_kl_forward_backward( """ assert target_format == TargetFormat.logits, "Forward KL only supports logits format" Assert.eq(target.shape, logits.shape) - distillation_loss, distillation_grad, teacher_entropy = _fused_cross_entropy_forward_backward( + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + + return _fused_cross_entropy_forward_backward( logits=logits, target=target, loss_mask=loss_mask, @@ -417,8 +407,5 @@ def forward_kl_forward_backward( target_format=target_format, group=group, teacher_softmax_temperature=teacher_softmax_temperature, - return_target_entropy=True, + return_kl_loss=True, ) - distillation_loss -= teacher_entropy - - return distillation_loss, distillation_grad diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index 335debb12..1c8fe1c73 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -3,9 +3,9 @@ class AuxiliaryLoss(torch.autograd.Function): @staticmethod - def forward(ctx, scores: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa + def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa ctx.grad = torch.full_like(aux_loss, grad) - return scores + return input_ @staticmethod def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa @@ -14,14 +14,33 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: @torch.compile def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor: - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - return torch.mean(torch.logsumexp(logits, dim=-1) ** 2) + return torch.mean( + torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 + ) -def z_loss( +def auxiliary_z_loss( logits: torch.Tensor, + z_loss_factor: float, + training: bool, grad_scale: float | None = None, + losses: dict | None = None, + loss_name: str | None = None, + logits_scale_factor: float = 1.0, +) -> torch.Tensor: + if losses is not None or (training 
and grad_scale is not None): + loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) + if losses is not None and loss_name is not None: + losses[loss_name].append(loss.detach()) + if training and grad_scale is not None: + logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale) + + return logits + + +def z_loss_forward_backward( + logits: torch.Tensor, + grad_output: float | None = None, logits_scale_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -33,22 +52,14 @@ def z_loss( loss: The z-loss value (unscaled) grad: The gradient w.r.t. logits (scaled by grad_scale), or None if grad_scale is None """ - if logits_scale_factor != 1.0: - scaled_logits = logits * logits_scale_factor - else: - scaled_logits = logits - - # Forward: z_loss = mean(logsumexp^2) - lse = torch.logsumexp(scaled_logits, dim=-1) # (N,) - loss = torch.mean(lse**2) - - # Backward: grad = (2/N) * lse * softmax(scaled_logits) - grad = None - if grad_scale is not None: - N = scaled_logits.shape[0] - softmax_logits = torch.softmax(scaled_logits, dim=-1) - grad = (2.0 / N) * lse.unsqueeze(-1) * softmax_logits * grad_scale - if logits_scale_factor != 1.0: - grad = grad * logits_scale_factor # Chain rule for logits_scale_factor + + with torch.set_grad_enabled(grad_output is not None): + logits_ = logits.detach().requires_grad_(grad_output is not None) + loss = calculate_z_loss(logits, logits_scale_factor) + if grad_output is None: + grad = None + else: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.detach().to(logits.dtype) return loss, grad diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 5cc351dac..fd3647389 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -13,7 +13,7 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss +from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, auxiliary_z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.decoder.mlp.mlp import MLPBase @@ -102,7 +102,7 @@ def _forward( # Apply z_loss if applicable if self._config.z_loss_coefficient > 0.0: - logits = z_loss( + logits = auxiliary_z_loss( logits, self._config.z_loss_coefficient, self.training, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index ab8848d99..e3de9e9cb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -81,7 +81,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -118,13 +118,13 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: pass @config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) -class CrossEntropyLMLossConfig(LanguageModelLossConfig): 
+class CrossEntropyLanguageModelLossConfig(LanguageModelLossConfig): _name: typing.ClassVar[str] = "CE_loss" _abstract: typing.ClassVar[bool] = False @@ -134,10 +134,10 @@ class CrossEntropyLMLossConfig(LanguageModelLossConfig): hint=FieldHint.performance, ) - teacher_softmax_temperature: float = Field( + temperature: float = Field( default=1.0, hint=FieldHint.optional, - desc="Temperature for teacher softmax (used in distillation losses).", + desc="Temperature for teacher softmax.", valid=check_field(Assert.gt, 0.0), ) @@ -147,7 +147,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} @@ -202,19 +202,19 @@ def get_loss( group=group, implementation=implementation, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, + teacher_softmax_temperature=self.temperature, target_format=TargetFormat.labels, ) @config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) -class ForwardKLLossConfig(LanguageModelLossConfig): +class ForwardKLDistillationLossConfig(LanguageModelLossConfig): """Forward KL divergence KL(p||q) for distillation (mode-covering).""" _name: typing.ClassVar[str] = "FwdKL_loss" _abstract: typing.ClassVar[bool] = False - teacher_softmax_temperature: float = Field( + temperature: float = Field( default=1.0, hint=FieldHint.optional, desc="Temperature for teacher softmax.", @@ -231,7 +231,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: if kwargs is None: kwargs = {} @@ -250,7 +250,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -266,13 +266,13 @@ def get_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, + teacher_softmax_temperature=self.temperature, target_format=TargetFormat.logits, ) @config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(ForwardKLLossConfig): +class ReverseKLLossConfig(ForwardKLDistillationLossConfig): """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" _name: typing.ClassVar[str] = "RevKL_loss" @@ -287,7 +287,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -304,7 +304,7 @@ def get_loss( grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.teacher_softmax_temperature, + teacher_softmax_temperature=self.temperature, target_format=TargetFormat.logits, ) @@ -339,7 +339,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, 
"torch.Tensor"]: if kwargs is None: kwargs = {} @@ -365,7 +365,7 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, @@ -401,7 +401,7 @@ def get_targets( prediction_distance: int | None = None, prediction_heads: int | None = None, sequence_parallel_logits: bool | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, ) -> dict[str, "torch.Tensor"]: return {} @@ -410,16 +410,20 @@ def get_loss( logits: "torch.Tensor", loss_mask: "torch.Tensor | None", grad_output: float | None = None, - group: "ProcessGroup" = None, + group: "ProcessGroup|None" = None, logits_scale_factor: float | None = None, vocab_parallel: bool = False, kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.layers.common.auxiliary_loss import z_loss + from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward + + # TODO: ====== Support loss mask, vocab_parallel ====== + assert loss_mask is None + assert group is None - return z_loss( + return z_loss_forward_backward( logits=logits.flatten(0, -2), - grad_scale=grad_output, + grad_output=grad_output, logits_scale_factor=logits_scale_factor, ) @@ -549,13 +553,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - logit_z_loss: float = Field( - default=0.0, - desc="Regularize the logits with Z-loss.", - doc="We recommend 1e-4 for stability, as used for training PaLM.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) def get_layer( self, @@ -604,14 +601,17 @@ def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: if "losses" not in self._explicit_fields: - self.losses = {"lm_loss": CrossEntropyLMLossConfig()} + self.losses = {"lm_loss": CrossEntropyLanguageModelLossConfig()} super()._validate() if DPOLossConfig in self._loss_configs: - assert ForwardKLLossConfig not in self._loss_configs.keys() # currently don't support both + assert ForwardKLDistillationLossConfig not in self._loss_configs.keys() # currently don't support both assert ReverseKLLossConfig not in self._loss_configs.keys() # currently don't support both - if ForwardKLLossConfig in self._loss_configs.keys() and ReverseKLLossConfig in self._loss_configs.keys(): + if ( + ForwardKLDistillationLossConfig in self._loss_configs.keys() + and ReverseKLLossConfig in self._loss_configs.keys() + ): assert ( - self._loss_configs[ForwardKLLossConfig].distillation_model + self._loss_configs[ForwardKLDistillationLossConfig].distillation_model == self._loss_configs[ReverseKLLossConfig].distillation_model ), "Distillation losses must use the same teacher." 
@@ -629,7 +629,10 @@ def enable_dpo(self) -> bool: @property def enable_distillation(self) -> bool: - return ForwardKLLossConfig in self._loss_configs.keys() or ReverseKLLossConfig in self._loss_configs.keys() + return ( + ForwardKLDistillationLossConfig in self._loss_configs.keys() + or ReverseKLLossConfig in self._loss_configs.keys() + ) @property def requires_loss_masks(self) -> bool: @@ -637,7 +640,7 @@ def requires_loss_masks(self) -> bool: @property def distillation_model(self) -> str | None: - for loss_type in [ForwardKLLossConfig, ReverseKLLossConfig]: + for loss_type in [ForwardKLDistillationLossConfig, ReverseKLLossConfig]: if loss_type in self._loss_configs: return self._loss_configs[loss_type].distillation_model return None diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 8d705583d..d18ce934e 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -246,12 +246,7 @@ def update_and_add_testing_config( }, "num_blocks": 2, }, - "head": { - "output_weight": init_1, - "losses": { - "lm_loss": {"type": "cross_entropy"}, - }, - }, + "head": {"output_weight": init_1}, "hidden_size": 256, "tied_embedding_weight": True, }, From 705c482dc9d1e1ba80b167b13f27b68f8f873c3e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 Jan 2026 13:11:32 -0500 Subject: [PATCH 35/51] fix --- fast_llm/functional/cross_entropy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index c21b49a6c..2503dca5e 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -119,6 +119,7 @@ def _fused_cross_entropy_forward_backward( else: # Target should be tensor-parallel already, no further manipulation needed. target_mask = None + target_masked = target if loss_mask is not None: loss_mask = loss_mask.unsqueeze(-1) @@ -392,7 +393,6 @@ def forward_kl_forward_backward( loss: Forward KL divergence loss grad: Gradients w.r.t. 
logits """ - assert target_format == TargetFormat.logits, "Forward KL only supports logits format" Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: From b156f4eabae459f139cc5d398ad3730f9177504d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 Jan 2026 13:31:48 -0500 Subject: [PATCH 36/51] fix --- fast_llm/layers/ssm/gdn.py | 6 +++--- fast_llm/layers/ssm/kda.py | 6 +++--- tests/layers/test_ssm.py | 6 ++---- tests/layers/test_varlen.py | 13 ++++--------- tests/models/test_model.py | 1 - 5 files changed, 12 insertions(+), 20 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 86e7bcf56..c7a2c1c59 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -32,9 +32,9 @@ try: from fla.ops.gated_delta_rule import chunk_gated_delta_rule - _fla_available = torch.cuda.is_available() + _fast_gdn_available = torch.cuda.is_available() except (ImportError, RuntimeError): - _fla_available = False + _fast_gdn_available = False def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: @@ -245,7 +245,7 @@ def __init__( self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft ) - if _fla_available: + if _fast_gdn_available: self.chunk_gated_delta_rule = chunk_gated_delta_rule else: logger.warning( diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 3781792e3..94cde7d5f 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -23,9 +23,9 @@ from fla.ops.kda import chunk_kda from fla.ops.kda.gate import fused_kda_gate - _fla_available = torch.cuda.is_available() + _kda_available = torch.cuda.is_available() except (ImportError, RuntimeError): - _fla_available = False + _kda_available = False def index_first_axis(x, indices): @@ -59,7 +59,7 @@ def __init__( super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) - if not _fla_available: + if not _kda_available: raise ImportError( "KimiDeltaAttention requires the `fla-core` package. " "Please install it with `pip install -U fla-core`." 
diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index 64dd1afeb..b210b7c8b 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -8,8 +8,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.decoder.config import MixerConfig -from fast_llm.layers.ssm import kda as kda_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig +from fast_llm.layers.ssm.kda import _kda_available from fast_llm.utils import Assert from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba, KimiDeltaAttention from tests.utils.utils import get_stage @@ -102,9 +102,7 @@ def test_gdn(): @pytest.mark.slow -@pytest.mark.skipif( - kda_module.chunk_kda is None or not torch.cuda.is_available(), reason="KDA fused kernels not available" -) +@pytest.mark.skipif(not _kda_available, reason="KDA fused kernels not available") def test_kda(): NUM_HEADS = 4 HEAD_DIM = 4 diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index 090276eed..bc538f9a0 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -8,9 +8,9 @@ from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.decoder.config import MixerConfig -from fast_llm.layers.ssm import gdn as gdn_module -from fast_llm.layers.ssm import kda as kda_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig +from fast_llm.layers.ssm.gdn import _causal_conv1d_available +from fast_llm.layers.ssm.kda import _kda_available from fast_llm.utils import Assert from tests.utils.utils import get_stage @@ -33,16 +33,11 @@ ), pytest.param( GatedDeltaNetConfig(value_heads=4, key_heads=2, key_head_dim=16, value_head_dim=16), - marks=pytest.mark.skipif( - not gdn_module._causal_conv1d_available, - reason="GDN not available", - ), + marks=pytest.mark.skipif(not _causal_conv1d_available, reason="GDN not available"), ), pytest.param( KimiDeltaAttentionConfig(heads=4, head_dim=16), - marks=pytest.mark.skipif( - kda_module.chunk_kda is None or not torch.cuda.is_available(), reason="KDA fused kernels not available" - ), + marks=pytest.mark.skipif(not _kda_available, reason="KDA not available"), ), ], ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 8ec0e0303..0c58afade 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -47,7 +47,6 @@ def test_and_compare_model( if config.compare is not None: compare_results_for_all_models(config) - # raise ValueError() def _run_model_distributed( From 4fbc7a88af1f56fe494120cb76b7c12538d41f5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 Jan 2026 17:37:12 -0500 Subject: [PATCH 37/51] stuff --- fast_llm/functional/config.py | 8 +- fast_llm/functional/cross_entropy.py | 471 ++++++++------------ fast_llm/functional/triton/cross_entropy.py | 9 +- fast_llm/layers/language_model/config.py | 18 +- tests/functional/test_cross_entropy.py | 107 +---- tests/layers/test_lm_head.py | 6 +- 6 files changed, 242 insertions(+), 377 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 511c2d9f3..050c700c9 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -93,13 +93,19 @@ def _set_activation_fn_map() -> None: MAX_DROPLESS_BLOCK_SIZE_ROW = 128 -class CrossEntropyImpl(str, 
enum.Enum): +class EntropyLossImplementation(enum.StrEnum): auto = "auto" torch = "torch" fused = "fused" triton = "triton" +class EntropyLossType(enum.StrEnum): + cross_entropy = "cross_entropy" + forward_kl = "forward_kl" + reverse_kl = "reverse_kl" + + class TargetFormat(enum.StrEnum): labels = "labels" logits = "logits" diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 2503dca5e..6ab934212 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -1,19 +1,20 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.utils import Assert -def _torch_cross_entropy_forward_backward( +def _torch_entropy_loss_forward_backward( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, - teacher_softmax_temperature: float = 1.0, + entropy_loss_type: EntropyLossType, + temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A wrapper for the pytorch implementation of cross-entropy. @@ -24,23 +25,40 @@ def _torch_cross_entropy_forward_backward( # Torch compile doesn't understand this. with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) + logits_scaled = logits_ if logits_scale_factor == 1.0 else logits_ * logits_scale_factor if target_format == TargetFormat.logits: - if logits_scale_factor != 1.0: - target = target * logits_scale_factor - if teacher_softmax_temperature != 1.0: - target = target / teacher_softmax_temperature - target = torch.softmax(target, dim=-1) - if loss_mask is None: - loss = torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target - ) + target_scale = logits_scale_factor / temperature + target = target if target_scale == 1.0 else target * target_scale + else: + Assert.eq(temperature, 1.0) + + reduction = "mean" if loss_mask is None else "none" + if entropy_loss_type == EntropyLossType.cross_entropy: + if target_format == TargetFormat.logits: + target = torch.softmax(target, dim=-1) + loss = torch.nn.functional.cross_entropy(logits_scaled, target, reduction=reduction) else: - loss = ( - torch.nn.functional.cross_entropy( - logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none" + predicted_log_probability = torch.nn.functional.log_softmax(logits_scaled, dim=-1) + if target_format == TargetFormat.logits: + target_log_probability = torch.nn.functional.log_softmax(target, dim=-1) + elif target_format == TargetFormat.probabilities: + target_log_probability = target.log() + else: + target_log_probability = torch.nn.functional.one_hot(target, num_classes=logits_scaled.size(-1)).log() + if entropy_loss_type == EntropyLossType.forward_kl: + loss = torch.nn.functional.kl_div( + predicted_log_probability, target_log_probability, reduction=reduction, log_target=True + ) + elif entropy_loss_type == EntropyLossType.reverse_kl: + loss = torch.nn.functional.kl_div( + target_log_probability, predicted_log_probability, reduction=reduction, log_target=True ) - * loss_mask - ).mean() + else: + raise 
NotImplementedError(entropy_loss_type) + + if loss_mask is not None: + loss = (loss * loss_mask).mean() + if grad_output is None: grad = None else: @@ -67,135 +85,209 @@ def _fused_softmax_base( return logits_norm, exp_logits, sum_exp_logits -@torch.compile -def _fused_softmax( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup = None, dim: int = -1 -) -> torch.Tensor: - _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim) - return exp_logits / sum_exp_logits +def _fused_reverse_kl_base( + logits: torch.Tensor, + target: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, + temperature: float = 1.0, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + if target_format == TargetFormat.logits: + target_logits, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( + target, logits_scale_factor / temperature, group + ) + target = exp_logits_targets / sum_exp_target_logits + target_log_probability = target_logits - sum_exp_target_logits.log() + else: + target_log_probability = torch.log(target) -@torch.compile -def _fused_cross_entropy_forward_backward( + predicted_log_probability = logits_norm - sum_exp_logits.log() + # Compute loss terms: student_probs * log_ratio, then sum over vocab + # This is equivalent to kl_div(..., log_target=True) but more memory efficient + log_ratio = predicted_log_probability - target_log_probability + per_sample_loss = (predicted_log_probability.exp() * log_ratio).sum(dim=-1) + if group is not None: + all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group) + + if grad_output is None: + grad = None + else: + # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) + # where E_q[log(q/p)] is the expected log ratio under the student distribution + grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * target * grad_output + return per_sample_loss, grad + + +def _fused_cross_entropy_base( logits: torch.Tensor, target: torch.Tensor, - loss_mask: torch.Tensor | None, grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, - teacher_softmax_temperature: float = 1.0, + temperature: float = 1.0, return_kl_loss: bool = False, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - A fused implementation of cross-entropy with torch compile. - It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, - but still suboptimal because it needs multiple kernels. - """ - # Do the forward and backward passes all at once, and fused with dtype conversion. - # Way faster and more memory-efficient than the pytorch version. - +): logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: target_logits, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( - target, logits_scale_factor / teacher_softmax_temperature, group + target, logits_scale_factor / temperature, group ) - target = exp_logits_targets / sum_exp_target_logits + target_log_probability = target_logits - sum_exp_target_logits.log() + else: + target_log_probability = torch.log(target) - if target_format == TargetFormat.labels: - target = target.unsqueeze(-1) - loss_mask = target >= 0 - if group is None: - # Keep values within range for scatter and gather ops to work. 
- target_masked = target * loss_mask - target_mask = None - else: - # Mask the target (fused) - # TODO: Could mask earlier on cpu or overlap with reduce? - vocab_start_index = logits.size(-1) * group.rank() - target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target_masked = (target - vocab_start_index) * target_mask + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) + if return_kl_loss: + logits_norm = logits_norm - target_log_probability + predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + if group is not None: + # We need to sum the over the tensor-parallel group, + # but this is handled in the final averaging provided we multiply by the group size. + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + + per_sample_loss = sum_exp_logits.log() - predicted_logits + + if grad_output is None: + grad = None else: - # Target should be tensor-parallel already, no further manipulation needed. + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + grad = (exp_logits - sum_exp_logits * target) * (grad_output / sum_exp_logits) + + return per_sample_loss, grad + + +def _fused_cross_entropy_base_from_labels( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor, + grad_output: float | None, + logits_scale_factor: float, + group: ProcessGroup | None = None, +): + logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + + target = target.unsqueeze(-1) + if group is None: + # Keep values within range for scatter and gather ops to work. + target_masked = target * loss_mask target_mask = None - target_masked = target - if loss_mask is not None: - loss_mask = loss_mask.unsqueeze(-1) + else: + # Mask the target (fused) + # TODO: Could mask earlier on cpu or overlap with reduce? + vocab_start_index = logits.size(-1) * group.rank() + target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) + target_masked = (target - vocab_start_index) * target_mask + + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss is the same because P * log(P) == 0. + predicted_logits = logits_norm.gather(1, target_masked) + if group is not None: + predicted_logits = target_mask * predicted_logits + all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + per_sample_loss = sum_exp_logits.log() - predicted_logits if grad_output is None: grad = None else: # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. 
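# A small numeric check (illustrative only, not part of the patch) of the identities the
# fused path relies on, with q = softmax(z) the student and p the teacher probabilities:
#   cross_entropy(p, q) = logsumexp(z) - sum_i p_i * z_i
#   KL(p || q)          = cross_entropy(p, q) - H(p)
#                       = logsumexp(z) - sum_i p_i * (z_i - log p_i)
# and the gradient of either loss w.r.t. z is softmax(z) - p, i.e. the
# "exp_logits / sum_exp_logits - target_probabilities" term noted above.
import torch

z = torch.randn(8, requires_grad=True)        # student logits
p = torch.softmax(torch.randn(8), dim=-1)     # teacher probabilities
loss = torch.logsumexp(z, -1) - (p * (z - p.log())).sum()
loss.backward()
assert torch.allclose(
    loss.detach(), torch.nn.functional.kl_div(z.detach().log_softmax(-1), p, reduction="sum"), atol=1e-5
)
assert torch.allclose(z.grad, torch.softmax(z.detach(), -1) - p, atol=1e-5)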
- if target_format == TargetFormat.labels: - grad_base = exp_logits.scatter_add( - 1, target_masked, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) - ) - else: - grad_base = exp_logits - sum_exp_logits * target_masked + grad = exp_logits.scatter_add( + 1, target_masked, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) + ) * (grad_output / sum_exp_logits) - grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits) - if logits_scale_factor != 1.0: - grad *= logits_scale_factor - if loss_mask is not None: - grad *= loss_mask - grad = grad.to(logits.dtype) + return per_sample_loss, grad - # loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) - if target_format == TargetFormat.labels: - predicted_logits = logits_norm.gather(1, target_masked) - if group is not None: - predicted_logits = target_mask * predicted_logits - all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) +@torch.compile +def _fused_entropy_loss_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + entropy_loss_type: EntropyLossType, + group: ProcessGroup | None = None, + temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + A fused implementation of cross-entropy with torch compile. + It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, + but still suboptimal because it needs multiple kernels. + """ + grad_output = None if grad_output is None else grad_output / logits.size(0) * logits_scale_factor + if target_format == TargetFormat.labels: + assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) + if loss_mask is None: + loss_mask = target >= 0 + per_sample_loss, grad = _fused_cross_entropy_base_from_labels( + logits, + target, + grad_output, + logits_scale_factor, + group, + ) + elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl): + per_sample_loss, grad = _fused_cross_entropy_base( + logits, + target, + grad_output, + logits_scale_factor, + target_format, + group, + temperature, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, + ) + elif entropy_loss_type == EntropyLossType.reverse_kl: + per_sample_loss, grad = _fused_reverse_kl_base( + logits, + target, + grad_output, + logits_scale_factor, + target_format, + group, + temperature, + ) else: - predicted_logits = (target_masked * logits_norm).sum(dim=-1, keepdim=True) - if group is not None and target_format != TargetFormat.labels: - # this is needed because on each rank we calculate log Z - sum_i t_i * z_i, where z_i is logit. - # Then we average on line 160: 1/K sum_ranks (log Z - sum_i t_i * z_i) - # = log Z - 1/K sum_ranks (sum_i t_i * z_i), where is the global predicted_logits, so without multiplying it by K 1/K there does not cancel out. 
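# A worked restatement (illustrative only, not part of the patch) of the bookkeeping the
# removed comment above describes. With the vocab sharded over K ranks, log Z is already
# global after the reductions inside the fused softmax, while the per-rank partial
# s_k = sum_{i in shard_k} p_i * z_i only covers the local shard:
#   correct loss:  log Z - (s_0 + ... + s_{K-1})                           (SUM over ranks)
#   removed path:  AVG_k(log Z - K * s_k) = log Z - (s_0 + ... + s_{K-1})  (same value)
# The rewritten _fused_cross_entropy_base keeps an explicit SUM reduction on the partial
# term instead of scaling by the group size and averaging.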
- predicted_logits = predicted_logits * group.size() + raise NotImplementedError(entropy_loss_type) - per_sample_loss = sum_exp_logits.log() - predicted_logits if loss_mask is not None: + loss_mask = loss_mask.unsqueeze(-1) per_sample_loss = per_sample_loss * loss_mask - loss = per_sample_loss.mean() - if target_format != TargetFormat.labels and group is not None: - all_reduce(loss, op=ReduceOp.AVG, group=group) - if return_kl_loss: - if target_format == TargetFormat.logits: - teacher_log_prob = target_logits - sum_exp_target_logits.log() - else: - teacher_log_prob = torch.log(target + 1e-20) - target_entropy = -(target * teacher_log_prob).sum(dim=-1) + + if grad is not None: if loss_mask is not None: - target_entropy = target_entropy * loss_mask.squeeze(-1) - target_entropy = target_entropy.mean() - if group is not None: - all_reduce(target_entropy, op=ReduceOp.SUM, group=group) - loss -= target_entropy + grad = grad * loss_mask + grad = grad.to(logits.dtype) return loss, grad _CROSS_ENTROPY_IMPLEMENTATIONS = { - CrossEntropyImpl.torch: _torch_cross_entropy_forward_backward, - CrossEntropyImpl.fused: _fused_cross_entropy_forward_backward, - CrossEntropyImpl.triton: triton_cross_entropy_forward_backward, + EntropyLossImplementation.torch: _torch_entropy_loss_forward_backward, + EntropyLossImplementation.fused: _fused_entropy_loss_forward_backward, + EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, } -def cross_entropy_forward_backward( +def entropy_loss_forward_backward( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, group: ProcessGroup | None = None, - implementation: CrossEntropyImpl = CrossEntropyImpl.fused, + implementation: EntropyLossImplementation = EntropyLossImplementation.fused, logits_scale_factor: float = 1.0, teacher_softmax_temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, + entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. @@ -213,14 +305,15 @@ def cross_entropy_forward_backward( if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) if group: - Assert.eq(implementation, CrossEntropyImpl.fused) - return _fused_cross_entropy_forward_backward( + Assert.eq(implementation, EntropyLossImplementation.fused) + return _fused_entropy_loss_forward_backward( logits, target, loss_mask, grad_output, logits_scale_factor, target_format, + entropy_loss_type, group, teacher_softmax_temperature, ) @@ -232,180 +325,6 @@ def cross_entropy_forward_backward( grad_output, logits_scale_factor, target_format, + entropy_loss_type, teacher_softmax_temperature=teacher_softmax_temperature, ) - - -def distributed_log_softmax( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 -): - logits_norm, _, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group=group, dim=dim) - - return logits_norm - sum_exp_logits.log() # log_softmax - - -@torch.compile -def _reverse_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Reverse KL using PyTorch's native kl_div function. - This is used for TP version where we split accross vocab dimantion. KL is additive over partitions of the vocab. 
- - Takes: - logits: [BxS, V] or [B, S, V] - target: [BxS, V] or [B, S, V] (logits format) - loss_mask: [BxS] or [B, S] or None - ... - """ - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - teacher_log_probs = distributed_log_softmax(target.float(), group=group) - log_ratio = distributed_log_softmax(logits, group=group) - - student_probs = log_ratio.exp() - log_ratio = log_ratio - teacher_log_probs # In-place: log_ratio = student_log_probs - teacher_log_probs - del teacher_log_probs - # Compute loss terms: student_probs * log_ratio, then sum over vocab - # This is equivalent to kl_div(..., log_target=True) but more memory efficient - loss_terms = (student_probs * log_ratio).sum(dim=-1) - - if loss_mask is not None: - # loss mask is the same on all ranks for TP over vocab. - valid = loss_mask.to(loss_terms.dtype) - loss_terms = loss_terms * valid - valid_tokens = torch.prod(torch.tensor(loss_terms.shape, device=loss_terms.device, dtype=loss_terms.dtype)) - loss = loss_terms.sum() # sums over batch and seq. len. - - if group is not None: - all_reduce(loss, op=ReduceOp.SUM, group=group) - loss /= valid_tokens - - if grad_output is not None: - # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) - # where E_q[log(q/p)] is the expected log ratio under the student distribution - expected = torch.sum(student_probs * log_ratio, dim=-1, keepdim=True) - if group is not None: - all_reduce(expected, op=ReduceOp.SUM, group=group) - log_ratio = log_ratio - expected - log_ratio = log_ratio * student_probs - del student_probs # Free after use - - if loss_mask is not None: - log_ratio = log_ratio * loss_mask.to(logits.dtype).unsqueeze(-1) - - log_ratio = log_ratio * (grad_output / valid_tokens) - grad = log_ratio.to(logits.dtype) - else: - grad = None - - return loss.detach_(), grad - - -def reverse_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Compute reverse KL divergence: KL(q||p) where q is the predicted distribution (student) and p is the target (teacher). - This is mode-seeking (vs. mode-covering for forward KL) and useful for: - - Encouraging the model to focus on the modes of the target distribution - - Avoiding probability mass on low-probability regions of the target - - Distillation scenarios where you want sharp, focused predictions - - Key differences from standard cross-entropy: - - Standard CE: KL(p||q) = mode-covering (spreads mass broadly) - - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) - - Takes: - logits: [BxS, V] or [B, S, V], where V is local vocab size - target: [BxS, V] or [B, S, V] (logits format) - loss_mask: [BxS] or [B, S] or None - ... - - Returns: - loss: Reverse KL divergence loss - grad: Gradients w.r.t. 
logits - """ - Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") - Assert.eq( - teacher_softmax_temperature, - 1, - msg="Teacher softmax temperature must be 1 for reverse KL", - ) - Assert.eq(logits_scale_factor, 1, msg="Logits scale factor must be 1 for reverse KL") - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - # TODO: implement fused? - distillation_loss, distillation_grad = _reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - group=group, - ) - return distillation_loss, distillation_grad - - -def forward_kl_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - teacher_softmax_temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Compute forward KL divergence: KL(p||q) where p is the target distribution (teacher) and q is the predicted (student). - This is mode-covering (vs. mode-seeking for reverse KL) and useful for: - - Encouraging the model to cover all modes of the target distribution - - Spreading probability mass broadly across the target support - - Standard distillation scenarios where you want to match the full teacher distribution - - Key differences from reverse KL: - - Forward KL: KL(p||q) = mode-covering (spreads mass broadly) - - Reverse KL: KL(q||p) = mode-seeking (focuses on target modes) - - Takes: - logits: [BxS, V] or [B, S, V], where V is local vocab size - target: [BxS, V] or [B, S, V] (logits format) - loss_mask: [BxS] or [B, S] or None - ... - - Returns: - loss: Forward KL divergence loss - grad: Gradients w.r.t. 
logits - """ - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - - return _fused_cross_entropy_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - group=group, - teacher_softmax_temperature=teacher_softmax_temperature, - return_kl_loss=True, - ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 295cdb74d..709d0c52d 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,7 +1,8 @@ import torch -from fast_llm.functional.config import TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.utils import Assert @triton_jit() @@ -125,7 +126,8 @@ def triton_cross_entropy_forward_backward( grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, - teacher_softmax_temperature: float = 1.0, + entropy_loss_type: EntropyLossType, + temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -134,6 +136,7 @@ def triton_cross_entropy_forward_backward( TODO: Better handling of `grad_output = None` """ assert TritonConfig.TRITON_ENABLED + Assert.eq(entropy_loss_type, EntropyLossType.cross_entropy) # TODO: Improve assumptions. assert logits.is_contiguous() assert target.is_contiguous() @@ -163,7 +166,7 @@ def triton_cross_entropy_forward_backward( assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, - target / teacher_softmax_temperature, + target / temperature, loss_mask, grad_logits, losses, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e3de9e9cb..3e6eb2d3d 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -9,7 +9,7 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossImplementation, TargetFormat, TritonConfig from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig @@ -128,8 +128,8 @@ class CrossEntropyLanguageModelLossConfig(LanguageModelLossConfig): _name: typing.ClassVar[str] = "CE_loss" _abstract: typing.ClassVar[bool] = False - implementation: CrossEntropyImpl = Field( - default=CrossEntropyImpl.auto, + implementation: EntropyLossImplementation = Field( + default=EntropyLossImplementation.auto, desc="Implementation for the cross-entropy computation.", hint=FieldHint.performance, ) @@ -182,19 +182,19 @@ def get_loss( vocab_parallel: bool = False, kwargs: dict | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import cross_entropy_forward_backward + from 
fast_llm.functional.cross_entropy import entropy_loss_forward_backward target = kwargs.get(TargetsKwargs.lm_target) implementation = self.implementation - if implementation == CrossEntropyImpl.auto: + if implementation == EntropyLossImplementation.auto: if vocab_parallel: - implementation = CrossEntropyImpl.fused + implementation = EntropyLossImplementation.fused elif TritonConfig.TRITON_ENABLED: - implementation = CrossEntropyImpl.triton + implementation = EntropyLossImplementation.triton else: - implementation = CrossEntropyImpl.fused + implementation = EntropyLossImplementation.fused - return cross_entropy_forward_backward( + return entropy_loss_forward_backward( logits=logits.flatten(0, -2), target=target, loss_mask=None, # Labels are already masked diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 23eea12b4..5df203ed3 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -7,12 +7,8 @@ import pytest import torch -from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import ( - cross_entropy_forward_backward, - forward_kl_forward_backward, - reverse_kl_forward_backward, -) +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.cross_entropy import entropy_loss_forward_backward from fast_llm.utils import Assert from tests.utils.utils import requires_cuda @@ -37,7 +33,7 @@ def _get_cross_entropy_inputs( return logits, target, loss_mask -def _compare_cross_entropy_outputs( +def _compare_entropy_loss_outputs( loss: torch.Tensor, ref_loss: torch.Tensor, has_grad: bool, @@ -69,7 +65,10 @@ def _compare_cross_entropy_outputs( ), ) @pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) -def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format): +@pytest.mark.parametrize( + "entropy_loss_type", (EntropyLossType.cross_entropy, EntropyLossType.forward_kl, EntropyLossType.reverse_kl) +) +def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type): # TODO: Test tensor-parallel implementation. assert TritonConfig.TRITON_ENABLED logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) @@ -80,94 +79,32 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski "grad_output": grad_output, "logits_scale_factor": logits_scale_factor, "target_format": target_format, + "entropy_loss_type": entropy_loss_type, } # Torch serves as the reference implementation. - out_torch, grad_torch = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.torch) - out_fused, grad_fused = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.fused) + out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch) + out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused) # TODO: Why is the error so high with logit scaling? 
threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 - _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) + _compare_entropy_loss_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) + + if entropy_loss_type != EntropyLossType.cross_entropy: + # Triton implementation only supports cross-entropy. + return if num_columns > 65536: with pytest.raises(AssertionError): - cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) + entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.triton) else: - out_triton, grad_triton = cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) - _compare_cross_entropy_outputs( + out_triton, grad_triton = entropy_loss_forward_backward( + **kwargs, implementation=EntropyLossImplementation.triton + ) + _compare_entropy_loss_outputs( out_triton, out_torch, grad_output is not None, grad_triton, grad_torch, threshold ) -def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): - # Manual reference: sum over vocab then average over valid tokens. - logits = logits.detach().requires_grad_() - per_sample = torch.nn.functional.kl_div( - torch.log_softmax(target.float(), dim=-1), - torch.log_softmax(logits.float(), dim=-1), - reduction="none", - log_target=True, - ).sum(dim=-1) - if loss_mask is not None: - per_sample = per_sample * loss_mask - output = per_sample.mean() - output.backward() - return output, logits.grad - - -@requires_cuda -@pytest.mark.slow -# TODO: Support the same parameterization as above in the reference implementation. -@pytest.mark.parametrize("loss_masking", [False, True]) -@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) -def test_reverse_kl(loss_masking, target_format): - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - out_ref, grad_ref = _reverse_kl_forward_backward_torch(logits, target, loss_mask) - out, grad = reverse_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1.0, - target_format=TargetFormat.logits, - ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) - - -def _forward_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None): - # Manual reference: sum over vocab then average over all tokens (not just valid ones). - # Forward KL: KL(p||q) where p=teacher, q=student - logits = logits.detach().requires_grad_(True) - per_sample = torch.nn.functional.kl_div( - torch.log_softmax(logits.float(), dim=-1), - torch.log_softmax(target.float(), dim=-1), - reduction="none", - log_target=True, - ).sum(dim=-1) - if loss_mask is not None: - per_sample = per_sample * loss_mask - output = per_sample.sum() / per_sample.numel() - output.backward() - return output, logits.grad - - -@requires_cuda -@pytest.mark.slow -# TODO: Support the same parameterization as above in the reference implementation. 
-@pytest.mark.parametrize("loss_masking", [False, True]) -@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) -def test_forward_kl(loss_masking, target_format): - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - out_ref, grad_ref = _forward_kl_forward_backward_torch(logits, target, loss_mask) - out, grad = forward_kl_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1.0, - target_format=TargetFormat.logits, - ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref, 1e-3) - - def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): try: torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) @@ -225,12 +162,12 @@ def _compare_parallel_cross_entropy( grad_output=1, target_format=target_format, ) - _compare_cross_entropy_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) + _compare_entropy_loss_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): success = True - for function in (reverse_kl_forward_backward, forward_kl_forward_backward, cross_entropy_forward_backward): + for function in (reverse_kl_forward_backward, forward_kl_forward_backward, entropy_loss_forward_backward): for target_format in (TargetFormat.logits,): for loss_masking in [False, True]: try: diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index e01beb031..67d674b65 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -5,7 +5,7 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.functional.config import EntropyLossImplementation from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs, LanguageModelLossConfig from fast_llm.layers.language_model.head import LanguageModelHead @@ -134,7 +134,7 @@ def _lm_head( @requires_cuda @pytest.mark.slow -@pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) +@pytest.mark.parametrize("cross_entropy_impl", tuple(EntropyLossImplementation)) @pytest.mark.parametrize( ("config_dict", "distributed_config_dict", "loss_masking", "prediction_heads"), ( @@ -277,7 +277,7 @@ def _lm_head( ), ) def test_lm_head( - cross_entropy_impl: CrossEntropyImpl, + cross_entropy_impl: EntropyLossImplementation, config_dict: dict[str, typing.Any], distributed_config_dict: dict[str, typing.Any], loss_masking: bool, From 99a73b5feee2c7a7234453f440edf862a155c54a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 13 Jan 2026 18:16:58 -0500 Subject: [PATCH 38/51] fixes --- fast_llm/functional/cross_entropy.py | 7 ++++--- fast_llm/layers/language_model/config.py | 2 +- tests/functional/test_cross_entropy.py | 17 +++++++++-------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 6ab934212..7508126e3 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -229,6 +229,7 @@ def _fused_entropy_loss_forward_backward( per_sample_loss, grad = _fused_cross_entropy_base_from_labels( logits, target, + loss_mask, grad_output, logits_scale_factor, group, @@ -285,7 +286,7 @@ def 
entropy_loss_forward_backward(
     group: ProcessGroup | None = None,
     implementation: EntropyLossImplementation = EntropyLossImplementation.fused,
     logits_scale_factor: float = 1.0,
-    teacher_softmax_temperature: float = 1.0,
+    temperature: float = 1.0,
     target_format: TargetFormat = TargetFormat.labels,
     entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -315,7 +316,7 @@ def entropy_loss_forward_backward(
             target_format,
             entropy_loss_type,
             group,
-            teacher_softmax_temperature,
+            temperature,
         )
     else:
         return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
@@ -326,5 +327,5 @@ def entropy_loss_forward_backward(
             logits_scale_factor,
             target_format,
             entropy_loss_type,
-            teacher_softmax_temperature=teacher_softmax_temperature,
+            temperature=temperature,
         )
diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index 3e6eb2d3d..fc32617ed 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -202,7 +202,7 @@ def get_loss(
             group=group,
             implementation=implementation,
             logits_scale_factor=logits_scale_factor,
-            teacher_softmax_temperature=self.temperature,
+            temperature=self.temperature,
             target_format=TargetFormat.labels,
         )
 
diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py
index 5df203ed3..362c12ed7 100644
--- a/tests/functional/test_cross_entropy.py
+++ b/tests/functional/test_cross_entropy.py
@@ -14,10 +14,11 @@
 
 
 def _get_cross_entropy_inputs(
-    num_columns: int, loss_masking: bool, target_format: TargetFormat, device="cuda"
+    num_columns: int, loss_masking: bool, target_format: TargetFormat
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     # We want something moderately close to the target for the test to be meaningful
-    logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) / 3
+    logits_var = torch.randn(256, num_columns, dtype=torch.float32, device=device) / 3
     loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None
     if target_format == TargetFormat.labels:
         target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device)
@@ -26,7 +27,7 @@ def _get_cross_entropy_inputs(
             logits = torch.where(loss_mask.unsqueeze(-1), logits, -100)
             loss_mask = None
     else:
-        target = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device)
+        target = torch.randn(256, num_columns, dtype=torch.float32, device=device)
         logits = target + logits_var
     if target_format == TargetFormat.probabilities:
         target = torch.softmax(target, -1)
@@ -49,7 +50,6 @@ def _compare_entropy_loss_outputs(
         assert ref_grad is None
 
 
-@requires_cuda
 @pytest.mark.slow
 @pytest.mark.parametrize(
     ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"),
@@ -68,9 +68,10 @@ def _compare_entropy_loss_outputs(
 @pytest.mark.parametrize(
     "entropy_loss_type", (EntropyLossType.cross_entropy, EntropyLossType.forward_kl, EntropyLossType.reverse_kl)
 )
-def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type):
+def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type):
+    if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl:
+        pytest.skip(reason="Not implemented")
     # TODO: Test tensor-parallel implementation.
- assert TritonConfig.TRITON_ENABLED logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) kwargs = { "logits": logits, @@ -89,10 +90,10 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 _compare_entropy_loss_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) - if entropy_loss_type != EntropyLossType.cross_entropy: + if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available(): # Triton implementation only supports cross-entropy. return - + assert TritonConfig.TRITON_ENABLED if num_columns > 65536: with pytest.raises(AssertionError): entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.triton) From fb679d1fb138f5103030e881fb37d71acbb488fa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 16 Jan 2026 05:22:36 -0500 Subject: [PATCH 39/51] stuff --- fast_llm/core/distributed.py | 9 +- fast_llm/engine/schedule/runner.py | 4 +- fast_llm/functional/cross_entropy.py | 58 ++- fast_llm/layers/common/auxiliary_loss.py | 17 +- fast_llm/layers/language_model/config.py | 457 +----------------- fast_llm/layers/language_model/head.py | 296 +++++------- .../layers/language_model/loss/__init__.py | 0 fast_llm/layers/language_model/loss/config.py | 307 ++++++++++++ .../loss/language_model_loss.py | 0 ..._cross_entropy.py => test_entropy_loss.py} | 160 +++--- tests/layers/test_lm_head.py | 7 +- tests/utils/subtest.py | 34 +- 12 files changed, 595 insertions(+), 754 deletions(-) create mode 100644 fast_llm/layers/language_model/loss/__init__.py create mode 100644 fast_llm/layers/language_model/loss/config.py create mode 100644 fast_llm/layers/language_model/loss/language_model_loss.py rename tests/functional/{test_cross_entropy.py => test_entropy_loss.py} (54%) diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 4dcc53d55..8eb2f149e 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -72,10 +72,12 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: +def safe_barrier( + group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None +) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -86,9 +88,10 @@ def allreduce_scalar( group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, timeout: float | None = None, + device: torch.device | None = None, ) -> float | int: if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device) with set_timeout(group, timeout): torch.distributed.all_reduce(value, op=op, group=group) return value.item() diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 133b3206b..62999b7ca 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -320,7 +320,9 @@ def _preprocess_data( self, context: 
BatchContext, data_iterator: typing.Iterator, preprocessed: bool ) -> typing.Generator[None, None, None]: batch_config = context.schedule.batch_config - grad_output = (1 if self._optimizer is None else self._optimizer.grad_scale) / batch_config.num_inputs + grad_output = ( + self._optimizer.grad_scale / batch_config.num_inputs if context.schedule.phase.is_training else None + ) for micro_batch in range(batch_config.sequential_micro_batches): micro_batch_data = next(data_iterator) if not preprocessed: diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index 7508126e3..0c0fe9fa3 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -32,11 +32,12 @@ def _torch_entropy_loss_forward_backward( else: Assert.eq(temperature, 1.0) - reduction = "mean" if loss_mask is None else "none" if entropy_loss_type == EntropyLossType.cross_entropy: if target_format == TargetFormat.logits: target = torch.softmax(target, dim=-1) - loss = torch.nn.functional.cross_entropy(logits_scaled, target, reduction=reduction) + loss = torch.nn.functional.cross_entropy( + logits_scaled, target, reduction="mean" if loss_mask is None else "none" + ) else: predicted_log_probability = torch.nn.functional.log_softmax(logits_scaled, dim=-1) if target_format == TargetFormat.logits: @@ -44,17 +45,27 @@ def _torch_entropy_loss_forward_backward( elif target_format == TargetFormat.probabilities: target_log_probability = target.log() else: - target_log_probability = torch.nn.functional.one_hot(target, num_classes=logits_scaled.size(-1)).log() + target_log_probability = ( + torch.nn.functional.one_hot(target, num_classes=logits_scaled.size(-1)).add(1.0e-10).log() + ) if entropy_loss_type == EntropyLossType.forward_kl: loss = torch.nn.functional.kl_div( - predicted_log_probability, target_log_probability, reduction=reduction, log_target=True + predicted_log_probability, + target_log_probability, + reduction="batchmean" if loss_mask is None else "none", + log_target=True, ) elif entropy_loss_type == EntropyLossType.reverse_kl: loss = torch.nn.functional.kl_div( - target_log_probability, predicted_log_probability, reduction=reduction, log_target=True + target_log_probability, + predicted_log_probability, + reduction="batchmean" if loss_mask is None else "none", + log_target=True, ) else: raise NotImplementedError(entropy_loss_type) + if loss_mask is not None: + loss = loss.sum(dim=-1) if loss_mask is not None: loss = (loss * loss_mask).mean() @@ -95,21 +106,21 @@ def _fused_reverse_kl_base( temperature: float = 1.0, ): logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) + predicted_log_probability = logits_norm - sum_exp_logits.log() + predicted_probability = exp_logits / sum_exp_logits if target_format == TargetFormat.logits: - target_logits, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( + target_logits_norm, _, sum_exp_target_logits = _fused_softmax_base( target, logits_scale_factor / temperature, group ) - target = exp_logits_targets / sum_exp_target_logits - target_log_probability = target_logits - sum_exp_target_logits.log() + target_log_probability = target_logits_norm - sum_exp_target_logits.log() else: target_log_probability = torch.log(target) - predicted_log_probability = logits_norm - sum_exp_logits.log() # Compute loss terms: student_probs * log_ratio, then sum over vocab # This is equivalent to kl_div(..., log_target=True) but more memory efficient log_ratio = predicted_log_probability 
- target_log_probability - per_sample_loss = (predicted_log_probability.exp() * log_ratio).sum(dim=-1) + per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1) if group is not None: all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group) @@ -118,7 +129,8 @@ def _fused_reverse_kl_base( else: # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) # where E_q[log(q/p)] is the expected log ratio under the student distribution - grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * target * grad_output + grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output + return per_sample_loss, grad @@ -135,16 +147,18 @@ def _fused_cross_entropy_base( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target_logits, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( + target_logits_norm, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( target, logits_scale_factor / temperature, group ) - target_log_probability = target_logits - sum_exp_target_logits.log() - else: - target_log_probability = torch.log(target) + target = exp_logits_targets / sum_exp_target_logits # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) if return_kl_loss: + if target_format == TargetFormat.logits: + target_log_probability = target_logits_norm - sum_exp_target_logits.log() + else: + target_log_probability = torch.log(target) logits_norm = logits_norm - target_log_probability predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) if group is not None: @@ -174,20 +188,21 @@ def _fused_cross_entropy_base_from_labels( logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) target = target.unsqueeze(-1) + if group is None: # Keep values within range for scatter and gather ops to work. - target_masked = target * loss_mask + target = target * loss_mask.unsqueeze(-1) target_mask = None else: # Mask the target (fused) # TODO: Could mask earlier on cpu or overlap with reduce? vocab_start_index = logits.size(-1) * group.rank() target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target_masked = (target - vocab_start_index) * target_mask + target = (target - vocab_start_index) * target_mask # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) # KL loss is the same because P * log(P) == 0. - predicted_logits = logits_norm.gather(1, target_masked) + predicted_logits = logits_norm.gather(1, target) if group is not None: predicted_logits = target_mask * predicted_logits all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) @@ -198,7 +213,7 @@ def _fused_cross_entropy_base_from_labels( else: # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. 
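         # The scatter_add subtracts sum_exp_logits at each row's target index, so after the division by
         # sum_exp_logits this yields softmax(logits) minus a one-hot at the target, scaled by grad_output.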
grad = exp_logits.scatter_add( - 1, target_masked, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) + 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) ) * (grad_output / sum_exp_logits) return per_sample_loss, grad @@ -259,13 +274,12 @@ def _fused_entropy_loss_forward_backward( raise NotImplementedError(entropy_loss_type) if loss_mask is not None: - loss_mask = loss_mask.unsqueeze(-1) - per_sample_loss = per_sample_loss * loss_mask + per_sample_loss = per_sample_loss * loss_mask.unsqueeze(-1) loss = per_sample_loss.mean() if grad is not None: if loss_mask is not None: - grad = grad * loss_mask + grad = grad * loss_mask.unsqueeze(-1) grad = grad.to(logits.dtype) return loss, grad diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index 1c8fe1c73..baec73b1c 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -13,10 +13,13 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: @torch.compile -def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor: - return torch.mean( - torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 - ) +def calculate_z_loss( + logits: torch.Tensor, logits_scale_factor: float = 1.0, loss_mask: "torch.Tensor | None" = None +) -> torch.Tensor: + out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 + if loss_mask is not None: + out *= loss_mask.unsqueeze(-1) + return torch.mean(out) def auxiliary_z_loss( @@ -27,9 +30,10 @@ def auxiliary_z_loss( losses: dict | None = None, loss_name: str | None = None, logits_scale_factor: float = 1.0, + loss_mask: "torch.Tensor | None" = None, ) -> torch.Tensor: if losses is not None or (training and grad_scale is not None): - loss = calculate_z_loss(logits, logits_scale_factor=logits_scale_factor) + loss = calculate_z_loss(logits, logits_scale_factor, loss_mask) if losses is not None and loss_name is not None: losses[loss_name].append(loss.detach()) if training and grad_scale is not None: @@ -41,6 +45,7 @@ def auxiliary_z_loss( def z_loss_forward_backward( logits: torch.Tensor, grad_output: float | None = None, + loss_mask: "torch.Tensor | None" = None, logits_scale_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -55,7 +60,7 @@ def z_loss_forward_backward( with torch.set_grad_enabled(grad_output is not None): logits_ = logits.detach().requires_grad_(grad_output is not None) - loss = calculate_z_loss(logits, logits_scale_factor) + loss = calculate_z_loss(logits, logits_scale_factor, loss_mask) if grad_output is None: grad = None else: diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index fc32617ed..e0fbef1d7 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,38 +1,25 @@ import abc +import functools import typing -import warnings -from functools import cached_property -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import LossDef -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from 
fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import EntropyLossImplementation, TargetFormat, TritonConfig from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig, LanguageModelLossConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch - - from fast_llm.core.distributed import ProcessGroup from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class TargetsKwargs: - lm_target = "preprocessed_lm_target" - dpo_target = "preprocessed_dpo_target" - reference_model_logits = "reference_model_logits" - dpo_reference_model_logits = "dpo_reference_model_logits" - - class LanguageModelKwargs(BlockKwargs): token_ids = "token_ids" position_ids = "position_ids" @@ -52,382 +39,6 @@ def _format_name(name: str) -> str: return name.replace("_", " ") -@config_class(registry=True) -class LanguageModelLossConfig(Config): - """ - Losses can register themselves using @config_class(dynamic_type= {LanguageModelLossConfig: "loss_type_name"}). - """ - - _name: typing.ClassVar[str] - _abstract: typing.ClassVar[bool] = True - - weight: float = Field( - default=1.0, - hint=FieldHint.core, - desc="Weight for this loss in the total loss computation.", - valid=check_field(Assert.geq, 0.0), - ) - - distillation_model: str | None = Field( - default=None, - desc="Name of the reference model to use for knowledge distillation." - "If provided, replace the loss with a distillation loss.", - hint=FieldHint.feature, - ) - - @abc.abstractmethod - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup|None" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - pass - - def get_loss_definitions(self, name: str, count: int = 1, prediction_distance: int | None = None) -> LossDef: - name = self.get_formatted_name(name, prediction_distance) - return LossDef( - name=name, - formatted_name=_format_name(name), - count=count, - dtype=DataType.float32, - ) - - def _validate(self): - Assert.geq(self.weight, 0.0) - super()._validate() - - def get_formatted_name(self, registered_loss_name=None, prediction_distance: int | None = None) -> str: - """ - Returns loss name for logging as '()', - e.g. 
lm_loss(CE_loss), distillation(FwdKL_loss) - """ - name = f"{registered_loss_name}({self._name})" - if prediction_distance is not None: - name = f"{name}_{prediction_distance}" - return name - - @abc.abstractmethod - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup|None" = None, - ) -> dict[str, "torch.Tensor"]: - pass - - -@config_class(dynamic_type={LanguageModelLossConfig: "cross_entropy"}) -class CrossEntropyLanguageModelLossConfig(LanguageModelLossConfig): - _name: typing.ClassVar[str] = "CE_loss" - _abstract: typing.ClassVar[bool] = False - - implementation: EntropyLossImplementation = Field( - default=EntropyLossImplementation.auto, - desc="Implementation for the cross-entropy computation.", - hint=FieldHint.performance, - ) - - temperature: float = Field( - default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax.", - valid=check_field(Assert.gt, 0.0), - ) - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup|None" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - lm_target = kwargs.get(LanguageModelKwargs.labels) - if lm_target is not None: - # MTP: Shift the labels - lm_target_sequence_length = ( - lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads - ) - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) - lm_target_slice = slice(prediction_distance, prediction_distance + lm_target_sequence_length) - lm_target = ( - lm_target[lm_target_slice] - if kwargs[LanguageModelKwargs.sequence_first] - else lm_target[:, lm_target_slice] - ).flatten() - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - lm_target = split_op(lm_target, group, 0) - return {TargetsKwargs.lm_target: lm_target} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import entropy_loss_forward_backward - - target = kwargs.get(TargetsKwargs.lm_target) - implementation = self.implementation - if implementation == EntropyLossImplementation.auto: - if vocab_parallel: - implementation = EntropyLossImplementation.fused - elif TritonConfig.TRITON_ENABLED: - implementation = EntropyLossImplementation.triton - else: - implementation = EntropyLossImplementation.fused - - return entropy_loss_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=None, # Labels are already masked - grad_output=grad_output, - group=group, - implementation=implementation, - logits_scale_factor=logits_scale_factor, - temperature=self.temperature, - target_format=TargetFormat.labels, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "forward_kl_distillation"}) -class ForwardKLDistillationLossConfig(LanguageModelLossConfig): - """Forward KL divergence KL(p||q) for distillation (mode-covering).""" - - _name: typing.ClassVar[str] = "FwdKL_loss" - _abstract: typing.ClassVar[bool] = False - - temperature: float = Field( 
- default=1.0, - hint=FieldHint.optional, - desc="Temperature for teacher softmax.", - valid=check_field(Assert.gt, 0.0), - ) - - def _validate(self): - assert self.distillation_model is not None, "Distillation loss required by ForwardKL Loss." - super()._validate() - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup|None" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - reference_model_logits = kwargs.get(f"{self.distillation_model}_logits") - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - reference_model_logits = split_op(reference_model_logits, group, 0) - return {TargetsKwargs.reference_model_logits: reference_model_logits} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup|None" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import forward_kl_forward_backward - - target = kwargs.get(TargetsKwargs.reference_model_logits) - - return forward_kl_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - group=group, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.temperature, - target_format=TargetFormat.logits, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "reverse_kl_distillation"}) -class ReverseKLLossConfig(ForwardKLDistillationLossConfig): - """Reverse KL divergence KL(q||p) for distillation (mode-seeking).""" - - _name: typing.ClassVar[str] = "RevKL_loss" - _abstract: typing.ClassVar[bool] = False - - def _validate(self): - assert self.distillation_model is not None, "Distillation loss required by Reverse KL Loss." 
- super()._validate() - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup|None" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import reverse_kl_forward_backward - - # Use distillation_target for KL losses - target = kwargs.get(TargetsKwargs.reference_model_logits) - - return reverse_kl_forward_backward( - logits=logits.flatten(0, -2), - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - group=group, - logits_scale_factor=logits_scale_factor, - teacher_softmax_temperature=self.temperature, - target_format=TargetFormat.logits, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) -class DPOLossConfig(LanguageModelLossConfig): - """Direct Preference Optimization (DPO) loss for alignment.""" - - _name: typing.ClassVar[str] = "DPO_loss" - _abstract: typing.ClassVar[bool] = False - - beta: float = Field( - default=1.0, - hint=FieldHint.core, - desc="Beta parameter for DPO loss (controls strength of preference optimization).", - valid=check_field(Assert.gt, 0.0), - ) - - dpo_reference_model: str | None = Field( - default=None, - desc="Name of the reference model to use for dpo.", - hint=FieldHint.feature, - ) - - def _validate(self): - assert self.dpo_reference_model is not None, "DPO loss requires a reference model." - super()._validate() - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup|None" = None, - ) -> dict[str, "torch.Tensor"]: - if kwargs is None: - kwargs = {} - - reference_model_logits = kwargs.get(f"{self.dpo_reference_model}_logits") - dpo_target = kwargs.get(LanguageModelKwargs.labels) - if reference_model_logits is not None or dpo_target is not None: - from fast_llm.core.ops import split_op - - if reference_model_logits is not None: - reference_model_logits = reference_model_logits.flatten(0, -2) - if sequence_parallel_logits: - reference_model_logits = split_op(reference_model_logits, group, 0) - if dpo_target is not None: - dpo_target = split_op(dpo_target, group, 0) - return { - TargetsKwargs.dpo_reference_model_logits: reference_model_logits, - TargetsKwargs.dpo_target: dpo_target, - } - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup|None" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.dpo import compute_dpo_loss - - dpo_target = kwargs.get(TargetsKwargs.dpo_target) - dpo_reference_model_logits = kwargs.get(TargetsKwargs.dpo_reference_model_logits) - chosen_spans = kwargs.get(LanguageModelKwargs.chosen_spans) - rejected_spans = kwargs.get(LanguageModelKwargs.rejected_spans) - - return compute_dpo_loss( - logits=logits, - targets=dpo_target, - reference_model_logits=dpo_reference_model_logits, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, - beta=self.beta, - grad_output=grad_output, - ) - - -@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) -class ZLossConfig(LanguageModelLossConfig): - """Z-loss regularization to prevent overconfidence.""" - - _name: typing.ClassVar[str] 
= "Z_loss" - _abstract: typing.ClassVar[bool] = False - - def get_targets( - self, - kwargs: dict | None = None, - prediction_distance: int | None = None, - prediction_heads: int | None = None, - sequence_parallel_logits: bool | None = None, - group: "ProcessGroup|None" = None, - ) -> dict[str, "torch.Tensor"]: - return {} - - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - group: "ProcessGroup|None" = None, - logits_scale_factor: float | None = None, - vocab_parallel: bool = False, - kwargs: dict | None = None, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward - - # TODO: ====== Support loss mask, vocab_parallel ====== - assert loss_mask is None - assert group is None - - return z_loss_forward_backward( - logits=logits.flatten(0, -2), - grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - ) - - @config_class() class LanguageModelEmbeddingsConfig(BlockConfig): _abstract = False @@ -539,8 +150,8 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) - cross_entropy_splits: int | None = Field( - default=None, + cross_entropy_splits: int = Field( + default=1, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), @@ -584,38 +195,14 @@ def layer_class(self) -> "type[LanguageModelHead]": return LanguageModelHead - @classmethod - def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: - removed_fields = ["distillation_loss_factor", "distillation_model", "language_model_loss_factor"] - for field in removed_fields: - if field in default: - warnings.warn( - f"Field `{field}` has been removed from {cls.__name__}. " - "Loss configuration should now be done via the `losses` field.", - DeprecationWarning, - ) - default.pop(field) - return super()._from_dict(default, strict=strict) - def _validate(self) -> None: with self._set_implicit_default(): if not self.losses: if "losses" not in self._explicit_fields: - self.losses = {"lm_loss": CrossEntropyLanguageModelLossConfig()} + self.losses = {"lm_loss": LanguageModelLabelEntropyLossConfig()} super()._validate() - if DPOLossConfig in self._loss_configs: - assert ForwardKLDistillationLossConfig not in self._loss_configs.keys() # currently don't support both - assert ReverseKLLossConfig not in self._loss_configs.keys() # currently don't support both - if ( - ForwardKLDistillationLossConfig in self._loss_configs.keys() - and ReverseKLLossConfig in self._loss_configs.keys() - ): - assert ( - self._loss_configs[ForwardKLDistillationLossConfig].distillation_model - == self._loss_configs[ReverseKLLossConfig].distillation_model - ), "Distillation losses must use the same teacher." 
- - @cached_property + + @functools.cached_property def _loss_configs(self) -> dict[type, LanguageModelLossConfig]: return {loss.__class__: loss for loss in self.losses.values()} @@ -623,34 +210,6 @@ def _loss_configs(self) -> dict[type, LanguageModelLossConfig]: def max_prediction_distance(self) -> int: return 1 - @property - def enable_dpo(self) -> bool: - return DPOLossConfig in self._loss_configs.keys() - - @property - def enable_distillation(self) -> bool: - return ( - ForwardKLDistillationLossConfig in self._loss_configs.keys() - or ReverseKLLossConfig in self._loss_configs.keys() - ) - - @property - def requires_loss_masks(self) -> bool: - return self.enable_distillation - - @property - def distillation_model(self) -> str | None: - for loss_type in [ForwardKLDistillationLossConfig, ReverseKLLossConfig]: - if loss_type in self._loss_configs: - return self._loss_configs[loss_type].distillation_model - return None - - @property - def dpo_reference_model(self) -> str | None: - if DPOLossConfig in self._loss_configs: - return self._loss_configs[DPOLossConfig].dpo_reference_model - return None - @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 2fa2dffe0..13ffc4f16 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -9,6 +9,7 @@ from fast_llm.core.ops import gather_op, split_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -25,8 +26,9 @@ LanguageModelKwargs, _format_name, ) +from fast_llm.layers.language_model.loss.config import LanguageModelDistillationLossConfig, LanguageModelDPOLossConfig from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert, div, get_unique +from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -67,7 +69,10 @@ def __init__( lr_scale=lr_scale, peft=peft, ) - if prediction_distance > 0 and (self._config.enable_dpo or self._config.enable_distillation): + if prediction_distance > 0 and any( + isinstance(loss, (LanguageModelDPOLossConfig, LanguageModelDistillationLossConfig)) + for loss in self.losses.values() + ): raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") Assert.in_range(prediction_distance, 0, prediction_heads) @@ -80,7 +85,7 @@ def __init__( self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_parallel_logits = self._sequence_parallel and not self._vocab_parallel - if self._config.cross_entropy_splits is not None and self._sequence_parallel: + if self._config.cross_entropy_splits > 1 and self._sequence_parallel: assert not self._vocab_parallel self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) @@ -101,9 +106,19 @@ def __init__( self._formatted_loss_names = {} for registered_loss_name, loss_config in self._config.losses.items(): - self._formatted_loss_names[registered_loss_name] = loss_config.get_formatted_name( - registered_loss_name, self._prediction_distance - ) + self._formatted_loss_names[registered_loss_name] = 
loss_config.get_name(self._prediction_distance) + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + # TODO: Add marginal compute? (loss) + return ( + 2 + * (config.forward + 2 * config.backward) + * (input_.global_shape if config.global_ else input_).numel() + * (self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) + ) + + def get_output_weights(self) -> list[torch.Tensor]: + return [self.output_weights] def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None @@ -140,162 +155,108 @@ def forward( # MTP: Return shared_hidden to be used by the next head. return shared_hidden - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - # TODO: Add marginal compute? (loss) - return ( - 2 - * (config.forward + 2 * config.backward) - * (input_.global_shape if config.global_ else input_).numel() - * (self._vocab_dim.global_size if config.global_ else self._vocab_dim.size) - ) - def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: - targets = self._get_targets(kwargs) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - if self._sequence_parallel_logits: - loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) - - input_ = input_.detach().requires_grad_(do_grad := targets is not None and self.training) + input_ = input_.detach().requires_grad_(self.training) with torch.enable_grad(): ln_output = self.final_norm(input_) - # Transormers expect normalized outputs for the last transformer layer, + # Transformers expect normalized outputs for the last transformer layer, # so we add the norm output to the hidden states. 
self._debug(ln_output, "final_norm", kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) - - grad_output = kwargs[LanguageModelKwargs.grad_output] / ( - self._parallel_dim.size if self._sequence_parallel_logits else 1 - ) - - output_weights = self.output_weights - loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( - ln_output.detach(), targets, loss_mask, output_weights, grad_output, kwargs, losses - ) - - if do_grad: - ln_output.backward(ln_output_grad) - return loss, input_.grad - else: + loss, ln_output_grad = self._logits_loss_forward_backward(ln_output.detach().flatten(0, -2), kwargs, losses) + if ln_output_grad is None: return loss, None + else: + ln_output.backward(ln_output_grad.view_as(ln_output)) + return loss, input_.grad - def _get_targets(self, kwargs: dict) -> dict | None: - targets = {} - for loss_config in self._config.losses.values(): - loss_targets = loss_config.get_targets( - kwargs, - prediction_distance=self._prediction_distance, - prediction_heads=self._prediction_heads, - sequence_parallel_logits=self._sequence_parallel_logits, - group=self._parallel_dim.group, - ) - targets.update({k: v for k, v in loss_targets.items() if v is not None}) - if len(targets) == 0: - return None - return targets - - def get_output_weights(self) -> list[torch.Tensor]: - return [self.output_weights] - - def _logits_cross_entropy_forward_backward_split( + def _logits_loss_forward_backward( self, input_: torch.Tensor, - targets: dict[str, "torch.Tensor"] | None, - loss_mask: torch.Tensor | None, - weight: torch.Tensor, - grad_output: float, kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._config.cross_entropy_splits is None: - loss, logit_input_grad = self._logits_loss_forward_backward( - input_, targets, loss_mask, weight, grad_output, kwargs, losses - ) - if targets is None: - # TODO: Make a proper way of returning the model output. - loss = loss.detach() - if kwargs.get("global_logits"): - if self._vocab_parallel: - loss = gather_op(loss, self._parallel_dim.group, 2) - elif self._sequence_parallel_logits: - loss = gather_op( - loss, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - ) - kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss - return None, None - else: - loss = None - # TODO MTP: allow a cross_entropy_splits that is not a divisor of the sequence length - grad_output /= self._config.cross_entropy_splits - logit_input = input_.flatten(0, -2) - if self.training: - logit_input_grad = torch.empty_like(logit_input) - else: - logit_input_grad = None - - # Collect all tensors that need to be split to determine the split size - tensors_to_check = [logit_input] - if loss_mask is not None: - tensors_to_check.append(loss_mask) - tensors_to_check.extend(target for target in targets.values() if target is not None) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + if loss_mask is not None: + loss_mask = loss_mask.flatten() + if self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) - split_size = div( - get_unique(tensor.size(0) for tensor in tensors_to_check), - self._config.cross_entropy_splits, + if not self.training: + logits, _ = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs, True) + # TODO: Make a proper way of returning the model output. 
+ logits = logits.detach() + if kwargs.get("global_logits"): + if self._vocab_parallel: + logits = gather_op(logits, self._parallel_dim.group, 2) + elif self._sequence_parallel_logits: + logits = gather_op( + logits, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 + ) + kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = ( + logits.detach() ) + return None, None + if self._config.cross_entropy_splits == 1 or self.training: + losses_, input_grad = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs) + else: + input_grad = torch.empty_like(input_) tensors_split = [ - [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) - for tensor in [logit_input, loss_mask, logit_input_grad] - ] - target_split = { - name: ( + ( [None] * self._config.cross_entropy_splits - if targets[name] is None - else targets[name].split(split_size) + if tensor is None + else tensor.chunk(self._config.cross_entropy_splits) ) - for name in targets - } - - for i, (logit_input_, loss_mask_, logit_input_grad_) in enumerate(zip(*tensors_split, strict=True)): - loss_, grad_ = self._logits_loss_forward_backward( - logit_input_, - {name: target_split[name][i] for name in target_split}, + for tensor in [input_, loss_mask, input_grad] + ] + for i, (partial_input_, loss_mask_, input_grad_) in enumerate(zip(*tensors_split, strict=True)): + partial_losses_, grad_ = self._logits_loss_forward_backward_partial( + partial_input_, loss_mask_, - weight, - grad_output, kwargs, ) # TODO: Avoid copy with explicit out argument. - if self.training: - logit_input_grad_.copy_(grad_) - loss = loss_ if loss is None else loss + loss_ - del grad_, loss_ - loss_count = (self._config.cross_entropy_splits or 1) * ( - self._parallel_dim.size if self._sequence_parallel_logits else 1 + input_grad_.copy_(grad_) + if i == 0: + losses_ = partial_losses_ + else: + for name in self._config.losses: + losses_[name] += partial_losses_[name] + + loss: torch.Tensor = sum( + (loss_config.weight * self._loss_coefficient / self._config.cross_entropy_splits) * losses_[name] + for name, loss_config in self._config.losses.items() ) - if loss_count != 1: - loss.div_(loss_count) if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._parallel_dim.group) - return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None - - def _logits_loss_forward_backward( + all_reduce(loss, op=ReduceOp.AVG, group=self._parallel_dim.group) + + if losses is not None: + losses[self._total_head_loss_name].append(loss) + if len(self._config.losses) > 1: + for name, loss_config in self._config.losses.items(): + loss_ = losses_[name] + if self._config.cross_entropy_splits != 1: + loss_ /= self._config.cross_entropy_splits + if self._sequence_parallel_logits: + # TODO: Async + all_reduce(loss_, op=ReduceOp.AVG, group=self._parallel_dim.group) + losses[loss_config.get_name(self._prediction_distance)].append(loss_) + + return loss, input_grad + + def _logits_loss_forward_backward_partial( self, input_: torch.Tensor, - targets: dict[str, "torch.Tensor"] | None, loss_mask: torch.Tensor | None, - weight: torch.Tensor, - grad_output: float, kwargs: dict, - losses: dict | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + return_logits: bool = False, + ) -> tuple[dict[str, torch.Tensor] | torch.Tensor, torch.Tensor | None]: group = self._parallel_dim.group if self._vocab_parallel else None logits, context = 
output_parallel_linear_forward( input_=input_, - weight=weight, + weight=self.output_weights, bias=None, group=group, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -313,46 +274,36 @@ def _logits_loss_forward_backward( dims = None self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) - if targets is None: - return logits * self._config.logits_scale_factor, None + if return_logits: + return logits, None - total_loss, grad = None, None + losses, grad = {}, None for loss_name, loss_config in self._config.losses.items(): # losses are returned unscaled but the grads are already scaled - loss_unscaled_, grad_ = loss_config.get_loss( + # TODO: ====== grad_output can't be None? + grad_output = kwargs.get(LanguageModelKwargs.grad_output) + if grad_output is not None: + grad_output = ( + grad_output + * self._loss_coefficient + * loss_config.weight + / (self._parallel_dim.size if self._sequence_parallel_logits else 1) + / self._config.cross_entropy_splits + ) + loss, grad_ = loss_config.get_loss( logits, loss_mask, - grad_output=( - (grad_output * self._loss_coefficient * loss_config.weight if grad_output is not None else None) - if loss_config.weight != 0.0 - else None - ), + grad_output=None if grad_output == 0.0 else grad_output, group=group, logits_scale_factor=self._config.logits_scale_factor, - vocab_parallel=self._vocab_parallel, - kwargs={**kwargs, **targets}, + kwargs=kwargs, ) - - loss_ = loss_unscaled_ * loss_config.weight * self._loss_coefficient - - if losses is not None: - losses[self._formatted_loss_names[loss_name]].append(loss_unscaled_.detach()) - - if total_loss is None: - total_loss = loss_ - else: - total_loss = total_loss + loss_ - + losses[loss_name] = loss.detach() if grad_ is not None: - if grad is None: - grad = grad_ - else: - grad = grad + grad_ - - if losses is not None and total_loss is not None: - losses[self._total_head_loss_name].append(total_loss.detach()) + # TODO: Accumulate grads in-place to reduce memory and compute overhead. 
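+                # Each `grad_` already carries the loss weight, loss coefficient and grad scaling
+                # (folded into `grad_output` above), so a plain sum gives the gradient of the weighted total.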
+ grad = grad_ if grad is None else grad + grad_ - return total_loss, output_parallel_linear_backward(grad, context) if self.training else None + return losses, output_parallel_linear_backward(grad, context) if self.training else None @functools.cached_property def _total_head_loss_name(self) -> str: @@ -364,26 +315,21 @@ def _total_head_loss_name(self) -> str: name = f"{name}_{self._prediction_distance}" return name - @functools.cached_property - def _z_loss_name(self) -> str: - name = "z_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - loss_defs = [ + return [ LossDef( name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count - ) + ), + *( + LossDef( + name=loss_config.get_name(self._prediction_distance), + formatted_name=_format_name(loss_config.get_name(self._prediction_distance)), + count=count, + dtype=DataType.float32, + ) + for loss_config in self._config.losses.values() + ), ] - for loss_name, loss_config in self._config.losses.items(): - loss_def: LossDef = loss_config.get_loss_definitions( - name=loss_name, count=count, prediction_distance=self._prediction_distance - ) - loss_defs.append(loss_def) - - return loss_defs @property def heads(self): diff --git a/fast_llm/layers/language_model/loss/__init__.py b/fast_llm/layers/language_model/loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py new file mode 100644 index 000000000..609bcbac9 --- /dev/null +++ b/fast_llm/layers/language_model/loss/config.py @@ -0,0 +1,307 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + import torch + + +@config_class(registry=True) +class LanguageModelLossConfig(Config): + _abstract: typing.ClassVar[bool] = True + + weight: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Weight for this loss in the total loss computation.", + valid=check_field(Assert.geq, 0.0), + ) + + def get_name(self, prediction_distance: int = 0) -> str: + return self._name if prediction_distance == 0 else f"{self._name}_{prediction_distance}" + + @property + def _name(self) -> str: + raise NotImplementedError() + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + raise NotImplementedError() + + +@config_class(dynamic_type={LanguageModelLossConfig: "label"}) +class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): + _name: typing.ClassVar[str] = "CE_loss" + _abstract: typing.ClassVar[bool] = False + + loss_type: EntropyLossType = Field( + default=EntropyLossType.cross_entropy, + desc="Type of loss to use.", + hint=FieldHint.core, + ) + + implementation: EntropyLossImplementation = Field( + default=EntropyLossImplementation.auto, + 
desc="Loss implementation.", + hint=FieldHint.performance, + ) + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import entropy_loss_forward_backward + + labels = kwargs[LanguageModelKwargs.labels] + + # MTP: Shift the labels + if prediction_heads > 1: + sequence_q_length = labels.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(sequence_q_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + label_slice = slice(prediction_distance, prediction_distance + sequence_q_length) + labels = labels[label_slice] if kwargs[LanguageModelKwargs.sequence_first] else labels[:, label_slice] + + labels = labels.flatten() + + # Get the local chunk. + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + labels = split_op(labels, group, 0) + + # Get the chunk for the current split. + if num_splits > 1: + labels = labels.chunk(num_splits)[split_index] + + implementation = self.implementation + if implementation == EntropyLossImplementation.auto: + if ( + TritonConfig.TRITON_ENABLED + and torch.cuda.is_available() + and group is None + and self.loss_type == EntropyLossType.cross_entropy + ): + implementation = EntropyLossImplementation.triton + else: + implementation = EntropyLossImplementation.fused + + return entropy_loss_forward_backward( + logits, + labels, + None, # Labels are already masked + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + target_format=TargetFormat.labels, + entropy_loss_type=self.loss_type, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "distillation"}) +class LanguageModelDistillationLossConfig(LanguageModelLossConfig): + _name: typing.ClassVar[str] = "FwdKL_loss" + _abstract: typing.ClassVar[bool] = False + + loss_type: EntropyLossType = Field( + default=EntropyLossType.cross_entropy, + desc="Type of loss to use.", + hint=FieldHint.core, + ) + implementation: EntropyLossImplementation = Field( + default=EntropyLossImplementation.auto, + desc="Loss implementation.", + hint=FieldHint.performance, + ) + reference_model: str = Field( + desc="Name of the reference model for knowledge distillation.", + hint=FieldHint.feature, + ) + temperature: float = Field( + default=1.0, + hint=FieldHint.optional, + desc="Temperature for teacher softmax.", + valid=check_field(Assert.gt, 0.0), + ) + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.cross_entropy import entropy_loss_forward_backward + + if prediction_distance > 0: + raise NotImplementedError() + + reference_model_logits = kwargs[f"{self.reference_model}_logits"].flatten(0, -2) + + # 
Get the local chunk. + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + reference_model_logits = split_op(reference_model_logits, group, 0) + + # Get the chunk for the current split. + if num_splits > 1: + reference_model_logits = reference_model_logits.chunk(num_splits)[split_index] + + implementation = ( + EntropyLossImplementation.fused + if self.implementation == EntropyLossImplementation.auto + else self.implementation + ) + return entropy_loss_forward_backward( + logits, + reference_model_logits, + loss_mask, + grad_output=grad_output, + group=group, + implementation=implementation, + logits_scale_factor=logits_scale_factor, + temperature=self.temperature, + target_format=TargetFormat.labels, + entropy_loss_type=self.loss_type, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) +class LanguageModelDPOLossConfig(LanguageModelLossConfig): + """Direct Preference Optimization (DPO) loss for alignment.""" + + _name: typing.ClassVar[str] = "DPO_loss" + _abstract: typing.ClassVar[bool] = False + + beta: float = Field( + default=1.0, + hint=FieldHint.core, + desc="Beta parameter for DPO loss (controls strength of preference optimization).", + valid=check_field(Assert.gt, 0.0), + ) + + reference_model: str = Field( + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.functional.dpo import compute_dpo_loss + + if num_splits > 1: + raise NotImplementedError() + if prediction_distance > 0: + raise NotImplementedError() + + if logits_scale_factor != 1.0: + # TODO: Make more efficient. 
+ logits = logits * logits_scale_factor + + reference_model_logits = kwargs[f"{self.reference_model}_logits"].flatten(0, -2) + target = kwargs[LanguageModelKwargs.labels] + + if sequence_parallel_logits: + from fast_llm.core.ops import split_op + + reference_model_logits = split_op(reference_model_logits, group, 0) + target = split_op(target, group, 0) + + chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] + rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] + + return compute_dpo_loss( + logits=logits, + targets=target, + reference_model_logits=reference_model_logits, + chosen_spans=chosen_spans, + rejected_spans=rejected_spans, + beta=self.beta, + grad_output=grad_output, + ) + + +@config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) +class LanguageModelZLossConfig(LanguageModelLossConfig): + """Z-loss regularization to prevent overconfidence.""" + + _name: typing.ClassVar[str] = "Z_loss" + _abstract: typing.ClassVar[bool] = False + + def get_loss( + self, + logits: "torch.Tensor", + loss_mask: "torch.Tensor | None", + grad_output: float | None = None, + *, + group: "torch.distributed.ProcessGroup|None" = None, + logits_scale_factor: float = 1.0, + prediction_distance: int = 0, + prediction_heads: int = 1, + split_index: int = 0, + num_splits: int = 1, + sequence_parallel_logits: bool = False, + kwargs: dict[str, typing.Any], + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward + + # TODO: ====== Support loss mask, vocab_parallel ====== + assert loss_mask is None + assert group is None + + return z_loss_forward_backward( + logits, + grad_output, + loss_mask, + logits_scale_factor, + ) diff --git a/fast_llm/layers/language_model/loss/language_model_loss.py b/fast_llm/layers/language_model/loss/language_model_loss.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_entropy_loss.py similarity index 54% rename from tests/functional/test_cross_entropy.py rename to tests/functional/test_entropy_loss.py index 362c12ed7..cb2036f94 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_entropy_loss.py @@ -1,16 +1,13 @@ -import os -import sys -import tempfile -import traceback -import typing +import pathlib import pytest import torch +from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import entropy_loss_forward_backward from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda +from tests.utils.subtest import DistributedTestContext def _get_cross_entropy_inputs( @@ -41,8 +38,9 @@ def _compare_entropy_loss_outputs( grad: torch.Tensor | None, ref_grad: torch.Tensor | None, threshold=1e-5, + loss_min_threshold=1e-6, ): - Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) + Assert.rms_close_relative(loss, ref_loss, threshold, loss_min_threshold) if has_grad: Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) else: @@ -64,13 +62,11 @@ def _compare_entropy_loss_outputs( (65537, 1.0, 1.0, False), # Above max block size ), ) -@pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) -@pytest.mark.parametrize( - "entropy_loss_type", (EntropyLossType.cross_entropy, EntropyLossType.forward_kl, EntropyLossType.reverse_kl) -) 
+@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - pytest.skip(reason="rNot implemented") + pytest.skip(reason="Not implemented") # TODO: Test tensor-parallel implementation. logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) kwargs = { @@ -86,9 +82,15 @@ def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_maskin out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch) out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused) - # TODO: Why is the error so high with logit scaling? - threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 - _compare_entropy_loss_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) + # TODO: Why is the error so high with loss masking for reverse KL? + _compare_entropy_loss_outputs( + out_fused, + out_torch, + grad_output is not None, + grad_fused, + grad_torch, + loss_min_threshold=2e-4 if entropy_loss_type == EntropyLossType.reverse_kl and loss_masking else 5e-6, + ) if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available(): # Triton implementation only supports cross-entropy. @@ -101,89 +103,77 @@ def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_maskin out_triton, grad_triton = entropy_loss_forward_backward( **kwargs, implementation=EntropyLossImplementation.triton ) - _compare_entropy_loss_outputs( - out_triton, out_torch, grad_output is not None, grad_triton, grad_torch, threshold - ) - - -def _mp_worker(rank: int, world_size: int, init_method: str, fn_args: tuple): - try: - torch.distributed.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method) - fn_args[0](rank, torch.distributed.group.WORLD, *fn_args[1:]) - finally: - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - - -def _spawn_dist(world_size: int, *fn_args): - """ - Run `fn(rank, group, *fn_args)` across `world_size` ranks using torch.multiprocessing. - """ - with tempfile.NamedTemporaryFile(delete=False) as tmp: - init_method = f"file://{tmp.name}" - - try: - torch.multiprocessing.spawn( - _mp_worker, - args=(world_size, init_method, fn_args), - nprocs=world_size, - join=True, - start_method="spawn", - ) - finally: - if os.path.exists(tmp.name): - os.remove(tmp.name) + _compare_entropy_loss_outputs(out_triton, out_torch, grad_output is not None, grad_triton, grad_torch) -def _compare_parallel_cross_entropy( - rank: int, - group: torch.distributed.ProcessGroup, +def _entropy_loss_distributed( target_format: TargetFormat, - function: typing.Callable, + entropy_loss_type: EntropyLossType, loss_masking: bool, + group: torch.distributed.ProcessGroup, ): # Ensure all workers have the same inputs. 
torch.manual_seed(0) - world_size = torch.distributed.get_world_size(group) + rank = group.rank() + world_size = group.size() logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - out, grad = function( - logits=logits.chunk(world_size, 1)[rank], - target=target.chunk(world_size, 1)[rank], - loss_mask=loss_mask, - grad_output=1, - group=group, - target_format=target_format, - ) + kwargs = { + "loss_mask": loss_mask, + "grad_output": 1.0, + "target_format": target_format, + "implementation": EntropyLossImplementation.fused, + "entropy_loss_type": entropy_loss_type, + } + out_ref, grad_ref = entropy_loss_forward_backward(logits, target, **kwargs) - out_ref, grad_ref = function( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=1, - target_format=target_format, + out, grad = entropy_loss_forward_backward( + logits.chunk(world_size, 1)[rank], + target if target_format == TargetFormat.labels else target.chunk(world_size, 1)[rank], + group=group, + **kwargs, ) _compare_entropy_loss_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) -def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGroup): - success = True - for function in (reverse_kl_forward_backward, forward_kl_forward_backward, entropy_loss_forward_backward): - for target_format in (TargetFormat.logits,): +def _run_entropy_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path): + for entropy_loss_type in EntropyLossType: + for target_format in TargetFormat: + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + continue for loss_masking in [False, True]: - try: - _compare_parallel_cross_entropy(rank, group, target_format, function, loss_masking) - except Exception: - print( - f" >>>>>> Failed {function.__name__}, target_format, use_mask={loss_masking}", file=sys.stderr - ) - traceback.print_exc() - success = False - if not success: - raise RuntimeError("Test failed") - - -@requires_cuda + name = f"{entropy_loss_type}_{target_format}_{loss_masking}" + with test_context.subtest(base_path, name, 2) as subtest: + if subtest.do_run: + _entropy_loss_distributed(target_format, entropy_loss_type, loss_masking, test_context.group) + + @pytest.mark.slow -def test_distillation_losses(): - _spawn_dist(2, compare_parallel_cross_entropy) +def test_entropy_loss_distributed_dependency(): + # Mock test so the distributed subtests are placed in the same dependency group. + pass + + +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) +def test_run_entropy_loss_distributed(run_parallel_script, result_path): + run_parallel_script( + _run_entropy_loss_distributed, + (result_path / "test_entropy_loss",), + world_size=2, + backend=DistributedBackend.gloo, + use_cpu=True, # Disable device count check. + ) + + +# We don't want to depend on `test_run_entropy_loss_distributed` because we still want to run this in case of failure.
+# This should still run after `test_run_entropy_loss_distributed` +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) +@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) +@pytest.mark.parametrize("loss_masking", (False, True)) +def test_entropy_loss_distributed(result_path, report_subtest, target_format, entropy_loss_type, loss_masking): + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + pytest.skip(reason="Not implemented") + report_subtest(result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 67d674b65..4ea86621c 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,8 +7,9 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import EntropyLossImplementation from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs, LanguageModelLossConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead +from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -322,9 +323,7 @@ def test_lm_head( ) ) - sequence_first = config.sequence_first or ( - head_config.cross_entropy_splits is not None and head_config.cross_entropy_splits > 1 - ) + sequence_first = config.sequence_first or head_config.cross_entropy_splits > 1 input_ = torch.randn( (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=( diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index 4fea1fbba..a30440ad1 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -2,6 +2,7 @@ import json import logging import math +import os import pathlib import sys import time @@ -27,11 +28,13 @@ def __init__( timeout: float = 20.0, init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, + use_cpu: bool = False, ) -> None: self._do_capture = do_capture self._timeout = timeout self._init_method = init_method self._backend = backend + self._use_cpu = use_cpu def __enter__(self): if self._do_capture: @@ -40,7 +43,7 @@ def __enter__(self): ) self._pool = ProcessGroupPool( - timeout=self._timeout, init_method=self._init_method, backend=self._backend + timeout=self._timeout, init_method=self._init_method, backend=self._backend, use_cpu=self._use_cpu ).__enter__() self._rank = self._pool.rank self._world_size = self._pool.world_size @@ -48,12 +51,12 @@ def __enter__(self): self._configure_logging() self._group = self._pool.get_process_group(range(self._world_size), self._rank) # TODO: Barriers needed? - safe_barrier(self._group, "start") + safe_barrier(self._group, "start", device=self._pool.device) return self def __exit__(self, exc_type, exc_val, exc_tb): # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(self._group, "testing end") + safe_barrier(self._group, "testing end", device=self._pool.device) # Let pytest know how things went. 
# These should already be reported above, we repeat for convenience. if self._failures: @@ -75,6 +78,10 @@ def rank(self) -> int: def world_size(self) -> int: return self._world_size + @property + def group(self) -> torch.distributed.ProcessGroup: + return self._group + class DistributedSubtestContext: def __init__( self, test_context: "DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int @@ -83,7 +90,7 @@ def __init__( self._path = base_path / name self._name = name self._num_gpus = num_gpus - self._skip = self._test_context._world_size < self._num_gpus + self._skip = self._test_context._world_size < self._num_gpus and not self._test_context._use_cpu self._do_run = self._test_context._rank < num_gpus and not self._skip self._do_capture = self._test_context._do_capture and self._do_run self._success = False @@ -131,10 +138,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): if (group := self._test_context._group) is not None: # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(group, self._name) - self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() + safe_barrier(group, self._name, device=self._test_context._pool.device) + self._success = ( + allreduce_scalar( + self._success, dtype=torch.int64, group=group, device=self._test_context._pool.device + ) + == group.size() + ) - if self._do_capture: + if self._do_capture and torch.cuda.is_available(): # Free resources to limit memory usage. report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) report["duration"] = time.perf_counter() - self._start @@ -233,13 +245,14 @@ def parallel_worker( init_method: str, backend: DistributedBackend, do_capture: bool, + use_cpu: bool, fn: typing.Callable, fn_args: typing.Sequence[typing.Any], ): DistributedConfig.default_rank = rank DistributedConfig.default_world_size = world_size DistributedConfig.default_local_world_size = world_size - with DistributedTestContext(do_capture, 60, init_method, backend) as test_context: + with DistributedTestContext(do_capture, 60, init_method, backend, use_cpu) as test_context: fn(test_context, *fn_args) @@ -251,14 +264,17 @@ def do_run_parallel_script( world_size: int, timeout: float = 240, backend: DistributedBackend = DistributedBackend.nccl, + use_cpu: bool = False, # Use CPU device in process group pool. May be used to disable device count check ): + if "PYTHONHASHSEED" not in os.environ: + os.environ["PYTHONHASHSEED"] = "0" if do_capture: logger.warning( "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
) torch.multiprocessing.spawn( parallel_worker, - args=(world_size, f"tcp://localhost:{port}", backend, do_capture, fn, fn_args), + args=(world_size, f"tcp://localhost:{port}", backend, do_capture, use_cpu, fn, fn_args), nprocs=world_size, join=False, ).join(timeout, grace_period=5) From 7f96009c2641d59953dc6ce1e042f84c8d40805b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 16 Jan 2026 05:28:45 -0500 Subject: [PATCH 40/51] fix --- tests/utils/distributed_configs.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 0b54f63f7..60f7b22fc 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -60,12 +60,12 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ("init", None): get_config(), (None, "fw"): get_config(1.5e-2, 1.5e-3), (None, "bw"): get_config(1.5e-2, 1e-5), - # TODO: Diff too big for normalization gradients on CPU. + # TODO: Normalization gradient broken on CPU, getting inconsistent results across machines. **( {} if torch.cuda.is_available() else { - (None, "norm"): get_config(0.25, 2e-3), + (None, "norm"): get_config(ignore_tensors=True), (None, "word_embeddings_weight"): get_config(0.08, 1e-4), } ), @@ -80,9 +80,8 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Saved gradient include the gradient scaling by 2**16 (default initial value) (None, "fw"): get_config(1.2e-3, 3e-4), (None, "bw"): get_config(3e-3, 1e-5, scale=2**16), - # TODO: Diff too big on CPU, especially for bias and normalization. - # TODO: Diff too big for normalization gradients on CPU. - **({} if torch.cuda.is_available() else {(None, "norm"): get_config(0.25, 2e-3, scale=2**16)}), + # TODO: Normalization gradient broken on CPU, getting inconsistent results across machines. 
+ **({} if torch.cuda.is_available() else {(None, "norm"): get_config(ignore_tensors=True)}), (None, "bias"): ( get_config(3e-3, 1e-4, scale=2**16) if torch.cuda.is_available() else get_config(6e-3, 2e-4, scale=2**16) ), From afc33f33654863a33f3f319d9cef38ab6b4b2ea7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 16 Jan 2026 06:26:55 -0500 Subject: [PATCH 41/51] stuff --- fast_llm/functional/{cross_entropy.py => entropy_loss.py} | 0 fast_llm/layers/language_model/loss/config.py | 4 ++-- tests/functional/test_entropy_loss.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename fast_llm/functional/{cross_entropy.py => entropy_loss.py} (100%) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/entropy_loss.py similarity index 100% rename from fast_llm/functional/cross_entropy.py rename to fast_llm/functional/entropy_loss.py diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 609bcbac9..948d7eea0 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -77,7 +77,7 @@ def get_loss( sequence_parallel_logits: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import entropy_loss_forward_backward + from fast_llm.functional.entropy_loss import entropy_loss_forward_backward labels = kwargs[LanguageModelKwargs.labels] @@ -167,7 +167,7 @@ def get_loss( sequence_parallel_logits: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.cross_entropy import entropy_loss_forward_backward + from fast_llm.functional.entropy_loss import entropy_loss_forward_backward if prediction_distance > 0: raise NotImplementedError() diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py index cb2036f94..4f3f5b6cb 100644 --- a/tests/functional/test_entropy_loss.py +++ b/tests/functional/test_entropy_loss.py @@ -5,7 +5,7 @@ from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.cross_entropy import entropy_loss_forward_backward +from fast_llm.functional.entropy_loss import entropy_loss_forward_backward from fast_llm.utils import Assert from tests.utils.subtest import DistributedTestContext From f8dcce6699d4e2cce8cfaab0eaffc892a103078b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 16 Jan 2026 10:08:58 -0500 Subject: [PATCH 42/51] stuff --- fast_llm/engine/distributed/distributed.py | 2 +- fast_llm/layers/language_model/config.py | 13 +- fast_llm/layers/language_model/loss/config.py | 24 +- tests/layers/test_lm_head.py | 624 +++++++----------- 4 files changed, 246 insertions(+), 417 deletions(-) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index 7b9d1e75d..9f81760cf 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -259,5 +259,5 @@ def set_step(self, step: int, phase: PhaseType) -> None: self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED) def __del__(self): - if self._local_pool: + if getattr(self, "_local_pool", False) and hasattr(self, "_pool"): self._pool.shutdown() diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e0fbef1d7..c41860ea9 100644 
--- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -6,11 +6,15 @@ from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockSequenceConfig +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import DecoderBlockConfig -from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig, LanguageModelLossConfig +from fast_llm.layers.language_model.loss.config import ( + LanguageModelLabelEntropyLossConfig, + LanguageModelLossConfig, + LanguageModelLossKwargs, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -20,17 +24,14 @@ from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction -class LanguageModelKwargs(BlockKwargs): +class LanguageModelKwargs(LanguageModelLossKwargs): token_ids = "token_ids" position_ids = "position_ids" token_map = "token_map" sample_map = "sample_map" embedding_map = "embedding_map" # TODO: These are generic - labels = "labels" phase = "phase" - chosen_spans = "chosen_spans" - rejected_spans = "rejected_spans" loss_mask = "loss_mask" mask_inputs = "mask_inputs" diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 948d7eea0..81d881793 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -2,13 +2,19 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: import torch +class LanguageModelLossKwargs(BlockKwargs): + labels = "labels" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" + + @config_class(registry=True) class LanguageModelLossConfig(Config): _abstract: typing.ClassVar[bool] = True @@ -79,15 +85,15 @@ def get_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.entropy_loss import entropy_loss_forward_backward - labels = kwargs[LanguageModelKwargs.labels] + labels = kwargs[LanguageModelLossKwargs.labels] # MTP: Shift the labels if prediction_heads > 1: - sequence_q_length = labels.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - prediction_heads - if LanguageModelKwargs.sequence_q_dim in kwargs: - Assert.eq(sequence_q_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) + sequence_q_length = labels.size(1 - kwargs[LanguageModelLossKwargs.sequence_first]) + 1 - prediction_heads + if LanguageModelLossKwargs.sequence_q_dim in kwargs: + Assert.eq(sequence_q_length, kwargs[LanguageModelLossKwargs.sequence_q_dim].size) label_slice = slice(prediction_distance, prediction_distance + sequence_q_length) - labels = labels[label_slice] if kwargs[LanguageModelKwargs.sequence_first] else labels[:, label_slice] + labels = labels[label_slice] if kwargs[LanguageModelLossKwargs.sequence_first] else labels[:, label_slice] labels = 
labels.flatten() @@ -249,7 +255,7 @@ def get_loss( logits = logits * logits_scale_factor reference_model_logits = kwargs[f"{self.reference_model}_logits"].flatten(0, -2) - target = kwargs[LanguageModelKwargs.labels] + target = kwargs[LanguageModelLossKwargs.labels] if sequence_parallel_logits: from fast_llm.core.ops import split_op @@ -257,8 +263,8 @@ def get_loss( reference_model_logits = split_op(reference_model_logits, group, 0) target = split_op(target, group, 0) - chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] - rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] + chosen_spans = kwargs[LanguageModelLossKwargs.chosen_spans] + rejected_spans = kwargs[LanguageModelLossKwargs.rejected_spans] return compute_dpo_loss( logits=logits, diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 4ea86621c..b23c18a36 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -1,131 +1,16 @@ +import dataclasses import typing import pytest import torch -from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.functional.config import EntropyLossImplementation from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, LanguageModelKwargs +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda - - -def _reverse_kl_loss( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - teacher_softmax_temperature: float = 1.0, -): - scaled_target = torch.clamp(target / teacher_softmax_temperature, min=-50, max=50) - teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) - - with torch.enable_grad(): - # Use log_softmax for consistency instead of _fused_softmax - logits = torch.clamp(logits, min=-50, max=50) - student_log_probs = torch.log_softmax(logits, dim=-1) - if loss_mask is None: - loss = torch.nn.functional.kl_div( - teacher_log_probs, # input = log(p) - student_log_probs, # target = log(q) - reduction="batchmean", - log_target=True, - ) - else: - # Apply loss mask - this requires some reshaping - loss_per_sample = torch.nn.functional.kl_div( - teacher_log_probs, student_log_probs, reduction="none", log_target=True - ).sum(dim=-1) - loss = (loss_per_sample * loss_mask.flatten()).mean() - return loss - - -def _kl_loss( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - teacher_softmax_temperature: float = 1.0, -): - return _reverse_kl_loss( - target, - logits, - loss_mask, - teacher_softmax_temperature, - ) - - -def _lm_head( - input_: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - *, - # config:LanguageModelBaseConfig, - rms_weight: torch.Tensor, - logit_weight: torch.Tensor, - grad_output: float = 1.0, - logit_scale_factor: float = 1.0, - losses: dict[str, LanguageModelLossConfig], -): - hidden = torch.rms_norm( - input_.to(rms_weight.dtype), - input_.shape[-1:], - rms_weight, - 1e-5, - ) - logits = torch.nn.functional.linear(hidden, logit_weight).float() - - if "dist_loss" in losses: - if losses["dist_loss"].type == 
"reverse_kl_distillation": - Assert.eq(logits.shape, target.shape) - loss = _reverse_kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) - # Return scaled loss - return loss * losses["dist_loss"].weight, None - elif losses["dist_loss"].type == "forward_kl_distillation": - Assert.eq(logits.shape, target.shape) - loss = _kl_loss( - (logits * logit_scale_factor).flatten(0, -2), (target * logit_scale_factor).flatten(0, -2), loss_mask - ) - # Apply distillation_loss_factor to grad_output for backward - loss.backward(torch.full_like(loss, grad_output * losses["dist_loss"].weight)) - # Return scaled loss - return loss * losses["dist_loss"].weight, None - - if logit_scale_factor != 1.0: - logits *= logit_scale_factor - - # Compute z_loss if configured - if "z_loss" in losses: - z_loss_unscaled = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) - # Backward through z_loss (retain_graph since we need to also backward through ce_loss) - z_loss_unscaled.backward( - torch.full_like(z_loss_unscaled, grad_output * losses["z_loss"].weight), retain_graph=True - ) - z_loss_scaled = z_loss_unscaled * losses["z_loss"].weight - else: - z_loss_unscaled = None - z_loss_scaled = None - - # Language model loss (cross-entropy with hard labels) - ce_loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) - # Backward through ce_loss - ce_loss.backward(torch.full_like(ce_loss, grad_output * losses["lm_loss"].weight)) - ce_loss_scaled = ce_loss * losses["lm_loss"].weight - - # Total loss = ce_loss + z_loss (both scaled) - total_loss = ce_loss_scaled - if z_loss_scaled is not None: - total_loss = total_loss + z_loss_scaled - - return total_loss, z_loss_unscaled - +from tests.utils.utils import get_base_model, get_stage SEQUENCE_LENGTH = 200 BATCH_SIZE = 4 @@ -133,340 +18,277 @@ def _lm_head( VOCAB_SIZE = 500 -@requires_cuda -@pytest.mark.slow -@pytest.mark.parametrize("cross_entropy_impl", tuple(EntropyLossImplementation)) -@pytest.mark.parametrize( - ("config_dict", "distributed_config_dict", "loss_masking", "prediction_heads"), - ( - ({}, {}, False, 1), - ({}, {"compute_dtype": DataType.bfloat16}, False, 1), - ({"embeddings": {"full_precision_residual": True}}, {"compute_dtype": DataType.bfloat16}, False, 1), - ({"sequence_first": True}, {}, False, 1), - ( - { - "head": { - "losses": { - "z_loss": { - "type": "z_loss", - "weight": 1e-3, - }, - }, - } - }, - {}, - False, - 1, - ), - ({"head": {"logits_scale_factor": 5.0}}, {}, False, 1), - ({"tied_embedding_weight": True}, {}, False, 1), - ({}, {}, False, 2), - ({}, {}, True, 1), - # Skip CE distillation for now - not yet implemented in new losses system - # ( - # { - # "head": { - # "distillation_model": "distillation", - # "losses": { - # "lm_loss": { - # "type": "cross_entropy", - # "weight_scalor": 0.0, - # }, - # "dist_loss": { - # "type": "cross_entropy_dist", # TODO: Not implemented yet - # "weight_scalor": 1.0, - # } - # } - # } - # }, - # {}, - # False, - # 1, - # ), - pytest.param( - { - "head": { - "losses": { - "lm_loss": { - "type": "cross_entropy", - "weight": 0.0, - }, - "dist_loss": { - "type": "reverse_kl_distillation", - "weight": 1.0, - "distillation_model": "distillation", - }, - }, - } - }, - {}, - False, - 1, - id="track_lm_zero_factor", - ), - pytest.param( - { - "head": { - "losses": { - "lm_loss": { - "type": 
"cross_entropy", - "weight": 0.0, - }, - "dist_loss": { - "type": "forward_kl_distillation", - "weight": 1.0, - "distillation_model": "distillation", - }, - }, - } - }, - {}, - False, - 1, - id="forward_kl_distillation", - ), - pytest.param( - { - "head": { - "losses": { - "lm_loss": { - "type": "cross_entropy", - "weight": 0.0, - }, - "dist_loss": { - "type": "reverse_kl_distillation", - "weight": 0.0, - "distillation_model": "distillation", - }, - }, - } - }, - {}, - False, - 1, - marks=pytest.mark.xfail( - reason="At least one loss has to have non-zero factor to track gradients", - strict=True, - ), - id="track_both_zero_factors", - ), - pytest.param( - { - "head": { - "losses": { - "lm_loss": { - "type": "cross_entropy", - "weight": 1.0, - }, - "dist_loss": { - "type": "reverse_kl_distillation", - "weight": 1.0, - "distillation_model": "distillation", - }, - }, - } - }, - {}, - False, - 1, - marks=pytest.mark.xfail( - reason="Cannot track distillation loss without distillation model being set", - strict=True, - ), - id="track_distillation_without_model", - ), - ), -) -def test_lm_head( - cross_entropy_impl: EntropyLossImplementation, - config_dict: dict[str, typing.Any], - distributed_config_dict: dict[str, typing.Any], - loss_masking: bool, - prediction_heads: int, -): - head_config = { - "normalization": {"type": "rms_norm"}, - "losses": { - "lm_loss": { - "type": "cross_entropy", - "implementation": cross_entropy_impl, - "weight": 1.0, - } - }, - } - config = GPTBaseModelConfig.from_dict( - { - "decoder": {"num_blocks": 0}, - "embeddings": {"vocab_size": VOCAB_SIZE}, - "head": ( - head_config - if prediction_heads == 1 - else { - "type": "multi_token_prediction", - "head": head_config, - "prediction_heads": prediction_heads, - } - ), - "hidden_size": HIDDEN_SIZE, - }, - config_dict, - update_type=UpdateType.update, - ) - head_config: LanguageModelHeadConfig = config.head if prediction_heads == 1 else config.head.head +@dataclasses.dataclass +class LMHeadTestConfig: + name: str + label_loss: bool | float = False + distillation_loss: bool | float = False + z_loss: bool | float = False + logits_scale_factor: float = 1.0 + compute_dtype: DataType = DataType.float32 + full_precision_residual: bool = False + sequence_first: bool = False + loss_masking: bool = False + prediction_heads: int = 1 + tied_embedding_weight: bool = False + cross_entropy_splits: int = 1 + + @property + def actual_label_loss(self): + return ( + True + if self.label_loss is False and self.distillation_loss is False and self.z_loss is False + else self.label_loss + ) - model, distributed = get_base_model( - GPTModelConfig.from_dict( + def get_config(self) -> GPTModelConfig: + head_config = { + "normalization": {"type": "rms_norm"}, + "logits_scale_factor": self.logits_scale_factor, + "cross_entropy_splits": self.cross_entropy_splits, + } + losses = {} + if self.label_loss is not False: + losses["label"] = {"type": "label"} + if isinstance(self.label_loss, float): + losses["label"]["weight"] = self.label_loss + if self.distillation_loss is not False: + losses["distillation"] = {"type": "distillation", "reference_model": "distillation"} + if isinstance(self.distillation_loss, float): + losses["distillation"]["weight"] = self.distillation_loss + if self.z_loss is not False: + losses["z_loss"] = {"type": "z_loss"} + if isinstance(self.z_loss, float): + losses["z_loss"]["weight"] = self.z_loss + if losses: + head_config["losses"] = losses + + return GPTModelConfig.from_dict( { - "base_model": config, - "distributed": 
distributed_config_dict, + "base_model": { + "decoder": {"num_blocks": 0}, + "embeddings": {"vocab_size": VOCAB_SIZE, "full_precision_residual": self.full_precision_residual}, + "head": ( + head_config + if self.prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": self.prediction_heads, + } + ), + "hidden_size": HIDDEN_SIZE, + "tied_embedding_weight": self.tied_embedding_weight, + }, + "distributed": {"compute_dtype": self.compute_dtype}, }, ) - ) - sequence_first = config.sequence_first or head_config.cross_entropy_splits > 1 - input_ = torch.randn( - (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), - dtype=( - distributed.config.optimization_dtype.torch - if config.embeddings.full_precision_residual - else distributed.config.compute_dtype.torch - ), - device=distributed.device, - requires_grad=True, - ) - label_shape = ( - (SEQUENCE_LENGTH + config.head.max_prediction_distance - 1, BATCH_SIZE) - if sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + config.head.max_prediction_distance - 1) - ) - if loss_masking: - loss_mask = torch.randint(0, 2, label_shape, dtype=torch.bool, device=distributed.device) - else: - loss_mask = None - kwargs = { - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.grad_output: 1.0, - } - # always set lm targets - target = torch.randint( - 0, - VOCAB_SIZE, - label_shape, - dtype=torch.int64, - device=distributed.device, - ) - if loss_mask is not None: - target *= loss_mask - - kwargs[LanguageModelKwargs.labels] = target - if head_config.distillation_model is not None: - assert config.head.max_prediction_distance == 1 - target = torch.randn( - input_.shape[:-1] + (VOCAB_SIZE,), - dtype=input_.dtype, - device=distributed.device, + def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: + device = "cuda" if torch.cuda.is_available() else "cpu" + input_ = torch.randn( + ( + (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) + if self.sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE) + ), + dtype=(torch.float32 if self.full_precision_residual else self.compute_dtype.torch), + device=device, + requires_grad=True, + ) + label_shape = ( + (SEQUENCE_LENGTH + self.max_prediction_distance - 1, BATCH_SIZE) + if self.sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH + self.max_prediction_distance - 1) ) - kwargs[f"{head_config.distillation_model}_logits"] = target - if loss_mask is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask + kwargs: dict[str, typing.Any] = { + AttentionKwargs.sequence_first: self.sequence_first, + AttentionKwargs.grad_output: 1.0, + } + if self.loss_masking: + kwargs[LanguageModelKwargs.loss_mask] = torch.randint(0, 2, label_shape, dtype=torch.bool, device=device) + if self.actual_label_loss is not False: + labels = torch.randint( + 0, + VOCAB_SIZE, + label_shape, + dtype=torch.int64, + device=device, + ) + if LanguageModelKwargs: + labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels) + kwargs[LanguageModelKwargs.labels] = labels + + if self.distillation_loss is not False: + assert self.max_prediction_distance == 1 + kwargs[f"distillation_logits"] = torch.randn( + input_.shape[:-1] + (VOCAB_SIZE,), + dtype=input_.dtype, + device=device, + ) + return input_, kwargs + + def get_reference_outputs( + self, + head: LanguageModelHead, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + tied_logit_weight: torch.Tensor | None, + ) -> tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]: + # Get reference outputs and grads + logit_weight = ( + (head.output_weights if tied_logit_weight is None else tied_logit_weight).detach().requires_grad_() + ) + normalization_weight = head.final_norm.weight.detach().requires_grad_() + input_ = input_.detach().requires_grad_() + + hidden = torch.rms_norm(input_.to(normalization_weight.dtype), input_.shape[-1:], normalization_weight, 1e-5) + logits = torch.nn.functional.linear(hidden, logit_weight).float() + + if self.logits_scale_factor is not None: + logits = logits * self.logits_scale_factor + + total_loss = 0 + losses = {} - if config.tied_embedding_weight or config.head.max_prediction_distance > 1: - logit_weight = torch.nn.Parameter( + if self.actual_label_loss is not False: + label_loss = torch.nn.functional.cross_entropy( + logits.flatten(0, -2), kwargs[LanguageModelKwargs.labels].flatten() + ) + losses["label_loss"] = label_loss.detach() + total_loss = total_loss + float(self.actual_label_loss) * label_loss + + if self.distillation_loss is not False: + distillation_loss = torch.nn.functional.cross_entropy( + logits.flatten(0, -2), torch.softmax(kwargs[f"distillation_logits"].flatten(0, -2), -1) + ) + losses["distillation_loss"] = distillation_loss.detach() + total_loss = total_loss + float(self.distillation_loss) * distillation_loss + + if self.z_loss is not False: + z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) + losses["z_loss"] = z_loss.detach() + total_loss = total_loss + float(self.z_loss) * z_loss + + total_loss.backward() + return total_loss.detach(), input_.grad, logit_weight.grad, normalization_weight.grad, losses + + +_lm_head_test_configs = ( + # TODO: Test DPO loss. + LMHeadTestConfig("default"), + LMHeadTestConfig("bfloat16", compute_dtype=DataType.bfloat16), + LMHeadTestConfig("full_precision_residual", full_precision_residual=True), + LMHeadTestConfig("sequence_first", sequence_first=True), + LMHeadTestConfig("logit_scaling", logits_scale_factor=5.0), + LMHeadTestConfig("tied_embedding_weight", tied_embedding_weight=True), + LMHeadTestConfig("multi_token_prediction", prediction_heads=2), + LMHeadTestConfig("cross_entropy_splits", cross_entropy_splits=2, sequence_first=True), + LMHeadTestConfig("loss_masking", loss_masking=True), + LMHeadTestConfig("label_loss", label_loss=True), + LMHeadTestConfig("distillation_loss", distillation_loss=True), + LMHeadTestConfig("z_loss", z_loss=True), + LMHeadTestConfig("label_and_distillation_loss", label_loss=True, distillation_loss=True), + LMHeadTestConfig("label_and_z_loss_weighted", label_loss=True, z_loss=0.5), + LMHeadTestConfig("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0), +) + + +@pytest.mark.slow +@pytest.mark.parametrize("test_config", _lm_head_test_configs) +def test_lm_head(test_config): + model_config = test_config.get_config() + model, distributed = get_base_model(model_config) + input_, kwargs = test_config.get_inputs() + + tied_logit_weight = ( + torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device - ).normal_(config.hidden_size**-0.5) + ).normal_(HIDDEN_SIZE**-0.5) ) - else: - logit_weight = None + if test_config.tied_embedding_weight or test_config.max_prediction_distance > 1 + else None + ) for prediction_distance, head in enumerate(model.head.heads): # Prepare the LM head Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, 
prediction_distance) - is_duplicate = config.tied_embedding_weight or prediction_distance > 0 + is_duplicate = test_config.tied_embedding_weight or prediction_distance > 0 stage = get_stage( [head], distributed, tied_parameter_duplicates=[head.output_weights.tensor_name] if is_duplicate else [], - tied_parameter_duplicate_buffers={head.output_weights.tensor_name: logit_weight} if is_duplicate else {}, + tied_parameter_duplicate_buffers=( + {head.output_weights.tensor_name: tied_logit_weight} if is_duplicate else {} + ), # Names must be kept as-is for tied weights. set_names=False, ) - # Get reference outputs and grads - if is_duplicate: - logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) - logit_weight.param_grad_is_zero = True - else: - logit_weight = head.output_weights - - ref_input = input_.detach().requires_grad_() - ref_rms_weight = head.final_norm.weight.detach().requires_grad_() - ref_logit_weight = logit_weight.detach().requires_grad_() - - ref_loss, ref_z_loss = _lm_head( - ref_input, - ( - target[prediction_distance : prediction_distance + SEQUENCE_LENGTH] - if sequence_first - else target[:, prediction_distance : prediction_distance + SEQUENCE_LENGTH] - ), - loss_mask, - rms_weight=ref_rms_weight, - logit_weight=ref_logit_weight, - logit_scale_factor=head_config.logits_scale_factor, - losses=head_config.losses, + ref_total_loss, ref_input_grad, ref_logit_weight_grad, ref_normalization_weight_grad, ref_losses = ( + test_config.get_reference_outputs(head, input_, kwargs, tied_logit_weight) ) # Prepare LM head inputs if head._is_last_head: - head_input = input_ - output_grad = ref_input.new_full((), float("nan")) + head_input = input_.detach().requires_grad_() + output_grad = input_.new_full((), float("nan")) else: shared_hidden = torch.randn_like(input_) head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_() output_grad = torch.randn_like(shared_hidden) - lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" - expected_loss_keys = {lm_head_loss_name} + if is_duplicate: + logit_weight = tied_logit_weight + logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) + logit_weight.param_grad_is_zero = True + else: + logit_weight = head.output_weights + + # lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" + # expected_loss_keys = {lm_head_loss_name} - # Get expected loss names from the loss configs - for loss_name, loss_config in head._config.losses.items(): - formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) - expected_loss_keys.add(formatted_name) + ## Get expected loss names from the loss configs + # for loss_name, loss_config in head._config.losses.items(): + # formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) + # expected_loss_keys.add(formatted_name) # if ref_z_loss is not None: # expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - Assert.eq( - {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, - {loss_key: 1 for loss_key in expected_loss_keys}, - ) - losses = {key: [] for key in expected_loss_keys} - output, context = stage.forward(head_input, kwargs, losses) + # Assert.eq( + # {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, + # {loss_key: 1 for loss_key in expected_loss_keys}, + # ) + # losses = {key: [] for key in 
expected_loss_keys} + output, context = stage.forward(head_input, kwargs, None) stage.backward(output_grad, context) threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 - ) * head_config.logits_scale_factor + ) * test_config.logits_scale_factor - Assert.eq(losses.keys(), expected_loss_keys) - Assert.eq(len(losses[lm_head_loss_name]), 1) + # Assert.eq(losses.keys(), expected_loss_keys) + # Assert.eq(len(losses[lm_head_loss_name]), 1) # if ref_z_loss is not None: # Assert.eq(len(losses["z_loss"]), 1) # Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) - Assert.rms_close_relative(losses[lm_head_loss_name][0], ref_loss, threshold, min_threshold) + Assert.rms_close_relative(output, ref_total_loss, threshold, min_threshold) if head._is_last_head: - Assert.all_equal(output, losses[lm_head_loss_name][0]) + # Assert.all_equal(output, losses[lm_head_loss_name][0]) input_grad = head_input.grad else: Assert.all_equal(output, shared_hidden) shared_hidden_grad, input_grad = head_input.grad.unbind() Assert.all_equal(shared_hidden_grad, output_grad) - Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) - Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) - Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) + Assert.rms_close_relative(input_grad, ref_input_grad, threshold, min_threshold) + Assert.rms_close_relative( + head.final_norm.weight.grad_buffer, ref_normalization_weight_grad, threshold, min_threshold + ) + Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight_grad, threshold, min_threshold) From f96c72f332e2c99577b408578990498c471f5d72 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 16 Jan 2026 10:24:37 -0500 Subject: [PATCH 43/51] stuff --- fast_llm/models/gpt/model.py | 2 +- tests/utils/distributed_configs.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 5d3fc3cad..8de6822fd 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -273,7 +273,7 @@ def preprocess_batch( loss_mask[sample_index, begin:end] = False labels = torch.where(loss_mask, labels, -100) - if self._config.head.distillation_model is not None: # loss masks only used for distillation currently + if self._config.head.get_distillation_models(): # loss masks only used for distillation currently # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders kwargs[LanguageModelKwargs.loss_mask] = labels >= 0 diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 60f7b22fc..f08e9a488 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -66,7 +66,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon if torch.cuda.is_available() else { (None, "norm"): get_config(ignore_tensors=True), - (None, "word_embeddings_weight"): get_config(0.08, 1e-4), + (None, "word_embeddings_weight"): get_config(8e-2, 1e-4), } ), (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(2e-2, 2e-3), @@ -81,7 +81,14 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon (None, "fw"): get_config(1.2e-3, 3e-4), (None, "bw"): 
get_config(3e-3, 1e-5, scale=2**16), # TODO: Normalization gradient broken on CPU, getting inconsistent results across machines. - **({} if torch.cuda.is_available() else {(None, "norm"): get_config(ignore_tensors=True)}), + **( + {} + if torch.cuda.is_available() + else { + (None, "norm"): get_config(ignore_tensors=True), + (None, "word_embeddings_weight"): get_config(2e-2, 1e-4, scale=2**16), + } + ), (None, "bias"): ( get_config(3e-3, 1e-4, scale=2**16) if torch.cuda.is_available() else get_config(6e-3, 2e-4, scale=2**16) ), From 2a4362f35515b9da2ac28ee115f109f2a76a9097 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 16 Jan 2026 13:39:33 -0500 Subject: [PATCH 44/51] fixes --- fast_llm/layers/common/auxiliary_loss.py | 8 +-- fast_llm/layers/language_model/head.py | 31 +++++----- fast_llm/layers/language_model/loss/config.py | 7 ++- .../loss/language_model_loss.py | 0 tests/layers/test_lm_head.py | 58 ++++++++++++++----- 5 files changed, 66 insertions(+), 38 deletions(-) delete mode 100644 fast_llm/layers/language_model/loss/language_model_loss.py diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index baec73b1c..97e04de16 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -13,12 +13,12 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: @torch.compile -def calculate_z_loss( +def z_loss( logits: torch.Tensor, logits_scale_factor: float = 1.0, loss_mask: "torch.Tensor | None" = None ) -> torch.Tensor: out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 if loss_mask is not None: - out *= loss_mask.unsqueeze(-1) + out = out * loss_mask return torch.mean(out) @@ -33,7 +33,7 @@ def auxiliary_z_loss( loss_mask: "torch.Tensor | None" = None, ) -> torch.Tensor: if losses is not None or (training and grad_scale is not None): - loss = calculate_z_loss(logits, logits_scale_factor, loss_mask) + loss = z_loss(logits, logits_scale_factor, loss_mask) if losses is not None and loss_name is not None: losses[loss_name].append(loss.detach()) if training and grad_scale is not None: @@ -60,7 +60,7 @@ def z_loss_forward_backward( with torch.set_grad_enabled(grad_output is not None): logits_ = logits.detach().requires_grad_(grad_output is not None) - loss = calculate_z_loss(logits, logits_scale_factor, loss_mask) + loss = z_loss(logits_, logits_scale_factor, loss_mask) if grad_output is None: grad = None else: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 13ffc4f16..c0242e25e 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -26,7 +26,6 @@ LanguageModelKwargs, _format_name, ) -from fast_llm.layers.language_model.loss.config import LanguageModelDistillationLossConfig, LanguageModelDPOLossConfig from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -69,12 +68,6 @@ def __init__( lr_scale=lr_scale, peft=peft, ) - if prediction_distance > 0 and any( - isinstance(loss, (LanguageModelDPOLossConfig, LanguageModelDistillationLossConfig)) - for loss in self.losses.values() - ): - raise NotImplementedError("Multi-token prediction not supported with distillation or dpo.") - Assert.in_range(prediction_distance, 0, prediction_heads) self._prediction_distance = prediction_distance self._prediction_heads = prediction_heads @@ -184,7 +177,7 @@ def _logits_loss_forward_backward( loss_mask = split_op(loss_mask, 
self._parallel_dim.group, 0) if not self.training: - logits, _ = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs, True) + logits, _ = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs, return_logits=True) # TODO: Make a proper way of returning the model output. logits = logits.detach() if kwargs.get("global_logits"): @@ -198,7 +191,7 @@ def _logits_loss_forward_backward( logits.detach() ) return None, None - if self._config.cross_entropy_splits == 1 or self.training: + elif self._config.cross_entropy_splits == 1: losses_, input_grad = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs) else: input_grad = torch.empty_like(input_) @@ -210,23 +203,24 @@ def _logits_loss_forward_backward( ) for tensor in [input_, loss_mask, input_grad] ] - for i, (partial_input_, loss_mask_, input_grad_) in enumerate(zip(*tensors_split, strict=True)): + for split_index, (partial_input_, loss_mask_, input_grad_) in enumerate(zip(*tensors_split, strict=True)): partial_losses_, grad_ = self._logits_loss_forward_backward_partial( partial_input_, loss_mask_, kwargs, + split_index=split_index, ) # TODO: Avoid copy with explicit out argument. input_grad_.copy_(grad_) - if i == 0: + if split_index == 0: losses_ = partial_losses_ else: for name in self._config.losses: losses_[name] += partial_losses_[name] loss: torch.Tensor = sum( - (loss_config.weight * self._loss_coefficient / self._config.cross_entropy_splits) * losses_[name] - for name, loss_config in self._config.losses.items() + (self.config.losses[name].weight * self._loss_coefficient / self._config.cross_entropy_splits) * loss_ + for name, loss_ in losses_.items() ) if self._sequence_parallel_logits: # TODO: Async @@ -235,14 +229,13 @@ def _logits_loss_forward_backward( if losses is not None: losses[self._total_head_loss_name].append(loss) if len(self._config.losses) > 1: - for name, loss_config in self._config.losses.items(): - loss_ = losses_[name] + for name, loss_ in losses_.items(): if self._config.cross_entropy_splits != 1: loss_ /= self._config.cross_entropy_splits if self._sequence_parallel_logits: # TODO: Async all_reduce(loss_, op=ReduceOp.AVG, group=self._parallel_dim.group) - losses[loss_config.get_name(self._prediction_distance)].append(loss_) + losses[self.config.losses[name].get_name(self._prediction_distance)].append(loss_) return loss, input_grad @@ -251,6 +244,7 @@ def _logits_loss_forward_backward_partial( input_: torch.Tensor, loss_mask: torch.Tensor | None, kwargs: dict, + split_index: int = 0, return_logits: bool = False, ) -> tuple[dict[str, torch.Tensor] | torch.Tensor, torch.Tensor | None]: group = self._parallel_dim.group if self._vocab_parallel else None @@ -297,6 +291,11 @@ def _logits_loss_forward_backward_partial( group=group, logits_scale_factor=self._config.logits_scale_factor, kwargs=kwargs, + prediction_distance=self._prediction_distance, + prediction_heads=self._prediction_heads, + split_index=split_index, + num_splits=self._config.cross_entropy_splits, + sequence_parallel_logits=self._sequence_parallel_logits, ) losses[loss_name] = loss.detach() if grad_ is not None: diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 81d881793..f4ded2062 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -83,6 +83,8 @@ def get_loss( sequence_parallel_logits: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": + import 
torch + from fast_llm.functional.entropy_loss import entropy_loss_forward_backward labels = kwargs[LanguageModelLossKwargs.labels] @@ -204,7 +206,7 @@ def get_loss( implementation=implementation, logits_scale_factor=logits_scale_factor, temperature=self.temperature, - target_format=TargetFormat.labels, + target_format=TargetFormat.logits, entropy_loss_type=self.loss_type, ) @@ -301,8 +303,7 @@ def get_loss( ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward - # TODO: ====== Support loss mask, vocab_parallel ====== - assert loss_mask is None + # TODO: Support vocab_parallel assert group is None return z_loss_forward_backward( diff --git a/fast_llm/layers/language_model/loss/language_model_loss.py b/fast_llm/layers/language_model/loss/language_model_loss.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index b23c18a36..775b2653d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -1,3 +1,4 @@ +import collections import dataclasses import typing @@ -80,7 +81,7 @@ def get_config(self) -> GPTModelConfig: "hidden_size": HIDDEN_SIZE, "tied_embedding_weight": self.tied_embedding_weight, }, - "distributed": {"compute_dtype": self.compute_dtype}, + "distributed": {"compute_dtype": self.compute_dtype, "use_cuda": torch.cuda.is_available()}, }, ) @@ -97,9 +98,9 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: requires_grad=True, ) label_shape = ( - (SEQUENCE_LENGTH + self.max_prediction_distance - 1, BATCH_SIZE) + (SEQUENCE_LENGTH + self.prediction_heads - 1, BATCH_SIZE) if self.sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + self.max_prediction_distance - 1) + else (BATCH_SIZE, SEQUENCE_LENGTH + self.prediction_heads - 1) ) kwargs: dict[str, typing.Any] = { AttentionKwargs.sequence_first: self.sequence_first, @@ -115,12 +116,12 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: dtype=torch.int64, device=device, ) - if LanguageModelKwargs: + if LanguageModelKwargs.loss_mask in kwargs: labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels) kwargs[LanguageModelKwargs.labels] = labels if self.distillation_loss is not False: - assert self.max_prediction_distance == 1 + assert self.prediction_heads == 1 kwargs[f"distillation_logits"] = torch.randn( input_.shape[:-1] + (VOCAB_SIZE,), dtype=input_.dtype, @@ -152,21 +153,37 @@ def get_reference_outputs( losses = {} if self.actual_label_loss is not False: + if self.sequence_first: + labels = kwargs[LanguageModelKwargs.labels][ + head._prediction_distance : head._prediction_distance + logits.size(0) + ] + else: + labels = kwargs[LanguageModelKwargs.labels][ + :, head._prediction_distance : head._prediction_distance + logits.size(1) + ] label_loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), kwargs[LanguageModelKwargs.labels].flatten() - ) + logits.flatten(0, -2), labels.flatten(), reduction="none" + ).mean() losses["label_loss"] = label_loss.detach() total_loss = total_loss + float(self.actual_label_loss) * label_loss if self.distillation_loss is not False: distillation_loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), torch.softmax(kwargs[f"distillation_logits"].flatten(0, -2), -1) + logits.flatten(0, -2), + torch.softmax(kwargs[f"distillation_logits"].flatten(0, -2), -1), + reduction="none", ) + if LanguageModelKwargs.loss_mask in kwargs: + distillation_loss = distillation_loss * 
kwargs[LanguageModelKwargs.loss_mask].flatten() + distillation_loss = distillation_loss.mean() losses["distillation_loss"] = distillation_loss.detach() total_loss = total_loss + float(self.distillation_loss) * distillation_loss if self.z_loss is not False: - z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) + z_loss = torch.logsumexp(logits, dim=-1) ** 2 + if LanguageModelKwargs.loss_mask in kwargs: + z_loss = z_loss * kwargs[LanguageModelKwargs.loss_mask] + z_loss = z_loss.mean() losses["z_loss"] = z_loss.detach() total_loss = total_loss + float(self.z_loss) * z_loss @@ -187,7 +204,9 @@ def get_reference_outputs( LMHeadTestConfig("loss_masking", loss_masking=True), LMHeadTestConfig("label_loss", label_loss=True), LMHeadTestConfig("distillation_loss", distillation_loss=True), + LMHeadTestConfig("distillation_loss_masked", distillation_loss=True, loss_masking=True), LMHeadTestConfig("z_loss", z_loss=True), + LMHeadTestConfig("z_loss_masked", z_loss=True, loss_masking=True), LMHeadTestConfig("label_and_distillation_loss", label_loss=True, distillation_loss=True), LMHeadTestConfig("label_and_z_loss_weighted", label_loss=True, z_loss=0.5), LMHeadTestConfig("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0), @@ -195,7 +214,13 @@ def get_reference_outputs( @pytest.mark.slow -@pytest.mark.parametrize("test_config", _lm_head_test_configs) +@pytest.mark.parametrize( + "test_config", + [ + pytest.param(_lm_head_test_config, id=_lm_head_test_config.name) + for _lm_head_test_config in _lm_head_test_configs + ], +) def test_lm_head(test_config): model_config = test_config.get_config() model, distributed = get_base_model(model_config) @@ -207,7 +232,7 @@ def test_lm_head(test_config): VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device ).normal_(HIDDEN_SIZE**-0.5) ) - if test_config.tied_embedding_weight or test_config.max_prediction_distance > 1 + if test_config.tied_embedding_weight or test_config.prediction_heads > 1 else None ) @@ -228,7 +253,9 @@ def test_lm_head(test_config): ) ref_total_loss, ref_input_grad, ref_logit_weight_grad, ref_normalization_weight_grad, ref_losses = ( - test_config.get_reference_outputs(head, input_, kwargs, tied_logit_weight) + test_config.get_reference_outputs( + head, input_, kwargs, tied_logit_weight if prediction_distance > 0 else None + ) ) # Prepare LM head inputs @@ -263,9 +290,10 @@ def test_lm_head(test_config): # {loss_key: 1 for loss_key in expected_loss_keys}, # ) # losses = {key: [] for key in expected_loss_keys} - output, context = stage.forward(head_input, kwargs, None) + losses = collections.defaultdict(list) + output, context = stage.forward(head_input, kwargs, losses) + print(losses) stage.backward(output_grad, context) - threshold = 1e-5 if distributed.config.compute_dtype == DataType.float32 else 5e-3 min_threshold = ( 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 @@ -277,7 +305,7 @@ def test_lm_head(test_config): # Assert.eq(len(losses["z_loss"]), 1) # Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) - Assert.rms_close_relative(output, ref_total_loss, threshold, min_threshold) + Assert.rms_close_relative(losses[head._total_head_loss_name][0], ref_total_loss, threshold, min_threshold) if head._is_last_head: # Assert.all_equal(output, losses[lm_head_loss_name][0]) From b464e4ee515fd7f897fd2ee377ed512ae707930e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 16 Jan 2026 14:02:44 -0500 Subject: 
[PATCH 45/51] fixes --- fast_llm/layers/decoder/block.py | 3 +- fast_llm/layers/language_model/config.py | 9 +--- fast_llm/layers/language_model/head.py | 36 +++++---------- fast_llm/layers/language_model/loss/config.py | 8 ---- tests/layers/test_lm_head.py | 45 +++++++++---------- 5 files changed, 35 insertions(+), 66 deletions(-) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index f5abd1f6d..3ae47c0a7 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -14,7 +14,6 @@ from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig -from fast_llm.layers.language_model.head import _format_name from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -287,7 +286,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_definitions.append( LossDef( name=self._activation_distillation_loss_name, - formatted_name=_format_name(self._activation_distillation_loss_name), + formatted_name=self._activation_distillation_loss_name, count=count, ) ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index c41860ea9..b09a354b5 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,5 +1,4 @@ import abc -import functools import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -36,8 +35,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs): mask_inputs = "mask_inputs" -def _format_name(name: str) -> str: - return name.replace("_", " ") +LM_HEAD_LOSS_NAME = "lm_head_loss" @config_class() @@ -202,10 +200,7 @@ def _validate(self) -> None: if "losses" not in self._explicit_fields: self.losses = {"lm_loss": LanguageModelLabelEntropyLossConfig()} super()._validate() - - @functools.cached_property - def _loss_configs(self) -> dict[type, LanguageModelLossConfig]: - return {loss.__class__: loss for loss in self.losses.values()} + assert LM_HEAD_LOSS_NAME not in self.losses @property def max_prediction_distance(self) -> int: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c0242e25e..736d8faf0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,5 +1,4 @@ import abc -import functools import logging import typing @@ -20,11 +19,11 @@ from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( + LM_HEAD_LOSS_NAME, LanguageModelEmbeddingsConfig, LanguageModelHeadBaseConfig, LanguageModelHeadConfig, LanguageModelKwargs, - _format_name, ) from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -97,10 +96,6 @@ def __init__( peft=self._peft, ) - self._formatted_loss_names = {} - for registered_loss_name, loss_config in self._config.losses.items(): - self._formatted_loss_names[registered_loss_name] = loss_config.get_name(self._prediction_distance) - def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? 
(loss) return ( @@ -227,7 +222,7 @@ def _logits_loss_forward_backward( all_reduce(loss, op=ReduceOp.AVG, group=self._parallel_dim.group) if losses is not None: - losses[self._total_head_loss_name].append(loss) + losses[self.get_full_loss_name(LM_HEAD_LOSS_NAME)].append(loss) if len(self._config.losses) > 1: for name, loss_ in losses_.items(): if self._config.cross_entropy_splits != 1: @@ -235,7 +230,7 @@ def _logits_loss_forward_backward( if self._sequence_parallel_logits: # TODO: Async all_reduce(loss_, op=ReduceOp.AVG, group=self._parallel_dim.group) - losses[self.config.losses[name].get_name(self._prediction_distance)].append(loss_) + losses[name].append(loss_) return loss, input_grad @@ -304,32 +299,25 @@ def _logits_loss_forward_backward_partial( return losses, output_parallel_linear_backward(grad, context) if self.training else None - @functools.cached_property - def _total_head_loss_name(self) -> str: - """ - Combined total scaled loss used for training. - """ - name = "lm_head_loss" - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name - def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return [ - LossDef( - name=self._total_head_loss_name, formatted_name=_format_name(self._total_head_loss_name), count=count - ), + LossDef(name=(name := self.get_full_loss_name(LM_HEAD_LOSS_NAME)), formatted_name=name, count=count), *( LossDef( - name=loss_config.get_name(self._prediction_distance), - formatted_name=_format_name(loss_config.get_name(self._prediction_distance)), + name=(name_ := self.get_full_loss_name(name)), + formatted_name=name_, count=count, dtype=DataType.float32, ) - for loss_config in self._config.losses.values() + for name, loss_config in self._config.losses.values() ), ] + def get_full_loss_name(self, name) -> str: + if self._prediction_distance > 0: + name = f"{name}_{self._prediction_distance}" + return name + @property def heads(self): # For compatibility with MTP. 
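As a quick illustration of the loss-naming scheme introduced above (a sketch only; the "label" and "distillation" keys are assumed user-chosen config names, not fixed identifiers):

# Heads beyond the first (multi-token prediction) get a distance suffix,
# and the aggregated head loss is reported under LM_HEAD_LOSS_NAME.
LM_HEAD_LOSS_NAME = "lm_head_loss"

def full_loss_name(name: str, prediction_distance: int) -> str:
    return name if prediction_distance == 0 else f"{name}_{prediction_distance}"

configured_losses = ["label", "distillation"]  # assumed config keys
reported_keys = [full_loss_name(LM_HEAD_LOSS_NAME, 1)] + [
    full_loss_name(name, 1) for name in configured_losses
]
# reported_keys == ["lm_head_loss_1", "label_1", "distillation_1"]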
diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index f4ded2062..551554132 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -29,10 +29,6 @@ class LanguageModelLossConfig(Config): def get_name(self, prediction_distance: int = 0) -> str: return self._name if prediction_distance == 0 else f"{self._name}_{prediction_distance}" - @property - def _name(self) -> str: - raise NotImplementedError() - def get_loss( self, logits: "torch.Tensor", @@ -53,7 +49,6 @@ def get_loss( @config_class(dynamic_type={LanguageModelLossConfig: "label"}) class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): - _name: typing.ClassVar[str] = "CE_loss" _abstract: typing.ClassVar[bool] = False loss_type: EntropyLossType = Field( @@ -136,7 +131,6 @@ def get_loss( @config_class(dynamic_type={LanguageModelLossConfig: "distillation"}) class LanguageModelDistillationLossConfig(LanguageModelLossConfig): - _name: typing.ClassVar[str] = "FwdKL_loss" _abstract: typing.ClassVar[bool] = False loss_type: EntropyLossType = Field( @@ -215,7 +209,6 @@ def get_loss( class LanguageModelDPOLossConfig(LanguageModelLossConfig): """Direct Preference Optimization (DPO) loss for alignment.""" - _name: typing.ClassVar[str] = "DPO_loss" _abstract: typing.ClassVar[bool] = False beta: float = Field( @@ -283,7 +276,6 @@ def get_loss( class LanguageModelZLossConfig(LanguageModelLossConfig): """Z-loss regularization to prevent overconfidence.""" - _name: typing.ClassVar[str] = "Z_loss" _abstract: typing.ClassVar[bool] = False def get_loss( diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 775b2653d..9aa53fcc4 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -7,7 +7,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.utils import Assert @@ -164,7 +164,7 @@ def get_reference_outputs( label_loss = torch.nn.functional.cross_entropy( logits.flatten(0, -2), labels.flatten(), reduction="none" ).mean() - losses["label_loss"] = label_loss.detach() + losses["label"] = label_loss.detach() total_loss = total_loss + float(self.actual_label_loss) * label_loss if self.distillation_loss is not False: @@ -176,7 +176,7 @@ def get_reference_outputs( if LanguageModelKwargs.loss_mask in kwargs: distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask].flatten() distillation_loss = distillation_loss.mean() - losses["distillation_loss"] = distillation_loss.detach() + losses["distillation"] = distillation_loss.detach() total_loss = total_loss + float(self.distillation_loss) * distillation_loss if self.z_loss is not False: @@ -188,11 +188,22 @@ def get_reference_outputs( total_loss = total_loss + float(self.z_loss) * z_loss total_loss.backward() + + if len(losses) > 1: + losses[LM_HEAD_LOSS_NAME] = total_loss.detach() + else: + losses = {LM_HEAD_LOSS_NAME: total_loss.detach()} + + if head._prediction_distance > 0: + losses = {f"{name}_{head._prediction_distance}": loss for name, loss in losses.items()} + return total_loss.detach(), input_.grad, logit_weight.grad, normalization_weight.grad, losses 
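The reference distillation path above (soft-target cross-entropy against the teacher, masked, then averaged) can be reproduced standalone; a minimal sketch with made-up shapes:

import torch

torch.manual_seed(0)
batch, seq, vocab = 2, 8, 16
student_logits = torch.randn(batch, seq, vocab, requires_grad=True)
teacher_logits = torch.randn(batch, seq, vocab)
loss_mask = torch.rand(batch, seq) > 0.25  # True where a token contributes

# Cross-entropy against the teacher's softmax distribution, per token,
# then masked and averaged, mirroring get_reference_outputs.
per_token = torch.nn.functional.cross_entropy(
    student_logits.flatten(0, -2),
    torch.softmax(teacher_logits.flatten(0, -2), dim=-1),
    reduction="none",
)
distillation_loss = (per_token * loss_mask.flatten()).mean()
distillation_loss.backward()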
_lm_head_test_configs = ( # TODO: Test DPO loss. + # TODO: Add more configs + # TODO: Add distributed test LMHeadTestConfig("default"), LMHeadTestConfig("bfloat16", compute_dtype=DataType.bfloat16), LMHeadTestConfig("full_precision_residual", full_precision_residual=True), @@ -274,22 +285,6 @@ def test_lm_head(test_config): else: logit_weight = head.output_weights - # lm_head_loss_name = f"lm_head_loss_{prediction_distance}" if prediction_distance > 0 else "lm_head_loss" - # expected_loss_keys = {lm_head_loss_name} - - ## Get expected loss names from the loss configs - # for loss_name, loss_config in head._config.losses.items(): - # formatted_name = loss_config.get_formatted_name(loss_name, prediction_distance) - # expected_loss_keys.add(formatted_name) - - # if ref_z_loss is not None: - # expected_loss_keys.add(f"z_loss_{prediction_distance}" if prediction_distance > 0 else "z_loss") - - # Assert.eq( - # {loss_definition.name: loss_definition.count for loss_definition in head.get_loss_definitions()}, - # {loss_key: 1 for loss_key in expected_loss_keys}, - # ) - # losses = {key: [] for key in expected_loss_keys} losses = collections.defaultdict(list) output, context = stage.forward(head_input, kwargs, losses) print(losses) @@ -299,13 +294,13 @@ def test_lm_head(test_config): 1e-5 if distributed.config.compute_dtype == DataType.float32 else 1e-4 ) * test_config.logits_scale_factor - # Assert.eq(losses.keys(), expected_loss_keys) - # Assert.eq(len(losses[lm_head_loss_name]), 1) - # if ref_z_loss is not None: - # Assert.eq(len(losses["z_loss"]), 1) - # Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) + Assert.eq(losses.keys(), ref_losses.keys()) + for name, loss in losses.items(): + assert len(loss) == 1, name + losses = {name: loss[0] for name, loss in losses.items()} - Assert.rms_close_relative(losses[head._total_head_loss_name][0], ref_total_loss, threshold, min_threshold) + for name, loss in losses.items(): + Assert.rms_close_relative(loss, ref_losses[name], threshold, min_threshold, msg=name) if head._is_last_head: # Assert.all_equal(output, losses[lm_head_loss_name][0]) From ba40a407107634232d5867622e5df0837ed8087e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 19 Jan 2026 12:52:18 -0500 Subject: [PATCH 46/51] fixes --- fast_llm/layers/language_model/config.py | 2 ++ fast_llm/layers/language_model/head.py | 22 +++++++++++++--------- fast_llm/models/gpt/config.py | 3 +-- tests/models/test_checkpoint.py | 2 ++ 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b09a354b5..7b4d69a8e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -149,6 +149,8 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) + # TODO: Option to chose whether to split in batch or sequence dimension? 
+ # (Currently split merged batch and sequence, depends on `sequence_first`) cross_entropy_splits: int = Field( default=1, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 736d8faf0..2de1ae726 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -152,7 +152,7 @@ def _forward_backward( # Transformers expect normalized outputs for the last transformer layer, # so we add the norm output to the hidden states. self._debug(ln_output, "final_norm", kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) - loss, ln_output_grad = self._logits_loss_forward_backward(ln_output.detach().flatten(0, -2), kwargs, losses) + loss, ln_output_grad = self._logits_loss_forward_backward(ln_output.detach(), kwargs, losses) if ln_output_grad is None: return loss, None else: @@ -165,14 +165,9 @@ def _logits_loss_forward_backward( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - if loss_mask is not None: - loss_mask = loss_mask.flatten() - if self._sequence_parallel_logits: - loss_mask = split_op(loss_mask, self._parallel_dim.group, 0) if not self.training: - logits, _ = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs, return_logits=True) + logits, _ = self._logits_loss_forward_backward_partial(input_, None, kwargs, return_logits=True) # TODO: Make a proper way of returning the model output. logits = logits.detach() if kwargs.get("global_logits"): @@ -186,7 +181,16 @@ def _logits_loss_forward_backward( logits.detach() ) return None, None - elif self._config.cross_entropy_splits == 1: + + input_ = input_.flatten(0, -2) + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + sequence_dim = 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 + if loss_mask is not None: + if self._sequence_parallel_logits: + loss_mask = split_op(loss_mask, self._parallel_dim.group, sequence_dim) + loss_mask = loss_mask.flatten() + + if self._config.cross_entropy_splits == 1: losses_, input_grad = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs) else: input_grad = torch.empty_like(input_) @@ -309,7 +313,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, dtype=DataType.float32, ) - for name, loss_config in self._config.losses.values() + for name, loss_config in self._config.losses.items() ), ] diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index dc7f63299..9ed4a95f3 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -166,8 +166,7 @@ def _validate(self) -> None: else: prediction_heads = 1 - expected_names = {name for name in (head.distillation_model, head.dpo_reference_model) if name is not None} - expected_names.update(self.model.base_model.decoder.get_distillation_models()) + expected_names = head.get_distillation_models() | self.model.base_model.decoder.get_distillation_models() Assert.eq(self.reference_models.keys(), expected_names) for reference_model in self.reference_models.values(): diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 955fa534c..92832f09f 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -383,6 +383,8 @@ def test_huggingface_model(model_testing_config, get_convert_path): print(name) output = model(test_input, 
**kwargs) # TODO: Make a generic comparison util. + print("AAA", output_ref.logits.shape) + print("BBB", output.logits.shape) CompareConfig().compare_tensors( {"samples": output_ref.logits, "shape": output_ref.logits.shape, "step": 0}, {"samples": output.logits, "shape": output.logits.shape, "step": 0}, From a2ff5fb94f66bd4f51b0dc9e566805473da18f0e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 19 Jan 2026 12:52:55 -0500 Subject: [PATCH 47/51] fixes --- tests/models/test_checkpoint.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 92832f09f..955fa534c 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -383,8 +383,6 @@ def test_huggingface_model(model_testing_config, get_convert_path): print(name) output = model(test_input, **kwargs) # TODO: Make a generic comparison util. - print("AAA", output_ref.logits.shape) - print("BBB", output.logits.shape) CompareConfig().compare_tensors( {"samples": output_ref.logits, "shape": output_ref.logits.shape, "step": 0}, {"samples": output.logits, "shape": output.logits.shape, "step": 0}, From 44bad568967e427756b499d250feca1b3f612ab7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 19 Jan 2026 17:38:17 -0500 Subject: [PATCH 48/51] fixes --- fast_llm/functional/entropy_loss.py | 6 ++-- fast_llm/layers/block/config.py | 10 +++---- fast_llm/layers/decoder/block.py | 14 ++++----- fast_llm/layers/decoder/config.py | 17 ++++------- fast_llm/layers/language_model/config.py | 6 ++++ fast_llm/layers/language_model/head.py | 8 ++--- fast_llm/layers/language_model/loss/config.py | 26 ++++++++++++++--- fast_llm/models/gpt/config.py | 29 +++++-------------- fast_llm/models/gpt/model.py | 4 +-- tests/functional/test_entropy_loss.py | 2 +- tests/utils/model_configs.py | 10 ++----- tests/utils/subtest.py | 16 +++++----- 12 files changed, 73 insertions(+), 75 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 0c0fe9fa3..cd7eca950 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -84,7 +84,7 @@ def _fused_softmax_base( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: logits = logits.float() if logits_scale_factor != 1.0: - logits *= logits_scale_factor + logits = logits * logits_scale_factor logits_max = torch.max(logits, dim=dim, keepdim=True)[0] if group is not None: all_reduce(logits_max, op=ReduceOp.MAX, group=group) @@ -285,7 +285,7 @@ def _fused_entropy_loss_forward_backward( return loss, grad -_CROSS_ENTROPY_IMPLEMENTATIONS = { +_ENTROPY_LOSS_IMPLEMENTATIONS = { EntropyLossImplementation.torch: _torch_entropy_loss_forward_backward, EntropyLossImplementation.fused: _fused_entropy_loss_forward_backward, EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, @@ -333,7 +333,7 @@ def entropy_loss_forward_backward( temperature, ) else: - return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation]( + return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( logits, target, loss_mask, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index b06f69ee5..fd76d36cb 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -92,7 +92,7 @@ def get_layer( peft=peft, ) - def get_distillation_models(self) -> set[str]: + def get_reference_models(self) -> set[str]: return set() @@ -126,8 +126,8 @@ def layer_class(self) -> "type[FixedBlockSequence]": return FixedBlockSequence - def 
get_distillation_models(self) -> set[str]: - return self.block.get_distillation_models() + def get_reference_models(self) -> set[str]: + return self.block.get_reference_models() @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) @@ -176,10 +176,10 @@ def preprocessing_layers(self) -> dict[str, int]: # The index at which each block first appears. These blocks are used for preprocessing. return {name: self.expanded_pattern.index(name) for name in set(self.expanded_pattern)} - def get_distillation_models(self) -> set[str]: + def get_reference_models(self) -> set[str]: models = set() for block in self.blocks.values(): - models.update(block.get_distillation_models()) + models.update(block.get_reference_models()) return models @classmethod diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 3ae47c0a7..9d5166cc7 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -175,7 +175,7 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metr Assert.eq(teacher_tensor.shape, mixer_output.shape) # TODO: un-scaled loss for reporting? Average loss over layers? # L2 loss - activation_loss_factor = self._config.activation_distillation_factor + activation_loss_factor = self._config.distillation_loss_weight # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim. # Handle possible padding by using pre-computed activation mask @@ -248,8 +248,8 @@ def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metr hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, 1.0) bias = AuxiliaryLoss.apply(bias, scaled_activation_loss, 1.0) if bias is not None else None # Logging - if losses is not None and self._activation_distillation_loss_name in losses: - losses[self._activation_distillation_loss_name].append(activation_loss.detach()) + if losses is not None and self._distillation_loss_name in losses: + losses[self._distillation_loss_name].append(activation_loss.detach()) # Per-layer metrics if metrics is not None: metrics[f"{self.module_name}/activation_distillation_loss"] = activation_loss.detach() @@ -278,15 +278,15 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self.mlp.preprocess(kwargs) # TODO: add layer_index - _activation_distillation_loss_name = "activation_distillation_loss" + _distillation_loss_name = "activation_distillation_loss" def get_loss_definitions(self, count: int = 1) -> list[LossDef]: loss_definitions = [] - if self._config.activation_distillation_factor > 0.0 and self._config.distillation_model is not None: + if self._config.distillation_model is not None: loss_definitions.append( LossDef( - name=self._activation_distillation_loss_name, - formatted_name=self._activation_distillation_loss_name, + name=self._distillation_loss_name, + formatted_name=self._distillation_loss_name, count=count, ) ) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 875be5624..2f5990ccb 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -210,18 +210,13 @@ class DecoderBlockConfig(BlockConfig): desc="Name of the reference model to use for activation-level distillation.", hint=FieldHint.feature, ) - activation_distillation_factor: float = Field( - default=0.0, - desc="Factor to scale the activation-level distillation loss by.", + distillation_loss_weight: float = Field( + default=1.0, + desc="Weight for the scale the activation distillation loss.", 
hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - def _validate(self) -> None: - super()._validate() - if self.activation_distillation_factor > 0.0 and self.distillation_model is None: - raise ValueError("Activation distillation requires a distillation_model.") - @property def layer_class(self) -> "type[DecoderBlock]": from fast_llm.layers.decoder.block import DecoderBlock @@ -245,7 +240,5 @@ def get_layer( return_input=return_input, ) - def get_distillation_models(self) -> set[str]: - if self.distillation_model is not None and self.activation_distillation_factor > 0.0: - return {self.distillation_model} - return set() + def get_reference_models(self) -> set[str]: + return set() if self.distillation_model is None else {self.distillation_model} diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 7b4d69a8e..5f58024e0 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -208,6 +208,9 @@ def _validate(self) -> None: def max_prediction_distance(self) -> int: return 1 + def get_reference_models(self) -> set[str]: + return {reference_model for loss in self.losses.values() for reference_model in loss.get_reference_models()} + @config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): @@ -292,3 +295,6 @@ def layer_class(self) -> "type[LanguageModel]": from fast_llm.layers.language_model.language_model import LanguageModel return LanguageModel + + def get_reference_models(self) -> set[str]: + return self.decoder.get_reference_models() | self.head.get_reference_models() diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 2de1ae726..3d526bad1 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -227,7 +227,7 @@ def _logits_loss_forward_backward( if losses is not None: losses[self.get_full_loss_name(LM_HEAD_LOSS_NAME)].append(loss) - if len(self._config.losses) > 1: + if len(self._config.losses) > 1 or any(loss_.weight != 1.0 for loss_ in self._config.losses.values()): for name, loss_ in losses_.items(): if self._config.cross_entropy_splits != 1: loss_ /= self._config.cross_entropy_splits @@ -246,12 +246,11 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[dict[str, torch.Tensor] | torch.Tensor, torch.Tensor | None]: - group = self._parallel_dim.group if self._vocab_parallel else None logits, context = output_parallel_linear_forward( input_=input_, weight=self.output_weights, bias=None, - group=group, + group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) @@ -287,7 +286,7 @@ def _logits_loss_forward_backward_partial( logits, loss_mask, grad_output=None if grad_output == 0.0 else grad_output, - group=group, + group=self._parallel_dim.group, logits_scale_factor=self._config.logits_scale_factor, kwargs=kwargs, prediction_distance=self._prediction_distance, @@ -295,6 +294,7 @@ def _logits_loss_forward_backward_partial( split_index=split_index, num_splits=self._config.cross_entropy_splits, sequence_parallel_logits=self._sequence_parallel_logits, + vocab_parallel=self._vocab_parallel, ) losses[loss_name] = loss.detach() if grad_ is not None: diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 551554132..8a9f4251d 100644 
--- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -42,10 +42,14 @@ def get_loss( split_index: int = 0, num_splits: int = 1, sequence_parallel_logits: bool = False, + vocab_parallel: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": raise NotImplementedError() + def get_reference_models(self) -> set[str]: + return set() + @config_class(dynamic_type={LanguageModelLossConfig: "label"}) class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): @@ -76,6 +80,7 @@ def get_loss( split_index: int = 0, num_splits: int = 1, sequence_parallel_logits: bool = False, + vocab_parallel: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": import torch @@ -109,7 +114,7 @@ def get_loss( if ( TritonConfig.TRITON_ENABLED and torch.cuda.is_available() - and group is None + and (group is None or not vocab_parallel) and self.loss_type == EntropyLossType.cross_entropy ): implementation = EntropyLossImplementation.triton @@ -121,7 +126,7 @@ def get_loss( labels, None, # Labels are already masked grad_output=grad_output, - group=group, + group=group if vocab_parallel else None, implementation=implementation, logits_scale_factor=logits_scale_factor, target_format=TargetFormat.labels, @@ -144,6 +149,7 @@ class LanguageModelDistillationLossConfig(LanguageModelLossConfig): hint=FieldHint.performance, ) reference_model: str = Field( + default="teacher", desc="Name of the reference model for knowledge distillation.", hint=FieldHint.feature, ) @@ -167,6 +173,7 @@ def get_loss( split_index: int = 0, num_splits: int = 1, sequence_parallel_logits: bool = False, + vocab_parallel: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.entropy_loss import entropy_loss_forward_backward @@ -196,7 +203,7 @@ def get_loss( reference_model_logits, loss_mask, grad_output=grad_output, - group=group, + group=group if vocab_parallel else None, implementation=implementation, logits_scale_factor=logits_scale_factor, temperature=self.temperature, @@ -204,6 +211,9 @@ def get_loss( entropy_loss_type=self.loss_type, ) + def get_reference_models(self) -> set[str]: + return {self.reference_model} + @config_class(dynamic_type={LanguageModelLossConfig: "dpo"}) class LanguageModelDPOLossConfig(LanguageModelLossConfig): @@ -236,6 +246,7 @@ def get_loss( split_index: int = 0, num_splits: int = 1, sequence_parallel_logits: bool = False, + vocab_parallel: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.functional.dpo import compute_dpo_loss @@ -244,6 +255,8 @@ def get_loss( raise NotImplementedError() if prediction_distance > 0: raise NotImplementedError() + if vocab_parallel and group is not None: + raise NotImplementedError() if logits_scale_factor != 1.0: # TODO: Make more efficient. 
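For reference, the DPO objective handled by LanguageModelDPOLossConfig boils down to the usual preference loss over chosen/rejected spans; a condensed sketch (span masks stand in for the per-sample span lists used by the actual helper):

import torch

def dpo_loss_sketch(policy_logits, reference_logits, targets, chosen_mask, rejected_mask, beta=1.0):
    # Log-probability of each observed target token under a given model.
    def target_log_probs(logits):
        return torch.log_softmax(logits.float(), dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Chosen-minus-rejected log-likelihood, summed over the marked spans.
    def log_ratio(logits):
        log_probs = target_log_probs(logits)
        return (log_probs * chosen_mask).sum() - (log_probs * rejected_mask).sum()

    margin = log_ratio(policy_logits) - log_ratio(reference_logits)
    return -torch.nn.functional.logsigmoid(beta * margin)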
@@ -271,6 +284,9 @@ def get_loss( grad_output=grad_output, ) + def get_reference_models(self) -> set[str]: + return {self.reference_model} + @config_class(dynamic_type={LanguageModelLossConfig: "z_loss"}) class LanguageModelZLossConfig(LanguageModelLossConfig): @@ -291,12 +307,14 @@ def get_loss( split_index: int = 0, num_splits: int = 1, sequence_parallel_logits: bool = False, + vocab_parallel: bool = False, kwargs: dict[str, typing.Any], ) -> "tuple[torch.Tensor, torch.Tensor | None]": from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward # TODO: Support vocab_parallel - assert group is None + if vocab_parallel and group is not None: + raise NotImplementedError() return z_loss_forward_backward( logits, diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 9ed4a95f3..a315beecc 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -11,7 +11,7 @@ from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig +from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.models.gpt.conversion.config import ( Apriel2TextCheckpointFormat, AprielHybridSSMCheckpointFormat, @@ -159,29 +159,14 @@ def _validate(self) -> None: Assert.geq(self.model.base_model.embeddings.num_position_embeddings, self.batch.sequence_length) # TODO: Avoid digging inside the model. - head = self.model.base_model.head - if isinstance(head, MultiTokenPredictionConfig): - prediction_heads = head.prediction_heads - head = head.head - else: - prediction_heads = 1 - - expected_names = head.get_distillation_models() | self.model.base_model.decoder.get_distillation_models() - Assert.eq(self.reference_models.keys(), expected_names) + Assert.eq(self.reference_models.keys(), self.model.base_model.get_reference_models()) for reference_model in self.reference_models.values(): - reference_head = reference_model.model.base_model.head - if isinstance(reference_head, MultiTokenPredictionConfig): - reference_prediction_heads = reference_head.prediction_heads - reference_head = reference_head.heads - else: - reference_prediction_heads = 1 - Assert.geq(reference_prediction_heads, prediction_heads) - - Assert.none(reference_head.distillation_model) - Assert.none(reference_head.dpo_reference_model) - # TODO: Support more LM head features. - Assert.none(reference_head.cross_entropy_splits) + Assert.geq( + reference_model.model.base_model.head.max_prediction_distance, + self.model.base_model.head.max_prediction_distance, + ) + Assert.empty(reference_model.model.base_model.get_reference_models()) Assert.eq( reference_model.model.base_model.embeddings.vocab_parallel, self.model.base_model.embeddings.vocab_parallel, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 8de6822fd..bd2932984 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -167,7 +167,7 @@ def preprocess_batch( if preprocessed_meta is None: preprocessed_meta = self.preprocess_meta(batch, phase) - distillation_models = self._config.decoder.get_distillation_models() + distillation_models = self._config.decoder.get_reference_models() # TODO: Support multiple distillation models? 
assert len(distillation_models) <= 1 reference_logits = [{} for _ in preprocessed_meta] @@ -273,7 +273,7 @@ def preprocess_batch( loss_mask[sample_index, begin:end] = False labels = torch.where(loss_mask, labels, -100) - if self._config.head.get_distillation_models(): # loss masks only used for distillation currently + if self._config.head.get_reference_models(): # loss masks only used for distillation currently # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders kwargs[LanguageModelKwargs.loss_mask] = labels >= 0 diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py index 4f3f5b6cb..9c06c1919 100644 --- a/tests/functional/test_entropy_loss.py +++ b/tests/functional/test_entropy_loss.py @@ -162,7 +162,7 @@ def test_run_entropy_loss_distributed(run_parallel_script, result_path): (result_path / "test_entropy_loss",), world_size=2, backend=DistributedBackend.gloo, - use_cpu=True, # Disable device count check. + use_cuda=False, # Disable device count check. ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4e1e9d507..5e7526377 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -564,12 +564,8 @@ def update_and_add_testing_config( "mistral", "mistral_distill_logits", updates={ - ("model", "base_model", "head", "distillation_model"): "teacher", ("model", "base_model", "head", "losses"): { - "distillation_loss": { - "type": "reverse_kl_distillation", - "factor": 1.0, - }, + "distillation": {"type": "distillation", "loss_type": "reverse_kl", "reference_model": "teacher"}, }, ("batch", "use_loss_masking_spans"): True, ("reference_models"): { @@ -598,9 +594,9 @@ def update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ - ("model", "base_model", "head", "losses", "distillation_loss", "factor"): 0.001, + ("model", "base_model", "head", "losses", "distillation", "weight"): 0.001, ("model", "base_model", "decoder", "block", "distillation_model"): "teacher", - ("model", "base_model", "decoder", "block", "activation_distillation_factor"): 0.1, + ("model", "base_model", "decoder", "block", "distillation_loss_weight"): 0.1, ("reference_models"): { "teacher": { "model": {"base_model": copy.deepcopy(_mistral_base_model)}, diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index a30440ad1..b69bdace2 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -28,13 +28,13 @@ def __init__( timeout: float = 20.0, init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, - use_cpu: bool = False, + use_cuda: bool = False, ) -> None: self._do_capture = do_capture self._timeout = timeout self._init_method = init_method self._backend = backend - self._use_cpu = use_cpu + self._use_cuda = use_cuda def __enter__(self): if self._do_capture: @@ -43,7 +43,7 @@ def __enter__(self): ) self._pool = ProcessGroupPool( - timeout=self._timeout, init_method=self._init_method, backend=self._backend, use_cpu=self._use_cpu + timeout=self._timeout, init_method=self._init_method, backend=self._backend, use_cuda=self._use_cuda ).__enter__() self._rank = self._pool.rank self._world_size = self._pool.world_size @@ -90,7 +90,7 @@ def __init__( self._path = base_path / name self._name = name self._num_gpus = num_gpus - self._skip = self._test_context._world_size < self._num_gpus and not self._test_context._use_cpu + self._skip = self._test_context._world_size < self._num_gpus and not self._test_context._use_cuda 
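The mistral_distill_logits test config above selects reverse-KL distillation against a "teacher" reference model. In plain PyTorch the per-token objective is roughly the following (a sketch only; the repository uses the fused/compiled kernels in fast_llm.functional.entropy_loss, and temperature scaling and tensor parallelism are omitted here):

import torch

def reverse_kl_sketch(student_logits, teacher_logits, loss_mask=None):
    # KL(student || teacher) per token, computed from raw logits.
    student_log_probs = torch.log_softmax(student_logits.float(), dim=-1)
    teacher_log_probs = torch.log_softmax(teacher_logits.float(), dim=-1)
    per_token = (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(dim=-1)
    if loss_mask is not None:
        per_token = per_token * loss_mask
    return per_token.mean()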
self._do_run = self._test_context._rank < num_gpus and not self._skip self._do_capture = self._test_context._do_capture and self._do_run self._success = False @@ -245,14 +245,14 @@ def parallel_worker( init_method: str, backend: DistributedBackend, do_capture: bool, - use_cpu: bool, + use_cuda: bool, fn: typing.Callable, fn_args: typing.Sequence[typing.Any], ): DistributedConfig.default_rank = rank DistributedConfig.default_world_size = world_size DistributedConfig.default_local_world_size = world_size - with DistributedTestContext(do_capture, 60, init_method, backend, use_cpu) as test_context: + with DistributedTestContext(do_capture, 60, init_method, backend, use_cuda) as test_context: fn(test_context, *fn_args) @@ -264,7 +264,7 @@ def do_run_parallel_script( world_size: int, timeout: float = 240, backend: DistributedBackend = DistributedBackend.nccl, - use_cpu: bool = False, # Use CPU device in process group pool. May be used to disable device count check + use_cuda: bool = True, # Use CPU device in process group pool. May be used to disable device count check ): if "PYTHONHASHSEED" not in os.environ: os.environ["PYTHONHASHSEED"] = "0" @@ -274,7 +274,7 @@ def do_run_parallel_script( ) torch.multiprocessing.spawn( parallel_worker, - args=(world_size, f"tcp://localhost:{port}", backend, do_capture, use_cpu, fn, fn_args), + args=(world_size, f"tcp://localhost:{port}", backend, do_capture, use_cuda, fn, fn_args), nprocs=world_size, join=False, ).join(timeout, grace_period=5) From e626a0378cfb8631c3f3d0c714998a49fb367f3f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 20 Jan 2026 09:54:04 -0500 Subject: [PATCH 49/51] fix --- tests/utils/subtest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index b69bdace2..b6764c0e2 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -28,7 +28,7 @@ def __init__( timeout: float = 20.0, init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, - use_cuda: bool = False, + use_cuda: bool = True, ) -> None: self._do_capture = do_capture self._timeout = timeout @@ -90,7 +90,7 @@ def __init__( self._path = base_path / name self._name = name self._num_gpus = num_gpus - self._skip = self._test_context._world_size < self._num_gpus and not self._test_context._use_cuda + self._skip = self._test_context._world_size < self._num_gpus and self._test_context._use_cuda self._do_run = self._test_context._rank < num_gpus and not self._skip self._do_capture = self._test_context._do_capture and self._do_run self._success = False From 336560e556b9316ad2b54f93b8825f6715cab30b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 21 Jan 2026 15:20:27 -0500 Subject: [PATCH 50/51] stuff --- fast_llm/functional/autograd.py | 11 + fast_llm/functional/dpo.py | 49 --- fast_llm/functional/entropy_loss.py | 3 + fast_llm/layers/common/auxiliary_loss.py | 70 ----- fast_llm/layers/decoder/block.py | 2 +- .../layers/decoder/mlp/mixture_of_experts.py | 18 +- fast_llm/layers/language_model/head.py | 121 ++++---- fast_llm/layers/language_model/loss/config.py | 278 +++++------------- fast_llm/layers/language_model/loss/dpo.py | 81 +++++ .../language_model/loss/entropy_loss.py | 86 ++++++ fast_llm/layers/language_model/loss/grpo.py | 64 ++++ fast_llm/layers/language_model/loss/loss.py | 121 ++++++++ fast_llm/layers/language_model/loss/z_loss.py | 43 +++ tests/functional/test_functional.py | 18 +- tests/layers/test_lm_head.py | 56 ++-- 15 files changed, 582 
insertions(+), 439 deletions(-) delete mode 100644 fast_llm/functional/dpo.py delete mode 100644 fast_llm/layers/common/auxiliary_loss.py create mode 100644 fast_llm/layers/language_model/loss/dpo.py create mode 100644 fast_llm/layers/language_model/loss/entropy_loss.py create mode 100644 fast_llm/layers/language_model/loss/grpo.py create mode 100644 fast_llm/layers/language_model/loss/loss.py create mode 100644 fast_llm/layers/language_model/loss/z_loss.py diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/autograd.py index 1428ed25e..cea5f6ee2 100644 --- a/fast_llm/functional/autograd.py +++ b/fast_llm/functional/autograd.py @@ -60,3 +60,14 @@ def call(*args, **kwargs): def grad_is_context(grad_output: torch.Tensor, context: torch.Tensor) -> torch.Tensor: # noqa return context + + +class AuxiliaryLoss(torch.autograd.Function): + @staticmethod + def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa + ctx.grad = torch.full_like(aux_loss, grad) + return input_ + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa + return grad_output, ctx.grad, None diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py deleted file mode 100644 index c5ae48eba..000000000 --- a/fast_llm/functional/dpo.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch - - -def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): - # Gather log probabilities corresponding to the target tokens - return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - -def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): - return sum( - log_probabilities[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(spans) - for begin, end in sample_spans - ) - - -def compute_dpo_loss( - logits: torch.Tensor, - targets: torch.Tensor, - reference_model_logits: torch.Tensor, - chosen_spans: list[list[tuple[int, int]]], - rejected_spans: list[list[tuple[int, int]]], - beta: float, - grad_output: float | None, -) -> tuple[torch.Tensor, torch.Tensor]: - with torch.enable_grad(): - logits_ = logits.float().detach().requires_grad_() - reference_model_logits_ = reference_model_logits.float().detach() - - policy_log_probabilities = _get_target_log_probabilities(logits_, targets) - policy_log_ratios = _get_target_log_probability_for_spans( - policy_log_probabilities, chosen_spans - ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - - reference_log_probabilities = _get_target_log_probabilities(reference_model_logits_, targets) - reference_log_ratios = _get_target_log_probability_for_spans( - reference_log_probabilities, chosen_spans - ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - - # TODO: ====== Shouldn't the sigmoid be computed independently for each document? 
======= - losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) - - if grad_output is None: - loss = None - else: - loss = losses.mean() - loss.backward(torch.full_like(loss, grad_output)) - loss.detach() - return loss.detach(), logits_.grad.detach().to(logits.dtype) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index cd7eca950..757832a71 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -96,6 +96,7 @@ def _fused_softmax_base( return logits_norm, exp_logits, sum_exp_logits +@torch.compile def _fused_reverse_kl_base( logits: torch.Tensor, target: torch.Tensor, @@ -134,6 +135,7 @@ def _fused_reverse_kl_base( return per_sample_loss, grad +@torch.compile def _fused_cross_entropy_base( logits: torch.Tensor, target: torch.Tensor, @@ -177,6 +179,7 @@ def _fused_cross_entropy_base( return per_sample_loss, grad +@torch.compile def _fused_cross_entropy_base_from_labels( logits: torch.Tensor, target: torch.Tensor, diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py deleted file mode 100644 index 97e04de16..000000000 --- a/fast_llm/layers/common/auxiliary_loss.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch - - -class AuxiliaryLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa - ctx.grad = torch.full_like(aux_loss, grad) - return input_ - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: # noqa - return grad_output, ctx.grad, None - - -@torch.compile -def z_loss( - logits: torch.Tensor, logits_scale_factor: float = 1.0, loss_mask: "torch.Tensor | None" = None -) -> torch.Tensor: - out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 - if loss_mask is not None: - out = out * loss_mask - return torch.mean(out) - - -def auxiliary_z_loss( - logits: torch.Tensor, - z_loss_factor: float, - training: bool, - grad_scale: float | None = None, - losses: dict | None = None, - loss_name: str | None = None, - logits_scale_factor: float = 1.0, - loss_mask: "torch.Tensor | None" = None, -) -> torch.Tensor: - if losses is not None or (training and grad_scale is not None): - loss = z_loss(logits, logits_scale_factor, loss_mask) - if losses is not None and loss_name is not None: - losses[loss_name].append(loss.detach()) - if training and grad_scale is not None: - logits = AuxiliaryLoss.apply(logits, loss, z_loss_factor * grad_scale) - - return logits - - -def z_loss_forward_backward( - logits: torch.Tensor, - grad_output: float | None = None, - loss_mask: "torch.Tensor | None" = None, - logits_scale_factor: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Compute z-loss and its gradient. - - Z-loss = mean(logsumexp(logits, dim=-1) ** 2) - - Returns: - loss: The z-loss value (unscaled) - grad: The gradient w.r.t. 
logits (scaled by grad_scale), or None if grad_scale is None - """ - - with torch.set_grad_enabled(grad_output is not None): - logits_ = logits.detach().requires_grad_(grad_output is not None) - loss = z_loss(logits_, logits_scale_factor, loss_mask) - if grad_output is None: - grad = None - else: - loss.backward(torch.full_like(loss, grad_output)) - grad = logits_.grad.detach().to(logits.dtype) - - return loss, grad diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 9d5166cc7..8f6e360fd 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -9,9 +9,9 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.config import BlockWithBiasConfig, DecoderBlockConfig from fast_llm.tensor import TensorMeta diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index fd3647389..413a88ed6 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -9,14 +9,15 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, auxiliary_z_loss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType from fast_llm.layers.decoder.mlp.mlp import MLPBase +from fast_llm.layers.language_model.loss.z_loss import z_loss from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -102,14 +103,13 @@ def _forward( # Apply z_loss if applicable if self._config.z_loss_coefficient > 0.0: - logits = auxiliary_z_loss( - logits, - self._config.z_loss_coefficient, - self.training, - grad_scale=kwargs.get("grad_output"), - losses=losses, - loss_name=MLPLossNames.router_z_loss, - ) + is_training = (grad_scale := kwargs.get("grad_output")) is not None and self.training + if is_training or losses is not None: + loss = z_loss(logits) + if losses is not None: + losses[MLPLossNames.router_z_loss].append(loss.detach()) + if is_training: + logits = AuxiliaryLoss.apply(logits, loss, self._config.z_loss_coefficient * grad_scale) # Apply input_jitter if applicable: if self.training and self._config.jitter_eps > 0.0: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 3d526bad1..e8c60ae9c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,4 +1,5 @@ import abc +import functools import logging import typing @@ -6,17 +7,16 @@ from torch._C._distributed_c10d import ReduceOp # noqa from 
torch.distributed import all_reduce -from fast_llm.core.ops import gather_op, split_op +from fast_llm.core.ops import gather_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward +from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.block.block import Block from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LM_HEAD_LOSS_NAME, @@ -95,6 +95,19 @@ def __init__( lr_scale=self._lr_scale, peft=self._peft, ) + self._losses = [ + loss_config.get_layer( + distributed_config, + self._get_full_loss_name(name), + self._prediction_distance, + self._prediction_heads, + self._vocab_parallel, + self._config.cross_entropy_splits, + self._config.logits_scale_factor, + self._loss_coefficient, + ) + for name, loss_config in self._config.losses.items() + ] def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? (loss) @@ -163,11 +176,11 @@ def _logits_loss_forward_backward( self, input_: torch.Tensor, kwargs: dict, - losses: dict | None = None, + all_losses_dict: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: if not self.training: - logits, _ = self._logits_loss_forward_backward_partial(input_, None, kwargs, return_logits=True) + logits, _ = self._logits_loss_forward_backward_partial(input_, kwargs, return_logits=True) # TODO: Make a proper way of returning the model output. 
logits = logits.detach() if kwargs.get("global_logits"): @@ -183,15 +196,9 @@ def _logits_loss_forward_backward( return None, None input_ = input_.flatten(0, -2) - loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) - sequence_dim = 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - if loss_mask is not None: - if self._sequence_parallel_logits: - loss_mask = split_op(loss_mask, self._parallel_dim.group, sequence_dim) - loss_mask = loss_mask.flatten() if self._config.cross_entropy_splits == 1: - losses_, input_grad = self._logits_loss_forward_backward_partial(input_, loss_mask, kwargs) + loss_dict, input_grad = self._logits_loss_forward_backward_partial(input_, kwargs) else: input_grad = torch.empty_like(input_) tensors_split = [ @@ -200,48 +207,49 @@ def _logits_loss_forward_backward( if tensor is None else tensor.chunk(self._config.cross_entropy_splits) ) - for tensor in [input_, loss_mask, input_grad] + for tensor in [input_, input_grad] ] - for split_index, (partial_input_, loss_mask_, input_grad_) in enumerate(zip(*tensors_split, strict=True)): - partial_losses_, grad_ = self._logits_loss_forward_backward_partial( + for split_index, (partial_input_, input_grad_) in enumerate(zip(*tensors_split, strict=True)): + partial_loss_dict, grad_ = self._logits_loss_forward_backward_partial( partial_input_, - loss_mask_, kwargs, split_index=split_index, ) # TODO: Avoid copy with explicit out argument. input_grad_.copy_(grad_) if split_index == 0: - losses_ = partial_losses_ + loss_dict = partial_loss_dict else: - for name in self._config.losses: - losses_[name] += partial_losses_[name] - - loss: torch.Tensor = sum( - (self.config.losses[name].weight * self._loss_coefficient / self._config.cross_entropy_splits) * loss_ - for name, loss_ in losses_.items() + Assert.eq(partial_loss_dict.keys(), loss_dict.keys()) + for name in loss_dict: + loss_dict[name] += partial_loss_dict[name] + + total_loss = sum( + (loss_.weight / self._config.cross_entropy_splits) * loss_dict[loss_.name] + for loss_ in self._losses + if loss_.weight != 0.0 and loss_.name in loss_dict ) + if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, op=ReduceOp.AVG, group=self._parallel_dim.group) + all_reduce(total_loss, op=ReduceOp.AVG, group=self._parallel_dim.group) - if losses is not None: - losses[self.get_full_loss_name(LM_HEAD_LOSS_NAME)].append(loss) - if len(self._config.losses) > 1 or any(loss_.weight != 1.0 for loss_ in self._config.losses.values()): - for name, loss_ in losses_.items(): + if all_losses_dict is not None: + all_losses_dict[self._total_loss_name].append(total_loss) + if len(self._losses) > 1 or any(loss_.weight != 1.0 for loss_ in self._losses): + for name, loss_value in loss_dict.items(): if self._config.cross_entropy_splits != 1: - loss_ /= self._config.cross_entropy_splits + loss_value /= self._config.cross_entropy_splits if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss_, op=ReduceOp.AVG, group=self._parallel_dim.group) - losses[name].append(loss_) + all_reduce(loss_value, op=ReduceOp.AVG, group=self._parallel_dim.group) + all_losses_dict[name].append(loss_value) - return loss, input_grad + return total_loss, input_grad def _logits_loss_forward_backward_partial( self, input_: torch.Tensor, - loss_mask: torch.Tensor | None, kwargs: dict, split_index: int = 0, return_logits: bool = False, @@ -270,33 +278,14 @@ def _logits_loss_forward_backward_partial( return logits, None losses, grad = {}, None - for loss_name, loss_config in self._config.losses.items(): + for loss 
in self._losses: # losses are returned unscaled but the grads are already scaled - # TODO: ====== grad_output can't be None? - grad_output = kwargs.get(LanguageModelKwargs.grad_output) - if grad_output is not None: - grad_output = ( - grad_output - * self._loss_coefficient - * loss_config.weight - / (self._parallel_dim.size if self._sequence_parallel_logits else 1) - / self._config.cross_entropy_splits - ) - loss, grad_ = loss_config.get_loss( + loss_value, grad_ = loss.forward_backward( logits, - loss_mask, - grad_output=None if grad_output == 0.0 else grad_output, - group=self._parallel_dim.group, - logits_scale_factor=self._config.logits_scale_factor, - kwargs=kwargs, - prediction_distance=self._prediction_distance, - prediction_heads=self._prediction_heads, - split_index=split_index, - num_splits=self._config.cross_entropy_splits, - sequence_parallel_logits=self._sequence_parallel_logits, - vocab_parallel=self._vocab_parallel, + kwargs, + split_index, ) - losses[loss_name] = loss.detach() + losses[loss.name] = loss_value.detach() if grad_ is not None: # TODO: Accumulate grads in-place to reduce memory and compute overhead. grad = grad_ if grad is None else grad + grad_ @@ -305,22 +294,24 @@ def _logits_loss_forward_backward_partial( def get_loss_definitions(self, count: int = 1) -> list[LossDef]: return [ - LossDef(name=(name := self.get_full_loss_name(LM_HEAD_LOSS_NAME)), formatted_name=name, count=count), + LossDef(name=self._total_loss_name, formatted_name=self._total_loss_name, count=count), *( LossDef( - name=(name_ := self.get_full_loss_name(name)), - formatted_name=name_, + name=loss.name, + formatted_name=loss.name, count=count, dtype=DataType.float32, ) - for name, loss_config in self._config.losses.items() + for loss in self._losses ), ] - def get_full_loss_name(self, name) -> str: - if self._prediction_distance > 0: - name = f"{name}_{self._prediction_distance}" - return name + def _get_full_loss_name(self, name) -> str: + return name if self._prediction_distance == 0 else f"{name}_{self._prediction_distance}" + + @functools.cached_property + def _total_loss_name(self) -> str: + return self._get_full_loss_name(LM_HEAD_LOSS_NAME) @property def heads(self): diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 8a9f4251d..a6057d67f 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -1,18 +1,30 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType from fast_llm.layers.block.config import BlockKwargs from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch + pass + + from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss + from fast_llm.layers.language_model.loss.entropy_loss import ( + LanguageModelDistillationLoss, + LanguageModelLabelEntropyLoss, + ) + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + from fast_llm.layers.language_model.loss.loss import LanguageModelLoss + from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss class LanguageModelLossKwargs(BlockKwargs): labels = "labels" chosen_spans = "chosen_spans" rejected_spans = "rejected_spans" + advantages = "advantages" + 
old_log_probabilities = "old_log_probabilities" @config_class(registry=True) @@ -26,25 +38,31 @@ class LanguageModelLossConfig(Config): valid=check_field(Assert.geq, 0.0), ) - def get_name(self, prediction_distance: int = 0) -> str: - return self._name if prediction_distance == 0 else f"{self._name}_{prediction_distance}" - - def get_loss( + def get_layer( self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - *, - group: "torch.distributed.ProcessGroup|None" = None, - logits_scale_factor: float = 1.0, + distributed_config: DistributedConfig, + name: str, prediction_distance: int = 0, prediction_heads: int = 1, - split_index: int = 0, - num_splits: int = 1, - sequence_parallel_logits: bool = False, vocab_parallel: bool = False, - kwargs: dict[str, typing.Any], - ) -> "tuple[torch.Tensor, torch.Tensor | None]": + num_splits: int = 1, + logits_scale_factor: float = 1.0, + weight: float = 1.0, + ): + return self.loss_class( + self, + distributed_config, + name=name, + prediction_distance=prediction_distance, + prediction_heads=prediction_heads, + vocab_parallel=vocab_parallel, + num_splits=num_splits, + logits_scale_factor=logits_scale_factor, + weight=weight, + ) + + @property + def loss_class(self) -> "type[LanguageModelLoss]": raise NotImplementedError() def get_reference_models(self) -> set[str]: @@ -67,71 +85,11 @@ class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): hint=FieldHint.performance, ) - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - *, - group: "torch.distributed.ProcessGroup|None" = None, - logits_scale_factor: float = 1.0, - prediction_distance: int = 0, - prediction_heads: int = 1, - split_index: int = 0, - num_splits: int = 1, - sequence_parallel_logits: bool = False, - vocab_parallel: bool = False, - kwargs: dict[str, typing.Any], - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - import torch - - from fast_llm.functional.entropy_loss import entropy_loss_forward_backward - - labels = kwargs[LanguageModelLossKwargs.labels] - - # MTP: Shift the labels - if prediction_heads > 1: - sequence_q_length = labels.size(1 - kwargs[LanguageModelLossKwargs.sequence_first]) + 1 - prediction_heads - if LanguageModelLossKwargs.sequence_q_dim in kwargs: - Assert.eq(sequence_q_length, kwargs[LanguageModelLossKwargs.sequence_q_dim].size) - label_slice = slice(prediction_distance, prediction_distance + sequence_q_length) - labels = labels[label_slice] if kwargs[LanguageModelLossKwargs.sequence_first] else labels[:, label_slice] - - labels = labels.flatten() - - # Get the local chunk. - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - labels = split_op(labels, group, 0) - - # Get the chunk for the current split. 
- if num_splits > 1: - labels = labels.chunk(num_splits)[split_index] - - implementation = self.implementation - if implementation == EntropyLossImplementation.auto: - if ( - TritonConfig.TRITON_ENABLED - and torch.cuda.is_available() - and (group is None or not vocab_parallel) - and self.loss_type == EntropyLossType.cross_entropy - ): - implementation = EntropyLossImplementation.triton - else: - implementation = EntropyLossImplementation.fused - - return entropy_loss_forward_backward( - logits, - labels, - None, # Labels are already masked - grad_output=grad_output, - group=group if vocab_parallel else None, - implementation=implementation, - logits_scale_factor=logits_scale_factor, - target_format=TargetFormat.labels, - entropy_loss_type=self.loss_type, - ) + @property + def loss_class(self) -> "type[LanguageModelLabelEntropyLoss]": + from fast_llm.layers.language_model.loss.entropy_loss import LanguageModelLabelEntropyLoss + + return LanguageModelLabelEntropyLoss @config_class(dynamic_type={LanguageModelLossConfig: "distillation"}) @@ -160,56 +118,11 @@ class LanguageModelDistillationLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0.0), ) - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - *, - group: "torch.distributed.ProcessGroup|None" = None, - logits_scale_factor: float = 1.0, - prediction_distance: int = 0, - prediction_heads: int = 1, - split_index: int = 0, - num_splits: int = 1, - sequence_parallel_logits: bool = False, - vocab_parallel: bool = False, - kwargs: dict[str, typing.Any], - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.entropy_loss import entropy_loss_forward_backward - - if prediction_distance > 0: - raise NotImplementedError() - - reference_model_logits = kwargs[f"{self.reference_model}_logits"].flatten(0, -2) - - # Get the local chunk. - if sequence_parallel_logits: - from fast_llm.core.ops import split_op + @property + def loss_class(self) -> "type[LanguageModelDistillationLoss]": + from fast_llm.layers.language_model.loss.entropy_loss import LanguageModelDistillationLoss - reference_model_logits = split_op(reference_model_logits, group, 0) - - # Get the chunk for the current split. 
- if num_splits > 1: - reference_model_logits = reference_model_logits.chunk(num_splits)[split_index] - - implementation = ( - EntropyLossImplementation.fused - if self.implementation == EntropyLossImplementation.auto - else self.implementation - ) - return entropy_loss_forward_backward( - logits, - reference_model_logits, - loss_mask, - grad_output=grad_output, - group=group if vocab_parallel else None, - implementation=implementation, - logits_scale_factor=logits_scale_factor, - temperature=self.temperature, - target_format=TargetFormat.logits, - entropy_loss_type=self.loss_type, - ) + return LanguageModelDistillationLoss def get_reference_models(self) -> set[str]: return {self.reference_model} @@ -233,56 +146,11 @@ class LanguageModelDPOLossConfig(LanguageModelLossConfig): hint=FieldHint.feature, ) - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - *, - group: "torch.distributed.ProcessGroup|None" = None, - logits_scale_factor: float = 1.0, - prediction_distance: int = 0, - prediction_heads: int = 1, - split_index: int = 0, - num_splits: int = 1, - sequence_parallel_logits: bool = False, - vocab_parallel: bool = False, - kwargs: dict[str, typing.Any], - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.functional.dpo import compute_dpo_loss - - if num_splits > 1: - raise NotImplementedError() - if prediction_distance > 0: - raise NotImplementedError() - if vocab_parallel and group is not None: - raise NotImplementedError() - - if logits_scale_factor != 1.0: - # TODO: Make more efficient. - logits = logits * logits_scale_factor - - reference_model_logits = kwargs[f"{self.reference_model}_logits"].flatten(0, -2) - target = kwargs[LanguageModelLossKwargs.labels] - - if sequence_parallel_logits: - from fast_llm.core.ops import split_op - - reference_model_logits = split_op(reference_model_logits, group, 0) - target = split_op(target, group, 0) - - chosen_spans = kwargs[LanguageModelLossKwargs.chosen_spans] - rejected_spans = kwargs[LanguageModelLossKwargs.rejected_spans] - - return compute_dpo_loss( - logits=logits, - targets=target, - reference_model_logits=reference_model_logits, - chosen_spans=chosen_spans, - rejected_spans=rejected_spans, - beta=self.beta, - grad_output=grad_output, - ) + @property + def loss_class(self) -> "type[LanguageModelDPOLoss]": + from fast_llm.layers.language_model.loss.dpo import LanguageModelDPOLoss + + return LanguageModelDPOLoss def get_reference_models(self) -> set[str]: return {self.reference_model} @@ -294,31 +162,23 @@ class LanguageModelZLossConfig(LanguageModelLossConfig): _abstract: typing.ClassVar[bool] = False - def get_loss( - self, - logits: "torch.Tensor", - loss_mask: "torch.Tensor | None", - grad_output: float | None = None, - *, - group: "torch.distributed.ProcessGroup|None" = None, - logits_scale_factor: float = 1.0, - prediction_distance: int = 0, - prediction_heads: int = 1, - split_index: int = 0, - num_splits: int = 1, - sequence_parallel_logits: bool = False, - vocab_parallel: bool = False, - kwargs: dict[str, typing.Any], - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - from fast_llm.layers.common.auxiliary_loss import z_loss_forward_backward - - # TODO: Support vocab_parallel - if vocab_parallel and group is not None: - raise NotImplementedError() - - return z_loss_forward_backward( - logits, - grad_output, - loss_mask, - logits_scale_factor, - ) + @property + def loss_class(self) -> "type[LanguageModelZLoss]": + from 
fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss + + return LanguageModelZLoss + + +@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) +class LanguageModelGRPOLossConfig(LanguageModelLossConfig): + + _abstract: typing.ClassVar[bool] = False + + epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") + epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") + + @property + def loss_class(self) -> "type[LanguageModelGRPOLoss]": + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + + return LanguageModelGRPOLoss diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py new file mode 100644 index 000000000..15c4c788c --- /dev/null +++ b/fast_llm/layers/language_model/loss/dpo.py @@ -0,0 +1,81 @@ +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelDPOLossConfig, LanguageModelLossKwargs +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward + + +class LanguageModelDPOLoss[ConfigType: LanguageModelDPOLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self._prediction_distance > 0: + raise NotImplementedError() + if self._num_splits > 1: + raise NotImplementedError() + if self._prediction_distance > 0: + raise NotImplementedError() + if self._vocab_parallel: + raise NotImplementedError() + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + + if self._get_loss_mask(kwargs, split_index) is not None: + raise NotImplementedError() + + return loss_forward_backward( + self._get_grad_output(kwargs), + dpo_loss, + logits, + self._get_labels(kwargs, split_index), + self._get_reference_model_logits(self._config.reference_model, kwargs, split_index), + kwargs[LanguageModelLossKwargs.chosen_spans], + kwargs[LanguageModelLossKwargs.rejected_spans], + self._config.beta, + ) + + +def dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: list[list[tuple[int, int]]], + rejected_spans: list[list[tuple[int, int]]], + beta: float = 1.0, + logits_scale_factor: float = 1.0, +) -> torch.Tensor: + + if logits_scale_factor != 1.0: + # TODO: Make more efficient. + logits = logits * logits_scale_factor + + policy_log_probabilities = _get_target_log_probabilities(logits, targets) + policy_log_ratios = _get_target_log_probability_for_spans( + policy_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) + + reference_log_probabilities = _get_target_log_probabilities(reference_model_logits.float().detach(), targets) + reference_log_ratios = _get_target_log_probability_for_spans( + reference_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) + + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? 
======= + return -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)).mean() + + +def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + + +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans + ) diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py new file mode 100644 index 000000000..3ae87d2e9 --- /dev/null +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -0,0 +1,86 @@ +import typing + +import torch + +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.entropy_loss import entropy_loss_forward_backward +from fast_llm.layers.language_model.loss.config import ( + LanguageModelDistillationLossConfig, + LanguageModelLabelEntropyLossConfig, +) +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss + + +def _get_imlementation( + default: EntropyLossImplementation = EntropyLossImplementation.auto, + loss_type: EntropyLossType = EntropyLossType.cross_entropy, + vocab_parallel: bool = False, +) -> EntropyLossImplementation: + # Vocab parallel requires fused. + if vocab_parallel: + assert default in (EntropyLossImplementation.auto, EntropyLossImplementation.fused) + return EntropyLossImplementation.fused + + # Triton only available for cross_entropy + if TritonConfig.TRITON_ENABLED and torch.cuda.is_available() and loss_type == EntropyLossType.cross_entropy: + return EntropyLossImplementation.triton if default == EntropyLossImplementation.auto else default + else: + assert default != EntropyLossImplementation.triton + + # Otherwise, use fused. 
+ return EntropyLossImplementation.fused if default == EntropyLossImplementation.auto else default + + +class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._implementation = _get_imlementation( + self._config.implementation, self._config.loss_type, self._vocab_parallel + ) + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + return entropy_loss_forward_backward( + logits, + self._get_labels(kwargs, split_index), + None, # Labels are already masked + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + implementation=self._implementation, + logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.labels, + entropy_loss_type=self._config.loss_type, + ) + + +class LanguageModelDistillationLoss[ConfigType: LanguageModelDistillationLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self._prediction_distance > 0: + raise NotImplementedError() + + self._implementation = _get_imlementation( + self._config.implementation, self._config.loss_type, self._vocab_parallel + ) + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + return entropy_loss_forward_backward( + logits, + self._get_reference_model_logits(self._config.reference_model, kwargs, split_index), + self._get_loss_mask(kwargs, split_index), + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + implementation=self._implementation, + logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.logits, + entropy_loss_type=self._config.loss_type, + ) diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py new file mode 100644 index 000000000..eeac6b9c4 --- /dev/null +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -0,0 +1,64 @@ +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward + + +class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Support vocab_parallel + if self._num_splits > 1: + raise NotImplementedError() + if self._prediction_distance > 0: + raise NotImplementedError() + if self._vocab_parallel: + raise NotImplementedError() + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + return loss_forward_backward( + self._get_grad_output(kwargs), + grpo_loss, + logits, + self._get_loss_mask(kwargs, split_index), + self._get_labels(kwargs, split_index), + advantages, + old_log_probabilities, + self._config.epsilon_low, + self._config.epsilon_high, + self._logits_scale_factor, + ) + + +def grpo_loss( + logits: torch.Tensor, + loss_mask: "torch.Tensor | None", + labels: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + epsilon_low: float, + epsilon_high: float, + 
logits_scale_factor: float = 1.0, +) -> torch.Tensor: + if logits_scale_factor != 1.0: + # TODO: Make more efficient. + logits = logits * logits_scale_factor + + # Log probabilities. + logprobs = torch.nn.functional.log_softmax(logits, dim=-1) + target_log_probabilities = torch.gather(logprobs, dim=2, index=labels.unsqueeze(2)).squeeze(2) + probability_ratio = torch.exp(target_log_probabilities - old_log_probabilities) + loss = -torch.min( + probability_ratio * advantages, + torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages, + ) + if loss_mask is not None: + loss = loss * loss_mask + return loss.mean() diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py new file mode 100644 index 000000000..711560a8f --- /dev/null +++ b/fast_llm/layers/language_model/loss/loss.py @@ -0,0 +1,121 @@ +import abc +import typing + +import torch + +from fast_llm.config import Configurable +from fast_llm.core.ops import split_op +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs +from fast_llm.utils import Assert + + +class LanguageModelLoss[ConfigType: LanguageModelLossConfig](Configurable[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + name: str, + prediction_distance: int = 0, + prediction_heads: int = 1, + vocab_parallel: bool = False, + num_splits: int = 1, + logits_scale_factor: float = 1.0, + weight: float = 1.0, + ): + super().__init__(config) + Assert.in_range(prediction_distance, 0, prediction_heads) + self._prediction_distance = prediction_distance + self._prediction_heads = prediction_heads + self._name = name + self._num_splits = num_splits + self._logits_scale_factor = logits_scale_factor + self._weight = weight * self._config.weight + self._vocab_parallel = distributed_config.tensor_parallel > 1 and vocab_parallel + self._sequence_parallel = distributed_config.sequence_tensor_parallel and not self._vocab_parallel + self._parallel_dim = distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + @abc.abstractmethod + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + pass + + @property + def name(self) -> str: + return self._name + + @property + def weight(self) -> float: + return self._weight + + def _prepare_target( + self, + target: torch.Tensor | None, + kwargs: dict[str, typing.Any], + split_index: int = 0, + *, + multi_token_format: bool = False, + ) -> torch.Tensor | None: + # MTP shift + if multi_token_format and self._prediction_heads > 1: + sequence_first: bool = kwargs[LanguageModelLossKwargs.sequence_first] + sequence_q_length = target.size(1 - sequence_first) + 1 - self._prediction_heads + target_slice = slice(self._prediction_distance, self._prediction_distance + sequence_q_length) + target = target[target_slice] if sequence_first else target[:, target_slice] + + # Flatten the batch and sequence dimensions. + target = target.flatten(0, 1) + + # Get the local chunk. + if self._sequence_parallel: + target = split_op(target, self._parallel_dim.group, 0) + + # Get the chunk for the current split. 
+ if self._num_splits > 1: + target = target.chunk(self._num_splits)[split_index] + + return target + + def _get_grad_output(self, kwargs: dict[str, typing.Any]) -> float | None: + grad_output = kwargs.get(LanguageModelKwargs.grad_output) + if grad_output is not None: + grad_output = ( + grad_output + * self._weight + / (self._parallel_dim.size if self._sequence_parallel else 1) + / self._num_splits + ) + return grad_output + + def _get_labels(self, kwargs: dict[str, typing.Any], split_index: int = 0): + return self._prepare_target( + kwargs[LanguageModelLossKwargs.labels], kwargs, split_index, multi_token_format=True + ) + + def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): + loss_mask = kwargs.get(LanguageModelKwargs.loss_mask) + return None if loss_mask is None else self._prepare_target(loss_mask, kwargs, split_index) + + def _get_reference_model_logits(self, reference_model: str, kwargs: dict[str, typing.Any], split_index: int = 0): + return self._prepare_target(kwargs[f"{reference_model}_logits"], kwargs, split_index) + + +def loss_forward_backward( + grad_output: float | None, fn: typing.Callable, input_: torch.Tensor, *args, **kwargs +) -> tuple[torch.Tensor, torch.Tensor | None]: + with torch.set_grad_enabled(grad_output is not None): + input_ = input_.detach().requires_grad_(grad_output is not None) + loss = fn(input_, *args, **kwargs) + if grad_output is None: + grad = None + else: + loss.backward(torch.full_like(loss, grad_output)) + grad = input_.grad.detach().to(input_.dtype) + + return loss, grad diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py new file mode 100644 index 000000000..c94851bf2 --- /dev/null +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -0,0 +1,43 @@ +import typing + +import torch + +from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward + + +class LanguageModelZLoss[ConfigType: LanguageModelZLossConfig](LanguageModelLoss[ConfigType]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO: Support vocab_parallel + if self._vocab_parallel: + raise NotImplementedError() + + def forward_backward( + self, + logits: "torch.Tensor", + kwargs: dict[str, typing.Any], + split_index: int = 0, + ) -> "tuple[torch.Tensor, torch.Tensor | None]": + return loss_forward_backward( + self._get_grad_output(kwargs), + z_loss, + logits, + self._get_loss_mask(kwargs, split_index), + self._logits_scale_factor, + ) + + +@torch.compile +def z_loss( + logits: torch.Tensor, + loss_mask: "torch.Tensor | None" = None, + logits_scale_factor: float = 1.0, +) -> torch.Tensor: + """ + Z-loss = mean(logsumexp(logits, dim=-1) ** 2) + """ + out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 + if loss_mask is not None: + out = out * loss_mask + return torch.mean(out) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 76c0841d9..840e3846d 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -3,9 +3,9 @@ import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map +from 
fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.utils import Assert from tests.utils.dataset import get_random_spans @@ -61,20 +61,14 @@ def reference_dpo_loss( def test_dpo_loss(): - random_state = np.random.RandomState(0) - logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32).requires_grad_() - reference_model_logits = torch.from_numpy(random_state.normal(size=(10, 50, 100))).to(torch.float32) - targets = torch.from_numpy(random_state.randint(0, 100, (10, 50))) + logits = torch.normal(0, 1, (10, 50, 100)) + reference_model_logits = torch.normal(0, 1, (10, 50, 100)) + targets = torch.randint(0, 100, (10, 50)) + spans = get_random_spans(np.full(10, 50), 0, 10) - spans = get_random_spans(np.full(10, 50), 0, 10, random_state) - - fastllm_loss, fast_llm_grad = compute_dpo_loss( - logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 - ) + fastllm_loss = dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2]) reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) - reference_loss.backward() Assert.rms_close(fastllm_loss, reference_loss, 1e-5) - Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) @pytest.mark.parametrize("gated", [True, False]) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9aa53fcc4..1d08986f8 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -32,7 +32,7 @@ class LMHeadTestConfig: loss_masking: bool = False prediction_heads: int = 1 tied_embedding_weight: bool = False - cross_entropy_splits: int = 1 + num_splits: int = 1 @property def actual_label_loss(self): @@ -46,7 +46,7 @@ def get_config(self) -> GPTModelConfig: head_config = { "normalization": {"type": "rms_norm"}, "logits_scale_factor": self.logits_scale_factor, - "cross_entropy_splits": self.cross_entropy_splits, + "cross_entropy_splits": self.num_splits, } losses = {} if self.label_loss is not False: @@ -200,28 +200,36 @@ def get_reference_outputs( return total_loss.detach(), input_.grad, logit_weight.grad, normalization_weight.grad, losses -_lm_head_test_configs = ( - # TODO: Test DPO loss. 
- # TODO: Add more configs - # TODO: Add distributed test - LMHeadTestConfig("default"), - LMHeadTestConfig("bfloat16", compute_dtype=DataType.bfloat16), - LMHeadTestConfig("full_precision_residual", full_precision_residual=True), - LMHeadTestConfig("sequence_first", sequence_first=True), - LMHeadTestConfig("logit_scaling", logits_scale_factor=5.0), - LMHeadTestConfig("tied_embedding_weight", tied_embedding_weight=True), - LMHeadTestConfig("multi_token_prediction", prediction_heads=2), - LMHeadTestConfig("cross_entropy_splits", cross_entropy_splits=2, sequence_first=True), - LMHeadTestConfig("loss_masking", loss_masking=True), - LMHeadTestConfig("label_loss", label_loss=True), - LMHeadTestConfig("distillation_loss", distillation_loss=True), - LMHeadTestConfig("distillation_loss_masked", distillation_loss=True, loss_masking=True), - LMHeadTestConfig("z_loss", z_loss=True), - LMHeadTestConfig("z_loss_masked", z_loss=True, loss_masking=True), - LMHeadTestConfig("label_and_distillation_loss", label_loss=True, distillation_loss=True), - LMHeadTestConfig("label_and_z_loss_weighted", label_loss=True, z_loss=0.5), - LMHeadTestConfig("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0), -) +_lm_head_test_configs = [] + + +def _add_configs(base_name: str, **kwargs): + # Loss masking and splits are important and error-prone, so we test them for all scenarios. + for loss_masking in (False, True): + for num_splits in (1, 2): + _lm_head_test_configs.append( + LMHeadTestConfig( + f"{base_name}{"_masked" if loss_masking else ""}{"" if num_splits == 1 else "_split"}", + loss_masking=loss_masking, + num_splits=num_splits, + **kwargs, + ) + ) + + +_add_configs("default") +_add_configs("bfloat16", compute_dtype=DataType.bfloat16) +_add_configs("full_precision_residual", full_precision_residual=True) +_add_configs("sequence_first", sequence_first=True) +_add_configs("logit_scaling", logits_scale_factor=5.0) +_add_configs("tied_embedding_weight", tied_embedding_weight=True) +_add_configs("multi_token_prediction", prediction_heads=2) +_add_configs("label_loss", label_loss=True) +_add_configs("distillation_loss", distillation_loss=True) +_add_configs("z_loss", z_loss=True) +_add_configs("label_and_distillation_loss", label_loss=True, distillation_loss=True) +_add_configs("label_and_z_loss_weighted", label_loss=True, z_loss=0.5) +_add_configs("label_and_distillation_loss_zero_weight", label_loss=True, distillation_loss=0.0) @pytest.mark.slow From 89bda84d8611805790004f2ff8f8ebd01c27c781 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 21 Jan 2026 15:21:29 -0500 Subject: [PATCH 51/51] stuff --- fast_llm/layers/language_model/loss/config.py | 16 ----- fast_llm/layers/language_model/loss/grpo.py | 64 ------------------- 2 files changed, 80 deletions(-) delete mode 100644 fast_llm/layers/language_model/loss/grpo.py diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index a6057d67f..f531a1d46 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -14,7 +14,6 @@ LanguageModelDistillationLoss, LanguageModelLabelEntropyLoss, ) - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss @@ -167,18 +166,3 @@ def loss_class(self) -> "type[LanguageModelZLoss]": from fast_llm.layers.language_model.loss.z_loss 
import LanguageModelZLoss return LanguageModelZLoss - - -@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) -class LanguageModelGRPOLossConfig(LanguageModelLossConfig): - - _abstract: typing.ClassVar[bool] = False - - epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") - epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") - - @property - def loss_class(self) -> "type[LanguageModelGRPOLoss]": - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss - - return LanguageModelGRPOLoss diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py deleted file mode 100644 index eeac6b9c4..000000000 --- a/fast_llm/layers/language_model/loss/grpo.py +++ /dev/null @@ -1,64 +0,0 @@ -import typing - -import torch - -from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig -from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward - - -class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # TODO: Support vocab_parallel - if self._num_splits > 1: - raise NotImplementedError() - if self._prediction_distance > 0: - raise NotImplementedError() - if self._vocab_parallel: - raise NotImplementedError() - - def forward_backward( - self, - logits: "torch.Tensor", - kwargs: dict[str, typing.Any], - split_index: int = 0, - ) -> "tuple[torch.Tensor, torch.Tensor | None]": - return loss_forward_backward( - self._get_grad_output(kwargs), - grpo_loss, - logits, - self._get_loss_mask(kwargs, split_index), - self._get_labels(kwargs, split_index), - advantages, - old_log_probabilities, - self._config.epsilon_low, - self._config.epsilon_high, - self._logits_scale_factor, - ) - - -def grpo_loss( - logits: torch.Tensor, - loss_mask: "torch.Tensor | None", - labels: torch.Tensor, - advantages: torch.Tensor, - old_log_probabilities: torch.Tensor, - epsilon_low: float, - epsilon_high: float, - logits_scale_factor: float = 1.0, -) -> torch.Tensor: - if logits_scale_factor != 1.0: - # TODO: Make more efficient. - logits = logits * logits_scale_factor - - # Log probabilities. - logprobs = torch.nn.functional.log_softmax(logits, dim=-1) - target_log_probabilities = torch.gather(logprobs, dim=2, index=labels.unsqueeze(2)).squeeze(2) - probability_ratio = torch.exp(target_log_probabilities - old_log_probabilities) - loss = -torch.min( - probability_ratio * advantages, - torch.clamp(probability_ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages, - ) - if loss_mask is not None: - loss = loss * loss_mask - return loss.mean()
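Note on extending the loss registry introduced above (illustrative sketch only, not shipped by these patches): a new loss plugs in by registering a config subclass under a fresh dynamic_type key and pairing it with a LanguageModelLoss subclass that implements forward_backward, reusing the shared loss_forward_backward helper for the detach/backward bookkeeping. The "hinge" key, the class names and the _hinge_loss payload below are invented for illustration and assume the module layout from these patches; masking and ignore-index handling are simplified.

import typing

import torch

from fast_llm.config import config_class
from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward


@config_class(dynamic_type={LanguageModelLossConfig: "hinge"})
class LanguageModelHingeLossConfig(LanguageModelLossConfig):

    _abstract: typing.ClassVar[bool] = False

    @property
    def loss_class(self) -> "type[LanguageModelHingeLoss]":
        return LanguageModelHingeLoss


class LanguageModelHingeLoss[ConfigType: LanguageModelHingeLossConfig](LanguageModelLoss[ConfigType]):
    def forward_backward(
        self,
        logits: "torch.Tensor",
        kwargs: dict[str, typing.Any],
        split_index: int = 0,
    ) -> "tuple[torch.Tensor, torch.Tensor | None]":
        # The shared helper detaches the logits, runs the payload under grad, and returns (loss, grad).
        return loss_forward_backward(
            self._get_grad_output(kwargs),
            _hinge_loss,
            logits,
            self._get_labels(kwargs, split_index),
            self._get_loss_mask(kwargs, split_index),
        )


def _hinge_loss(
    logits: torch.Tensor, labels: torch.Tensor, loss_mask: "torch.Tensor | None"
) -> torch.Tensor:
    # Multi-class hinge on the target-logit margin; the j == target term contributes exactly 1, hence the - 1.0.
    target_logits = logits.gather(dim=-1, index=labels.unsqueeze(-1))
    loss = torch.clamp(1.0 - target_logits + logits, min=0.0).sum(dim=-1) - 1.0
    if loss_mask is not None:
        loss = loss * loss_mask
    return loss.mean()

The head would then pick the new loss up through config.losses like the built-in ones, since get_layer on the base config instantiates loss_class with the shared name, weight and split plumbing.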