Skip to content

Commit 89b50d3

Browse files
authored
Fix save with multifile=True not saving separate subconfigs for items in a list (#779)
1 parent b4ff170 commit 89b50d3

File tree

6 files changed

+111
-23
lines changed

6 files changed

+111
-23
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ Fixed
3535
(`#772 <https://github.com/omni-us/jsonargparse/pull/772>`__).
3636
- ``omegaconf+`` parser mode failing when there are ``inf``, ``-inf`` or ``nan``
3737
values (`#773 <https://github.com/omni-us/jsonargparse/pull/773>`__).
38+
- ``save`` with ``multifile=True`` not saving separate subconfigs for items in a
39+
list (`#779 <https://github.com/omni-us/jsonargparse/pull/779>`__).
3840

3941

4042
v4.41.0 (2025-09-04)

jsonargparse/_core.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -930,25 +930,36 @@ def check_overwrite(path):
930930

931931
ActionLink.strip_link_target_keys(self, cfg)
932932

933+
def is_path_action(key):
934+
action = _find_action(self, key)
935+
return isinstance(action, (ActionJsonSchema, ActionJsonnet, ActionTypeHint, _ActionConfigLoad))
936+
937+
def save_path(val):
938+
val_path = Path(os.path.basename(val["__path__"].absolute), mode="fc")
939+
check_overwrite(val_path)
940+
val_out = strip_meta(val)
941+
if isinstance(val, Namespace):
942+
val_out = val_out.as_dict()
943+
if "__orig__" in val:
944+
val_str = val["__orig__"]
945+
else:
946+
is_json = str(val_path).lower().endswith(".json")
947+
val_str = dump_using_format(self, val_out, "json_indented" if is_json else format)
948+
with open(val_path.absolute, "w") as f:
949+
f.write(val_str)
950+
return os.path.basename(val_path)
951+
933952
def save_paths(cfg):
934953
for key in cfg.get_sorted_keys():
935954
val = cfg[key]
936955
if isinstance(val, (Namespace, dict)) and "__path__" in val:
937-
action = _find_action(self, key)
938-
if isinstance(action, (ActionJsonSchema, ActionJsonnet, ActionTypeHint, _ActionConfigLoad)):
939-
val_path = Path(os.path.basename(val["__path__"].absolute), mode="fc")
940-
check_overwrite(val_path)
941-
val_out = strip_meta(val)
942-
if isinstance(val, Namespace):
943-
val_out = val_out.as_dict()
944-
if "__orig__" in val:
945-
val_str = val["__orig__"]
946-
else:
947-
is_json = str(val_path).lower().endswith(".json")
948-
val_str = dump_using_format(self, val_out, "json_indented" if is_json else format)
949-
with open(val_path.absolute, "w") as f:
950-
f.write(val_str)
951-
cfg[key] = os.path.basename(val_path.absolute)
956+
if is_path_action(key):
957+
cfg[key] = save_path(val)
958+
elif isinstance(val, list):
959+
if is_path_action(key):
960+
for num, item in enumerate(val):
961+
if isinstance(item, (Namespace, dict)) and "__path__" in item:
962+
val[num] = save_path(item)
952963
elif isinstance(val, Path) and key in self.save_path_content and "r" in val.mode:
953964
val_path = Path(os.path.basename(val.absolute), mode="fc")
954965
check_overwrite(val_path)
@@ -1356,9 +1367,9 @@ def _apply_actions(
13561367
cfg_branch = cfg
13571368
cfg = Namespace()
13581369
cfg[parent_key] = cfg_branch
1359-
keys = [parent_key + "." + k for k in cfg_branch.__dict__]
1370+
keys = [parent_key + "." + k for k in cfg_branch.keys(branches=True, nested=False)]
13601371
else:
1361-
keys = list(cfg.__dict__)
1372+
keys = list(cfg.keys(branches=True, nested=False))
13621373

13631374
if prev_cfg:
13641375
prev_cfg = prev_cfg.clone()
@@ -1385,7 +1396,7 @@ def _apply_actions(
13851396
if isinstance(value, dict):
13861397
value = Namespace(value)
13871398
if isinstance(value, Namespace):
1388-
new_keys = value.__dict__.keys()
1399+
new_keys = value.keys(branches=True, nested=False)
13891400
keys += [key + "." + k for k in new_keys if key + "." + k not in keys]
13901401
cfg[key] = value
13911402
continue

jsonargparse/_namespace.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,21 +233,23 @@ def as_flat(self) -> argparse.Namespace:
233233
setattr(flat, key, val)
234234
return flat
235235

236-
def items(self, branches: bool = False) -> Iterator[Tuple[str, Any]]:
236+
def items(self, branches: bool = False, nested: bool = True) -> Iterator[Tuple[str, Any]]:
237237
"""Returns a generator of all leaf (key, value) items, optionally including branches."""
238238
for key, val in vars(self).items():
239239
key = del_clash_mark(key)
240240
if isinstance(val, Namespace):
241241
if branches:
242242
yield key, val
243+
if not nested:
244+
continue
243245
for subkey, subval in val.items(branches):
244246
yield key + "." + del_clash_mark(subkey), subval
245247
else:
246248
yield key, val
247249

248-
def keys(self, branches: bool = False) -> Iterator[str]:
250+
def keys(self, branches: bool = False, nested: bool = True) -> Iterator[str]:
249251
"""Returns a generator of all leaf keys, optionally including branches."""
250-
for key, _ in self.items(branches):
252+
for key, _ in self.items(branches=branches, nested=nested):
251253
yield key
252254

253255
def values(self, branches: bool = False) -> Iterator[Any]:

jsonargparse/_typehints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,7 +1396,7 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
13961396
parser = ActionTypeHint.get_class_parser(value["class_path"], sub_add_kwargs)
13971397
del_args = {}
13981398
prev_val = subclass_spec_as_namespace(prev_val)
1399-
for key, val in list(prev_val.init_args.__dict__.items()):
1399+
for key, val in list(prev_val.init_args.items(branches=True, nested=False)):
14001400
action = _find_action(parser, key)
14011401
if action:
14021402
with parser_context(lenient_check=False, load_value_mode=parser.parser_mode):
@@ -1515,7 +1515,7 @@ def adapt_classes_any(val, serialize, instantiate_classes, sub_add_kwargs):
15151515
val = subclass_spec_as_namespace(val)
15161516
init_args = val.get("init_args")
15171517
if init_args and not instantiate_classes:
1518-
for subkey, subval in init_args.__dict__.items():
1518+
for subkey, subval in init_args.items(branches=True, nested=False):
15191519
init_args[subkey] = adapt_classes_any(subval, serialize, instantiate_classes, sub_add_kwargs)
15201520
val["init_args"] = init_args
15211521
try:

jsonargparse_tests/test_core.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,47 @@ def rm_out_files():
699699
assert json_or_yaml_load(schema_yaml_out.read_text()) == {"a": 1, "b": 2}
700700

701701

702+
class ListItem:
703+
def __init__(self, a: int, b: str):
704+
self.a = a
705+
self.b = b
706+
707+
708+
def test_save_multifile_list(parser, tmp_cwd):
709+
in_dir = tmp_cwd / "input"
710+
out_dir = tmp_cwd / "output"
711+
in_dir.mkdir()
712+
out_dir.mkdir()
713+
main_file_in = in_dir / "main.yaml"
714+
main_file_out = out_dir / "main.yaml"
715+
item_1_file_in = in_dir / "item1.yaml"
716+
item_1_file_out = out_dir / "item1.yaml"
717+
item_2_file_in = in_dir / "item2.yaml"
718+
item_2_file_out = out_dir / "item2.yaml"
719+
720+
main_content = {"items": ["item1.yaml", "item2.yaml"]}
721+
main_file_in.write_text(json_or_yaml_dump(main_content))
722+
item_1_file_in.write_text(json_or_yaml_dump({"a": 1, "b": "x"}))
723+
item_2_file_in.write_text(json_or_yaml_dump({"a": 2, "b": "y"}))
724+
725+
parser.add_argument("--config", action="config")
726+
parser.add_argument(
727+
"--items",
728+
nargs="+",
729+
type=ListItem,
730+
required=True,
731+
enable_path=True,
732+
)
733+
734+
cfg = parser.parse_args([f"--config={main_file_in}"])
735+
parser.save(cfg, main_file_out)
736+
assert json_or_yaml_load(main_file_out.read_text()) == main_content
737+
item_1_expected = {"class_path": f"{__name__}.ListItem", "init_args": {"a": 1, "b": "x"}}
738+
assert json_or_yaml_load(item_1_file_out.read_text()) == item_1_expected
739+
item_2_expected = {"class_path": f"{__name__}.ListItem", "init_args": {"a": 2, "b": "y"}}
740+
assert json_or_yaml_load(item_2_file_out.read_text()) == item_2_expected
741+
742+
702743
def test_save_overwrite(example_parser, tmp_cwd):
703744
cfg = example_parser.parse_args(["--nums.val1=7"])
704745
example_parser.save(cfg, "config.yaml")

jsonargparse_tests/test_namespace.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,35 @@ def test_add_argument_meta_key_error(meta_key, parser):
296296
with pytest.raises(ValueError) as ctx:
297297
parser.add_argument(meta_key)
298298
ctx.match(f'"{meta_key}" not allowed')
299+
300+
301+
def test_items_branches_nested():
302+
ns = Namespace()
303+
ns["a.b"] = 1
304+
ns["a.c"] = 2
305+
ns["d"] = 3
306+
307+
items = list(ns.items(branches=True))
308+
assert items == [("a", Namespace(b=1, c=2)), ("a.b", 1), ("a.c", 2), ("d", 3)]
309+
310+
items = list(ns.items(branches=True, nested=False))
311+
assert items == [("a", Namespace(b=1, c=2)), ("d", 3)]
312+
313+
items = list(ns.items(nested=False))
314+
assert items == [("d", 3)]
315+
316+
317+
def test_keys_branches_nested():
318+
ns = Namespace()
319+
ns["a.b"] = 1
320+
ns["a.c"] = 2
321+
ns["d"] = 3
322+
323+
keys = list(ns.keys(branches=True))
324+
assert keys == ["a", "a.b", "a.c", "d"]
325+
326+
keys = list(ns.keys(branches=True, nested=False))
327+
assert keys == ["a", "d"]
328+
329+
keys = list(ns.keys(nested=False))
330+
assert keys == ["d"]

0 commit comments

Comments
 (0)