1212 get_class_instantiator ,
1313 get_generic_origin ,
1414 get_unaliased_type ,
15- is_dataclass_like ,
1615 is_final_class ,
16+ is_not_subclass_type ,
1717 is_subclass ,
1818)
1919from ._namespace import Namespace
20- from ._optionals import attrs_support , get_doc_short_description , is_pydantic_model , pydantic_support
20+ from ._optionals import attrs_support , get_doc_short_description , is_attrs_class , is_pydantic_model , pydantic_support
2121from ._parameter_resolvers import ParamData , get_parameter_origins , get_signature_parameters
2222from ._typehints import (
2323 ActionTypeHint ,
2424 LazyInitBaseClass ,
2525 callable_instances ,
2626 get_subclass_names ,
2727 is_optional ,
28+ is_subclass_spec ,
2829 not_required_types ,
2930)
3031from ._util import NoneType , get_private_kwargs , get_typehint_origin , iter_to_set_str
@@ -86,7 +87,7 @@ def add_class_arguments(
8687 or (isinstance (default , LazyInitBaseClass ) and isinstance (default , unaliased_class_type ))
8788 or (
8889 not is_final_class (default .__class__ )
89- and is_dataclass_like (default .__class__ )
90+ and is_not_subclass_type (default .__class__ )
9091 and isinstance (default , unaliased_class_type )
9192 )
9293 ):
@@ -120,15 +121,15 @@ def add_class_arguments(
120121 defaults = default
121122 if isinstance (default , LazyInitBaseClass ):
122123 defaults = default .lazy_get_init_args ().as_dict ()
123- elif is_dataclass_like (default .__class__ ):
124- defaults = dataclass_to_dict (default )
124+ elif is_convertible_to_dict (default .__class__ ):
125+ defaults = convert_to_dict (default )
125126 args = {k [len (prefix ) :] for k in added_args }
126127 skip_not_added = [k for k in defaults if k not in args ]
127128 if skip_not_added :
128129 skip .update (skip_not_added ) # skip init=False
129- elif isinstance (default , Namespace ):
130- defaults = default .as_dict ()
131130 if defaults :
131+ if is_subclass_spec (defaults ):
132+ defaults = defaults .get ("init_args" , {})
132133 defaults = {prefix + k : v for k , v in defaults .items () if k not in skip }
133134 self .set_defaults (** defaults ) # type: ignore[attr-defined]
134135
@@ -389,7 +390,7 @@ def _add_signature_parameter(
389390 elif not as_positional or is_non_positional :
390391 kwargs ["required" ] = True
391392 is_subclass_typehint = False
392- is_dataclass_like_typehint = is_dataclass_like (annotation )
393+ is_not_subclass_typehint = is_not_subclass_type (annotation )
393394 dest = (nested_key + "." if nested_key else "" ) + name
394395 args = [dest if is_required and as_positional and not is_non_positional else "--" + dest ]
395396 if param .origin :
@@ -407,7 +408,7 @@ def _add_signature_parameter(
407408 if (
408409 annotation in {str , int , float , bool }
409410 or is_subclass (annotation , (str , int , float ))
410- or is_dataclass_like_typehint
411+ or is_not_subclass_typehint
411412 ):
412413 kwargs ["type" ] = annotation
413414 register_pydantic_type (annotation )
@@ -441,7 +442,7 @@ def _add_signature_parameter(
441442 "sub_configs" : sub_configs ,
442443 "instantiate" : instantiate ,
443444 }
444- if is_dataclass_like_typehint :
445+ if is_not_subclass_typehint :
445446 kwargs .update (sub_add_kwargs )
446447 with ActionTypeHint .allow_default_instance_context ():
447448 action = container .add_argument (* args , ** kwargs )
@@ -492,8 +493,6 @@ def add_subclass_arguments(
492493 Raises:
493494 ValueError: When given an invalid base class.
494495 """
495- if is_dataclass_like (baseclass ):
496- raise ValueError ("Not allowed for dataclass-like classes." )
497496 if type (baseclass ) is not tuple :
498497 baseclass = (baseclass ,) # type: ignore[assignment]
499498 if not baseclass or not all (ActionTypeHint .is_subclass_typehint (c , also_lists = True ) for c in baseclass ):
@@ -590,7 +589,11 @@ def is_factory_class(value):
590589 return value .__class__ == dataclasses ._HAS_DEFAULT_FACTORY_CLASS
591590
592591
593- def dataclass_to_dict (value ) -> dict :
592+ def is_convertible_to_dict (value ):
593+ return dataclasses .is_dataclass (value ) or is_attrs_class (value ) or is_pydantic_model (value )
594+
595+
596+ def convert_to_dict (value ) -> dict :
594597 if pydantic_support :
595598 pydantic_model = is_pydantic_model (type (value ))
596599 if pydantic_model :
@@ -602,6 +605,7 @@ def dataclass_to_dict(value) -> dict:
602605 is_attrs_dataclass = attrs .has (type (value ))
603606 if is_attrs_dataclass :
604607 return attrs .asdict (value )
608+
605609 return dataclasses .asdict (value )
606610
607611
0 commit comments