13 changes: 3 additions & 10 deletions fast_llm/layers/language_model/config.py
@@ -9,11 +9,7 @@
 from fast_llm.layers.common.normalization.config import NormalizationConfig
 from fast_llm.layers.common.peft.config import PeftConfig
 from fast_llm.layers.decoder.config import DecoderBlockConfig
-from fast_llm.layers.language_model.loss.config import (
-    LanguageModelLabelEntropyLossConfig,
-    LanguageModelLossConfig,
-    LanguageModelLossKwargs,
-)
+from fast_llm.layers.language_model.loss.config import LanguageModelLossConfig, LanguageModelLossKwargs
 from fast_llm.utils import Assert
 
 if typing.TYPE_CHECKING:
@@ -141,7 +137,8 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig):
     )
     losses: dict[str, LanguageModelLossConfig] = Field(
         default_factory=dict,
-        desc="A dictionary of loss names and their configurations.",
+        desc="A dictionary of loss names and their configurations. "
+        "If not specified, a cross-entropy loss with respect to the targets will be used.",
         hint=FieldHint.core,
     )
     # TODO: Cleanup
@@ -197,10 +194,6 @@ def layer_class(self) -> "type[LanguageModelHead]":
         return LanguageModelHead
 
     def _validate(self) -> None:
-        with self._set_implicit_default():
-            if not self.losses:
-                if "losses" not in self._explicit_fields:
-                    self.losses = {"lm_loss": LanguageModelLabelEntropyLossConfig()}
         super()._validate()
         assert LM_HEAD_LOSS_NAME not in self.losses
 
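Net effect of the config change: `LanguageModelHeadConfig.losses` now genuinely defaults to an empty dict, and `_validate` no longer injects an implicit `lm_loss` entry. A minimal sketch of the two resulting configuration states; the `lm_loss` key and `cross_entropy` type are taken from the (now removed) test expectations below, while any structure around them is an assumption:

```python
# Explicitly configured head losses (key and type come from the removed
# test expectation; everything outside "losses" is an assumption):
explicit_head = {"losses": {"lm_loss": {"type": "cross_entropy"}}}

# After this change, an unspecified `losses` stays empty in the serialized
# config; the cross-entropy fallback is applied at head construction instead.
implicit_head = {"losses": {}}
```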
6 changes: 5 additions & 1 deletion fast_llm/layers/language_model/head.py
@@ -25,6 +25,7 @@
     LanguageModelHeadConfig,
     LanguageModelKwargs,
 )
+from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
 from fast_llm.tensor import TensorMeta
 from fast_llm.utils import Assert
 
@@ -95,6 +96,9 @@ def __init__(
             lr_scale=self._lr_scale,
             peft=self._peft,
         )
+        loss_configs = (
+            self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()}
+        )
         self._losses = [
             loss_config.get_layer(
                 distributed_config,
@@ -106,7 +110,7 @@
                 self._config.logits_scale_factor,
                 self._loss_coefficient,
             )
-            for name, loss_config in self._config.losses.items()
+            for name, loss_config in loss_configs.items()
         ]
 
     def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int:
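The default thus moves from config validation into the head's constructor. A self-contained sketch of the pattern, with simplified stand-in classes: only the `losses`-or-default expression mirrors the diff, and the dataclass fields and `resolve_loss_configs` helper are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class LanguageModelLabelEntropyLossConfig:
    # Stand-in for the real loss config; the fields here are assumptions.
    logits_scale_factor: float = 1.0


@dataclass
class LanguageModelHeadConfig:
    # Empty by default: nothing is injected at validation time anymore.
    losses: dict[str, LanguageModelLabelEntropyLossConfig] = field(default_factory=dict)


def resolve_loss_configs(
    config: LanguageModelHeadConfig,
) -> dict[str, LanguageModelLabelEntropyLossConfig]:
    # Mirrors the fallback in LanguageModelHead.__init__: an explicit
    # `losses` dict wins; otherwise a single cross-entropy loss is used.
    return config.losses if config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()}


# An unconfigured head falls back to cross-entropy...
assert "cross_entropy" in resolve_loss_configs(LanguageModelHeadConfig())
# ...while an explicit configuration is used verbatim.
explicit = LanguageModelHeadConfig(losses={"lm_loss": LanguageModelLabelEntropyLossConfig()})
assert set(resolve_loss_configs(explicit)) == {"lm_loss"}
```

One consequence, visible in the test changes below: since the default is no longer materialized into the config, a round-tripped config no longer carries an implicit `lm_loss` entry.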
3 changes: 0 additions & 3 deletions tests/test_config.py
@@ -148,15 +148,12 @@ def test_pretrained_config(load_config: ModelConfigType, result_path):
                 },
                 "num_blocks": 12,
             },
-            "head": {"losses": {"lm_loss": {"type": "cross_entropy"}}},
             "hidden_size": 512,
             "tied_embedding_weight": False,
             "peft": {"freeze_others": False},
         }
     else:
         expected_config["base_model"] = base_model_update
-        # added by default
-        expected_config["base_model"]["head"] = {"losses": {"lm_loss": {"type": "cross_entropy"}}}
 
     check_equal_nested(_trim_type(serialized_config), _trim_type(expected_config))
 