Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Added
- ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable
stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770
<https://github.com/omni-us/jsonargparse/pull/770>`__).
- Experimental support for dataclass inheritance (`#775
<https://github.com/omni-us/jsonargparse/pull/775>`__).

Fixed
^^^^^
Expand Down
58 changes: 37 additions & 21 deletions jsonargparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import ( # type: ignore[attr-defined]
Callable,
Dict,
Generic,
List,
Expand All @@ -28,6 +29,8 @@
import_reconplogger,
is_alias_type,
is_annotated,
is_attrs_class,
is_pydantic_model,
reconplogger_support,
typing_extensions_import,
)
Expand Down Expand Up @@ -212,11 +215,22 @@ def supports_optionals_as_positionals(parser):


def is_subclass(cls, class_or_tuple) -> bool:
"""Extension of issubclass that supports non-class arguments."""
"""Extension of issubclass that supports non-class arguments and generics."""
try:
return inspect.isclass(cls) and issubclass(cls, class_or_tuple)
class_or_tuple = get_generic_origins(class_or_tuple)
if inspect.isclass(cls):
return issubclass(cls, class_or_tuple)
elif is_generic_class(cls):
return issubclass(cls.__origin__, class_or_tuple)
except TypeError:
return False
pass # TypeError means that cls is not a class
return False


def is_instance(obj, class_or_tuple) -> bool:
"""Extension of isinstance that supports generics."""
class_or_tuple = get_generic_origins(class_or_tuple)
return isinstance(obj, class_or_tuple)


def is_final_class(cls) -> bool:
Expand All @@ -236,6 +250,12 @@ def get_generic_origin(cls):
return cls.__origin__ if is_generic_class(cls) else cls


def get_generic_origins(class_or_tuple):
if isinstance(class_or_tuple, tuple):
return tuple(get_generic_origin(cls) for cls in class_or_tuple)
return get_generic_origin(class_or_tuple)


def get_unaliased_type(cls):
new_cls = cls
while True:
Expand All @@ -249,29 +269,25 @@ def get_unaliased_type(cls):
return cur_cls


def is_dataclass_like(cls) -> bool:
if is_generic_class(cls):
return is_dataclass_like(cls.__origin__)
if not inspect.isclass(cls) or cls is object:
return False
if is_final_class(cls):
return True
def is_pure_dataclass(cls) -> bool:
classes = [c for c in inspect.getmro(cls) if c not in {object, Generic}]
all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)

if not all_dataclasses:
from ._optionals import attrs_support, is_pydantic_model
return all(dataclasses.is_dataclass(c) for c in classes)

if is_pydantic_model(cls):
return True

if attrs_support:
import attrs
not_subclass_type_selectors: Dict[str, Callable[[Type], Union[bool, int]]] = {
"final": is_final_class,
"dataclass": is_pure_dataclass,
"pydantic": is_pydantic_model,
"attrs": is_attrs_class,
}

if attrs.has(cls):
return True

return all_dataclasses
def is_not_subclass_type(cls) -> bool:
if is_generic_class(cls):
return is_not_subclass_type(cls.__origin__)
if not inspect.isclass(cls):
return False
return any(validator(cls) for validator in not_subclass_type_selectors.values())


def default_class_instantiator(class_type: Type[ClassType], *args, **kwargs) -> ClassType:
Expand Down
8 changes: 3 additions & 5 deletions jsonargparse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
class_instantiators,
debug_mode_active,
get_optionals_as_positionals_actions,
is_dataclass_like,
is_not_subclass_type,
lenient_check,
parser_context,
supports_optionals_as_positionals,
Expand Down Expand Up @@ -132,7 +132,7 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs):
return ActionParser._move_parser_actions(parser, args, kwargs)
ActionConfigFile._ensure_single_config_argument(self, kwargs["action"])
if "type" in kwargs:
if is_dataclass_like(kwargs["type"]):
if is_not_subclass_type(kwargs["type"]):
nested_key = args[0].lstrip("-")
self.add_class_arguments(kwargs.pop("type"), nested_key, **kwargs)
return _find_action(parser, nested_key)
Expand Down Expand Up @@ -1236,9 +1236,7 @@ def instantiate_classes(
"""
components: List[Union[ActionTypeHint, _ActionConfigLoad, ArgumentGroup]] = []
for action in filter_default_actions(self._actions):
if isinstance(action, ActionTypeHint) or (
isinstance(action, _ActionConfigLoad) and is_dataclass_like(action.basetype)
):
if isinstance(action, ActionTypeHint):
components.append(action)

if instantiate_groups:
Expand Down
2 changes: 2 additions & 0 deletions jsonargparse/_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ def update(
if not only_unset or key not in self:
self[key] = value
else:
if key and not isinstance(self.get(key), Namespace):
self[key] = Namespace()
prefix = key + "." if key else ""
for subkey, subval in value.items():
if not only_unset or prefix + subkey not in self:
Expand Down
28 changes: 19 additions & 9 deletions jsonargparse/_optionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,18 +346,28 @@ def get_pydantic_supports_field_init() -> bool:


def is_pydantic_model(class_type) -> int:
classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else []
for cls in classes:
if getattr(cls, "__module__", "").startswith("pydantic") and getattr(cls, "__name__", "") == "BaseModel":
import pydantic

if issubclass(cls, pydantic.BaseModel):
return pydantic_support
elif pydantic_support > 1 and issubclass(cls, pydantic.v1.BaseModel):
return 1
if pydantic_support:
classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else []
for cls in classes:
if getattr(cls, "__module__", "").startswith("pydantic") and getattr(cls, "__name__", "") == "BaseModel":
import pydantic

if issubclass(cls, pydantic.BaseModel):
return pydantic_support
elif pydantic_support > 1 and issubclass(cls, pydantic.v1.BaseModel):
return 1
return 0


def is_attrs_class(class_type) -> bool:
if attrs_support:
import attrs

if attrs.has(class_type):
return True
return False


def get_module(value):
return getattr(type(value), "__module__", "").split(".", 1)[0]

Expand Down
3 changes: 0 additions & 3 deletions jsonargparse/_parameter_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
LoggerProperty,
get_generic_origin,
get_unaliased_type,
is_dataclass_like,
is_generic_class,
is_subclass,
is_unpack_typehint,
Expand Down Expand Up @@ -372,8 +371,6 @@ def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component


def is_param_subclass_instance_default(param: ParamData) -> bool:
if is_dataclass_like(type(param.default)):
return False
from ._typehints import ActionTypeHint, get_optional_arg, get_subclass_types

annotation = get_optional_arg(param.annotation)
Expand Down
30 changes: 17 additions & 13 deletions jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
get_class_instantiator,
get_generic_origin,
get_unaliased_type,
is_dataclass_like,
is_final_class,
is_not_subclass_type,
is_subclass,
)
from ._namespace import Namespace
from ._optionals import attrs_support, get_doc_short_description, is_pydantic_model, pydantic_support
from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model, pydantic_support
from ._parameter_resolvers import ParamData, get_parameter_origins, get_signature_parameters
from ._typehints import (
ActionTypeHint,
LazyInitBaseClass,
callable_instances,
get_subclass_names,
is_optional,
is_subclass_spec,
not_required_types,
)
from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
Expand Down Expand Up @@ -86,7 +87,7 @@ def add_class_arguments(
or (isinstance(default, LazyInitBaseClass) and isinstance(default, unaliased_class_type))
or (
not is_final_class(default.__class__)
and is_dataclass_like(default.__class__)
and is_not_subclass_type(default.__class__)
and isinstance(default, unaliased_class_type)
)
):
Expand Down Expand Up @@ -120,15 +121,15 @@ def add_class_arguments(
defaults = default
if isinstance(default, LazyInitBaseClass):
defaults = default.lazy_get_init_args().as_dict()
elif is_dataclass_like(default.__class__):
defaults = dataclass_to_dict(default)
elif is_convertible_to_dict(default.__class__):
defaults = convert_to_dict(default)
args = {k[len(prefix) :] for k in added_args}
skip_not_added = [k for k in defaults if k not in args]
if skip_not_added:
skip.update(skip_not_added) # skip init=False
elif isinstance(default, Namespace):
defaults = default.as_dict()
if defaults:
if is_subclass_spec(defaults):
defaults = defaults.get("init_args", {})
defaults = {prefix + k: v for k, v in defaults.items() if k not in skip}
self.set_defaults(**defaults) # type: ignore[attr-defined]

Expand Down Expand Up @@ -389,7 +390,7 @@ def _add_signature_parameter(
elif not as_positional or is_non_positional:
kwargs["required"] = True
is_subclass_typehint = False
is_dataclass_like_typehint = is_dataclass_like(annotation)
is_not_subclass_typehint = is_not_subclass_type(annotation)
dest = (nested_key + "." if nested_key else "") + name
args = [dest if is_required and as_positional and not is_non_positional else "--" + dest]
if param.origin:
Expand All @@ -407,7 +408,7 @@ def _add_signature_parameter(
if (
annotation in {str, int, float, bool}
or is_subclass(annotation, (str, int, float))
or is_dataclass_like_typehint
or is_not_subclass_typehint
):
kwargs["type"] = annotation
register_pydantic_type(annotation)
Expand Down Expand Up @@ -441,7 +442,7 @@ def _add_signature_parameter(
"sub_configs": sub_configs,
"instantiate": instantiate,
}
if is_dataclass_like_typehint:
if is_not_subclass_typehint:
kwargs.update(sub_add_kwargs)
with ActionTypeHint.allow_default_instance_context():
action = container.add_argument(*args, **kwargs)
Expand Down Expand Up @@ -492,8 +493,6 @@ def add_subclass_arguments(
Raises:
ValueError: When given an invalid base class.
"""
if is_dataclass_like(baseclass):
raise ValueError("Not allowed for dataclass-like classes.")
if type(baseclass) is not tuple:
baseclass = (baseclass,) # type: ignore[assignment]
if not baseclass or not all(ActionTypeHint.is_subclass_typehint(c, also_lists=True) for c in baseclass):
Expand Down Expand Up @@ -590,7 +589,11 @@ def is_factory_class(value):
return value.__class__ == dataclasses._HAS_DEFAULT_FACTORY_CLASS


def dataclass_to_dict(value) -> dict:
def is_convertible_to_dict(value):
return dataclasses.is_dataclass(value) or is_attrs_class(value) or is_pydantic_model(value)


def convert_to_dict(value) -> dict:
if pydantic_support:
pydantic_model = is_pydantic_model(type(value))
if pydantic_model:
Expand All @@ -602,6 +605,7 @@ def dataclass_to_dict(value) -> dict:
is_attrs_dataclass = attrs.has(type(value))
if is_attrs_dataclass:
return attrs.asdict(value)

return dataclasses.asdict(value)


Expand Down
Loading