Skip to content

Commit b4ff170

Browse files
authored
Experimental support for pydantic BaseModel subclasses (#781)
1 parent 3eb1bdd commit b4ff170

File tree

4 files changed

+218
-11
lines changed

4 files changed

+218
-11
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ Added
2020
- ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable
2121
stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770
2222
<https://github.com/omni-us/jsonargparse/pull/770>`__).
23-
- Experimental support for dataclass inheritance (`#775
23+
- Experimental support for dataclass inheritance (`#775
2424
<https://github.com/omni-us/jsonargparse/pull/775>`__).
25+
- Experimental support for pydantic ``BaseModel`` subclasses (`#781
26+
<https://github.com/omni-us/jsonargparse/pull/781>`__).
2527

2628
Fixed
2729
^^^^^

jsonargparse/_signatures.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
is_subclass,
1818
)
1919
from ._namespace import Namespace
20-
from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model, pydantic_support
20+
from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model
2121
from ._parameter_resolvers import ParamData, get_parameter_origins, get_signature_parameters
2222
from ._typehints import (
2323
ActionTypeHint,
@@ -27,8 +27,9 @@
2727
is_optional,
2828
is_subclass_spec,
2929
not_required_types,
30+
sequence_origin_types,
3031
)
31-
from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
32+
from ._util import NoneType, get_import_path, get_private_kwargs, get_typehint_origin, iter_to_set_str
3233
from .typing import register_pydantic_type
3334

3435
__all__ = [
@@ -594,19 +595,28 @@ def is_convertible_to_dict(value):
594595

595596

596597
def convert_to_dict(value) -> dict:
597-
if pydantic_support:
598-
pydantic_model = is_pydantic_model(type(value))
599-
if pydantic_model:
600-
return value.dict() if pydantic_model == 1 else value.model_dump()
601-
602598
if attrs_support:
603599
import attrs
604600

605-
is_attrs_dataclass = attrs.has(type(value))
606-
if is_attrs_dataclass:
601+
if attrs.has(type(value)):
607602
return attrs.asdict(value)
608603

609-
return dataclasses.asdict(value)
604+
value_type = type(value)
605+
init_args = {}
606+
for name, attr in vars(value).items():
607+
attr_type = type(attr)
608+
if is_convertible_to_dict(attr_type):
609+
attr = convert_to_dict(attr)
610+
elif attr_type in sequence_origin_types:
611+
attr = attr.copy()
612+
for num, item in enumerate(attr):
613+
if is_convertible_to_dict(type(item)):
614+
attr[num] = convert_to_dict(item)
615+
init_args[name] = attr
616+
617+
if is_not_subclass_type(value_type):
618+
return init_args
619+
return {"class_path": get_import_path(value_type), "init_args": init_args}
610620

611621

612622
def compose_dataclasses(*args):

jsonargparse_tests/test_dataclasses.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
type_alias_type,
2222
typing_extensions_import,
2323
)
24+
from jsonargparse._signatures import convert_to_dict
2425
from jsonargparse.typing import PositiveFloat, PositiveInt
2526
from jsonargparse_tests.conftest import (
2627
get_parse_args_stdout,
@@ -852,3 +853,75 @@ def test_dataclass_nested_as_subclass(parser, subclass_behavior):
852853
assert isinstance(init.parent, ParentData)
853854
assert isinstance(init.parent.data, DataSub)
854855
assert dataclasses.asdict(init.parent.data) == {"p1": 3, "p2": "x"}
856+
857+
858+
@dataclasses.dataclass
859+
class Pet:
860+
name: str
861+
862+
863+
@dataclasses.dataclass
864+
class Cat(Pet):
865+
meows: int
866+
867+
868+
@dataclasses.dataclass
869+
class SpecialCat(Cat):
870+
number_of_tails: int
871+
872+
873+
@dataclasses.dataclass
874+
class Dog(Pet):
875+
barks: float
876+
friend: Pet
877+
878+
879+
@dataclasses.dataclass
880+
class Person(Pet):
881+
name: str
882+
pets: list[Pet]
883+
884+
885+
person = Person(
886+
name="jt",
887+
pets=[
888+
SpecialCat(name="sc", number_of_tails=2, meows=3),
889+
Dog(name="dog", barks=2, friend=Cat(name="cc", meows=2)),
890+
],
891+
)
892+
893+
894+
def test_convert_to_dict_not_subclass():
895+
person_dict = convert_to_dict(person)
896+
assert person_dict == {
897+
"name": "jt",
898+
"pets": [
899+
{"name": "sc", "meows": 3, "number_of_tails": 2},
900+
{
901+
"name": "dog",
902+
"barks": 2.0,
903+
"friend": {"name": "cc", "meows": 2},
904+
},
905+
],
906+
}
907+
908+
909+
def test_convert_to_dict_subclass(subclass_behavior):
910+
person_dict = convert_to_dict(person)
911+
assert person_dict == {
912+
"class_path": f"{__name__}.Person",
913+
"init_args": {
914+
"name": "jt",
915+
"pets": [
916+
{"class_path": f"{__name__}.SpecialCat", "init_args": {"name": "sc", "meows": 3, "number_of_tails": 2}},
917+
{
918+
"class_path": f"{__name__}.Dog",
919+
"init_args": {
920+
"name": "dog",
921+
"barks": 2.0,
922+
"friend": {"class_path": f"{__name__}.Cat", "init_args": {"name": "cc", "meows": 2}},
923+
},
924+
},
925+
],
926+
},
927+
}

jsonargparse_tests/test_pydantic.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import dataclasses
44
import json
5+
from copy import deepcopy
56
from typing import Dict, List, Literal, Optional, Union
7+
from unittest.mock import patch
68

79
import pytest
810

@@ -13,7 +15,9 @@
1315
pydantic_supports_field_init,
1416
typing_extensions_import,
1517
)
18+
from jsonargparse._signatures import convert_to_dict
1619
from jsonargparse_tests.conftest import (
20+
get_parse_args_stdout,
1721
get_parser_help,
1822
json_or_yaml_load,
1923
)
@@ -35,6 +39,13 @@ def missing_pydantic():
3539
pytest.skip("pydantic package is required")
3640

3741

42+
@pytest.fixture
43+
def subclass_behavior():
44+
with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors:
45+
not_subclass_type_selectors.pop("pydantic")
46+
yield
47+
48+
3849
@skip_if_pydantic_v1_on_v2
3950
def test_pydantic_secret_str(parser):
4051
parser.add_argument("--password", type=pydantic.SecretStr)
@@ -142,6 +153,10 @@ class PydanticDataNested:
142153
class PydanticDataFieldInitFalse:
143154
p1: str = PydanticV2Field("-", init=False)
144155

156+
@pydantic_v2_dataclass
157+
class ParentPydanticDataFieldInitFalse:
158+
y: PydanticDataFieldInitFalse = PydanticV2Field(default_factory=PydanticDataFieldInitFalse)
159+
145160
@pydantic.dataclasses.dataclass
146161
class PydanticDataStdlibField:
147162
p1: str = dataclasses.field(default="-")
@@ -263,6 +278,17 @@ def test_dataclass_field_init_false(self, parser):
263278
init = parser.instantiate_classes(cfg)
264279
assert init.data.p1 == "-"
265280

281+
@pytest.mark.skipif(not pydantic_supports_field_init, reason="Field.init is required")
282+
def test_nested_dataclass_field_init_false(self, parser):
283+
parser.add_class_arguments(ParentPydanticDataFieldInitFalse, "data")
284+
assert parser.get_defaults() == Namespace()
285+
cfg = parser.parse_args([])
286+
assert cfg == Namespace()
287+
init = parser.instantiate_classes(cfg)
288+
assert isinstance(init.data, ParentPydanticDataFieldInitFalse)
289+
assert isinstance(init.data.y, PydanticDataFieldInitFalse)
290+
assert init.data.y.p1 == "-"
291+
266292
def test_dataclass_stdlib_field(self, parser):
267293
parser.add_argument("--data", type=PydanticDataStdlibField)
268294
cfg = parser.parse_args(["--data", "{}"])
@@ -310,3 +336,99 @@ def test_nested_dict(self, parser):
310336
init = parser.instantiate_classes(cfg)
311337
assert isinstance(init.model, PydanticNestedDict)
312338
assert isinstance(init.model.nested["key"], NestedModel)
339+
340+
341+
if pydantic_support:
342+
343+
class Pet(pydantic.BaseModel):
344+
name: str
345+
346+
class Cat(Pet):
347+
meows: int
348+
349+
class SpecialCat(Cat):
350+
number_of_tails: int
351+
352+
class Dog(Pet):
353+
barks: float
354+
friend: Pet
355+
356+
class Person(Pet):
357+
name: str
358+
pets: list[Pet]
359+
360+
person = Person(
361+
name="jt",
362+
pets=[
363+
SpecialCat(name="sc", number_of_tails=2, meows=3),
364+
Dog(name="dog", barks=2, friend=Cat(name="cc", meows=2)),
365+
],
366+
)
367+
368+
person_expected_dict = {
369+
"name": "jt",
370+
"pets": [
371+
{"name": "sc", "meows": 3, "number_of_tails": 2},
372+
{
373+
"name": "dog",
374+
"barks": 2.0,
375+
"friend": {"name": "cc", "meows": 2},
376+
},
377+
],
378+
}
379+
380+
person_expected_subclass_dict = {
381+
"class_path": f"{__name__}.Person",
382+
"init_args": {
383+
"name": "jt",
384+
"pets": [
385+
{"class_path": f"{__name__}.SpecialCat", "init_args": {"name": "sc", "meows": 3, "number_of_tails": 2}},
386+
{
387+
"class_path": f"{__name__}.Dog",
388+
"init_args": {
389+
"name": "dog",
390+
"barks": 2.0,
391+
"friend": {"class_path": f"{__name__}.Cat", "init_args": {"name": "cc", "meows": 2}},
392+
},
393+
},
394+
],
395+
},
396+
}
397+
398+
399+
def test_model_argument_as_subclass(parser, subtests, subclass_behavior):
400+
parser.add_argument("--person", type=Person, default=person)
401+
402+
with subtests.test("help"):
403+
help_str = get_parser_help(parser)
404+
assert "--person.help [CLASS_PATH_OR_NAME]" in help_str
405+
assert f"{__name__}.Person" in help_str
406+
help_str = get_parse_args_stdout(parser, ["--person.help"])
407+
assert f"Help for --person.help={__name__}.Person" in help_str
408+
409+
with subtests.test("defaults"):
410+
defaults = parser.get_defaults()
411+
dump = json_or_yaml_load(parser.dump(defaults))["person"]
412+
assert dump == person_expected_subclass_dict
413+
414+
with subtests.test("sub-param"):
415+
cfg = parser.parse_args(["--person.pets.name=lucky"])
416+
init = parser.instantiate_classes(cfg)
417+
assert isinstance(init.person, Person)
418+
assert isinstance(init.person.pets[0], SpecialCat)
419+
assert isinstance(init.person.pets[1], Dog)
420+
assert init.person.pets[1].name == "lucky"
421+
dump = json_or_yaml_load(parser.dump(cfg))["person"]
422+
expected = deepcopy(person_expected_subclass_dict)
423+
expected["init_args"]["pets"][1]["init_args"]["name"] = "lucky"
424+
assert dump == expected
425+
426+
427+
def test_convert_to_dict_not_subclass():
428+
converted = convert_to_dict(person)
429+
assert converted == person_expected_dict
430+
431+
432+
def test_convert_to_dict_subclass(subclass_behavior):
433+
converted = convert_to_dict(person)
434+
assert converted == person_expected_subclass_dict

0 commit comments

Comments
 (0)