diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b2e3c2df..0f464aa5 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -20,8 +20,10 @@ Added - ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770 `__). -- Experimental support for dataclass inheritance (`#775 +- Experimental support for dataclass inheritance (`#775 `__). +- Experimental support for pydantic ``BaseModel`` subclasses (`#781 + `__). Fixed ^^^^^ diff --git a/jsonargparse/_signatures.py b/jsonargparse/_signatures.py index 57c726c9..beee9ac0 100644 --- a/jsonargparse/_signatures.py +++ b/jsonargparse/_signatures.py @@ -17,7 +17,7 @@ is_subclass, ) from ._namespace import Namespace -from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model, pydantic_support +from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model from ._parameter_resolvers import ParamData, get_parameter_origins, get_signature_parameters from ._typehints import ( ActionTypeHint, @@ -27,8 +27,9 @@ is_optional, is_subclass_spec, not_required_types, + sequence_origin_types, ) -from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str +from ._util import NoneType, get_import_path, get_private_kwargs, get_typehint_origin, iter_to_set_str from .typing import register_pydantic_type __all__ = [ @@ -594,19 +595,28 @@ def is_convertible_to_dict(value): def convert_to_dict(value) -> dict: - if pydantic_support: - pydantic_model = is_pydantic_model(type(value)) - if pydantic_model: - return value.dict() if pydantic_model == 1 else value.model_dump() - if attrs_support: import attrs - is_attrs_dataclass = attrs.has(type(value)) - if is_attrs_dataclass: + if attrs.has(type(value)): return attrs.asdict(value) - return dataclasses.asdict(value) + value_type = type(value) + init_args = {} + for name, attr in vars(value).items(): + attr_type = type(attr) + if is_convertible_to_dict(attr_type): + attr = convert_to_dict(attr) + elif attr_type in sequence_origin_types: + attr = attr.copy() + for num, item in enumerate(attr): + if is_convertible_to_dict(type(item)): + attr[num] = convert_to_dict(item) + init_args[name] = attr + + if is_not_subclass_type(value_type): + return init_args + return {"class_path": get_import_path(value_type), "init_args": init_args} def compose_dataclasses(*args): diff --git a/jsonargparse_tests/test_dataclasses.py b/jsonargparse_tests/test_dataclasses.py index 89781f71..1f465b2d 100644 --- a/jsonargparse_tests/test_dataclasses.py +++ b/jsonargparse_tests/test_dataclasses.py @@ -21,6 +21,7 @@ type_alias_type, typing_extensions_import, ) +from jsonargparse._signatures import convert_to_dict from jsonargparse.typing import PositiveFloat, PositiveInt from jsonargparse_tests.conftest import ( get_parse_args_stdout, @@ -852,3 +853,75 @@ def test_dataclass_nested_as_subclass(parser, subclass_behavior): assert isinstance(init.parent, ParentData) assert isinstance(init.parent.data, DataSub) assert dataclasses.asdict(init.parent.data) == {"p1": 3, "p2": "x"} + + +@dataclasses.dataclass +class Pet: + name: str + + +@dataclasses.dataclass +class Cat(Pet): + meows: int + + +@dataclasses.dataclass +class SpecialCat(Cat): + number_of_tails: int + + +@dataclasses.dataclass +class Dog(Pet): + barks: float + friend: Pet + + +@dataclasses.dataclass +class Person(Pet): + name: str + pets: list[Pet] + + +person = Person( + name="jt", + pets=[ + SpecialCat(name="sc", number_of_tails=2, meows=3), + Dog(name="dog", barks=2, friend=Cat(name="cc", meows=2)), + ], +) + + +def test_convert_to_dict_not_subclass(): + person_dict = convert_to_dict(person) + assert person_dict == { + "name": "jt", + "pets": [ + {"name": "sc", "meows": 3, "number_of_tails": 2}, + { + "name": "dog", + "barks": 2.0, + "friend": {"name": "cc", "meows": 2}, + }, + ], + } + + +def test_convert_to_dict_subclass(subclass_behavior): + person_dict = convert_to_dict(person) + assert person_dict == { + "class_path": f"{__name__}.Person", + "init_args": { + "name": "jt", + "pets": [ + {"class_path": f"{__name__}.SpecialCat", "init_args": {"name": "sc", "meows": 3, "number_of_tails": 2}}, + { + "class_path": f"{__name__}.Dog", + "init_args": { + "name": "dog", + "barks": 2.0, + "friend": {"class_path": f"{__name__}.Cat", "init_args": {"name": "cc", "meows": 2}}, + }, + }, + ], + }, + } diff --git a/jsonargparse_tests/test_pydantic.py b/jsonargparse_tests/test_pydantic.py index e5b660f4..8b9bcc77 100644 --- a/jsonargparse_tests/test_pydantic.py +++ b/jsonargparse_tests/test_pydantic.py @@ -2,7 +2,9 @@ import dataclasses import json +from copy import deepcopy from typing import Dict, List, Literal, Optional, Union +from unittest.mock import patch import pytest @@ -13,7 +15,9 @@ pydantic_supports_field_init, typing_extensions_import, ) +from jsonargparse._signatures import convert_to_dict from jsonargparse_tests.conftest import ( + get_parse_args_stdout, get_parser_help, json_or_yaml_load, ) @@ -35,6 +39,13 @@ def missing_pydantic(): pytest.skip("pydantic package is required") +@pytest.fixture +def subclass_behavior(): + with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors: + not_subclass_type_selectors.pop("pydantic") + yield + + @skip_if_pydantic_v1_on_v2 def test_pydantic_secret_str(parser): parser.add_argument("--password", type=pydantic.SecretStr) @@ -142,6 +153,10 @@ class PydanticDataNested: class PydanticDataFieldInitFalse: p1: str = PydanticV2Field("-", init=False) + @pydantic_v2_dataclass + class ParentPydanticDataFieldInitFalse: + y: PydanticDataFieldInitFalse = PydanticV2Field(default_factory=PydanticDataFieldInitFalse) + @pydantic.dataclasses.dataclass class PydanticDataStdlibField: p1: str = dataclasses.field(default="-") @@ -263,6 +278,17 @@ def test_dataclass_field_init_false(self, parser): init = parser.instantiate_classes(cfg) assert init.data.p1 == "-" + @pytest.mark.skipif(not pydantic_supports_field_init, reason="Field.init is required") + def test_nested_dataclass_field_init_false(self, parser): + parser.add_class_arguments(ParentPydanticDataFieldInitFalse, "data") + assert parser.get_defaults() == Namespace() + cfg = parser.parse_args([]) + assert cfg == Namespace() + init = parser.instantiate_classes(cfg) + assert isinstance(init.data, ParentPydanticDataFieldInitFalse) + assert isinstance(init.data.y, PydanticDataFieldInitFalse) + assert init.data.y.p1 == "-" + def test_dataclass_stdlib_field(self, parser): parser.add_argument("--data", type=PydanticDataStdlibField) cfg = parser.parse_args(["--data", "{}"]) @@ -310,3 +336,99 @@ def test_nested_dict(self, parser): init = parser.instantiate_classes(cfg) assert isinstance(init.model, PydanticNestedDict) assert isinstance(init.model.nested["key"], NestedModel) + + +if pydantic_support: + + class Pet(pydantic.BaseModel): + name: str + + class Cat(Pet): + meows: int + + class SpecialCat(Cat): + number_of_tails: int + + class Dog(Pet): + barks: float + friend: Pet + + class Person(Pet): + name: str + pets: list[Pet] + + person = Person( + name="jt", + pets=[ + SpecialCat(name="sc", number_of_tails=2, meows=3), + Dog(name="dog", barks=2, friend=Cat(name="cc", meows=2)), + ], + ) + + person_expected_dict = { + "name": "jt", + "pets": [ + {"name": "sc", "meows": 3, "number_of_tails": 2}, + { + "name": "dog", + "barks": 2.0, + "friend": {"name": "cc", "meows": 2}, + }, + ], + } + + person_expected_subclass_dict = { + "class_path": f"{__name__}.Person", + "init_args": { + "name": "jt", + "pets": [ + {"class_path": f"{__name__}.SpecialCat", "init_args": {"name": "sc", "meows": 3, "number_of_tails": 2}}, + { + "class_path": f"{__name__}.Dog", + "init_args": { + "name": "dog", + "barks": 2.0, + "friend": {"class_path": f"{__name__}.Cat", "init_args": {"name": "cc", "meows": 2}}, + }, + }, + ], + }, + } + + +def test_model_argument_as_subclass(parser, subtests, subclass_behavior): + parser.add_argument("--person", type=Person, default=person) + + with subtests.test("help"): + help_str = get_parser_help(parser) + assert "--person.help [CLASS_PATH_OR_NAME]" in help_str + assert f"{__name__}.Person" in help_str + help_str = get_parse_args_stdout(parser, ["--person.help"]) + assert f"Help for --person.help={__name__}.Person" in help_str + + with subtests.test("defaults"): + defaults = parser.get_defaults() + dump = json_or_yaml_load(parser.dump(defaults))["person"] + assert dump == person_expected_subclass_dict + + with subtests.test("sub-param"): + cfg = parser.parse_args(["--person.pets.name=lucky"]) + init = parser.instantiate_classes(cfg) + assert isinstance(init.person, Person) + assert isinstance(init.person.pets[0], SpecialCat) + assert isinstance(init.person.pets[1], Dog) + assert init.person.pets[1].name == "lucky" + dump = json_or_yaml_load(parser.dump(cfg))["person"] + expected = deepcopy(person_expected_subclass_dict) + expected["init_args"]["pets"][1]["init_args"]["name"] = "lucky" + assert dump == expected + + +def test_convert_to_dict_not_subclass(): + converted = convert_to_dict(person) + assert converted == person_expected_dict + + +def test_convert_to_dict_subclass(subclass_behavior): + converted = convert_to_dict(person) + assert converted == person_expected_subclass_dict