Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ Added
<https://github.com/omni-us/jsonargparse/pull/775>`__).
- Experimental support for pydantic ``BaseModel`` subclasses (`#781
<https://github.com/omni-us/jsonargparse/pull/781>`__).
- Argument to print help for dataclasses nested in types, e.g.
``Optional[Data]``, ``Union[Data1, Data2]`` (`#783
<https://github.com/omni-us/jsonargparse/pull/783>`__).

Fixed
^^^^^
Expand Down
65 changes: 32 additions & 33 deletions jsonargparse/_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextvars import ContextVar
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from ._common import Action, is_subclass, parser_context
from ._common import Action, is_not_subclass_type, is_subclass, parser_context
from ._loaders_dumpers import get_loader_exceptions, load_value
from ._namespace import Namespace, NSKeyError, split_key, split_key_root
from ._optionals import _get_config_read_mode, ruamel_support
Expand All @@ -20,7 +20,6 @@
change_to_path_dir,
default_config_option_help,
get_import_path,
get_typehint_origin,
import_object,
indent_text,
iter_to_set_str,
Expand Down Expand Up @@ -347,6 +346,12 @@ def check_type(self, value, parser):
class _ActionHelpClassPath(Action):
sub_add_kwargs: Dict[str, Any] = {}

@classmethod
def get_help_types(cls, typehint) -> Optional[tuple]:
from ._typehints import get_subclass_or_closed_types

return get_subclass_or_closed_types(typehint=typehint, also_lists=True, callable_return=True)

def __init__(self, typehint=None, **kwargs):
if typehint is not None:
self._typehint = typehint
Expand All @@ -355,34 +360,28 @@ def __init__(self, typehint=None, **kwargs):
super().__init__(**kwargs)

def update_init_kwargs(self, kwargs):
from ._typehints import (
get_optional_arg,
get_subclass_names,
get_subclass_types,
get_unaliased_type,
is_protocol,
)
from ._typehints import is_protocol

typehint = get_unaliased_type(get_optional_arg(kwargs.pop("_typehint")))
if get_typehint_origin(typehint) is not Union:
assert "nargs" not in kwargs
kwargs["nargs"] = "?"
self._typehint = typehint
self._basename = iter_to_set_str(get_subclass_names(typehint, callable_return=True))
self._baseclasses = get_subclass_types(typehint, callable_return=True)
assert self._baseclasses and all(isinstance(b, type) for b in self._baseclasses)

self._kind = "subclass of"
if any(is_protocol(b) for b in self._baseclasses):
self._kind = "subclass or implementer of protocol"

kwargs.update(
{
"metavar": "CLASS_PATH_OR_NAME",
"default": SUPPRESS,
"help": f"Show the help for the given {self._kind} {self._basename} and exit.",
}
)
self._typehint = kwargs.pop("_typehint")
self._help_types = self.get_help_types(self._typehint)
assert self._help_types and all(isinstance(b, type) for b in self._help_types)
self._not_subclass = len(self._help_types) == 1 and is_not_subclass_type(self._help_types[0])
self._basename = iter_to_set_str(t.__name__ for t in self._help_types)

if len(self._help_types) == 1:
kwargs["nargs"] = 0 if self._not_subclass else "?"

if self._not_subclass:
msg = ""
else:
kwargs["metavar"] = "CLASS_PATH_OR_NAME"
self._kind = "subclass of"
if any(is_protocol(b) for b in self._help_types):
self._kind = "subclass or implementer of protocol"
msg = f"the given {self._kind} "

kwargs["default"] = SUPPRESS
kwargs["help"] = f"Show the help for {msg}{self._basename} and exit."

def __call__(self, *args, **kwargs):
if len(args) == 0:
Expand All @@ -399,14 +398,14 @@ def print_help(self, call_args):

parser, _, value, option_string = call_args
try:
if self.nargs == "?" and value is None:
val_class = self._typehint
if self.nargs == 0 or (self.nargs == "?" and value is None):
val_class = self._help_types[0]
else:
val_class = import_object(resolve_class_path_by_name(self._baseclasses, value))
val_class = import_object(resolve_class_path_by_name(self._help_types, value))
except Exception as ex:
raise TypeError(f"{option_string}: {ex}") from ex

if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses):
if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._help_types):
raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}')
dest = re.sub("\\.help$", "", self.dest)
subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}")
Expand Down
3 changes: 0 additions & 3 deletions jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
callable_instances,
get_subclass_names,
is_optional,
is_subclass_spec,
not_required_types,
sequence_origin_types,
)
Expand Down Expand Up @@ -129,8 +128,6 @@ def add_class_arguments(
if skip_not_added:
skip.update(skip_not_added) # skip init=False
if defaults:
if is_subclass_spec(defaults):
defaults = defaults.get("init_args", {})
defaults = {prefix + k: v for k, v in defaults.items() if k not in skip}
self.set_defaults(**defaults) # type: ignore[attr-defined]

Expand Down
55 changes: 40 additions & 15 deletions jsonargparse/_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,7 @@ def prepare_add_argument(args, kwargs, enable_path, container, logger, sub_add_k
typehint = kwargs.pop("type")
if args[0].startswith("--") and ActionTypeHint.supports_append(typehint):
args = tuple(list(args) + [args[0] + "+"])
if ActionTypeHint.is_subclass_typehint(
typehint, all_subtypes=False
) or ActionTypeHint.is_return_subclass_typehint(typehint):
if _ActionHelpClassPath.get_help_types(typehint):
help_option = f"--{args[0]}.help" if args[0][0] != "-" else f"{args[0]}.help"
help_action = container.add_argument(help_option, action=_ActionHelpClassPath(typehint=typehint))
if sub_add_kwargs:
Expand All @@ -315,6 +313,7 @@ def is_supported_typehint(typehint, full=False):
or get_typehint_origin(typehint) in root_types
or get_registered_type(typehint) is not None
or is_subclass(typehint, Enum)
or is_not_subclass_type(typehint)
or ActionTypeHint.is_subclass_typehint(typehint)
)
if full and supported:
Expand Down Expand Up @@ -349,7 +348,7 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):
test = all if all_subtypes else any
k = {"also_lists": also_lists}
return test(ActionTypeHint.is_subclass_typehint(s, **k) for s in subtypes)
return is_single_subclass_typehint(typehint, typehint_origin)
return is_single_subclass_type(typehint, typehint_origin)

@staticmethod
def is_return_subclass_typehint(typehint):
Expand Down Expand Up @@ -1243,8 +1242,8 @@ def get_callable_return_type(typehint):
return return_type


def is_single_subclass_typehint(typehint, typehint_origin):
return (
def is_single_class_type(typehint, typehint_origin, closed_class):
if not (
(
(inspect.isclass(typehint) and typehint_origin is None)
or (is_generic_class(typehint) and inspect.isclass(typehint.__origin__))
Expand All @@ -1254,35 +1253,61 @@ def is_single_subclass_typehint(typehint, typehint_origin):
and not is_pydantic_type(typehint)
and not is_subclass(typehint, (Path, Enum))
and getattr(typehint_origin, "__module__", "") != "builtins"
)
):
return False
if not closed_class:
return not is_not_subclass_type(typehint)
return True


def yield_subclass_types(typehint, also_lists=False, callable_return=False):
is_single_subclass_type = partial(is_single_class_type, closed_class=False)
is_single_subclass_or_closed_type = partial(is_single_class_type, closed_class=True)


def yield_class_types(typehint, is_single, also_lists=False, callable_return=False):
typehint = typehint_from_action(typehint)
if typehint is None:
return
typehint = get_unaliased_type(get_optional_arg(get_unaliased_type(typehint)))
typehint_origin = get_typehint_origin(typehint)
kwargs = {"is_single": is_single, "also_lists": also_lists, "callable_return": callable_return}
if callable_return and (typehint_origin in callable_origin_types or is_instance_factory_protocol(typehint)):
return_type = get_callable_return_type(typehint)
if return_type:
k = {"also_lists": also_lists, "callable_return": callable_return}
yield from yield_subclass_types(return_type, **k)
yield from yield_class_types(return_type, **kwargs)
elif typehint_origin == Union or (also_lists and typehint_origin in sequence_origin_types):
k = {"also_lists": also_lists, "callable_return": callable_return}
for subtype in typehint.__args__:
yield from yield_subclass_types(subtype, **k)
if is_single_subclass_typehint(typehint, typehint_origin):
yield from yield_class_types(subtype, **kwargs)
if is_single(typehint, typehint_origin):
yield typehint


def get_subclass_types(typehint, also_lists=False, callable_return=False):
types = tuple(yield_subclass_types(typehint, also_lists=also_lists, callable_return=callable_return))
types = tuple(
yield_class_types(
typehint, is_single=is_single_subclass_type, also_lists=also_lists, callable_return=callable_return
)
)
return types or None


def get_subclass_or_closed_types(typehint, also_lists=False, callable_return=False):
types = tuple(
yield_class_types(
typehint,
is_single=is_single_subclass_or_closed_type,
also_lists=also_lists,
callable_return=callable_return,
)
)
return types or None


def get_subclass_names(typehint, callable_return=False):
return tuple(t.__name__ for t in yield_subclass_types(typehint, callable_return=callable_return))
return tuple(
t.__name__
for t in yield_class_types(typehint, is_single=is_single_subclass_type, callable_return=callable_return)
)


def adapt_partial_callable_class(callable_type, subclass_spec):
Expand Down
74 changes: 63 additions & 11 deletions jsonargparse_tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,44 +417,56 @@ class Data:
p2: int = 0


parser_optional_data = ArgumentParser(exit_on_error=False)
parser_optional_data.add_argument("--data", type=Optional[Data])
@pytest.fixture
def parser_optional_data() -> ArgumentParser:
parser = ArgumentParser(exit_on_error=False)
parser.add_argument("--data", type=Optional[Data])
return parser


def test_optional_dataclass_help(parser_optional_data):
help_str = get_parser_help(parser_optional_data)
assert "--data.help" in help_str
assert "CLASS_PATH_OR_NAME" not in help_str
help_str = get_parse_args_stdout(parser_optional_data, ["--data.help"])
assert "--data.p1" in help_str
assert "--data.p2" in help_str

def test_optional_dataclass_type_all_fields():

def test_optional_dataclass_type_all_fields(parser_optional_data):
cfg = parser_optional_data.parse_args(['--data={"p1": "x", "p2": 1}'])
assert cfg == Namespace(data=Namespace(p1="x", p2=1))


def test_optional_dataclass_type_single_field():
def test_optional_dataclass_type_single_field(parser_optional_data):
cfg = parser_optional_data.parse_args(['--data={"p1": "y"}'])
assert cfg == Namespace(data=Namespace(p1="y", p2=0))


def test_optional_dataclass_type_invalid_field():
def test_optional_dataclass_type_invalid_field(parser_optional_data):
with pytest.raises(ArgumentError, match="Expected a <class 'str'>. Got value: 1"):
parser_optional_data.parse_args(['--data={"p1": 1}'])


def test_optional_dataclass_type_instantiate():
def test_optional_dataclass_type_instantiate(parser_optional_data):
cfg = parser_optional_data.parse_args(['--data={"p1": "y", "p2": 2}'])
init = parser_optional_data.instantiate_classes(cfg)
assert isinstance(init.data, Data)
assert init.data.p1 == "y"
assert init.data.p2 == 2


def test_optional_dataclass_type_dump():
def test_optional_dataclass_type_dump(parser_optional_data):
cfg = parser_optional_data.parse_args(['--data={"p1": "z"}'])
assert json_or_yaml_load(parser_optional_data.dump(cfg)) == {"data": {"p1": "z", "p2": 0}}


def test_optional_dataclass_type_missing_required_field():
def test_optional_dataclass_type_missing_required_field(parser_optional_data):
with pytest.raises(ArgumentError):
parser_optional_data.parse_args(['--data={"p2": 2}'])


def test_optional_dataclass_type_null_value():
def test_optional_dataclass_type_null_value(parser_optional_data):
cfg = parser_optional_data.parse_args(["--data=null"])
assert cfg == Namespace(data=None)
assert cfg == parser_optional_data.instantiate_classes(cfg)
Expand Down Expand Up @@ -514,6 +526,10 @@ def test_dataclass_in_union_type(parser):
cfg = parser.parse_args(["--union=1"])
assert cfg == Namespace(union=1)
assert cfg == parser.instantiate_classes(cfg)
help_str = get_parser_help(parser)
assert "--union.help" in help_str
help_str = get_parse_args_stdout(parser, ["--union.help"])
assert f"Help for --union.help={__name__}.Data" in help_str


def test_dataclass_in_list_type(parser):
Expand Down Expand Up @@ -631,6 +647,16 @@ def test_class_path_union_mixture_dataclass_and_class(parser, union_type):
assert init.union.prm_1 == 1.2
assert json_or_yaml_load(parser.dump(cfg))["union"] == value

help_str = get_parser_help(parser)
assert "--union.help" in help_str
help_str = [x for x in help_str.split("\n") if "help for the given subclass" in x][0]
assert "UnionData" in help_str
assert "UnionClass" in help_str
help_str = get_parse_args_stdout(parser, ["--union.help=UnionData"])
assert f"Help for --union.help={__name__}.UnionData" in help_str
help_str = get_parse_args_stdout(parser, ["--union.help=UnionClass"])
assert f"Help for --union.help={__name__}.UnionClass" in help_str


def test_class_path_union_dataclasses(parser):
parser.add_argument("--union", type=Union[Data, SingleParamChange, UnionData])
Expand Down Expand Up @@ -735,21 +761,47 @@ def test_dataclass_not_subclass(parser):
parser.add_argument("--data", type=DataMain, default=DataMain(p1=2))

help_str = get_parser_help(parser)
assert "--data.help [CLASS_PATH_OR_NAME]" not in help_str
assert "--data.help" not in help_str

config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}}
with pytest.raises(ArgumentError, match="Group 'data' does not accept nested key 'init_args.p2'"):
parser.parse_args([f"--data={json.dumps(config)}"])


def test_add_subclass_dataclass_not_subclass(parser):
with pytest.raises(ValueError, match="Expected .* a subclass type or a tuple of subclass types"):
parser.add_subclass_arguments(DataMain, "data")


@pytest.fixture
def subclass_behavior():
with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors:
not_subclass_type_selectors.pop("dataclass")
yield


def test_dataclass_argument_as_subclass(parser, subtests, subclass_behavior):
@pytest.mark.parametrize("default", [None, DataMain()])
def test_add_subclass_dataclass_as_subclass(parser, default, subclass_behavior):
parser.add_subclass_arguments(DataMain, "data", default=default)

config = {"class_path": f"{__name__}.DataMain", "init_args": {"p1": 2}}
cfg = parser.parse_args([f"--data={json.dumps(config)}"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.data, DataMain)
assert dataclasses.asdict(init.data) == {"p1": 2}
dump = json_or_yaml_load(parser.dump(cfg))["data"]
assert dump == {"class_path": f"{__name__}.DataMain", "init_args": {"p1": 2}}

config = {"class_path": f"{__name__}.DataSub", "init_args": {"p2": "y"}}
cfg = parser.parse_args([f"--data={json.dumps(config)}"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.data, DataSub)
assert dataclasses.asdict(init.data) == {"p1": 1, "p2": "y"}
dump = json_or_yaml_load(parser.dump(cfg))["data"]
assert dump == {"class_path": f"{__name__}.DataSub", "init_args": {"p1": 1, "p2": "y"}}


def test_add_argument_dataclass_as_subclass(parser, subtests, subclass_behavior):
parser.add_argument("--data", type=DataMain, default=DataMain(p1=2))

with subtests.test("help"):
Expand Down
1 change: 1 addition & 0 deletions jsonargparse_tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def test_model_argument_as_subclass(parser, subtests, subclass_behavior):
assert f"{__name__}.Person" in help_str
help_str = get_parse_args_stdout(parser, ["--person.help"])
assert f"Help for --person.help={__name__}.Person" in help_str
assert "--person.pets.help [CLASS_PATH_OR_NAME]" in help_str

with subtests.test("defaults"):
defaults = parser.get_defaults()
Expand Down
Loading