From 5facd4bbb12e3c34daf584e5016874d78be44ab7 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Thu, 18 Sep 2025 23:40:01 +0200 Subject: [PATCH] Experimental support for dataclass inheritance (#287) --- CHANGELOG.rst | 2 + jsonargparse/_common.py | 58 ++++--- jsonargparse/_core.py | 8 +- jsonargparse/_namespace.py | 2 + jsonargparse/_optionals.py | 28 ++-- jsonargparse/_parameter_resolvers.py | 3 - jsonargparse/_signatures.py | 30 ++-- jsonargparse/_typehints.py | 109 +++++++------ jsonargparse_tests/test_dataclasses.py | 171 +++++++++++++++++++-- jsonargparse_tests/test_final_classes.py | 1 - jsonargparse_tests/test_loaders_dumpers.py | 4 +- jsonargparse_tests/test_subclasses.py | 5 +- 12 files changed, 308 insertions(+), 113 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9234bd6c..b2e3c2df 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,8 @@ Added - ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770 `__). +- Experimental support for dataclass inheritance (`#775 + `__). Fixed ^^^^^ diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py index d140de31..e383d159 100644 --- a/jsonargparse/_common.py +++ b/jsonargparse/_common.py @@ -6,6 +6,7 @@ from contextlib import contextmanager from contextvars import ContextVar from typing import ( # type: ignore[attr-defined] + Callable, Dict, Generic, List, @@ -28,6 +29,8 @@ import_reconplogger, is_alias_type, is_annotated, + is_attrs_class, + is_pydantic_model, reconplogger_support, typing_extensions_import, ) @@ -212,11 +215,22 @@ def supports_optionals_as_positionals(parser): def is_subclass(cls, class_or_tuple) -> bool: - """Extension of issubclass that supports non-class arguments.""" + """Extension of issubclass that supports non-class arguments and generics.""" try: - return inspect.isclass(cls) and issubclass(cls, class_or_tuple) + class_or_tuple = get_generic_origins(class_or_tuple) + if inspect.isclass(cls): + return issubclass(cls, class_or_tuple) + elif is_generic_class(cls): + return issubclass(cls.__origin__, class_or_tuple) except TypeError: - return False + pass # TypeError means that cls is not a class + return False + + +def is_instance(obj, class_or_tuple) -> bool: + """Extension of isinstance that supports generics.""" + class_or_tuple = get_generic_origins(class_or_tuple) + return isinstance(obj, class_or_tuple) def is_final_class(cls) -> bool: @@ -236,6 +250,12 @@ def get_generic_origin(cls): return cls.__origin__ if is_generic_class(cls) else cls +def get_generic_origins(class_or_tuple): + if isinstance(class_or_tuple, tuple): + return tuple(get_generic_origin(cls) for cls in class_or_tuple) + return get_generic_origin(class_or_tuple) + + def get_unaliased_type(cls): new_cls = cls while True: @@ -249,29 +269,25 @@ def get_unaliased_type(cls): return cur_cls -def is_dataclass_like(cls) -> bool: - if is_generic_class(cls): - return is_dataclass_like(cls.__origin__) - if not inspect.isclass(cls) or cls is object: - return False - if is_final_class(cls): - return True +def is_pure_dataclass(cls) -> bool: classes = [c for c in inspect.getmro(cls) if c not in {object, Generic}] - all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes) - - if not all_dataclasses: - from ._optionals import attrs_support, is_pydantic_model + return all(dataclasses.is_dataclass(c) for c in classes) - if is_pydantic_model(cls): - return True - if attrs_support: - import attrs +not_subclass_type_selectors: Dict[str, Callable[[Type], Union[bool, int]]] = { + "final": is_final_class, + "dataclass": is_pure_dataclass, + "pydantic": is_pydantic_model, + "attrs": is_attrs_class, +} - if attrs.has(cls): - return True - return all_dataclasses +def is_not_subclass_type(cls) -> bool: + if is_generic_class(cls): + return is_not_subclass_type(cls.__origin__) + if not inspect.isclass(cls): + return False + return any(validator(cls) for validator in not_subclass_type_selectors.values()) def default_class_instantiator(class_type: Type[ClassType], *args, **kwargs) -> ClassType: diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index 558a5e96..a97801bb 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -43,7 +43,7 @@ class_instantiators, debug_mode_active, get_optionals_as_positionals_actions, - is_dataclass_like, + is_not_subclass_type, lenient_check, parser_context, supports_optionals_as_positionals, @@ -132,7 +132,7 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs): return ActionParser._move_parser_actions(parser, args, kwargs) ActionConfigFile._ensure_single_config_argument(self, kwargs["action"]) if "type" in kwargs: - if is_dataclass_like(kwargs["type"]): + if is_not_subclass_type(kwargs["type"]): nested_key = args[0].lstrip("-") self.add_class_arguments(kwargs.pop("type"), nested_key, **kwargs) return _find_action(parser, nested_key) @@ -1236,9 +1236,7 @@ def instantiate_classes( """ components: List[Union[ActionTypeHint, _ActionConfigLoad, ArgumentGroup]] = [] for action in filter_default_actions(self._actions): - if isinstance(action, ActionTypeHint) or ( - isinstance(action, _ActionConfigLoad) and is_dataclass_like(action.basetype) - ): + if isinstance(action, ActionTypeHint): components.append(action) if instantiate_groups: diff --git a/jsonargparse/_namespace.py b/jsonargparse/_namespace.py index f5d14c6f..f1f2d212 100644 --- a/jsonargparse/_namespace.py +++ b/jsonargparse/_namespace.py @@ -293,6 +293,8 @@ def update( if not only_unset or key not in self: self[key] = value else: + if key and not isinstance(self.get(key), Namespace): + self[key] = Namespace() prefix = key + "." if key else "" for subkey, subval in value.items(): if not only_unset or prefix + subkey not in self: diff --git a/jsonargparse/_optionals.py b/jsonargparse/_optionals.py index 37aa8cca..2df24c48 100644 --- a/jsonargparse/_optionals.py +++ b/jsonargparse/_optionals.py @@ -346,18 +346,28 @@ def get_pydantic_supports_field_init() -> bool: def is_pydantic_model(class_type) -> int: - classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else [] - for cls in classes: - if getattr(cls, "__module__", "").startswith("pydantic") and getattr(cls, "__name__", "") == "BaseModel": - import pydantic - - if issubclass(cls, pydantic.BaseModel): - return pydantic_support - elif pydantic_support > 1 and issubclass(cls, pydantic.v1.BaseModel): - return 1 + if pydantic_support: + classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else [] + for cls in classes: + if getattr(cls, "__module__", "").startswith("pydantic") and getattr(cls, "__name__", "") == "BaseModel": + import pydantic + + if issubclass(cls, pydantic.BaseModel): + return pydantic_support + elif pydantic_support > 1 and issubclass(cls, pydantic.v1.BaseModel): + return 1 return 0 +def is_attrs_class(class_type) -> bool: + if attrs_support: + import attrs + + if attrs.has(class_type): + return True + return False + + def get_module(value): return getattr(type(value), "__module__", "").split(".", 1)[0] diff --git a/jsonargparse/_parameter_resolvers.py b/jsonargparse/_parameter_resolvers.py index e440a7a7..129788cc 100644 --- a/jsonargparse/_parameter_resolvers.py +++ b/jsonargparse/_parameter_resolvers.py @@ -17,7 +17,6 @@ LoggerProperty, get_generic_origin, get_unaliased_type, - is_dataclass_like, is_generic_class, is_subclass, is_unpack_typehint, @@ -372,8 +371,6 @@ def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component def is_param_subclass_instance_default(param: ParamData) -> bool: - if is_dataclass_like(type(param.default)): - return False from ._typehints import ActionTypeHint, get_optional_arg, get_subclass_types annotation = get_optional_arg(param.annotation) diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py index 3af652d8..57c726c9 100644 --- a/jsonargparse/_signatures.py +++ b/jsonargparse/_signatures.py @@ -12,12 +12,12 @@ get_class_instantiator, get_generic_origin, get_unaliased_type, - is_dataclass_like, is_final_class, + is_not_subclass_type, is_subclass, ) from ._namespace import Namespace -from ._optionals import attrs_support, get_doc_short_description, is_pydantic_model, pydantic_support +from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model, pydantic_support from ._parameter_resolvers import ParamData, get_parameter_origins, get_signature_parameters from ._typehints import ( ActionTypeHint, @@ -25,6 +25,7 @@ callable_instances, get_subclass_names, is_optional, + is_subclass_spec, not_required_types, ) from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str @@ -86,7 +87,7 @@ def add_class_arguments( or (isinstance(default, LazyInitBaseClass) and isinstance(default, unaliased_class_type)) or ( not is_final_class(default.__class__) - and is_dataclass_like(default.__class__) + and is_not_subclass_type(default.__class__) and isinstance(default, unaliased_class_type) ) ): @@ -120,15 +121,15 @@ def add_class_arguments( defaults = default if isinstance(default, LazyInitBaseClass): defaults = default.lazy_get_init_args().as_dict() - elif is_dataclass_like(default.__class__): - defaults = dataclass_to_dict(default) + elif is_convertible_to_dict(default.__class__): + defaults = convert_to_dict(default) args = {k[len(prefix) :] for k in added_args} skip_not_added = [k for k in defaults if k not in args] if skip_not_added: skip.update(skip_not_added) # skip init=False - elif isinstance(default, Namespace): - defaults = default.as_dict() if defaults: + if is_subclass_spec(defaults): + defaults = defaults.get("init_args", {}) defaults = {prefix + k: v for k, v in defaults.items() if k not in skip} self.set_defaults(**defaults) # type: ignore[attr-defined] @@ -389,7 +390,7 @@ def _add_signature_parameter( elif not as_positional or is_non_positional: kwargs["required"] = True is_subclass_typehint = False - is_dataclass_like_typehint = is_dataclass_like(annotation) + is_not_subclass_typehint = is_not_subclass_type(annotation) dest = (nested_key + "." if nested_key else "") + name args = [dest if is_required and as_positional and not is_non_positional else "--" + dest] if param.origin: @@ -407,7 +408,7 @@ def _add_signature_parameter( if ( annotation in {str, int, float, bool} or is_subclass(annotation, (str, int, float)) - or is_dataclass_like_typehint + or is_not_subclass_typehint ): kwargs["type"] = annotation register_pydantic_type(annotation) @@ -441,7 +442,7 @@ def _add_signature_parameter( "sub_configs": sub_configs, "instantiate": instantiate, } - if is_dataclass_like_typehint: + if is_not_subclass_typehint: kwargs.update(sub_add_kwargs) with ActionTypeHint.allow_default_instance_context(): action = container.add_argument(*args, **kwargs) @@ -492,8 +493,6 @@ def add_subclass_arguments( Raises: ValueError: When given an invalid base class. """ - if is_dataclass_like(baseclass): - raise ValueError("Not allowed for dataclass-like classes.") if type(baseclass) is not tuple: baseclass = (baseclass,) # type: ignore[assignment] if not baseclass or not all(ActionTypeHint.is_subclass_typehint(c, also_lists=True) for c in baseclass): @@ -590,7 +589,11 @@ def is_factory_class(value): return value.__class__ == dataclasses._HAS_DEFAULT_FACTORY_CLASS -def dataclass_to_dict(value) -> dict: +def is_convertible_to_dict(value): + return dataclasses.is_dataclass(value) or is_attrs_class(value) or is_pydantic_model(value) + + +def convert_to_dict(value) -> dict: if pydantic_support: pydantic_model = is_pydantic_model(type(value)) if pydantic_model: @@ -602,6 +605,7 @@ def dataclass_to_dict(value) -> dict: is_attrs_dataclass = attrs.has(type(value)) if is_attrs_dataclass: return attrs.asdict(value) + return dataclasses.asdict(value) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index c9024ee9..2637f30a 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -52,7 +52,9 @@ from ._common import ( get_class_instantiator, get_unaliased_type, - is_dataclass_like, + is_generic_class, + is_instance, + is_not_subclass_type, is_subclass, lenient_check, nested_links, @@ -259,13 +261,13 @@ def __init__(self, typehint: Optional[Type] = None, enable_path: bool = False, * self.default = self.normalize_default(self.default) def normalize_default(self, default): + from ._signatures import convert_to_dict, is_convertible_to_dict + is_subclass_type = self.is_subclass_typehint(self._typehint, all_subtypes=False) if isinstance(default, LazyInitBaseClass): - default = default.lazy_get_init_data() - elif is_dataclass_like(default.__class__): - from ._signatures import dataclass_to_dict - - default = dataclass_to_dict(default) + default = default.lazy_get_init_data().as_dict() + elif is_convertible_to_dict(default.__class__): + default = convert_to_dict(default) elif is_subclass_type and isinstance(default, dict) and "class_path" in default: default = subclass_spec_as_namespace(default) default.class_path = normalize_import_path(default.class_path, self._typehint) @@ -313,7 +315,6 @@ def is_supported_typehint(typehint, full=False): or get_typehint_origin(typehint) in root_types or get_registered_type(typehint) is not None or is_subclass(typehint, Enum) - or is_dataclass_like(typehint) or ActionTypeHint.is_subclass_typehint(typehint) ) if full and supported: @@ -475,7 +476,7 @@ def skip_sub_defaults_apply(v): ) with ActionTypeHint.sub_defaults_context(), parent_parsers_context(None, None): - parser._apply_actions(cfg, skip_fn=skip_sub_defaults_apply) + parser._apply_actions(cfg, skip_fn=skip_sub_defaults_apply, prev_cfg=cfg.clone()) @staticmethod def supports_append(action): @@ -641,8 +642,6 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=None): kwargs = dict(sub_add_kwargs) if sub_add_kwargs else {} if skip_args: kwargs.setdefault("skip", set()).update(skip_args) - if is_subclass_spec(kwargs.get("default")): - kwargs["default"] = kwargs["default"].get("init_args") parser = parent_parser.get() from ._core import ArgumentParser @@ -749,7 +748,6 @@ def adapt_typehints( prev_val=None, orig_val=None, append=False, - list_item=False, enable_path=False, sub_add_kwargs=None, default=None, @@ -845,7 +843,7 @@ def adapt_typehints( # Union elif typehint_origin == Union: vals = [] - sorted_subtypes = sort_subtypes_for_union(subtypehints, val, append) + sorted_subtypes = sort_subtypes_for_union(subtypehints, val, prev_val, append) for subtype in sorted_subtypes: try: vals.append(adapt_typehints(val, subtype, **adapt_kwargs)) @@ -909,7 +907,7 @@ def adapt_typehints( else: adapt_kwargs_n = deepcopy(adapt_kwargs) with change_to_path_dir(list_path): - val[n] = adapt_typehints(v, subtypehints[0], list_item=True, **adapt_kwargs_n) + val[n] = adapt_typehints(v, subtypehints[0], **adapt_kwargs_n) # Dict, Mapping elif typehint_origin in mapping_origin_types: @@ -1035,29 +1033,8 @@ def adapt_typehints( except (ImportError, AttributeError, ArgumentError) as ex: raise_unexpected_value(f"Type {typehint} expects a function or a callable class: {ex}", val, ex) - # Dataclass-like - elif is_dataclass_like(typehint): - if isinstance(prev_val, (dict, Namespace)): - assert isinstance(sub_add_kwargs, dict) - sub_add_kwargs["default"] = prev_val - parser = ActionTypeHint.get_class_parser(typehint, sub_add_kwargs=sub_add_kwargs) - if instantiate_classes: - init_args = parser.instantiate_classes(val) - return typehint(**init_args) - if serialize: - val = load_value(parser.dump(val, **dump_kwargs.get())) - elif isinstance(val, (dict, Namespace)): - if is_subclass_spec(val) and get_import_path(typehint) == val.get("class_path"): - val = val.get("init_args") - val = parser.parse_object(val, defaults=sub_defaults.get() or list_item) - elif isinstance(val, NestedArg): - prev_val = prev_val if isinstance(prev_val, Namespace) else None - val = parser.parse_args([f"--{val.key}={val.val}"], namespace=prev_val) - else: - raise_unexpected_value(f"Type {typehint} expects a dict or Namespace", val) - # Subclass - elif not hasattr(typehint, "__origin__") and inspect.isclass(typehint): + elif inspect.isclass(typehint_origin): if is_instance_or_supports_protocol(val, typehint): if serialize: val = serialize_class_instance(val) @@ -1068,12 +1045,21 @@ def adapt_typehints( prev_implicit_defaults = False if prev_val is None and not inspect.isabstract(typehint) and not is_protocol(typehint): with suppress(ValueError): - prev_val = Namespace(class_path=get_import_path(typehint)) # implicit class_path + # implicit prev_val class_path + prev_val = Namespace(class_path=get_import_path(typehint)) if parse_kwargs.get().get("defaults") is True: prev_implicit_defaults = True + if isinstance(prev_val, (dict, Namespace)) and "class_path" not in prev_val: + # implicit prev_val class_path and init_args + prev_val = Namespace(class_path=get_import_path(typehint), init_args=Namespace(prev_val)) + val_input = val val = subclass_spec_as_namespace(val, prev_val) + if val and not is_subclass_spec(val) and "init_args" not in val: + # implicit val class_path + val = Namespace(class_path=get_import_path(typehint), init_args=val) + if not is_subclass_spec(val): msg = "Does not implement protocol" if is_protocol(typehint) else "Not a valid subclass of" raise_unexpected_value( @@ -1081,7 +1067,8 @@ def adapt_typehints( "Subclass types expect one of:\n" "- a class path (str)\n" "- a dict with class_path entry\n" - "- a dict without class_path but with init_args entry (class path given previously)" + "- a dict without class_path but with init_args entry (class path given previously)\n" + "- a dict with parameters accepted by the base class (implicit class_path)" ) try: @@ -1111,7 +1098,14 @@ def adapt_typehints( msg = "implement protocol" if is_protocol(typehint) else "correspond to a subclass of" raise_unexpected_value(f"Import path {val['class_path']} does not {msg} {typehint.__name__}") val["class_path"] = class_path - val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val) + val = adapt_class_type( + val, + serialize, + instantiate_classes, + sub_add_kwargs, + prev_val=prev_val, + typehint=typehint, + ) except (ImportError, AttributeError, AssertionError, ArgumentError) as ex: class_path = val if isinstance(val, str) else val["class_path"] error = indent_text(str(ex)) @@ -1184,7 +1178,7 @@ def is_subclass_or_implements_protocol(value, class_type) -> bool: def is_instance_or_supports_protocol(value, class_type): if is_protocol(class_type): return is_subclass_or_implements_protocol(value.__class__, class_type) - return isinstance(value, class_type) + return is_instance(value, class_type) def is_instance_factory_protocol(class_type, logger=None): @@ -1251,13 +1245,15 @@ def get_callable_return_type(typehint): def is_single_subclass_typehint(typehint, typehint_origin): return ( - inspect.isclass(typehint) + ( + (inspect.isclass(typehint) and typehint_origin is None) + or (is_generic_class(typehint) and inspect.isclass(typehint.__origin__)) + ) and typehint not in leaf_or_root_types and not get_registered_type(typehint) and not is_pydantic_type(typehint) - and not is_dataclass_like(typehint) - and typehint_origin is None and not is_subclass(typehint, (Path, Enum)) + and getattr(typehint_origin, "__module__", "") != "builtins" ) @@ -1417,10 +1413,21 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value): ) -def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev_val=None, partial_skip_args=None): +def adapt_class_type( + value, + serialize, + instantiate_classes, + sub_add_kwargs, + prev_val=None, + partial_skip_args=None, + typehint=None, +): prev_val = subclass_spec_as_namespace(prev_val) value = subclass_spec_as_namespace(value) - val_class = import_object(value.class_path) + if is_generic_class(typehint): + val_class = typehint + else: + val_class = import_object(value.class_path) parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs, skip_args=partial_skip_args) # No need to re-create the linked arg but just "inform" the corresponding parser actions that it exists upstream. @@ -1495,6 +1502,10 @@ def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev with suppress(get_loader_exceptions()): val = load_value(val, simple_types=True) value["dict_kwargs"][key] = val + + if is_not_subclass_type(typehint) and value.class_path == get_import_path(typehint): + value = Namespace({**value.get("init_args", {}), **value.get("dict_kwargs", {})}) + return value @@ -1520,7 +1531,7 @@ def adapt_classes_any(val, serialize, instantiate_classes, sub_add_kwargs): return val -def sort_subtypes_for_union(subtypes, val, append): +def sort_subtypes_for_union(subtypes, val, prev_val, append): if len(subtypes) > 1: if isinstance(val, str): key_fn = lambda x: ( @@ -1530,8 +1541,12 @@ def sort_subtypes_for_union(subtypes, val, append): else: key_fn = lambda x: x != NoneType subtypes = sorted(subtypes, key=key_fn) - if append: - subtypes = sorted(subtypes, key=lambda x: get_typehint_origin(x) not in sequence_origin_types) + if append or (isinstance(prev_val, list) and isinstance(val, NestedArg)): + key_fn = lambda x: ( + x != NoneType, + get_typehint_origin(x) not in sequence_origin_types, + ) + subtypes = sorted(subtypes, key=key_fn) return subtypes diff --git a/jsonargparse_tests/test_dataclasses.py b/jsonargparse_tests/test_dataclasses.py index 6ca4facd..89781f71 100644 --- a/jsonargparse_tests/test_dataclasses.py +++ b/jsonargparse_tests/test_dataclasses.py @@ -23,6 +23,7 @@ ) from jsonargparse.typing import PositiveFloat, PositiveInt from jsonargparse_tests.conftest import ( + get_parse_args_stdout, get_parser_help, json_or_yaml_load, skip_if_docstring_parser_unavailable, @@ -57,10 +58,6 @@ class DataClassB: b2: DataClassA = DataClassA(a2="x") -class MixedClass(int, DataClassA): - """MixedClass description""" - - def test_add_class_arguments(parser, subtests): parser.add_class_arguments(DataClassA, "a", default=DataClassA(), help="CustomA title") parser.add_class_arguments(DataClassB, "b", default=DataClassB()) @@ -96,8 +93,6 @@ def test_add_class_arguments(parser, subtests): parser.add_class_arguments(1, "c") with pytest.raises(NSKeyError, match='No action for key "c.b2.b1" to set its default'): parser.add_class_arguments(DataClassB, "c", default=DataClassB(b2=DataClassB())) - with pytest.raises(ValueError): - parser.add_class_arguments(MixedClass, "c") @dataclasses.dataclass @@ -238,7 +233,7 @@ def test_list_append_defaults(parser): def test_add_argument_dataclass_type(parser): parser.add_argument("--b", type=DataClassB, default=DataClassB(b1=7.0)) cfg = parser.get_defaults() - assert {"b1": 7.0, "b2": {"a1": 1, "a2": "x"}} == cfg.b.as_dict() + assert Namespace(b1=7.0, b2=Namespace(a1=1, a2="x")) == cfg.b init = parser.instantiate_classes(cfg) assert isinstance(init.b, DataClassB) assert isinstance(init.b.b2, DataClassA) @@ -599,9 +594,6 @@ def test_generic_dataclass_subclass(parser): assert isinstance(init.x.children[1], GenericChild) -# union mixture tests - - @dataclasses.dataclass class UnionData: data_a: int = 1 @@ -639,6 +631,31 @@ def test_class_path_union_mixture_dataclass_and_class(parser, union_type): assert json_or_yaml_load(parser.dump(cfg))["union"] == value +def test_class_path_union_dataclasses(parser): + parser.add_argument("--union", type=Union[Data, SingleParamChange, UnionData]) + + value = {"class_path": f"{__name__}.UnionData", "init_args": {"data_a": 2, "data_b": "x"}} + cfg = parser.parse_args([f"--union={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.union, UnionData) + assert dataclasses.asdict(init.union) == {"data_a": 2, "data_b": "x"} + assert json_or_yaml_load(parser.dump(cfg))["union"] == value["init_args"] + + value = {"class_path": f"{__name__}.SingleParamChange", "init_args": {"p1": 2}} + cfg = parser.parse_args([f"--union={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.union, SingleParamChange) + assert dataclasses.asdict(init.union) == {"p1": 2, "p2": 0} + assert json_or_yaml_load(parser.dump(cfg))["union"] == {"p1": 2, "p2": 0} + + value = {"class_path": f"{__name__}.Data", "init_args": {"p1": "x"}} + cfg = parser.parse_args([f"--union={json.dumps(value)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.union, Data) + assert dataclasses.asdict(init.union) == {"p1": "x", "p2": 0} + assert json_or_yaml_load(parser.dump(cfg))["union"] == {"p1": "x", "p2": 0} + + if type_alias_type: IntOrString = type_alias_type("IntOrString", Union[int, str]) @@ -701,3 +718,137 @@ def test_dataclass_with_annotated_alias_type(self, parser): assert cfg.data.p1 == "MyString" cfg = parser.parse_args(["--data.p1=3"]) assert cfg.data.p1 == 3 + + +@dataclasses.dataclass +class DataMain: + p1: int = 1 + + +@dataclasses.dataclass +class DataSub(DataMain): + p2: str = "-" + + +def test_dataclass_not_subclass(parser): + parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) + + help_str = get_parser_help(parser) + assert "--data.help [CLASS_PATH_OR_NAME]" not in help_str + + config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}} + with pytest.raises(ArgumentError, match="Group 'data' does not accept nested key 'init_args.p2'"): + parser.parse_args([f"--data={json.dumps(config)}"]) + + +@pytest.fixture +def subclass_behavior(): + with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors: + not_subclass_type_selectors.pop("dataclass") + yield + + +def test_dataclass_argument_as_subclass(parser, subtests, subclass_behavior): + parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) + + with subtests.test("help"): + help_str = get_parser_help(parser) + assert "--data.help [CLASS_PATH_OR_NAME]" in help_str + assert f"{__name__}.DataMain" in help_str + assert f"{__name__}.DataSub" in help_str + + with subtests.test("defaults"): + defaults = parser.get_defaults() + dump = json_or_yaml_load(parser.dump(defaults))["data"] + assert dump == {"class_path": f"{__name__}.DataMain", "init_args": {"p1": 2}} + + with subtests.test("sub-param"): + config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}} + cfg = parser.parse_args([f"--data={json.dumps(config)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataSub) + assert dataclasses.asdict(init.data) == {"p1": 2, "p2": "y"} + dump = json_or_yaml_load(parser.dump(cfg))["data"] + assert dump == {"class_path": f"{__name__}.DataSub", "init_args": {"p1": 2, "p2": "y"}} + + with subtests.test("sub-default"): + config = {"class_path": "DataSub", "init_args": {"p1": 4}} + cfg = parser.parse_args([f"--data={json.dumps(config)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataSub) + assert dataclasses.asdict(init.data) == {"p1": 4, "p2": "-"} + + with subtests.test("mixed params"): + config = {"class_path": f"{__name__}.DataSub", "init_args": {"p1": 3, "p2": "x"}} + cfg = parser.parse_args([f"--data={json.dumps(config)}"]) + assert cfg.data == Namespace(class_path=f"{__name__}.DataSub", init_args=Namespace(p1=3, p2="x")) + assert cfg.data.init_args.p1 == 3 + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataSub) + assert dataclasses.asdict(init.data) == {"p1": 3, "p2": "x"} + + with subtests.test("empty init_args"): + config = {"class_path": f"{__name__}.DataSub", "init_args": {}} + cfg = parser.parse_args([f"--data={json.dumps(config)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataSub) + assert dataclasses.asdict(init.data) == {"p1": 2, "p2": "-"} + + with subtests.test("class_path"): + cfg = parser.parse_args(["--data=DataSub"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataSub) + assert dataclasses.asdict(init.data) == {"p1": 2, "p2": "-"} + + +class ParentData: + def __init__(self, data: DataMain = DataMain(p1=2)): + self.data = data + + +def test_dataclass_nested_not_subclass(parser): + parser.add_argument("--parent", type=ParentData) + + help_str = get_parse_args_stdout(parser, ["--parent.help"]) + assert "--parent.data.help [CLASS_PATH_OR_NAME]" not in help_str + + config = { + "class_path": f"{__name__}.ParentData", + "init_args": { + "data": { + "class_path": f"{__name__}.DataSub", + "init_args": {"p1": 3, "p2": "x"}, + } + }, + } + with pytest.raises(ArgumentError, match="Group 'data' does not accept nested key 'init_args.p1'"): + parser.parse_args([f"--parent={json.dumps(config)}"]) + + +def test_dataclass_nested_as_subclass(parser, subclass_behavior): + parser.add_argument("--parent", type=ParentData) + + help_str = get_parse_args_stdout(parser, ["--parent.help"]) + assert "--parent.data.help [CLASS_PATH_OR_NAME]" in help_str + + config = { + "class_path": f"{__name__}.ParentData", + "init_args": { + "data": { + "class_path": f"{__name__}.DataSub", + "init_args": {"p1": 3, "p2": "x"}, + } + }, + } + + cfg = parser.parse_args([f"--parent={json.dumps(config)}"]) + assert cfg.parent.init_args.data == Namespace(class_path=f"{__name__}.DataSub", init_args=Namespace(p1=3, p2="x")) + + dump = json_or_yaml_load(parser.dump(cfg))["parent"] + assert dump["class_path"] == f"{__name__}.ParentData" + assert dump["init_args"]["data"] == {"class_path": f"{__name__}.DataSub", "init_args": {"p1": 3, "p2": "x"}} + + init = parser.instantiate_classes(cfg) + assert isinstance(init.parent, ParentData) + assert isinstance(init.parent.data, DataSub) + assert dataclasses.asdict(init.parent.data) == {"p1": 3, "p2": "x"} diff --git a/jsonargparse_tests/test_final_classes.py b/jsonargparse_tests/test_final_classes.py index 055f243a..b28bda64 100644 --- a/jsonargparse_tests/test_final_classes.py +++ b/jsonargparse_tests/test_final_classes.py @@ -32,5 +32,4 @@ def test_add_class_final(parser): pytest.raises(ArgumentError, lambda: parser.parse_args(['--b.b2={"bad": "value"}'])) pytest.raises(ArgumentError, lambda: parser.parse_args(['--b.b2="bad"'])) - pytest.raises(ValueError, lambda: parser.add_subclass_arguments(FinalClass, "a")) pytest.raises(ValueError, lambda: parser.add_class_arguments(FinalClass, "a", default=FinalClass())) diff --git a/jsonargparse_tests/test_loaders_dumpers.py b/jsonargparse_tests/test_loaders_dumpers.py index f9eb9b86..fcbfa82f 100644 --- a/jsonargparse_tests/test_loaders_dumpers.py +++ b/jsonargparse_tests/test_loaders_dumpers.py @@ -113,7 +113,7 @@ def custom_loader(data): else: data = json.loads(data) if isinstance(data, dict) and "fn" in data: - data["fn"] = {k: custom_loader for k in data["fn"]} + data["fn"] = {k: f"custom loaded {v}" for k, v in data["fn"].items()} return data @@ -129,7 +129,7 @@ def test_nested_parser_mode(parser): parser.parser_mode = "custom" parser.add_argument("--custom", type=CustomContainer) cfg = parser.parse_args(['--custom.data={"fn": {"key": "value"}}']) - assert cfg.custom.init_args.data["fn"]["key"] is custom_loader + assert cfg.custom.init_args.data["fn"]["key"] == "custom loaded value" dump = json_or_yaml_load(parser.dump(cfg)) assert dump["custom"]["init_args"]["data"] == {"fn": {"key": "dumped"}} diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index 2f0506c9..d4a78246 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -1547,7 +1547,7 @@ def test_parse_implements_protocol(parser): with pytest.raises(ArgumentError, match="does not implement protocol"): parser.parse_args([f"--cls={__name__}.NotImplementsInterface1"]) with pytest.raises(ArgumentError, match="Does not implement protocol Interface"): - parser.parse_args(['--cls={"batch_size": 5}']) + parser.parse_args(["--cls=[1]"]) # callable protocol tests @@ -1613,7 +1613,7 @@ def test_parse_implements_callable_protocol(parser): with pytest.raises(ArgumentError, match="does not implement protocol"): parser.parse_args([f"--cls={__name__}.NotImplementsCallableInterface1"]) with pytest.raises(ArgumentError, match="Does not implement protocol CallableInterface"): - parser.parse_args(['--cls={"batch_size": 7}']) + parser.parse_args(["--cls=[1]"]) # parameter skip tests @@ -1873,6 +1873,7 @@ def test_subclass_error_indentation_in_union_invalid_value(parser): - a class path (str) - a dict with class_path entry - a dict without class_path but with init_args entry (class path given previously) + - a dict with parameters accepted by the base class (implicit class_path) Given value type: Given value: [{'class_path': 'ErrorIndentation2', 'init_args': {'val': 'x'}}] """