diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 699c5015..692c9389 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,13 +12,22 @@ The semantic versioning only considers the public API as described in paths are considered internals and can change in minor and patch releases. +v4.40.2 (2025-07-??) +-------------------- + +Fixed +^^^^^ +- Subclass defaults incorrectly taken from base class (`#743 + `__). + + v4.40.1 (2025-07-24) -------------------- Fixed ^^^^^ -- ``print_shtab`` incorrectly parsed from environment variable (`#725 - `__). +- ``print_shtab`` incorrectly parsed from environment variable (`#726 + `__). - ``adapt_class_type`` used a locally defined `partial_instance` wrapper function that is not pickleable (`#728 `__). diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 4a49264f..ae9a7981 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -1081,7 +1081,8 @@ def adapt_typehints( ) try: - val_class = import_object(resolve_class_path_by_name(typehint, val["class_path"])) + class_path = resolve_class_path_by_name(typehint, val["class_path"]) + val_class = import_object(class_path) if is_instance_or_supports_protocol(val_class, typehint): return val_class # importable instance if is_protocol(val_class): @@ -1098,10 +1099,14 @@ def adapt_typehints( elif prev_implicit_defaults: inner_parser = ActionTypeHint.get_class_parser(typehint, sub_add_kwargs) prev_val.init_args = inner_parser.get_defaults() + if prev_val.class_path != class_path: + inner_parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs) + for key in inner_parser.get_defaults().keys(): + prev_val.init_args.pop(key, None) if not_subclass: msg = "implement protocol" if is_protocol(typehint) else "correspond to a subclass of" raise_unexpected_value(f"Import path {val['class_path']} does not {msg} {typehint.__name__}") - val["class_path"] = get_import_path(val_class) + val["class_path"] = class_path val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val) except (ImportError, AttributeError, AssertionError, ArgumentError) as ex: class_path = val if isinstance(val, str) else val["class_path"] diff --git a/jsonargparse_tests/test_subclasses.py b/jsonargparse_tests/test_subclasses.py index 7bb4ff21..86152c79 100644 --- a/jsonargparse_tests/test_subclasses.py +++ b/jsonargparse_tests/test_subclasses.py @@ -51,6 +51,24 @@ def test_subclass_basics(parser, type): assert init["op"] is None +class BaseClassDefault: + def __init__(self, param: str = "base_default"): + self.param = param + + +class SubClassDefault(BaseClassDefault): + def __init__(self, param: str = "sub_default"): + super().__init__(param=param) + + +def test_subclass_defaults(parser): + parser.add_subclass_arguments(BaseClassDefault, "cls") + cfg = parser.parse_args(["--cls=BaseClassDefault"]) + assert cfg.cls.init_args.param == "base_default" + cfg = parser.parse_args(["--cls=SubClassDefault"]) + assert cfg.cls.init_args.param == "sub_default" + + def test_subclass_init_args_in_subcommand(parser, subparser): subparser.add_subclass_arguments(Calendar, "cal", default=lazy_instance(Calendar)) subcommands = parser.add_subcommands()