Skip to content

Commit ce995f7

Browse files
authored
FromConfig not handling correctly required parameters (#813)
1 parent 9c18147 commit ce995f7

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ Fixed
2020
- Union types with str and default comment-like string incorrectly parsed as a
2121
stringified exception of an other subtype (`#812
2222
<https://github.com/omni-us/jsonargparse/pull/812>`__).
23+
- ``FromConfig`` not handling correctly required parameters (`#813
24+
<https://github.com/omni-us/jsonargparse/pull/813>`__).
2325

2426
Changed
2527
^^^^^^^

jsonargparse/_from_config.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
__all__ = ["FromConfigMixin"]
1010

1111
T = TypeVar("T")
12+
OVERRIDE_KINDS = {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}
1213

1314

1415
class FromConfigMixin:
@@ -55,6 +56,10 @@ def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, d
5556
"""Parse the init kwargs for ``cls`` from a config file or dict."""
5657
parser = ArgumentParser(exit_on_error=False, **kwargs)
5758
parser.add_class_arguments(cls)
59+
for required in parser.required_args:
60+
action = next((a for a in parser._actions if a.dest == required), None)
61+
action._required = False # type: ignore[union-attr]
62+
parser.required_args.clear()
5863
if isinstance(config, dict):
5964
cfg = parser.parse_object(config, defaults=False)
6065
else:
@@ -79,12 +84,15 @@ def _override_init_defaults_this_class(cls: Type[T], defaults: dict) -> None:
7984
params = inspect.signature(cls.__init__).parameters
8085
for name, default in defaults.copy().items():
8186
param = params.get(name)
82-
if param and param.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}:
87+
if param and param.kind in OVERRIDE_KINDS:
88+
if param.default == inspect._empty:
89+
raise TypeError(f"Overriding of required parameters not allowed: '{param.name}'")
8390
defaults.pop(name)
8491
if param.kind == inspect.Parameter.KEYWORD_ONLY:
8592
cls.__init__.__kwdefaults__[name] = default # type: ignore[index]
8693
else:
87-
index = list(params).index(name) - 1
94+
required = [p for p in params.values() if p.kind in OVERRIDE_KINDS and p.default == inspect._empty]
95+
index = list(params).index(name) - len(required)
8896
aux = cls.__init__.__defaults__ or ()
8997
cls.__init__.__defaults__ = aux[:index] + (default,) + aux[index + 1 :]
9098

jsonargparse_tests/test_from_config.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,38 @@ def __init__(self, *, child2: int = 2, child1: str = "child_default_value", **kw
8080
assert instance.child2 == 2
8181

8282

83+
def test_init_defaults_override_preserve_required(tmp_cwd):
84+
config_path = tmp_cwd / "config.yaml"
85+
config_path.write_text(json_or_yaml_dump({"param2": 2}))
86+
87+
class DefaultsOverrideRequiredParameters(FromConfigMixin):
88+
__from_config_init_defaults__ = config_path
89+
90+
def __init__(self, param1: str, param2: int = 1):
91+
self.param1 = param1
92+
self.param2 = param2
93+
94+
with pytest.raises(TypeError, match="missing 1 required positional argument: 'param1'"):
95+
DefaultsOverrideRequiredParameters()
96+
97+
instance = DefaultsOverrideRequiredParameters(param1="required")
98+
assert instance.param1 == "required"
99+
assert instance.param2 == 2
100+
101+
102+
def test_init_defaults_override_required_not_allowed(tmp_cwd):
103+
config_path = tmp_cwd / "config.yaml"
104+
config_path.write_text(json_or_yaml_dump({"param1": 2}))
105+
106+
with pytest.raises(TypeError, match="Overriding of required parameters not allowed: 'param1'"):
107+
108+
class DefaultsOverrideRequiredNotAllowed(FromConfigMixin):
109+
__from_config_init_defaults__ = config_path
110+
111+
def __init__(self, param1: int):
112+
self.param1 = param1
113+
114+
83115
def test_init_defaults_override_class_with_init_subclass(tmp_cwd):
84116
config_path = tmp_cwd / "config.yaml"
85117
config_path.write_text(json_or_yaml_dump({"parent": "overridden_parent", "child": "overridden_child"}))

0 commit comments

Comments
 (0)