diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c2eaeb7f..6d853100 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,6 +20,8 @@ Fixed - Union types with str and default comment-like string incorrectly parsed as a stringified exception of an other subtype (`#812 `__). +- ``FromConfig`` not handling correctly required parameters (`#813 + `__). Changed ^^^^^^^ diff --git a/jsonargparse/_from_config.py b/jsonargparse/_from_config.py index 10508731..37cd1fe5 100644 --- a/jsonargparse/_from_config.py +++ b/jsonargparse/_from_config.py @@ -9,6 +9,7 @@ __all__ = ["FromConfigMixin"] T = TypeVar("T") +OVERRIDE_KINDS = {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD} class FromConfigMixin: @@ -55,6 +56,10 @@ def _parse_class_kwargs_from_config(cls: Type[T], config: Union[str, PathLike, d """Parse the init kwargs for ``cls`` from a config file or dict.""" parser = ArgumentParser(exit_on_error=False, **kwargs) parser.add_class_arguments(cls) + for required in parser.required_args: + action = next((a for a in parser._actions if a.dest == required), None) + action._required = False # type: ignore[union-attr] + parser.required_args.clear() if isinstance(config, dict): cfg = parser.parse_object(config, defaults=False) else: @@ -79,12 +84,15 @@ def _override_init_defaults_this_class(cls: Type[T], defaults: dict) -> None: params = inspect.signature(cls.__init__).parameters for name, default in defaults.copy().items(): param = params.get(name) - if param and param.kind in {inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}: + if param and param.kind in OVERRIDE_KINDS: + if param.default == inspect._empty: + raise TypeError(f"Overriding of required parameters not allowed: '{param.name}'") defaults.pop(name) if param.kind == inspect.Parameter.KEYWORD_ONLY: cls.__init__.__kwdefaults__[name] = default # type: ignore[index] else: - index = list(params).index(name) - 1 + required = [p for p in params.values() if p.kind in OVERRIDE_KINDS and p.default == inspect._empty] + index = list(params).index(name) - len(required) aux = cls.__init__.__defaults__ or () cls.__init__.__defaults__ = aux[:index] + (default,) + aux[index + 1 :] diff --git a/jsonargparse_tests/test_from_config.py b/jsonargparse_tests/test_from_config.py index b324c7ec..a55df57c 100644 --- a/jsonargparse_tests/test_from_config.py +++ b/jsonargparse_tests/test_from_config.py @@ -80,6 +80,38 @@ def __init__(self, *, child2: int = 2, child1: str = "child_default_value", **kw assert instance.child2 == 2 +def test_init_defaults_override_preserve_required(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"param2": 2})) + + class DefaultsOverrideRequiredParameters(FromConfigMixin): + __from_config_init_defaults__ = config_path + + def __init__(self, param1: str, param2: int = 1): + self.param1 = param1 + self.param2 = param2 + + with pytest.raises(TypeError, match="missing 1 required positional argument: 'param1'"): + DefaultsOverrideRequiredParameters() + + instance = DefaultsOverrideRequiredParameters(param1="required") + assert instance.param1 == "required" + assert instance.param2 == 2 + + +def test_init_defaults_override_required_not_allowed(tmp_cwd): + config_path = tmp_cwd / "config.yaml" + config_path.write_text(json_or_yaml_dump({"param1": 2})) + + with pytest.raises(TypeError, match="Overriding of required parameters not allowed: 'param1'"): + + class DefaultsOverrideRequiredNotAllowed(FromConfigMixin): + __from_config_init_defaults__ = config_path + + def __init__(self, param1: int): + self.param1 = param1 + + def test_init_defaults_override_class_with_init_subclass(tmp_cwd): config_path = tmp_cwd / "config.yaml" config_path.write_text(json_or_yaml_dump({"parent": "overridden_parent", "child": "overridden_child"}))