diff --git a/CHANGELOG.rst b/CHANGELOG.rst index bc164350..f4b8faef 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,6 +27,10 @@ Added - Argument to print help for dataclasses nested in types, e.g. ``Optional[Data]``, ``Union[Data1, Data2]`` (`#783 `__). +- ``set_parsing_settings`` now supports ``omegaconf_absolute_to_relative_paths`` + to enable backward compatibility of ``omegaconf+`` parser mode by converting + absolute paths to relative in interpolations (`#774 + `__). Fixed ^^^^^ diff --git a/DOCUMENTATION.rst b/DOCUMENTATION.rst index b2114c9b..58e8a2c6 100644 --- a/DOCUMENTATION.rst +++ b/DOCUMENTATION.rst @@ -2475,7 +2475,10 @@ limitations of the ``omegaconf`` mode mentioned earlier. Instead of applying OmegaConf resolvers to each YAML config individually, the resolving is performed once at the end of the parsing process. As a result, in nested sub-configs, references to nodes must be either relative or parser-level absolute to function -correctly. +correctly. Alternatively, you can call +``set_parsing_settings(omegaconf_absolute_to_relative_paths=True)`` to enable +automatic conversion of absolute paths to relative ones during parsing. Be aware +that this automatic conversion does not work for every possible case. Based on community feedback, this mode may become the default ``omegaconf`` mode in version 5.0.0. 
This change would introduce a breaking modification, as diff --git a/jsonargparse/_common.py b/jsonargparse/_common.py index e383d159..4467e67f 100644 --- a/jsonargparse/_common.py +++ b/jsonargparse/_common.py @@ -100,6 +100,7 @@ def parser_context(**kwargs): "validate_defaults": False, "parse_optionals_as_positionals": False, "stubs_resolver_allow_py_files": False, + "omegaconf_absolute_to_relative_paths": False, } @@ -112,6 +113,7 @@ def set_parsing_settings( docstring_parse_attribute_docstrings: Optional[bool] = None, parse_optionals_as_positionals: Optional[bool] = None, stubs_resolver_allow_py_files: Optional[bool] = None, + omegaconf_absolute_to_relative_paths: Optional[bool] = None, ) -> None: """ Modify settings that affect the parsing behavior. @@ -136,6 +138,10 @@ def set_parsing_settings( parser. By default, this is False. stubs_resolver_allow_py_files: Whether the stubs resolver should search in ``.py`` files in addition to ``.pyi`` files. + omegaconf_absolute_to_relative_paths: If True, when loading configs + with ``omegaconf+`` parser mode, absolute interpolation paths are + converted to relative. This is only intended for backward + compatibility with ``omegaconf`` parser mode. """ # validate_defaults if isinstance(validate_defaults, bool): @@ -162,6 +168,13 @@ def set_parsing_settings( parsing_settings["stubs_resolver_allow_py_files"] = stubs_resolver_allow_py_files elif stubs_resolver_allow_py_files is not None: raise ValueError(f"stubs_resolver_allow_py_files must be a boolean, but got {stubs_resolver_allow_py_files}.") + # omegaconf_absolute_to_relative_paths + if isinstance(omegaconf_absolute_to_relative_paths, bool): + parsing_settings["omegaconf_absolute_to_relative_paths"] = omegaconf_absolute_to_relative_paths + elif omegaconf_absolute_to_relative_paths is not None: + raise ValueError( + f"omegaconf_absolute_to_relative_paths must be a boolean, but got {omegaconf_absolute_to_relative_paths}." 
+ ) def get_parsing_setting(name: str): diff --git a/jsonargparse/_loaders_dumpers.py b/jsonargparse/_loaders_dumpers.py index 66932892..e8be44ba 100644 --- a/jsonargparse/_loaders_dumpers.py +++ b/jsonargparse/_loaders_dumpers.py @@ -339,7 +339,7 @@ def set_omegaconf_loader(mode="omegaconf"): if omegaconf_support and mode not in loaders: from ._optionals import get_omegaconf_loader - loader = yaml_load if mode == "omegaconf+" else get_omegaconf_loader() + loader = get_omegaconf_loader(mode) set_loader(mode, loader, get_loader_exceptions("yaml")) diff --git a/jsonargparse/_optionals.py b/jsonargparse/_optionals.py index 2df24c48..587e206c 100644 --- a/jsonargparse/_optionals.py +++ b/jsonargparse/_optionals.py @@ -2,10 +2,12 @@ import inspect import os +import re from contextlib import contextmanager +from copy import deepcopy from importlib.metadata import version from importlib.util import find_spec -from typing import Optional, Union +from typing import List, Optional, Union __all__ = [ "_get_config_read_mode", @@ -266,7 +268,7 @@ def get_doc_short_description(function_or_class, method_name=None, logger=None): return None -def get_omegaconf_loader(): +def get_omegaconf_loader(mode): """Returns a yaml loader function based on OmegaConf which supports variable interpolation.""" import io @@ -275,6 +277,22 @@ def get_omegaconf_loader(): with missing_package_raise("omegaconf", "get_omegaconf_loader"): from omegaconf import OmegaConf + assert mode in {"omegaconf", "omegaconf+"} + + if mode == "omegaconf+": + from ._common import get_parsing_setting + + if not get_parsing_setting("omegaconf_absolute_to_relative_paths"): + return yaml_load + + def omegaconf_plus_load(value): + value = yaml_load(value) + if isinstance(value, dict): + value = omegaconf_absolute_to_relative_paths(value) + return value + + return omegaconf_plus_load + def omegaconf_load(value): value_pyyaml = yaml_load(value) if isinstance(value_pyyaml, (str, int, float, bool)) or value_pyyaml is None: @@ 
-302,6 +320,57 @@ def omegaconf_apply(parser, cfg): return parser._apply_actions(cfg_dict) +def omegaconf_tokenize(path: str) -> List[str]: + """Very small tokenizer: 'a.b[0].c' -> ['a','b','0','c'].""" + return [t for t in path.replace("]", "").replace("[", ".").split(".") if t] + + +def omegaconf_tokens_to_path(tokens: List[str]) -> str: + """Render tokens back to a normalized path: ['a','0','b'] -> 'a[0].b'.""" + s = "" + for t in tokens: + if t.isdigit(): + s += f"[{t}]" + else: + s += ("" if s == "" else ".") + t + return s + + +def omegaconf_absolute_to_relative_paths(data: dict) -> dict: + """ + Return a new nested dict/list where absolute ${...} interpolations + are rewritten to relative form from the node where they appear. + """ + data = deepcopy(data) + + regex_absolute_path = re.compile(r"\$\{([a-zA-Z][a-zA-Z0-9[\]_.]*)\}") + + def _walk(node, current_path: List[Union[str, int]]): + if isinstance(node, dict): + return {k: _walk(v, current_path + [k]) for k, v in node.items()} + if isinstance(node, list): + return [_walk(v, current_path + [i]) for i, v in enumerate(node)] + + if isinstance(node, str): + + def _replace(m: re.Match) -> str: + dst_tokens = omegaconf_tokenize(m.group(1)) + # compute common prefix length + i = 0 + while i < len(current_path) and i < len(dst_tokens) and str(current_path[i]) == dst_tokens[i]: + i += 1 + up = max(1, len(current_path) - i) + dots = "." 
* up + down = omegaconf_tokens_to_path(dst_tokens[i:]) + return "${" + dots + down + "}" + + return regex_absolute_path.sub(_replace, node) + + return node + + return _walk(data, []) + + annotated_alias = typing_extensions_import("_AnnotatedAlias") diff --git a/jsonargparse_tests/test_omegaconf.py b/jsonargparse_tests/test_omegaconf.py index a742009c..64cbb51c 100644 --- a/jsonargparse_tests/test_omegaconf.py +++ b/jsonargparse_tests/test_omegaconf.py @@ -10,9 +10,9 @@ import pytest from jsonargparse import ArgumentParser, Namespace -from jsonargparse._common import parser_context +from jsonargparse._common import parser_context, set_parsing_settings from jsonargparse._loaders_dumpers import loaders, yaml_dump -from jsonargparse._optionals import omegaconf_support +from jsonargparse._optionals import omegaconf_absolute_to_relative_paths, omegaconf_support from jsonargparse.typing import Path_fr from jsonargparse_tests.conftest import get_parser_help @@ -25,6 +25,12 @@ ) +@pytest.fixture(autouse=True) +def patch_loaders(): + with patch.dict("jsonargparse._loaders_dumpers.loaders"): + yield + + @pytest.mark.skipif( not (omegaconf_support and "JSONARGPARSE_OMEGACONF_FULL_TEST" in os.environ), reason="only for omegaconf as the yaml loader", @@ -57,19 +63,23 @@ def test_omegaconf_interpolation(mode): @skip_if_omegaconf_unavailable -@pytest.mark.parametrize("mode", ["omegaconf", "omegaconf+"]) +@pytest.mark.parametrize("mode", ["omegaconf", "omegaconf+", "omegaconf+absolute"]) +@patch.dict("jsonargparse._common.parsing_settings") def test_omegaconf_interpolation_in_subcommands(mode, parser, subparser): subparser.add_argument("--config", action="config") subparser.add_argument("--source", type=str) subparser.add_argument("--target", type=str) - parser.parser_mode = mode + if mode == "omegaconf+absolute": + set_parsing_settings(omegaconf_absolute_to_relative_paths=True) + + parser.parser_mode = mode.replace("absolute", "") subcommands = parser.add_subcommands() 
subcommands.add_subcommand("sub", subparser) config = { "source": "hello", - "target": "${source}" if mode == "omegaconf" else "${.source}", + "target": "${.source}" if mode == "omegaconf+" else "${source}", } cfg = parser.parse_args(["sub", f"--config={yaml_dump(config)}"]) assert cfg.sub.target == "hello" @@ -193,3 +203,24 @@ def test_omegaconf_inf_nan(parser): assert math.isnan(cfg.c) assert cfg.d == float("inf") assert cfg.e == float("-inf") + + +@skip_if_omegaconf_unavailable +def test_omegaconf_absolute_to_relative_paths(): + data = { + "a": "x", + "b": "prefix ${a} suffix", + "c": {"d": "${b}", "e": "${c.d}"}, + "f": [10, "${c.e}", "${..b}"], + "g": "${env:USER}", + "h": "${f[0]}", + } + expected = { + "a": "x", + "b": "prefix ${.a} suffix", + "c": {"d": "${..b}", "e": "${.d}"}, + "f": [10, "${..c.e}", "${..b}"], + "g": "${env:USER}", + "h": "${.f[0]}", + } + assert omegaconf_absolute_to_relative_paths(data) == expected diff --git a/jsonargparse_tests/test_parsing_settings.py b/jsonargparse_tests/test_parsing_settings.py index 0d2d2c6b..c1c824ed 100644 --- a/jsonargparse_tests/test_parsing_settings.py +++ b/jsonargparse_tests/test_parsing_settings.py @@ -195,3 +195,11 @@ def test_optionals_as_positionals_unsupported_arguments(parser): def test_set_stubs_resolver_allow_py_files_failure(): with pytest.raises(ValueError, match="stubs_resolver_allow_py_files must be a boolean"): set_parsing_settings(stubs_resolver_allow_py_files="invalid") + + +# omegaconf_absolute_to_relative_paths + + +def test_set_omegaconf_absolute_to_relative_paths_failure(): + with pytest.raises(ValueError, match="omegaconf_absolute_to_relative_paths must be a boolean"): + set_parsing_settings(omegaconf_absolute_to_relative_paths="invalid")