diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d5a91c85..1fc5056b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,7 +11,6 @@ The semantic versioning only considers the public API as described in :ref:`api-ref`. Components not mentioned in :ref:`api-ref` or different import paths are considered internals and can change in minor and patch releases. - v4.40.1 (2025-05-??) -------------------- @@ -19,6 +18,9 @@ Fixed ^^^^^ - ``print_shtab`` incorrectly parsed from environment variable (`#725 `__). +- ``adapt_class_type`` used a locally defined `partial_instance` wrapper + function that is not pickleable (`#728 + `__). v4.40.0 (2025-05-16) diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index c093eab5..4a49264f 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -1424,11 +1424,11 @@ def adapt_class_type( instantiator_fn = get_class_instantiator() if partial_classes: - - def partial_instance(*args): - return instantiator_fn(val_class, *args, **{**init_args, **dict_kwargs}) - - return partial_instance + return partial( + instantiator_fn, + val_class, + **{**init_args, **dict_kwargs}, + ) return instantiator_fn(val_class, **{**init_args, **dict_kwargs}) prev_init_args = prev_val.get("init_args") if isinstance(prev_val, Namespace) else None diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 0e2ac72a..d7fb2c2e 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -1239,6 +1239,21 @@ def test_callable_args_return_type_class_subconfig(parser, tmp_cwd): assert optimizer.momentum == 0.8 +def test_callable_args_pickleable(parser, tmp_cwd): + config = { + "class_path": "Adam", + "init_args": {"momentum": 0.8}, + } + Path("optimizer.yaml").write_text(json_or_yaml_dump(config)) + parser.add_class_arguments(CallableSubconfig, "m", sub_configs=True) + cfg = parser.parse_args(["--m.o=optimizer.yaml"]) + init = parser.instantiate_classes(cfg) + + filepath = str(tmp_cwd) + "/pickled.pkl" + with open(filepath, "wb") as f: + pickle.dump(init, f) + + class Module: pass