Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ Fixed
(`#772 <https://github.com/omni-us/jsonargparse/pull/772>`__).
- ``omegaconf+`` parser mode failing when there are ``inf``, ``-inf`` or ``nan``
values (`#773 <https://github.com/omni-us/jsonargparse/pull/773>`__).
- ``save`` with ``multifile=True`` not saving separate subconfigs for items in a
list (`#779 <https://github.com/omni-us/jsonargparse/pull/779>`__).


v4.41.0 (2025-09-04)
Expand Down
47 changes: 29 additions & 18 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,25 +930,36 @@ def check_overwrite(path):

ActionLink.strip_link_target_keys(self, cfg)

def is_path_action(key):
action = _find_action(self, key)
return isinstance(action, (ActionJsonSchema, ActionJsonnet, ActionTypeHint, _ActionConfigLoad))

def save_path(val):
val_path = Path(os.path.basename(val["__path__"].absolute), mode="fc")
check_overwrite(val_path)
val_out = strip_meta(val)
if isinstance(val, Namespace):
val_out = val_out.as_dict()
if "__orig__" in val:
val_str = val["__orig__"]
else:
is_json = str(val_path).lower().endswith(".json")
val_str = dump_using_format(self, val_out, "json_indented" if is_json else format)
with open(val_path.absolute, "w") as f:
f.write(val_str)
return os.path.basename(val_path)

def save_paths(cfg):
for key in cfg.get_sorted_keys():
val = cfg[key]
if isinstance(val, (Namespace, dict)) and "__path__" in val:
action = _find_action(self, key)
if isinstance(action, (ActionJsonSchema, ActionJsonnet, ActionTypeHint, _ActionConfigLoad)):
val_path = Path(os.path.basename(val["__path__"].absolute), mode="fc")
check_overwrite(val_path)
val_out = strip_meta(val)
if isinstance(val, Namespace):
val_out = val_out.as_dict()
if "__orig__" in val:
val_str = val["__orig__"]
else:
is_json = str(val_path).lower().endswith(".json")
val_str = dump_using_format(self, val_out, "json_indented" if is_json else format)
with open(val_path.absolute, "w") as f:
f.write(val_str)
cfg[key] = os.path.basename(val_path.absolute)
if is_path_action(key):
cfg[key] = save_path(val)
elif isinstance(val, list):
if is_path_action(key):
for num, item in enumerate(val):
if isinstance(item, (Namespace, dict)) and "__path__" in item:
val[num] = save_path(item)
elif isinstance(val, Path) and key in self.save_path_content and "r" in val.mode:
val_path = Path(os.path.basename(val.absolute), mode="fc")
check_overwrite(val_path)
Expand Down Expand Up @@ -1356,9 +1367,9 @@ def _apply_actions(
cfg_branch = cfg
cfg = Namespace()
cfg[parent_key] = cfg_branch
keys = [parent_key + "." + k for k in cfg_branch.__dict__]
keys = [parent_key + "." + k for k in cfg_branch.keys(branches=True, nested=False)]
else:
keys = list(cfg.__dict__)
keys = list(cfg.keys(branches=True, nested=False))

if prev_cfg:
prev_cfg = prev_cfg.clone()
Expand All @@ -1385,7 +1396,7 @@ def _apply_actions(
if isinstance(value, dict):
value = Namespace(value)
if isinstance(value, Namespace):
new_keys = value.__dict__.keys()
new_keys = value.keys(branches=True, nested=False)
keys += [key + "." + k for k in new_keys if key + "." + k not in keys]
cfg[key] = value
continue
Expand Down
8 changes: 5 additions & 3 deletions jsonargparse/_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,21 +233,23 @@ def as_flat(self) -> argparse.Namespace:
setattr(flat, key, val)
return flat

def items(self, branches: bool = False) -> Iterator[Tuple[str, Any]]:
def items(self, branches: bool = False, nested: bool = True) -> Iterator[Tuple[str, Any]]:
"""Returns a generator of all leaf (key, value) items, optionally including branches."""
for key, val in vars(self).items():
key = del_clash_mark(key)
if isinstance(val, Namespace):
if branches:
yield key, val
if not nested:
continue
for subkey, subval in val.items(branches):
yield key + "." + del_clash_mark(subkey), subval
else:
yield key, val

def keys(self, branches: bool = False) -> Iterator[str]:
def keys(self, branches: bool = False, nested: bool = True) -> Iterator[str]:
"""Returns a generator of all leaf keys, optionally including branches."""
for key, _ in self.items(branches):
for key, _ in self.items(branches=branches, nested=nested):
yield key

def values(self, branches: bool = False) -> Iterator[Any]:
Expand Down
4 changes: 2 additions & 2 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,7 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
parser = ActionTypeHint.get_class_parser(value["class_path"], sub_add_kwargs)
del_args = {}
prev_val = subclass_spec_as_namespace(prev_val)
for key, val in list(prev_val.init_args.__dict__.items()):
for key, val in list(prev_val.init_args.items(branches=True, nested=False)):
action = _find_action(parser, key)
if action:
with parser_context(lenient_check=False, load_value_mode=parser.parser_mode):
Expand Down Expand Up @@ -1515,7 +1515,7 @@ def adapt_classes_any(val, serialize, instantiate_classes, sub_add_kwargs):
val = subclass_spec_as_namespace(val)
init_args = val.get("init_args")
if init_args and not instantiate_classes:
for subkey, subval in init_args.__dict__.items():
for subkey, subval in init_args.items(branches=True, nested=False):
init_args[subkey] = adapt_classes_any(subval, serialize, instantiate_classes, sub_add_kwargs)
val["init_args"] = init_args
try:
Expand Down
41 changes: 41 additions & 0 deletions jsonargparse_tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,47 @@ def rm_out_files():
assert json_or_yaml_load(schema_yaml_out.read_text()) == {"a": 1, "b": 2}


class ListItem:
def __init__(self, a: int, b: str):
self.a = a
self.b = b


def test_save_multifile_list(parser, tmp_cwd):
in_dir = tmp_cwd / "input"
out_dir = tmp_cwd / "output"
in_dir.mkdir()
out_dir.mkdir()
main_file_in = in_dir / "main.yaml"
main_file_out = out_dir / "main.yaml"
item_1_file_in = in_dir / "item1.yaml"
item_1_file_out = out_dir / "item1.yaml"
item_2_file_in = in_dir / "item2.yaml"
item_2_file_out = out_dir / "item2.yaml"

main_content = {"items": ["item1.yaml", "item2.yaml"]}
main_file_in.write_text(json_or_yaml_dump(main_content))
item_1_file_in.write_text(json_or_yaml_dump({"a": 1, "b": "x"}))
item_2_file_in.write_text(json_or_yaml_dump({"a": 2, "b": "y"}))

parser.add_argument("--config", action="config")
parser.add_argument(
"--items",
nargs="+",
type=ListItem,
required=True,
enable_path=True,
)

cfg = parser.parse_args([f"--config={main_file_in}"])
parser.save(cfg, main_file_out)
assert json_or_yaml_load(main_file_out.read_text()) == main_content
item_1_expected = {"class_path": f"{__name__}.ListItem", "init_args": {"a": 1, "b": "x"}}
assert json_or_yaml_load(item_1_file_out.read_text()) == item_1_expected
item_2_expected = {"class_path": f"{__name__}.ListItem", "init_args": {"a": 2, "b": "y"}}
assert json_or_yaml_load(item_2_file_out.read_text()) == item_2_expected


def test_save_overwrite(example_parser, tmp_cwd):
cfg = example_parser.parse_args(["--nums.val1=7"])
example_parser.save(cfg, "config.yaml")
Expand Down
32 changes: 32 additions & 0 deletions jsonargparse_tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,35 @@ def test_add_argument_meta_key_error(meta_key, parser):
with pytest.raises(ValueError) as ctx:
parser.add_argument(meta_key)
ctx.match(f'"{meta_key}" not allowed')


def test_items_branches_nested():
ns = Namespace()
ns["a.b"] = 1
ns["a.c"] = 2
ns["d"] = 3

items = list(ns.items(branches=True))
assert items == [("a", Namespace(b=1, c=2)), ("a.b", 1), ("a.c", 2), ("d", 3)]

items = list(ns.items(branches=True, nested=False))
assert items == [("a", Namespace(b=1, c=2)), ("d", 3)]

items = list(ns.items(nested=False))
assert items == [("d", 3)]


def test_keys_branches_nested():
ns = Namespace()
ns["a.b"] = 1
ns["a.c"] = 2
ns["d"] = 3

keys = list(ns.keys(branches=True))
assert keys == ["a", "a.b", "a.c", "d"]

keys = list(ns.keys(branches=True, nested=False))
assert keys == ["a", "d"]

keys = list(ns.keys(nested=False))
assert keys == ["d"]