From 1dccf55064bb400a706bc2eb22102e0b9ae38864 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Fri, 3 Oct 2025 10:18:21 +0200 Subject: [PATCH 1/2] Fix save with multifile=True not saving separate subconfigs for items in a list (#777). --- CHANGELOG.rst | 2 ++ jsonargparse/_core.py | 47 ++++++++++++++++++++------------- jsonargparse/_namespace.py | 8 ++++-- jsonargparse/_typehints.py | 4 +-- jsonargparse_tests/test_core.py | 41 ++++++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0f464aa5..59c62edc 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -35,6 +35,8 @@ Fixed (`#772 `__). - ``omegaconf+`` parser mode failing when there are ``inf``, ``-inf`` or ``nan`` values (`#773 `__). +- ``save`` with ``multifile=True`` not saving separate subconfigs for items in a + list (`#779 `__). v4.41.0 (2025-09-04) diff --git a/jsonargparse/_core.py b/jsonargparse/_core.py index a97801bb..923067c8 100644 --- a/jsonargparse/_core.py +++ b/jsonargparse/_core.py @@ -930,25 +930,36 @@ def check_overwrite(path): ActionLink.strip_link_target_keys(self, cfg) + def is_path_action(key): + action = _find_action(self, key) + return isinstance(action, (ActionJsonSchema, ActionJsonnet, ActionTypeHint, _ActionConfigLoad)) + + def save_path(val): + val_path = Path(os.path.basename(val["__path__"].absolute), mode="fc") + check_overwrite(val_path) + val_out = strip_meta(val) + if isinstance(val, Namespace): + val_out = val_out.as_dict() + if "__orig__" in val: + val_str = val["__orig__"] + else: + is_json = str(val_path).lower().endswith(".json") + val_str = dump_using_format(self, val_out, "json_indented" if is_json else format) + with open(val_path.absolute, "w") as f: + f.write(val_str) + return os.path.basename(val_path) + def save_paths(cfg): for key in cfg.get_sorted_keys(): val = cfg[key] if isinstance(val, (Namespace, dict)) and "__path__" in val: - action = _find_action(self, key) - if isinstance(action, (ActionJsonSchema, ActionJsonnet, ActionTypeHint, _ActionConfigLoad)): - val_path = Path(os.path.basename(val["__path__"].absolute), mode="fc") - check_overwrite(val_path) - val_out = strip_meta(val) - if isinstance(val, Namespace): - val_out = val_out.as_dict() - if "__orig__" in val: - val_str = val["__orig__"] - else: - is_json = str(val_path).lower().endswith(".json") - val_str = dump_using_format(self, val_out, "json_indented" if is_json else format) - with open(val_path.absolute, "w") as f: - f.write(val_str) - cfg[key] = os.path.basename(val_path.absolute) + if is_path_action(key): + cfg[key] = save_path(val) + elif isinstance(val, list): + if is_path_action(key): + for num, item in enumerate(val): + if isinstance(item, (Namespace, dict)) and "__path__" in item: + val[num] = save_path(item) elif isinstance(val, Path) and key in self.save_path_content and "r" in val.mode: val_path = Path(os.path.basename(val.absolute), mode="fc") check_overwrite(val_path) @@ -1356,9 +1367,9 @@ def _apply_actions( cfg_branch = cfg cfg = Namespace() cfg[parent_key] = cfg_branch - keys = [parent_key + "." + k for k in cfg_branch.__dict__] + keys = [parent_key + "." + k for k in cfg_branch.keys(branches=True, nested=False)] else: - keys = list(cfg.__dict__) + keys = list(cfg.keys(branches=True, nested=False)) if prev_cfg: prev_cfg = prev_cfg.clone() @@ -1385,7 +1396,7 @@ def _apply_actions( if isinstance(value, dict): value = Namespace(value) if isinstance(value, Namespace): - new_keys = value.__dict__.keys() + new_keys = value.keys(branches=True, nested=False) keys += [key + "." + k for k in new_keys if key + "." + k not in keys] cfg[key] = value continue diff --git a/jsonargparse/_namespace.py b/jsonargparse/_namespace.py index f1f2d212..e440f2d6 100644 --- a/jsonargparse/_namespace.py +++ b/jsonargparse/_namespace.py @@ -233,9 +233,11 @@ def as_flat(self) -> argparse.Namespace: setattr(flat, key, val) return flat - def items(self, branches: bool = False) -> Iterator[Tuple[str, Any]]: + def items(self, branches: bool = False, nested: bool = True) -> Iterator[Tuple[str, Any]]: """Returns a generator of all leaf (key, value) items, optionally including branches.""" for key, val in vars(self).items(): + if not nested and "." in key: + continue key = del_clash_mark(key) if isinstance(val, Namespace): if branches: @@ -245,9 +247,11 @@ def items(self, branches: bool = False) -> Iterator[Tuple[str, Any]]: else: yield key, val - def keys(self, branches: bool = False) -> Iterator[str]: + def keys(self, branches: bool = False, nested: bool = True) -> Iterator[str]: """Returns a generator of all leaf keys, optionally including branches.""" for key, _ in self.items(branches): + if not nested and "." in key: + continue yield key def values(self, branches: bool = False) -> Iterator[Any]: diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 2637f30a..02a55bd4 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -1396,7 +1396,7 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value): parser = ActionTypeHint.get_class_parser(value["class_path"], sub_add_kwargs) del_args = {} prev_val = subclass_spec_as_namespace(prev_val) - for key, val in list(prev_val.init_args.__dict__.items()): + for key, val in list(prev_val.init_args.items(branches=True, nested=False)): action = _find_action(parser, key) if action: with parser_context(lenient_check=False, load_value_mode=parser.parser_mode): @@ -1515,7 +1515,7 @@ def adapt_classes_any(val, serialize, instantiate_classes, sub_add_kwargs): val = subclass_spec_as_namespace(val) init_args = val.get("init_args") if init_args and not instantiate_classes: - for subkey, subval in init_args.__dict__.items(): + for subkey, subval in init_args.items(branches=True, nested=False): init_args[subkey] = adapt_classes_any(subval, serialize, instantiate_classes, sub_add_kwargs) val["init_args"] = init_args try: diff --git a/jsonargparse_tests/test_core.py b/jsonargparse_tests/test_core.py index 019f680c..13dad120 100644 --- a/jsonargparse_tests/test_core.py +++ b/jsonargparse_tests/test_core.py @@ -699,6 +699,47 @@ def rm_out_files(): assert json_or_yaml_load(schema_yaml_out.read_text()) == {"a": 1, "b": 2} +class ListItem: + def __init__(self, a: int, b: str): + self.a = a + self.b = b + + +def test_save_multifile_list(parser, tmp_cwd): + in_dir = tmp_cwd / "input" + out_dir = tmp_cwd / "output" + in_dir.mkdir() + out_dir.mkdir() + main_file_in = in_dir / "main.yaml" + main_file_out = out_dir / "main.yaml" + item_1_file_in = in_dir / "item1.yaml" + item_1_file_out = out_dir / "item1.yaml" + item_2_file_in = in_dir / "item2.yaml" + item_2_file_out = out_dir / "item2.yaml" + + main_content = {"items": ["item1.yaml", "item2.yaml"]} + main_file_in.write_text(json_or_yaml_dump(main_content)) + item_1_file_in.write_text(json_or_yaml_dump({"a": 1, "b": "x"})) + item_2_file_in.write_text(json_or_yaml_dump({"a": 2, "b": "y"})) + + parser.add_argument("--config", action="config") + parser.add_argument( + "--items", + nargs="+", + type=ListItem, + required=True, + enable_path=True, + ) + + cfg = parser.parse_args([f"--config={main_file_in}"]) + parser.save(cfg, main_file_out) + assert json_or_yaml_load(main_file_out.read_text()) == main_content + item_1_expected = {"class_path": f"{__name__}.ListItem", "init_args": {"a": 1, "b": "x"}} + assert json_or_yaml_load(item_1_file_out.read_text()) == item_1_expected + item_2_expected = {"class_path": f"{__name__}.ListItem", "init_args": {"a": 2, "b": "y"}} + assert json_or_yaml_load(item_2_file_out.read_text()) == item_2_expected + + def test_save_overwrite(example_parser, tmp_cwd): cfg = example_parser.parse_args(["--nums.val1=7"]) example_parser.save(cfg, "config.yaml") From d389da959dbacfb5a9c567cde853eeecea6cb433 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Wed, 8 Oct 2025 06:04:47 +0200 Subject: [PATCH 2/2] Tests for namespace nested parameter --- jsonargparse/_namespace.py | 8 +++---- jsonargparse_tests/test_namespace.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/jsonargparse/_namespace.py b/jsonargparse/_namespace.py index e440f2d6..1ae156be 100644 --- a/jsonargparse/_namespace.py +++ b/jsonargparse/_namespace.py @@ -236,12 +236,12 @@ def as_flat(self) -> argparse.Namespace: def items(self, branches: bool = False, nested: bool = True) -> Iterator[Tuple[str, Any]]: """Returns a generator of all leaf (key, value) items, optionally including branches.""" for key, val in vars(self).items(): - if not nested and "." in key: - continue key = del_clash_mark(key) if isinstance(val, Namespace): if branches: yield key, val + if not nested: + continue for subkey, subval in val.items(branches): yield key + "." + del_clash_mark(subkey), subval else: @@ -249,9 +249,7 @@ def items(self, branches: bool = False, nested: bool = True) -> Iterator[Tuple[s def keys(self, branches: bool = False, nested: bool = True) -> Iterator[str]: """Returns a generator of all leaf keys, optionally including branches.""" - for key, _ in self.items(branches): - if not nested and "." in key: - continue + for key, _ in self.items(branches=branches, nested=nested): yield key def values(self, branches: bool = False) -> Iterator[Any]: diff --git a/jsonargparse_tests/test_namespace.py b/jsonargparse_tests/test_namespace.py index 77552f09..3ac08afb 100644 --- a/jsonargparse_tests/test_namespace.py +++ b/jsonargparse_tests/test_namespace.py @@ -296,3 +296,35 @@ def test_add_argument_meta_key_error(meta_key, parser): with pytest.raises(ValueError) as ctx: parser.add_argument(meta_key) ctx.match(f'"{meta_key}" not allowed') + + +def test_items_branches_nested(): + ns = Namespace() + ns["a.b"] = 1 + ns["a.c"] = 2 + ns["d"] = 3 + + items = list(ns.items(branches=True)) + assert items == [("a", Namespace(b=1, c=2)), ("a.b", 1), ("a.c", 2), ("d", 3)] + + items = list(ns.items(branches=True, nested=False)) + assert items == [("a", Namespace(b=1, c=2)), ("d", 3)] + + items = list(ns.items(nested=False)) + assert items == [("d", 3)] + + +def test_keys_branches_nested(): + ns = Namespace() + ns["a.b"] = 1 + ns["a.c"] = 2 + ns["d"] = 3 + + keys = list(ns.keys(branches=True)) + assert keys == ["a", "a.b", "a.c", "d"] + + keys = list(ns.keys(branches=True, nested=False)) + assert keys == ["a", "d"] + + keys = list(ns.keys(nested=False)) + assert keys == ["d"]