From fe2492f904523b494704d943de8f069d79703dc7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 22 Jan 2026 17:26:50 -0500 Subject: [PATCH 1/3] Fix dict ordering --- fast_llm/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 02a83375c..ab2402ff7 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -555,9 +555,9 @@ def _validate_dict(cls, value, type_, name: str): errors = [] new_value = {} old_keys = {} - for key, value_ in value.items(): + for key in sorted(value): new_key = cls._validate_nested(key, args[0], f"{name}(key {key})", None, errors, True) - new_value_ = cls._validate_nested(value_, args[1], f"{name}[{key}]", None, errors, True) + new_value_ = cls._validate_nested(value[key], args[1], f"{name}[{key}]", None, errors, True) if key in new_value: errors.append(f"Duplicate key `{new_key}` after validation (from `{old_keys[new_key]}`, `{key}`)") old_keys[new_key] = key From 9ce58a7432d0df2fc4e1070ff57597c60bca480f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 22 Jan 2026 18:08:02 -0500 Subject: [PATCH 2/3] fix --- fast_llm/config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index ab2402ff7..4339fa401 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -555,13 +555,15 @@ def _validate_dict(cls, value, type_, name: str): errors = [] new_value = {} old_keys = {} - for key in sorted(value): + for key, value_ in value.items(): new_key = cls._validate_nested(key, args[0], f"{name}(key {key})", None, errors, True) - new_value_ = cls._validate_nested(value[key], args[1], f"{name}[{key}]", None, errors, True) + new_value_ = cls._validate_nested(value_, args[1], f"{name}[{key}]", None, errors, True) if key in new_value: errors.append(f"Duplicate key `{new_key}` after validation (from `{old_keys[new_key]}`, `{key}`)") old_keys[new_key] = key new_value[new_key] = new_value_ + # Ensure dicts are sorted W.R.T validated keys. + new_value = {new_key: new_value[new_key] for new_key in sorted(new_value)} if errors: raise ValidationError(*errors) return new_value From e3a4af63065f39fa8a28ebb20270f0b9870d699b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 22 Jan 2026 18:48:30 -0500 Subject: [PATCH 3/3] fix --- tests/config/common.py | 2 +- tests/config/test_field.py | 6 +++--- tests/layers/test_lm_head.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/config/common.py b/tests/config/common.py index 9ccfb5972..b341bd0cb 100644 --- a/tests/config/common.py +++ b/tests/config/common.py @@ -31,7 +31,7 @@ class ExampleConfig(Config): type_field: type[int] = Field(default=int, hint=FieldHint.optional) enum_field: ExampleEnum = Field(default=ExampleEnum.a, hint=FieldHint.optional) core_field: int = Field(default=4, hint=FieldHint.core) - complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) + complex_field: dict[str, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional) def _validate(self) -> None: with self._set_implicit_default(): diff --git a/tests/config/test_field.py b/tests/config/test_field.py index 91b5c0d82..bc5881167 100644 --- a/tests/config/test_field.py +++ b/tests/config/test_field.py @@ -184,8 +184,8 @@ def test_core_field(): "value", ( {}, - {3: None, "text": [], 0: [["", 3], ["a", -7]]}, - {0: [[".", 8]]}, + {"3": None, "text": [], "0": [["", 3], ["a", -7]]}, + {"0": [[".", 8]]}, ), ) def test_complex_field(value): @@ -194,7 +194,7 @@ def test_complex_field(value): @pytest.mark.parametrize( "value", - ({3: None, "text": [], False: [["", 3], ["a", -7]]},), + ({"3": None, "text": [], False: [["", 3], ["a", -7]]},), ) def test_complex_field_invalid(value): check_invalid_config({"complex_field": value}) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 1d08986f8..ee3e0e2e1 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -117,7 +117,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: device=device, ) if LanguageModelKwargs.loss_mask in kwargs: - labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels) + labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], labels, -100) kwargs[LanguageModelKwargs.labels] = labels if self.distillation_loss is not False: