diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 59c62edc..bc164350 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,9 @@ Added `__). - Experimental support for pydantic ``BaseModel`` subclasses (`#781 `__). +- Argument to print help for dataclasses nested in types, e.g. + ``Optional[Data]``, ``Union[Data1, Data2]`` (`#783 + `__). Fixed ^^^^^ diff --git a/jsonargparse/_actions.py b/jsonargparse/_actions.py index 4aa50edc..3b1a975b 100644 --- a/jsonargparse/_actions.py +++ b/jsonargparse/_actions.py @@ -9,7 +9,7 @@ from contextvars import ContextVar from typing import Any, Dict, List, Optional, Tuple, Type, Union -from ._common import Action, is_subclass, parser_context +from ._common import Action, is_not_subclass_type, is_subclass, parser_context from ._loaders_dumpers import get_loader_exceptions, load_value from ._namespace import Namespace, NSKeyError, split_key, split_key_root from ._optionals import _get_config_read_mode, ruamel_support @@ -20,7 +20,6 @@ change_to_path_dir, default_config_option_help, get_import_path, - get_typehint_origin, import_object, indent_text, iter_to_set_str, @@ -347,6 +346,12 @@ def check_type(self, value, parser): class _ActionHelpClassPath(Action): sub_add_kwargs: Dict[str, Any] = {} + @classmethod + def get_help_types(cls, typehint) -> Optional[tuple]: + from ._typehints import get_subclass_or_closed_types + + return get_subclass_or_closed_types(typehint=typehint, also_lists=True, callable_return=True) + def __init__(self, typehint=None, **kwargs): if typehint is not None: self._typehint = typehint @@ -355,34 +360,28 @@ def __init__(self, typehint=None, **kwargs): super().__init__(**kwargs) def update_init_kwargs(self, kwargs): - from ._typehints import ( - get_optional_arg, - get_subclass_names, - get_subclass_types, - get_unaliased_type, - is_protocol, - ) + from ._typehints import is_protocol - typehint = get_unaliased_type(get_optional_arg(kwargs.pop("_typehint"))) - if get_typehint_origin(typehint) is not Union: - assert "nargs" not in kwargs - kwargs["nargs"] = "?" - self._typehint = typehint - self._basename = iter_to_set_str(get_subclass_names(typehint, callable_return=True)) - self._baseclasses = get_subclass_types(typehint, callable_return=True) - assert self._baseclasses and all(isinstance(b, type) for b in self._baseclasses) - - self._kind = "subclass of" - if any(is_protocol(b) for b in self._baseclasses): - self._kind = "subclass or implementer of protocol" - - kwargs.update( - { - "metavar": "CLASS_PATH_OR_NAME", - "default": SUPPRESS, - "help": f"Show the help for the given {self._kind} {self._basename} and exit.", - } - ) + self._typehint = kwargs.pop("_typehint") + self._help_types = self.get_help_types(self._typehint) + assert self._help_types and all(isinstance(b, type) for b in self._help_types) + self._not_subclass = len(self._help_types) == 1 and is_not_subclass_type(self._help_types[0]) + self._basename = iter_to_set_str(t.__name__ for t in self._help_types) + + if len(self._help_types) == 1: + kwargs["nargs"] = 0 if self._not_subclass else "?" + + if self._not_subclass: + msg = "" + else: + kwargs["metavar"] = "CLASS_PATH_OR_NAME" + self._kind = "subclass of" + if any(is_protocol(b) for b in self._help_types): + self._kind = "subclass or implementer of protocol" + msg = f"the given {self._kind} " + + kwargs["default"] = SUPPRESS + kwargs["help"] = f"Show the help for {msg}{self._basename} and exit." def __call__(self, *args, **kwargs): if len(args) == 0: @@ -399,14 +398,14 @@ def print_help(self, call_args): parser, _, value, option_string = call_args try: - if self.nargs == "?" and value is None: - val_class = self._typehint + if self.nargs == 0 or (self.nargs == "?" and value is None): + val_class = self._help_types[0] else: - val_class = import_object(resolve_class_path_by_name(self._baseclasses, value)) + val_class = import_object(resolve_class_path_by_name(self._help_types, value)) except Exception as ex: raise TypeError(f"{option_string}: {ex}") from ex - if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses): + if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._help_types): raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}') dest = re.sub("\\.help$", "", self.dest) subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}") diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py index beee9ac0..4ac90c3d 100644 --- a/jsonargparse/_signatures.py +++ b/jsonargparse/_signatures.py @@ -25,7 +25,6 @@ callable_instances, get_subclass_names, is_optional, - is_subclass_spec, not_required_types, sequence_origin_types, ) @@ -129,8 +128,6 @@ def add_class_arguments( if skip_not_added: skip.update(skip_not_added) # skip init=False if defaults: - if is_subclass_spec(defaults): - defaults = defaults.get("init_args", {}) defaults = {prefix + k: v for k, v in defaults.items() if k not in skip} self.set_defaults(**defaults) # type: ignore[attr-defined] diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index 02a55bd4..21c6808c 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -292,9 +292,7 @@ def prepare_add_argument(args, kwargs, enable_path, container, logger, sub_add_k typehint = kwargs.pop("type") if args[0].startswith("--") and ActionTypeHint.supports_append(typehint): args = tuple(list(args) + [args[0] + "+"]) - if ActionTypeHint.is_subclass_typehint( - typehint, all_subtypes=False - ) or ActionTypeHint.is_return_subclass_typehint(typehint): + if _ActionHelpClassPath.get_help_types(typehint): help_option = f"--{args[0]}.help" if args[0][0] != "-" else f"{args[0]}.help" help_action = container.add_argument(help_option, action=_ActionHelpClassPath(typehint=typehint)) if sub_add_kwargs: @@ -315,6 +313,7 @@ def is_supported_typehint(typehint, full=False): or get_typehint_origin(typehint) in root_types or get_registered_type(typehint) is not None or is_subclass(typehint, Enum) + or is_not_subclass_type(typehint) or ActionTypeHint.is_subclass_typehint(typehint) ) if full and supported: @@ -349,7 +348,7 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False): test = all if all_subtypes else any k = {"also_lists": also_lists} return test(ActionTypeHint.is_subclass_typehint(s, **k) for s in subtypes) - return is_single_subclass_typehint(typehint, typehint_origin) + return is_single_subclass_type(typehint, typehint_origin) @staticmethod def is_return_subclass_typehint(typehint): @@ -1243,8 +1242,8 @@ def get_callable_return_type(typehint): return return_type -def is_single_subclass_typehint(typehint, typehint_origin): - return ( +def is_single_class_type(typehint, typehint_origin, closed_class): + if not ( ( (inspect.isclass(typehint) and typehint_origin is None) or (is_generic_class(typehint) and inspect.isclass(typehint.__origin__)) @@ -1254,35 +1253,61 @@ def is_single_subclass_typehint(typehint, typehint_origin): and not is_pydantic_type(typehint) and not is_subclass(typehint, (Path, Enum)) and getattr(typehint_origin, "__module__", "") != "builtins" - ) + ): + return False + if not closed_class: + return not is_not_subclass_type(typehint) + return True -def yield_subclass_types(typehint, also_lists=False, callable_return=False): +is_single_subclass_type = partial(is_single_class_type, closed_class=False) +is_single_subclass_or_closed_type = partial(is_single_class_type, closed_class=True) + + +def yield_class_types(typehint, is_single, also_lists=False, callable_return=False): typehint = typehint_from_action(typehint) if typehint is None: return typehint = get_unaliased_type(get_optional_arg(get_unaliased_type(typehint))) typehint_origin = get_typehint_origin(typehint) + kwargs = {"is_single": is_single, "also_lists": also_lists, "callable_return": callable_return} if callable_return and (typehint_origin in callable_origin_types or is_instance_factory_protocol(typehint)): return_type = get_callable_return_type(typehint) if return_type: - k = {"also_lists": also_lists, "callable_return": callable_return} - yield from yield_subclass_types(return_type, **k) + yield from yield_class_types(return_type, **kwargs) elif typehint_origin == Union or (also_lists and typehint_origin in sequence_origin_types): - k = {"also_lists": also_lists, "callable_return": callable_return} for subtype in typehint.__args__: - yield from yield_subclass_types(subtype, **k) - if is_single_subclass_typehint(typehint, typehint_origin): + yield from yield_class_types(subtype, **kwargs) + if is_single(typehint, typehint_origin): yield typehint def get_subclass_types(typehint, also_lists=False, callable_return=False): - types = tuple(yield_subclass_types(typehint, also_lists=also_lists, callable_return=callable_return)) + types = tuple( + yield_class_types( + typehint, is_single=is_single_subclass_type, also_lists=also_lists, callable_return=callable_return + ) + ) + return types or None + + +def get_subclass_or_closed_types(typehint, also_lists=False, callable_return=False): + types = tuple( + yield_class_types( + typehint, + is_single=is_single_subclass_or_closed_type, + also_lists=also_lists, + callable_return=callable_return, + ) + ) return types or None def get_subclass_names(typehint, callable_return=False): - return tuple(t.__name__ for t in yield_subclass_types(typehint, callable_return=callable_return)) + return tuple( + t.__name__ + for t in yield_class_types(typehint, is_single=is_single_subclass_type, callable_return=callable_return) + ) def adapt_partial_callable_class(callable_type, subclass_spec): diff --git a/jsonargparse_tests/test_dataclasses.py b/jsonargparse_tests/test_dataclasses.py index 1f465b2d..253376a2 100644 --- a/jsonargparse_tests/test_dataclasses.py +++ b/jsonargparse_tests/test_dataclasses.py @@ -417,26 +417,38 @@ class Data: p2: int = 0 -parser_optional_data = ArgumentParser(exit_on_error=False) -parser_optional_data.add_argument("--data", type=Optional[Data]) +@pytest.fixture +def parser_optional_data() -> ArgumentParser: + parser = ArgumentParser(exit_on_error=False) + parser.add_argument("--data", type=Optional[Data]) + return parser + +def test_optional_dataclass_help(parser_optional_data): + help_str = get_parser_help(parser_optional_data) + assert "--data.help" in help_str + assert "CLASS_PATH_OR_NAME" not in help_str + help_str = get_parse_args_stdout(parser_optional_data, ["--data.help"]) + assert "--data.p1" in help_str + assert "--data.p2" in help_str -def test_optional_dataclass_type_all_fields(): + +def test_optional_dataclass_type_all_fields(parser_optional_data): cfg = parser_optional_data.parse_args(['--data={"p1": "x", "p2": 1}']) assert cfg == Namespace(data=Namespace(p1="x", p2=1)) -def test_optional_dataclass_type_single_field(): +def test_optional_dataclass_type_single_field(parser_optional_data): cfg = parser_optional_data.parse_args(['--data={"p1": "y"}']) assert cfg == Namespace(data=Namespace(p1="y", p2=0)) -def test_optional_dataclass_type_invalid_field(): +def test_optional_dataclass_type_invalid_field(parser_optional_data): with pytest.raises(ArgumentError, match="Expected a . Got value: 1"): parser_optional_data.parse_args(['--data={"p1": 1}']) -def test_optional_dataclass_type_instantiate(): +def test_optional_dataclass_type_instantiate(parser_optional_data): cfg = parser_optional_data.parse_args(['--data={"p1": "y", "p2": 2}']) init = parser_optional_data.instantiate_classes(cfg) assert isinstance(init.data, Data) @@ -444,17 +456,17 @@ def test_optional_dataclass_type_instantiate(): assert init.data.p2 == 2 -def test_optional_dataclass_type_dump(): +def test_optional_dataclass_type_dump(parser_optional_data): cfg = parser_optional_data.parse_args(['--data={"p1": "z"}']) assert json_or_yaml_load(parser_optional_data.dump(cfg)) == {"data": {"p1": "z", "p2": 0}} -def test_optional_dataclass_type_missing_required_field(): +def test_optional_dataclass_type_missing_required_field(parser_optional_data): with pytest.raises(ArgumentError): parser_optional_data.parse_args(['--data={"p2": 2}']) -def test_optional_dataclass_type_null_value(): +def test_optional_dataclass_type_null_value(parser_optional_data): cfg = parser_optional_data.parse_args(["--data=null"]) assert cfg == Namespace(data=None) assert cfg == parser_optional_data.instantiate_classes(cfg) @@ -514,6 +526,10 @@ def test_dataclass_in_union_type(parser): cfg = parser.parse_args(["--union=1"]) assert cfg == Namespace(union=1) assert cfg == parser.instantiate_classes(cfg) + help_str = get_parser_help(parser) + assert "--union.help" in help_str + help_str = get_parse_args_stdout(parser, ["--union.help"]) + assert f"Help for --union.help={__name__}.Data" in help_str def test_dataclass_in_list_type(parser): @@ -631,6 +647,16 @@ def test_class_path_union_mixture_dataclass_and_class(parser, union_type): assert init.union.prm_1 == 1.2 assert json_or_yaml_load(parser.dump(cfg))["union"] == value + help_str = get_parser_help(parser) + assert "--union.help" in help_str + help_str = [x for x in help_str.split("\n") if "help for the given subclass" in x][0] + assert "UnionData" in help_str + assert "UnionClass" in help_str + help_str = get_parse_args_stdout(parser, ["--union.help=UnionData"]) + assert f"Help for --union.help={__name__}.UnionData" in help_str + help_str = get_parse_args_stdout(parser, ["--union.help=UnionClass"]) + assert f"Help for --union.help={__name__}.UnionClass" in help_str + def test_class_path_union_dataclasses(parser): parser.add_argument("--union", type=Union[Data, SingleParamChange, UnionData]) @@ -735,13 +761,18 @@ def test_dataclass_not_subclass(parser): parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) help_str = get_parser_help(parser) - assert "--data.help [CLASS_PATH_OR_NAME]" not in help_str + assert "--data.help" not in help_str config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}} with pytest.raises(ArgumentError, match="Group 'data' does not accept nested key 'init_args.p2'"): parser.parse_args([f"--data={json.dumps(config)}"]) +def test_add_subclass_dataclass_not_subclass(parser): + with pytest.raises(ValueError, match="Expected .* a subclass type or a tuple of subclass types"): + parser.add_subclass_arguments(DataMain, "data") + + @pytest.fixture def subclass_behavior(): with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors: @@ -749,7 +780,28 @@ def subclass_behavior(): yield -def test_dataclass_argument_as_subclass(parser, subtests, subclass_behavior): +@pytest.mark.parametrize("default", [None, DataMain()]) +def test_add_subclass_dataclass_as_subclass(parser, default, subclass_behavior): + parser.add_subclass_arguments(DataMain, "data", default=default) + + config = {"class_path": f"{__name__}.DataMain", "init_args": {"p1": 2}} + cfg = parser.parse_args([f"--data={json.dumps(config)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataMain) + assert dataclasses.asdict(init.data) == {"p1": 2} + dump = json_or_yaml_load(parser.dump(cfg))["data"] + assert dump == {"class_path": f"{__name__}.DataMain", "init_args": {"p1": 2}} + + config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}} + cfg = parser.parse_args([f"--data={json.dumps(config)}"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, DataSub) + assert dataclasses.asdict(init.data) == {"p1": 1, "p2": "y"} + dump = json_or_yaml_load(parser.dump(cfg))["data"] + assert dump == {"class_path": f"{__name__}.DataSub", "init_args": {"p1": 1, "p2": "y"}} + + +def test_add_argument_dataclass_as_subclass(parser, subtests, subclass_behavior): parser.add_argument("--data", type=DataMain, default=DataMain(p1=2)) with subtests.test("help"): diff --git a/jsonargparse_tests/test_pydantic.py b/jsonargparse_tests/test_pydantic.py index 8b9bcc77..132e15f8 100644 --- a/jsonargparse_tests/test_pydantic.py +++ b/jsonargparse_tests/test_pydantic.py @@ -405,6 +405,7 @@ def test_model_argument_as_subclass(parser, subtests, subclass_behavior): assert f"{__name__}.Person" in help_str help_str = get_parse_args_stdout(parser, ["--person.help"]) assert f"Help for --person.help={__name__}.Person" in help_str + assert "--person.pets.help [CLASS_PATH_OR_NAME]" in help_str with subtests.test("defaults"): defaults = parser.get_defaults() diff --git a/jsonargparse_tests/test_stubs_resolver.py b/jsonargparse_tests/test_stubs_resolver.py index c17051f8..f06f1128 100644 --- a/jsonargparse_tests/test_stubs_resolver.py +++ b/jsonargparse_tests/test_stubs_resolver.py @@ -291,7 +291,7 @@ def test_get_params_complex_function_requests_get(parser): assert ["url", "params"] == list(parser.get_defaults().keys()) help_str = get_parser_help(parser) assert "default: Unknown" in help_str - assert "--cookies.help CLASS_PATH_OR_NAME" in help_str + assert "--cookies.help [CLASS_PATH_OR_NAME]" in help_str # stubs only resolver tests