Skip to content

Commit 2bb7ab4

Browse files
authored
Experimental support for dataclass inheritance (#775)
1 parent 66ce10d commit 2bb7ab4

File tree

12 files changed

+308
-113
lines changed

12 files changed

+308
-113
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ Added
2020
- ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable
2121
stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770
2222
<https://github.com/omni-us/jsonargparse/pull/770>`__).
23+
- Experimental support for dataclass inheritance (`#775
24+
<https://github.com/omni-us/jsonargparse/pull/775>`__).
2325

2426
Fixed
2527
^^^^^

jsonargparse/_common.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from contextlib import contextmanager
77
from contextvars import ContextVar
88
from typing import ( # type: ignore[attr-defined]
9+
Callable,
910
Dict,
1011
Generic,
1112
List,
@@ -28,6 +29,8 @@
2829
import_reconplogger,
2930
is_alias_type,
3031
is_annotated,
32+
is_attrs_class,
33+
is_pydantic_model,
3134
reconplogger_support,
3235
typing_extensions_import,
3336
)
@@ -212,11 +215,22 @@ def supports_optionals_as_positionals(parser):
212215

213216

214217
def is_subclass(cls, class_or_tuple) -> bool:
215-
"""Extension of issubclass that supports non-class arguments."""
218+
"""Extension of issubclass that supports non-class arguments and generics."""
216219
try:
217-
return inspect.isclass(cls) and issubclass(cls, class_or_tuple)
220+
class_or_tuple = get_generic_origins(class_or_tuple)
221+
if inspect.isclass(cls):
222+
return issubclass(cls, class_or_tuple)
223+
elif is_generic_class(cls):
224+
return issubclass(cls.__origin__, class_or_tuple)
218225
except TypeError:
219-
return False
226+
pass # TypeError means that cls is not a class
227+
return False
228+
229+
230+
def is_instance(obj, class_or_tuple) -> bool:
231+
"""Extension of isinstance that supports generics."""
232+
class_or_tuple = get_generic_origins(class_or_tuple)
233+
return isinstance(obj, class_or_tuple)
220234

221235

222236
def is_final_class(cls) -> bool:
@@ -236,6 +250,12 @@ def get_generic_origin(cls):
236250
return cls.__origin__ if is_generic_class(cls) else cls
237251

238252

253+
def get_generic_origins(class_or_tuple):
254+
if isinstance(class_or_tuple, tuple):
255+
return tuple(get_generic_origin(cls) for cls in class_or_tuple)
256+
return get_generic_origin(class_or_tuple)
257+
258+
239259
def get_unaliased_type(cls):
240260
new_cls = cls
241261
while True:
@@ -249,29 +269,25 @@ def get_unaliased_type(cls):
249269
return cur_cls
250270

251271

252-
def is_dataclass_like(cls) -> bool:
253-
if is_generic_class(cls):
254-
return is_dataclass_like(cls.__origin__)
255-
if not inspect.isclass(cls) or cls is object:
256-
return False
257-
if is_final_class(cls):
258-
return True
272+
def is_pure_dataclass(cls) -> bool:
259273
classes = [c for c in inspect.getmro(cls) if c not in {object, Generic}]
260-
all_dataclasses = all(dataclasses.is_dataclass(c) for c in classes)
261-
262-
if not all_dataclasses:
263-
from ._optionals import attrs_support, is_pydantic_model
274+
return all(dataclasses.is_dataclass(c) for c in classes)
264275

265-
if is_pydantic_model(cls):
266-
return True
267276

268-
if attrs_support:
269-
import attrs
277+
not_subclass_type_selectors: Dict[str, Callable[[Type], Union[bool, int]]] = {
278+
"final": is_final_class,
279+
"dataclass": is_pure_dataclass,
280+
"pydantic": is_pydantic_model,
281+
"attrs": is_attrs_class,
282+
}
270283

271-
if attrs.has(cls):
272-
return True
273284

274-
return all_dataclasses
285+
def is_not_subclass_type(cls) -> bool:
286+
if is_generic_class(cls):
287+
return is_not_subclass_type(cls.__origin__)
288+
if not inspect.isclass(cls):
289+
return False
290+
return any(validator(cls) for validator in not_subclass_type_selectors.values())
275291

276292

277293
def default_class_instantiator(class_type: Type[ClassType], *args, **kwargs) -> ClassType:

jsonargparse/_core.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
class_instantiators,
4444
debug_mode_active,
4545
get_optionals_as_positionals_actions,
46-
is_dataclass_like,
46+
is_not_subclass_type,
4747
lenient_check,
4848
parser_context,
4949
supports_optionals_as_positionals,
@@ -132,7 +132,7 @@ def add_argument(self, *args, enable_path: bool = False, **kwargs):
132132
return ActionParser._move_parser_actions(parser, args, kwargs)
133133
ActionConfigFile._ensure_single_config_argument(self, kwargs["action"])
134134
if "type" in kwargs:
135-
if is_dataclass_like(kwargs["type"]):
135+
if is_not_subclass_type(kwargs["type"]):
136136
nested_key = args[0].lstrip("-")
137137
self.add_class_arguments(kwargs.pop("type"), nested_key, **kwargs)
138138
return _find_action(parser, nested_key)
@@ -1236,9 +1236,7 @@ def instantiate_classes(
12361236
"""
12371237
components: List[Union[ActionTypeHint, _ActionConfigLoad, ArgumentGroup]] = []
12381238
for action in filter_default_actions(self._actions):
1239-
if isinstance(action, ActionTypeHint) or (
1240-
isinstance(action, _ActionConfigLoad) and is_dataclass_like(action.basetype)
1241-
):
1239+
if isinstance(action, ActionTypeHint):
12421240
components.append(action)
12431241

12441242
if instantiate_groups:

jsonargparse/_namespace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ def update(
293293
if not only_unset or key not in self:
294294
self[key] = value
295295
else:
296+
if key and not isinstance(self.get(key), Namespace):
297+
self[key] = Namespace()
296298
prefix = key + "." if key else ""
297299
for subkey, subval in value.items():
298300
if not only_unset or prefix + subkey not in self:

jsonargparse/_optionals.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,18 +346,28 @@ def get_pydantic_supports_field_init() -> bool:
346346

347347

348348
def is_pydantic_model(class_type) -> int:
349-
classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else []
350-
for cls in classes:
351-
if getattr(cls, "__module__", "").startswith("pydantic") and getattr(cls, "__name__", "") == "BaseModel":
352-
import pydantic
353-
354-
if issubclass(cls, pydantic.BaseModel):
355-
return pydantic_support
356-
elif pydantic_support > 1 and issubclass(cls, pydantic.v1.BaseModel):
357-
return 1
349+
if pydantic_support:
350+
classes = inspect.getmro(class_type) if pydantic_support and inspect.isclass(class_type) else []
351+
for cls in classes:
352+
if getattr(cls, "__module__", "").startswith("pydantic") and getattr(cls, "__name__", "") == "BaseModel":
353+
import pydantic
354+
355+
if issubclass(cls, pydantic.BaseModel):
356+
return pydantic_support
357+
elif pydantic_support > 1 and issubclass(cls, pydantic.v1.BaseModel):
358+
return 1
358359
return 0
359360

360361

362+
def is_attrs_class(class_type) -> bool:
363+
if attrs_support:
364+
import attrs
365+
366+
if attrs.has(class_type):
367+
return True
368+
return False
369+
370+
361371
def get_module(value):
362372
return getattr(type(value), "__module__", "").split(".", 1)[0]
363373

jsonargparse/_parameter_resolvers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
LoggerProperty,
1818
get_generic_origin,
1919
get_unaliased_type,
20-
is_dataclass_like,
2120
is_generic_class,
2221
is_subclass,
2322
is_unpack_typehint,
@@ -372,8 +371,6 @@ def add_stub_types(stubs: Optional[Dict[str, Any]], params: ParamList, component
372371

373372

374373
def is_param_subclass_instance_default(param: ParamData) -> bool:
375-
if is_dataclass_like(type(param.default)):
376-
return False
377374
from ._typehints import ActionTypeHint, get_optional_arg, get_subclass_types
378375

379376
annotation = get_optional_arg(param.annotation)

jsonargparse/_signatures.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,20 @@
1212
get_class_instantiator,
1313
get_generic_origin,
1414
get_unaliased_type,
15-
is_dataclass_like,
1615
is_final_class,
16+
is_not_subclass_type,
1717
is_subclass,
1818
)
1919
from ._namespace import Namespace
20-
from ._optionals import attrs_support, get_doc_short_description, is_pydantic_model, pydantic_support
20+
from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model, pydantic_support
2121
from ._parameter_resolvers import ParamData, get_parameter_origins, get_signature_parameters
2222
from ._typehints import (
2323
ActionTypeHint,
2424
LazyInitBaseClass,
2525
callable_instances,
2626
get_subclass_names,
2727
is_optional,
28+
is_subclass_spec,
2829
not_required_types,
2930
)
3031
from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
@@ -86,7 +87,7 @@ def add_class_arguments(
8687
or (isinstance(default, LazyInitBaseClass) and isinstance(default, unaliased_class_type))
8788
or (
8889
not is_final_class(default.__class__)
89-
and is_dataclass_like(default.__class__)
90+
and is_not_subclass_type(default.__class__)
9091
and isinstance(default, unaliased_class_type)
9192
)
9293
):
@@ -120,15 +121,15 @@ def add_class_arguments(
120121
defaults = default
121122
if isinstance(default, LazyInitBaseClass):
122123
defaults = default.lazy_get_init_args().as_dict()
123-
elif is_dataclass_like(default.__class__):
124-
defaults = dataclass_to_dict(default)
124+
elif is_convertible_to_dict(default.__class__):
125+
defaults = convert_to_dict(default)
125126
args = {k[len(prefix) :] for k in added_args}
126127
skip_not_added = [k for k in defaults if k not in args]
127128
if skip_not_added:
128129
skip.update(skip_not_added) # skip init=False
129-
elif isinstance(default, Namespace):
130-
defaults = default.as_dict()
131130
if defaults:
131+
if is_subclass_spec(defaults):
132+
defaults = defaults.get("init_args", {})
132133
defaults = {prefix + k: v for k, v in defaults.items() if k not in skip}
133134
self.set_defaults(**defaults) # type: ignore[attr-defined]
134135

@@ -389,7 +390,7 @@ def _add_signature_parameter(
389390
elif not as_positional or is_non_positional:
390391
kwargs["required"] = True
391392
is_subclass_typehint = False
392-
is_dataclass_like_typehint = is_dataclass_like(annotation)
393+
is_not_subclass_typehint = is_not_subclass_type(annotation)
393394
dest = (nested_key + "." if nested_key else "") + name
394395
args = [dest if is_required and as_positional and not is_non_positional else "--" + dest]
395396
if param.origin:
@@ -407,7 +408,7 @@ def _add_signature_parameter(
407408
if (
408409
annotation in {str, int, float, bool}
409410
or is_subclass(annotation, (str, int, float))
410-
or is_dataclass_like_typehint
411+
or is_not_subclass_typehint
411412
):
412413
kwargs["type"] = annotation
413414
register_pydantic_type(annotation)
@@ -441,7 +442,7 @@ def _add_signature_parameter(
441442
"sub_configs": sub_configs,
442443
"instantiate": instantiate,
443444
}
444-
if is_dataclass_like_typehint:
445+
if is_not_subclass_typehint:
445446
kwargs.update(sub_add_kwargs)
446447
with ActionTypeHint.allow_default_instance_context():
447448
action = container.add_argument(*args, **kwargs)
@@ -492,8 +493,6 @@ def add_subclass_arguments(
492493
Raises:
493494
ValueError: When given an invalid base class.
494495
"""
495-
if is_dataclass_like(baseclass):
496-
raise ValueError("Not allowed for dataclass-like classes.")
497496
if type(baseclass) is not tuple:
498497
baseclass = (baseclass,) # type: ignore[assignment]
499498
if not baseclass or not all(ActionTypeHint.is_subclass_typehint(c, also_lists=True) for c in baseclass):
@@ -590,7 +589,11 @@ def is_factory_class(value):
590589
return value.__class__ == dataclasses._HAS_DEFAULT_FACTORY_CLASS
591590

592591

593-
def dataclass_to_dict(value) -> dict:
592+
def is_convertible_to_dict(value):
593+
return dataclasses.is_dataclass(value) or is_attrs_class(value) or is_pydantic_model(value)
594+
595+
596+
def convert_to_dict(value) -> dict:
594597
if pydantic_support:
595598
pydantic_model = is_pydantic_model(type(value))
596599
if pydantic_model:
@@ -602,6 +605,7 @@ def dataclass_to_dict(value) -> dict:
602605
is_attrs_dataclass = attrs.has(type(value))
603606
if is_attrs_dataclass:
604607
return attrs.asdict(value)
608+
605609
return dataclasses.asdict(value)
606610

607611

0 commit comments

Comments
 (0)