22
33import inspect
44import os
5+ import re
56from contextlib import contextmanager
7+ from copy import deepcopy
68from importlib .metadata import version
79from importlib .util import find_spec
8- from typing import Optional , Union
10+ from typing import List , Optional , Union
911
1012__all__ = [
1113 "_get_config_read_mode" ,
@@ -266,7 +268,7 @@ def get_doc_short_description(function_or_class, method_name=None, logger=None):
266268 return None
267269
268270
269- def get_omegaconf_loader ():
271+ def get_omegaconf_loader (mode ):
270272 """Returns a yaml loader function based on OmegaConf which supports variable interpolation."""
271273 import io
272274
@@ -275,6 +277,22 @@ def get_omegaconf_loader():
275277 with missing_package_raise ("omegaconf" , "get_omegaconf_loader" ):
276278 from omegaconf import OmegaConf
277279
280+ assert mode in {"omegaconf" , "omegaconf+" }
281+
282+ if mode == "omegaconf+" :
283+ from ._common import get_parsing_setting
284+
285+ if not get_parsing_setting ("omegaconf_absolute_to_relative_paths" ):
286+ return yaml_load
287+
288+ def omegaconf_plus_load (value ):
289+ value = yaml_load (value )
290+ if isinstance (value , dict ):
291+ value = omegaconf_absolute_to_relative_paths (value )
292+ return value
293+
294+ return omegaconf_plus_load
295+
278296 def omegaconf_load (value ):
279297 value_pyyaml = yaml_load (value )
280298 if isinstance (value_pyyaml , (str , int , float , bool )) or value_pyyaml is None :
@@ -302,6 +320,57 @@ def omegaconf_apply(parser, cfg):
302320 return parser ._apply_actions (cfg_dict )
303321
304322
323+ def omegaconf_tokenize (path : str ) -> List [str ]:
324+ """Very small tokenizer: 'a.b[0].c' -> ['a','b','0','c']."""
325+ return [t for t in path .replace ("]" , "" ).replace ("[" , "." ).split ("." ) if t ]
326+
327+
328+ def omegaconf_tokens_to_path (tokens : List [str ]) -> str :
329+ """Render tokens back to a normalized path: ['a','0','b'] -> 'a[0].b'."""
330+ s = ""
331+ for t in tokens :
332+ if t .isdigit ():
333+ s += f"[{ t } ]"
334+ else :
335+ s += ("" if s == "" else "." ) + t
336+ return s
337+
338+
339+ def omegaconf_absolute_to_relative_paths (data : dict ) -> dict :
340+ """
341+ Return a new nested dict/list where absolute ${...} interpolations
342+ are rewritten to relative form from the node where they appear.
343+ """
344+ data = deepcopy (data )
345+
346+ regex_absolute_path = re .compile (r"\$\{([a-zA-Z][a-zA-Z0-9[\]_.]*)\}" )
347+
348+ def _walk (node , current_path : List [Union [str , int ]]):
349+ if isinstance (node , dict ):
350+ return {k : _walk (v , current_path + [k ]) for k , v in node .items ()}
351+ if isinstance (node , list ):
352+ return [_walk (v , current_path + [i ]) for i , v in enumerate (node )]
353+
354+ if isinstance (node , str ):
355+
356+ def _replace (m : re .Match ) -> str :
357+ dst_tokens = omegaconf_tokenize (m .group (1 ))
358+ # compute common prefix length
359+ i = 0
360+ while i < len (current_path ) and i < len (dst_tokens ) and str (current_path [i ]) == dst_tokens [i ]:
361+ i += 1
362+ up = max (1 , len (current_path ) - i )
363+ dots = "." * up
364+ down = omegaconf_tokens_to_path (dst_tokens [i :])
365+ return "${" + dots + down + "}"
366+
367+ return regex_absolute_path .sub (_replace , node )
368+
369+ return node
370+
371+ return _walk (data , [])
372+
373+
305374annotated_alias = typing_extensions_import ("_AnnotatedAlias" )
306375
307376
0 commit comments