Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,8 @@ def _validate_dict(cls, value, type_, name: str):
errors.append(f"Duplicate key `{new_key}` after validation (from `{old_keys[new_key]}`, `{key}`)")
old_keys[new_key] = key
new_value[new_key] = new_value_
# Ensure dicts are sorted W.R.T validated keys.
new_value = {new_key: new_value[new_key] for new_key in sorted(new_value)}
if errors:
raise ValidationError(*errors)
return new_value
Expand Down
2 changes: 1 addition & 1 deletion tests/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ExampleConfig(Config):
type_field: type[int] = Field(default=int, hint=FieldHint.optional)
enum_field: ExampleEnum = Field(default=ExampleEnum.a, hint=FieldHint.optional)
core_field: int = Field(default=4, hint=FieldHint.core)
complex_field: dict[str | int, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional)
complex_field: dict[str, list[tuple[str, int]] | None] = Field(default_factory=dict, hint=FieldHint.optional)

def _validate(self) -> None:
with self._set_implicit_default():
Expand Down
6 changes: 3 additions & 3 deletions tests/config/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def test_core_field():
"value",
(
{},
{3: None, "text": [], 0: [["", 3], ["a", -7]]},
{0: [[".", 8]]},
{"3": None, "text": [], "0": [["", 3], ["a", -7]]},
{"0": [[".", 8]]},
),
)
def test_complex_field(value):
Expand All @@ -194,7 +194,7 @@ def test_complex_field(value):

@pytest.mark.parametrize(
"value",
({3: None, "text": [], False: [["", 3], ["a", -7]]},),
({"3": None, "text": [], False: [["", 3], ["a", -7]]},),
)
def test_complex_field_invalid(value):
check_invalid_config({"complex_field": value})
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/test_lm_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]:
device=device,
)
if LanguageModelKwargs.loss_mask in kwargs:
labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], -100, labels)
labels = torch.where(kwargs[LanguageModelKwargs.loss_mask], labels, -100)
kwargs[LanguageModelKwargs.labels] = labels

if self.distillation_loss is not False:
Expand Down