From f71f560fa8db6a457c7e8e9676a6f4a36c59549c Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Mon, 26 Jan 2026 16:43:32 -0500
Subject: [PATCH 1/2] Fix default lm loss

---
 fast_llm/layers/language_model/config.py | 10 +-------
 fast_llm/layers/language_model/head.py   | 30 ++++++++++++++----------
 tests/test_config.py                     |  3 ---
 3 files changed, 18 insertions(+), 25 deletions(-)

diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index 5f58024e0..acdf1dbc2 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -9,11 +9,7 @@
 from fast_llm.layers.common.normalization.config import NormalizationConfig
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
-from fast_llm.layers.language_model.loss.config import (
-    LanguageModelLabelEntropyLossConfig,
-    LanguageModelLossConfig,
-    LanguageModelLossKwargs,
-)
+from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
@@ -197,10 +193,6 @@ def layer_class(self) -> "type[LanguageModelHead]":
         return LanguageModelHead
 
     def _validate(self) -> None:
-        with self._set_implicit_default():
-            if not self.losses:
-                if "losses" not in self._explicit_fields:
-                    self.losses = {"lm_loss": LanguageModelLabelEntropyLossConfig()}
         super()._validate()
         assert LM_HEAD_LOSS_NAME not in self.losses
 
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index e8c60ae9c..c6eac121e 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -25,6 +25,7 @@
     LanguageModelHeadConfig,
     LanguageModelKwargs,
 )
+from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
 from fast_llm.tensor import TensorMeta
 from fast_llm.utils import Assert
 
@@ -95,19 +96,22 @@
             lr_scale=self._lr_scale,
             peft=self._peft,
         )
-        self._losses = [
-            loss_config.get_layer(
-                distributed_config,
-                self._get_full_loss_name(name),
-                self._prediction_distance,
-                self._prediction_heads,
-                self._vocab_parallel,
-                self._config.cross_entropy_splits,
-                self._config.logits_scale_factor,
-                self._loss_coefficient,
-            )
-            for name, loss_config in self._config.losses.items()
-        ]
+        if self._config.losses:
+            self._losses = [
+                loss_config.get_layer(
+                    distributed_config,
+                    self._get_full_loss_name(name),
+                    self._prediction_distance,
+                    self._prediction_heads,
+                    self._vocab_parallel,
+                    self._config.cross_entropy_splits,
+                    self._config.logits_scale_factor,
+                    self._loss_coefficient,
+                )
+                for name, loss_config in self._config.losses.items()
+            ]
+        else:
+            self._losses = {"cross_entropy": LanguageModelLabelEntropyLossConfig().get_layer()}
 
     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
         # TODO: Add marginal compute? (loss)
diff --git a/tests/test_config.py b/tests/test_config.py
index 2e900cb14..4020b6fbc 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -148,15 +148,12 @@ def test_pretrained_config(load_config: ModelConfigType, result_path):
                 },
                 "num_blocks": 12,
             },
-            "head": {"losses": {"lm_loss": {"type": "cross_entropy"}}},
             "hidden_size": 512,
             "tied_embedding_weight": False,
             "peft": {"freeze_others": False},
         }
     else:
         expected_config["base_model"] = base_model_update
-        # added by default
-        expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy"}}}
 
     check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config))
 

From b9aec3b5dcc2d1d5317ff9cf5db84a9344fc0872 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Mon, 26 Jan 2026 16:51:46 -0500
Subject: [PATCH 2/2] Fix default lm loss

---
 fast_llm/layers/language_model/config.py |  3 ++-
 fast_llm/layers/language_model/head.py   | 32 ++++++++++++------------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index acdf1dbc2..9ba1f3433 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -137,7 +137,8 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
     )
     losses: dict[str, LanguageModelLossConfig] = Field(
         default_factory=dict,
-        desc="A dictionary of loss names and their configurations.",
+        desc="A dictionary of loss names and their configurations. "
+        "If not specified, a cross-entropy loss with respect to the targets will be used.",
         hint=FieldHint.core,
     )
     # TODO: Cleanup
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index c6eac121e..d9220d3e1 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -96,22 +96,22 @@
             lr_scale=self._lr_scale,
             peft=self._peft,
         )
-        if self._config.losses:
-            self._losses = [
-                loss_config.get_layer(
-                    distributed_config,
-                    self._get_full_loss_name(name),
-                    self._prediction_distance,
-                    self._prediction_heads,
-                    self._vocab_parallel,
-                    self._config.cross_entropy_splits,
-                    self._config.logits_scale_factor,
-                    self._loss_coefficient,
-                )
-                for name, loss_config in self._config.losses.items()
-            ]
-        else:
-            self._losses = {"cross_entropy": LanguageModelLabelEntropyLossConfig().get_layer()}
+        loss_configs = (
+            self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()}
+        )
+        self._losses = [
+            loss_config.get_layer(
+                distributed_config,
+                self._get_full_loss_name(name),
+                self._prediction_distance,
+                self._prediction_heads,
+                self._vocab_parallel,
+                self._config.cross_entropy_splits,
+                self._config.logits_scale_factor,
+                self._loss_coefficient,
+            )
+            for name, loss_config in loss_configs.items()
+        ]
 
     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
         # TODO: Add marginal compute? (loss)
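A minimal standalone sketch of the pattern PATCH 2/2 settles on: default the dict of loss configurations to a single cross-entropy entry, then build every loss layer through one shared code path. LossConfig, CrossEntropyLossConfig, and build_losses below are hypothetical stand-ins for illustration, not the real Fast-LLM types; the real get_layer() takes the distributed and parallelism arguments visible in the hunk above.

import dataclasses


@dataclasses.dataclass
class LossConfig:
    # Hypothetical stand-in for LanguageModelLossConfig; the real
    # get_layer() takes distributed/parallelism arguments omitted here.
    coefficient: float = 1.0

    def get_layer(self, name: str) -> str:
        return f"{name} (coefficient={self.coefficient})"


class CrossEntropyLossConfig(LossConfig):
    # Hypothetical stand-in for LanguageModelLabelEntropyLossConfig.
    pass


def build_losses(configured: dict[str, LossConfig]) -> list[str]:
    # The fallback from PATCH 2/2: an empty dict means "use the default
    # cross-entropy loss"; both cases then share one list comprehension,
    # so the result is always a list.
    loss_configs = configured if configured else {"cross_entropy": CrossEntropyLossConfig()}
    return [config.get_layer(name) for name, config in loss_configs.items()]


print(build_losses({}))  # ['cross_entropy (coefficient=1.0)']
print(build_losses({"lm_loss": LossConfig(coefficient=0.5)}))  # ['lm_loss (coefficient=0.5)']

Substituting the default before the loop, rather than branching into a separate construction path as in PATCH 1/2, keeps self._losses a list in both cases and avoids the interim inconsistency where the fallback produced a dict and called get_layer() without arguments.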