Skip to content

Commit 10cfd30

Browse files
authored
FromConfigMixin.from_config now supports subclasses (#822)
1 parent 8795fb0 commit 10cfd30

File tree

3 files changed

+95
-25
lines changed

3 files changed

+95
-25
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ Added
2020
- Signature methods now when given ``sub_configs=True``, list of paths types can
2121
now receive a file containing a list of paths (`#816
2222
<https://github.com/omni-us/jsonargparse/pull/816>`__).
23+
- ``FromConfigMixin.from_config`` now supports subclasses (`#822
24+
<https://github.com/omni-us/jsonargparse/pull/822>`__).
2325

2426
Fixed
2527
^^^^^
@@ -52,7 +54,7 @@ Fixed
5254
- Union types with str and default comment-like string incorrectly parsed as a
5355
stringified exception of an other subtype (`#812
5456
<https://github.com/omni-us/jsonargparse/pull/812>`__).
55-
- ``FromConfig`` not handling correctly required parameters (`#813
57+
- ``FromConfigMixin`` not handling correctly required parameters (`#813
5658
<https://github.com/omni-us/jsonargparse/pull/813>`__).
5759

5860
Changed

jsonargparse/_from_config.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from pathlib import Path
55
from typing import Optional, Type, TypeVar, Union
66

7+
from ._common import parser_context
78
from ._core import ArgumentParser
9+
from ._loaders_dumpers import get_loader_exceptions, load_value
10+
from ._optionals import _get_config_read_mode
11+
from ._typehints import is_subclass_spec, resolve_class_path_by_name
12+
from ._util import import_object
813

914
__all__ = ["FromConfigMixin"]
1015

@@ -48,23 +53,42 @@ def from_config(cls: Type[T], config: Union[str, PathLike, dict]) -> T:
4853
Args:
4954
config: Path to a config file or a dict with config values.
5055
"""
51-
kwargs = _parse_class_kwargs_from_config(cls, config, **cls.__from_config_parser_kwargs__) # type: ignore[attr-defined]
56+
kwargs, cls = _parse_class_kwargs_from_config(cls, config, **cls.__from_config_parser_kwargs__) # type: ignore[attr-defined]
5257
return cls(**kwargs)
5358

5459

55-
def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, dict], **kwargs) -> dict:
60+
def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, dict], **kwargs) -> tuple[dict, Type[T]]:
5661
"""Parse the init kwargs for ``cls`` from a config file or dict."""
5762
parser = ArgumentParser(exit_on_error=False, **kwargs)
63+
if not isinstance(config, dict):
64+
from .typing import Path
65+
66+
cfg_path = Path(config, mode=_get_config_read_mode())
67+
cfg_str = cfg_path.get_content()
68+
with parser_context(load_value_mode=parser.parser_mode):
69+
try:
70+
config = load_value(cfg_str, path=str(config))
71+
except get_loader_exceptions() as ex:
72+
raise TypeError(f"Problems parsing config '{config}': {ex}") from ex
73+
74+
if not isinstance(config, dict):
75+
raise TypeError(f"Expected config to be a dict or parse into a dict: {config}")
76+
77+
if is_subclass_spec(config):
78+
class_path = resolve_class_path_by_name(cls, config["class_path"])
79+
obj = import_object(class_path)
80+
if not issubclass(obj, cls):
81+
raise TypeError(f"Class '{class_path}' is not a subclass of '{cls.__name__}'")
82+
cls = obj
83+
config = {**config.get("init_args", {}), **config.get("dict_kwargs", {})}
84+
5885
parser.add_class_arguments(cls)
5986
for required in parser.required_args:
6087
action = next((a for a in parser._actions if a.dest == required), None)
6188
action._required = False # type: ignore[union-attr]
6289
parser.required_args.clear()
63-
if isinstance(config, dict):
64-
cfg = parser.parse_object(config, defaults=False)
65-
else:
66-
cfg = parser.parse_path(config, defaults=False)
67-
return parser.instantiate_classes(cfg).as_dict()
90+
cfg = parser.parse_object(config, defaults=False)
91+
return parser.instantiate_classes(cfg).as_dict(), cls
6892

6993

7094
def _override_init_defaults(cls: Type[T], parser_kwargs: dict) -> None:
@@ -75,7 +99,7 @@ def _override_init_defaults(cls: Type[T], parser_kwargs: dict) -> None:
7599
if not (isinstance(config, (str, PathLike)) and Path(config).is_file()):
76100
return
77101

78-
defaults = _parse_class_kwargs_from_config(cls, config, **parser_kwargs)
102+
defaults, cls = _parse_class_kwargs_from_config(cls, config, **parser_kwargs)
79103
_override_init_defaults_this_class(cls, defaults)
80104
_override_init_defaults_parent_classes(cls, defaults)
81105

jsonargparse_tests/test_from_config.py

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import pytest
44

55
from jsonargparse import FromConfigMixin
6-
from jsonargparse_tests.conftest import json_or_yaml_dump, skip_if_omegaconf_unavailable
6+
from jsonargparse._paths import PathError
7+
from jsonargparse_tests.conftest import json_or_yaml_dump, skip_if_no_pyyaml, skip_if_omegaconf_unavailable
78

89
# __init__ defaults override tests
910

@@ -178,16 +179,47 @@ class DefaultsOverrideInvalid(DefaultsOverrideParent):
178179
# from_config method tests
179180

180181

182+
class FromConfigMethodParent(FromConfigMixin):
183+
def __init__(self, parent_param: str = "parent_default"):
184+
self.parent_param = parent_param
185+
186+
187+
class FromConfigMethodChild(FromConfigMethodParent):
188+
def __init__(self, child_param: str = "child_default", **kwargs):
189+
super().__init__(**kwargs)
190+
self.child_param = child_param
191+
192+
181193
def test_from_config_method_path(tmp_cwd):
182194
config_path = tmp_cwd / "config.yaml"
183-
config_path.write_text(json_or_yaml_dump({"param": "value_from_file"}))
195+
config_path.write_text(json_or_yaml_dump({"parent_param": "value_from_file"}))
196+
197+
instance = FromConfigMethodParent.from_config(config_path)
198+
assert instance.parent_param == "value_from_file"
199+
200+
201+
def test_from_config_method_path_does_not_exist(tmp_cwd):
202+
config_path = tmp_cwd / "config.yaml"
203+
204+
with pytest.raises(PathError, match="File does not exist:"):
205+
FromConfigMethodParent.from_config(config_path)
206+
207+
208+
@skip_if_no_pyyaml
209+
def test_from_config_method_path_invalid_yaml(tmp_cwd):
210+
config_path = tmp_cwd / "config.yaml"
211+
config_path.write_text("::: invalid content :::")
212+
213+
with pytest.raises(TypeError, match="Problems parsing config"):
214+
FromConfigMethodParent.from_config(config_path)
184215

185-
class FromConfigMethodPath(FromConfigMixin):
186-
def __init__(self, param: str = "default_value"):
187-
self.param = param
188216

189-
instance = FromConfigMethodPath.from_config(config_path)
190-
assert instance.param == "value_from_file"
217+
def test_from_config_method_path_not_dict(tmp_cwd):
218+
config_path = tmp_cwd / "config.yaml"
219+
config_path.write_text("[1, 2]")
220+
221+
with pytest.raises(TypeError, match="Expected config to be a dict or parse into a dict"):
222+
FromConfigMethodParent.from_config(config_path)
191223

192224

193225
def test_from_config_method_dict():
@@ -224,15 +256,6 @@ def __init__(self, param1: str = "default_value", param2: int = 1):
224256

225257

226258
def test_from_config_method_subclass():
227-
class FromConfigMethodParent(FromConfigMixin):
228-
def __init__(self, parent_param: str = "parent_default"):
229-
self.parent_param = parent_param
230-
231-
class FromConfigMethodChild(FromConfigMethodParent):
232-
def __init__(self, child_param: str = "child_default", **kwargs):
233-
super().__init__(**kwargs)
234-
self.child_param = child_param
235-
236259
instance = FromConfigMethodChild.from_config(
237260
{"parent_param": "overridden_parent", "child_param": "overridden_child"}
238261
)
@@ -241,6 +264,27 @@ def __init__(self, child_param: str = "child_default", **kwargs):
241264
assert instance.child_param == "overridden_child"
242265

243266

267+
def test_from_config_method_class_path_subclass():
268+
instance = FromConfigMethodParent.from_config(
269+
{
270+
"class_path": f"{__name__}.FromConfigMethodChild",
271+
"init_args": {"parent_param": "overridden_parent", "child_param": "overridden_child"},
272+
}
273+
)
274+
assert isinstance(instance, FromConfigMethodChild)
275+
assert instance.parent_param == "overridden_parent"
276+
assert instance.child_param == "overridden_child"
277+
278+
279+
class SomeOtherClass:
280+
pass
281+
282+
283+
def test_from_config_method_class_path_not_subclass():
284+
with pytest.raises(TypeError, match="SomeOtherClass' is not a subclass of 'FromConfigMethodParent'"):
285+
FromConfigMethodParent.from_config({"class_path": f"{__name__}.SomeOtherClass"})
286+
287+
244288
@skip_if_omegaconf_unavailable
245289
def test_from_config_method_parser_kwargs():
246290
class FromConfigMethodParserKwargs(FromConfigMixin):

0 commit comments

Comments
 (0)