diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index 5f58024e0..9ba1f3433 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -9,11 +9,7 @@
 from fast_llm.layers.common.normalization.config import NormalizationConfig
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
-from fast_llm.layers.language_model.loss.config import (
-    LanguageModelLabelEntropyLossConfig,
-    LanguageModelLossConfig,
-    LanguageModelLossKwargs,
-)
+from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
@@ -141,7 +137,8 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
     )
     losses: dict[str, LanguageModelLossConfig] = Field(
         default_factory=dict,
-        desc="A dictionary of loss names and their configurations.",
+        desc="A dictionary of loss names and their configurations. "
+        "If not specified, a cross-entropy loss with respect to the targets will be used.",
         hint=FieldHint.core,
     )
     # TODO: Cleanup
@@ -197,10 +194,6 @@ def layer_class(self) -> "type[LanguageModelHead]":
         return LanguageModelHead
 
     def _validate(self) -> None:
-        with self._set_implicit_default():
-            if not self.losses:
-                if "losses" not in self._explicit_fields:
-                    self.losses = {"lm_loss": LanguageModelLabelEntropyLossConfig()}
         super()._validate()
         assert LM_HEAD_LOSS_NAME not in self.losses
 
diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index e8c60ae9c..d9220d3e1 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -25,6 +25,7 @@
     LanguageModelHeadConfig,
     LanguageModelKwargs,
 )
+from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
 from fast_llm.tensor import TensorMeta
 from fast_llm.utils import Assert
 
@@ -95,6 +96,9 @@ def __init__(
             lr_scale=self._lr_scale,
             peft=self._peft,
         )
+        loss_configs = (
+            self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()}
+        )
         self._losses = [
             loss_config.get_layer(
                 distributed_config,
@@ -106,7 +110,7 @@
                 self._config.logits_scale_factor,
                 self._loss_coefficient,
             )
-            for name, loss_config in self._config.losses.items()
+            for name, loss_config in loss_configs.items()
         ]
 
     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
diff --git a/tests/test_config.py b/tests/test_config.py
index 2e900cb14..4020b6fbc 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -148,15 +148,12 @@ def test_pretrained_config(load_config: ModelConfigType, result_path):
                 },
                 "num_blocks": 12,
             },
-            "head": {"losses": {"lm_loss": {"type": "cross_entropy"}}},
             "hidden_size": 512,
             "tied_embedding_weight": False,
             "peft": {"freeze_others": False},
         }
     else:
         expected_config["base_model"] = base_model_update
-        # added by default
-        expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy"}}}
 
     check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config))