diff --git a/fast_llm/config.py b/fast_llm/config.py index 02a83375c..4339fa401 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -562,6 +562,8 @@ def _validate_dict(cls, value, type_, name: str): errors.append(f"Duplicate key `{new_key}` after validation (from `{old_keys[new_key]}`, `{key}`)") old_keys[new_key] = key new_value[new_key] = new_value_ + # Ensure dicts are sorted W.R.T validated keys. + new_value = {new_key: new_value[new_key] for new_key in sorted(new_value)} if errors: raise ValidationError(*errors) return new_value diff --git a/tests/config/common.py b/tests/config/common.py index 9ccfb5972..b341bd0cb 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -31,7 +31,7 @@ class ExampleConfig(Config): type_field: type[int] = Field(default=int, hint=FieldHint.optional) enum_field: ExampleEnum = Field(default=ExampleEnum.a, hint=FieldHint.optional) core_field: int = Field(default=4, hint=FieldHint.core) - complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) + complex_field: dict[str, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 91b5c0d82..bc5881167 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -184,8 +184,8 @@ def test_core_field(): "value", ( {}, - {3: None, "text": [], 0: [["", 3], ["a", -7]]}, - {0: [[".", 8]]}, + {"3": None, "text": [], "0": [["", 3], ["a", -7]]}, + {"0": [[".", 8]]}, ), ) def test_complex_field(value): @@ -194,7 +194,7 @@ def test_complex_field(value): @pytest.mark.parametrize( "value", - ({3: None, "text": [], False: [["", 3], ["a", -7]]},), + ({"3": None, "text": [], False: [["", 3], ["a", -7]]},), ) def test_complex_field_invalid(value): check_invalid_config({"complex_field": value}) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 1d08986f8..ee3e0e2e1 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -117,7 +117,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: device=device, ) if LanguageModelKwargs.loss_mask in kwargs: - labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels) + labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], labels, -100) kwargs[LanguageModelKwargs.labels] = labels if self.distillation_loss is not False: