Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Fixed
- Union types with str and default comment-like string incorrectly parsed as a
stringified exception of an other subtype (`#812
<https://github.com/omni-us/jsonargparse/pull/812>`__).
- ``FromConfig`` not handling correctly required parameters (`#813
<https://github.com/omni-us/jsonargparse/pull/813>`__).

Changed
^^^^^^^
Expand Down
12 changes: 10 additions & 2 deletions jsonargparse/_from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
__all__ = ["FromConfigMixin"]

T = TypeVar("T")
OVERRIDE_KINDS = {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}


class FromConfigMixin:
Expand Down Expand Up @@ -55,6 +56,10 @@ def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, d
"""Parse the init kwargs for ``cls`` from a config file or dict."""
parser = ArgumentParser(exit_on_error=False, **kwargs)
parser.add_class_arguments(cls)
for required in parser.required_args:
action = next((a for a in parser._actions if a.dest == required), None)
action._required = False # type: ignore[union-attr]
parser.required_args.clear()
if isinstance(config, dict):
cfg = parser.parse_object(config, defaults=False)
else:
Expand All @@ -79,12 +84,15 @@ def _override_init_defaults_this_class(cls: Type[T], defaults: dict) -> None:
params = inspect.signature(cls.__init__).parameters
for name, default in defaults.copy().items():
param = params.get(name)
if param and param.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}:
if param and param.kind in OVERRIDE_KINDS:
if param.default == inspect._empty:
raise TypeError(f"Overriding of required parameters not allowed: '{param.name}'")
defaults.pop(name)
if param.kind == inspect.Parameter.KEYWORD_ONLY:
cls.__init__.__kwdefaults__[name] = default # type: ignore[index]
else:
index = list(params).index(name) - 1
required = [p for p in params.values() if p.kind in OVERRIDE_KINDS and p.default == inspect._empty]
index = list(params).index(name) - len(required)
aux = cls.__init__.__defaults__ or ()
cls.__init__.__defaults__ = aux[:index] + (default,) + aux[index + 1 :]

Expand Down
32 changes: 32 additions & 0 deletions jsonargparse_tests/test_from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,38 @@
assert instance.child2 == 2


def test_init_defaults_override_preserve_required(tmp_cwd):
config_path = tmp_cwd / "config.yaml"
config_path.write_text(json_or_yaml_dump({"param2": 2}))

class DefaultsOverrideRequiredParameters(FromConfigMixin):
__from_config_init_defaults__ = config_path

def __init__(self, param1: str, param2: int = 1):
self.param1 = param1
self.param2 = param2

with pytest.raises(TypeError, match="missing 1 required positional argument: 'param1'"):
DefaultsOverrideRequiredParameters()

instance = DefaultsOverrideRequiredParameters(param1="required")
assert instance.param1 == "required"
assert instance.param2 == 2


def test_init_defaults_override_required_not_allowed(tmp_cwd):
config_path = tmp_cwd / "config.yaml"
config_path.write_text(json_or_yaml_dump({"param1": 2}))

with pytest.raises(TypeError, match="Overriding of required parameters not allowed: 'param1'"):

class DefaultsOverrideRequiredNotAllowed(FromConfigMixin):
__from_config_init_defaults__ = config_path

def __init__(self, param1: int):
self.param1 = param1


def test_init_defaults_override_class_with_init_subclass(tmp_cwd):
config_path = tmp_cwd / "config.yaml"
config_path.write_text(json_or_yaml_dump({"parent": "overridden_parent", "child": "overridden_child"}))
Expand Down
Loading