Skip to content

Commit e2c0819

Browse files
authored
Improve load/save/to_dict (see description) (#233)
* Improve `load`/`save` functions (see desc.) - Fixes #173 - `save` and `to_dict` now have a `save_dc_types` argument. - When set, this saves the types of dataclass attributes in a new entry in the dict. By default, this is saved in `"_type_"`. - `load` has been improved: - When loading, if the `"_type_"` key is present, returns an instance of that type. - When `drop_extra_fields` is False, tries to use subclasses of the annotated type, rather than subclasses of Serializable. This means that dataclasses that don't subclass Serializable can still be dynamically found and used when loading a dataclass. - `get_decoding_fn`'s `lru_cache` has been disabled. If this causes significant performance drops, let me know. - Unit tests are no longer run twice (using the regular and 'simple' APIs). This was causing a lot of weird bugs, as some unit tests were leaking from one run to the other, particularly tests involving deserialization of dataclasses. - Updated the regression file names due to this - Update some other test regression files Signed-off-by: Fabrice Normandin <normandf@mila.quebec> --------- Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
1 parent 0059dd7 commit e2c0819

32 files changed

+325
-355
lines changed

simple_parsing/helpers/fields.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
DEFAULT_NEGATIVE_PREFIX,
1919
BooleanOptionalAction,
2020
)
21-
from simple_parsing.utils import Dataclass, str2bool
21+
from simple_parsing.utils import DataclassT, str2bool
2222

2323
# NOTE: backward-compatibility import because it was moved to a different file.
2424
from .subgroups import subgroups # noqa: F401
@@ -344,7 +344,9 @@ def mutable_field(
344344

345345

346346
def subparsers(
347-
subcommands: dict[str, type[Dataclass]], default: Dataclass | _MISSING_TYPE = MISSING, **kwargs
347+
subcommands: dict[str, type[DataclassT]],
348+
default: DataclassT | _MISSING_TYPE = MISSING,
349+
**kwargs,
348350
) -> Any:
349351
return field(
350352
metadata={

simple_parsing/helpers/serialization/decoding.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Mapping
99
from dataclasses import Field
1010
from enum import Enum
11-
from functools import lru_cache, partial
11+
from functools import partial
1212
from logging import getLogger
1313
from pathlib import Path
1414
from typing import Any, Callable, TypeVar
@@ -55,7 +55,12 @@ def decode_bool(v: Any) -> bool:
5555
_decoding_fns[bool] = decode_bool
5656

5757

58-
def decode_field(field: Field, raw_value: Any, containing_dataclass: type | None = None) -> Any:
58+
def decode_field(
59+
field: Field,
60+
raw_value: Any,
61+
containing_dataclass: type | None = None,
62+
drop_extra_fields: bool | None = None,
63+
) -> Any:
5964
"""Converts a "raw" value (e.g. from json file) to the type of the `field`.
6065
6166
When serializing a dataclass to json, all objects are converted to dicts.
@@ -84,10 +89,17 @@ def decode_field(field: Field, raw_value: Any, containing_dataclass: type | None
8489
if isinstance(field_type, str) and containing_dataclass:
8590
field_type = evaluate_string_annotation(field_type, containing_dataclass)
8691

87-
return get_decoding_fn(field_type)(raw_value)
92+
decoding_function = get_decoding_fn(field_type)
8893

94+
if is_dataclass_type(field_type) and drop_extra_fields is not None:
95+
# Pass the drop_extra_fields argument to the decoding function.
96+
return decoding_function(raw_value, drop_extra_fields=drop_extra_fields)
8997

90-
@lru_cache(maxsize=100)
98+
return decoding_function(raw_value)
99+
100+
101+
# NOTE: Disabling the caching here might help avoid some bugs, and it's unclear if this has that
102+
# much of a performance impact.
91103
def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]:
92104
"""Fetches/Creates a decoding function for the given type annotation.
93105

0 commit comments

Comments
 (0)