Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ Added
- ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable
stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770
<https://github.com/omni-us/jsonargparse/pull/770>`__).
- Experimental support for dataclass inheritance (`#775
- Experimental support for dataclass inheritance (`#775
<https://github.com/omni-us/jsonargparse/pull/775>`__).
- Experimental support for pydantic ``BaseModel`` subclasses (`#781
<https://github.com/omni-us/jsonargparse/pull/781>`__).

Fixed
^^^^^
Expand Down
30 changes: 20 additions & 10 deletions jsonargparse/_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
is_subclass,
)
from ._namespace import Namespace
from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model, pydantic_support
from ._optionals import attrs_support, get_doc_short_description, is_attrs_class, is_pydantic_model
from ._parameter_resolvers import ParamData, get_parameter_origins, get_signature_parameters
from ._typehints import (
ActionTypeHint,
Expand All @@ -27,8 +27,9 @@
is_optional,
is_subclass_spec,
not_required_types,
sequence_origin_types,
)
from ._util import NoneType, get_private_kwargs, get_typehint_origin, iter_to_set_str
from ._util import NoneType, get_import_path, get_private_kwargs, get_typehint_origin, iter_to_set_str
from .typing import register_pydantic_type

__all__ = [
Expand Down Expand Up @@ -594,19 +595,28 @@ def is_convertible_to_dict(value):


def convert_to_dict(value) -> dict:
if pydantic_support:
pydantic_model = is_pydantic_model(type(value))
if pydantic_model:
return value.dict() if pydantic_model == 1 else value.model_dump()

if attrs_support:
import attrs

is_attrs_dataclass = attrs.has(type(value))
if is_attrs_dataclass:
if attrs.has(type(value)):
return attrs.asdict(value)

return dataclasses.asdict(value)
value_type = type(value)
init_args = {}
for name, attr in vars(value).items():
attr_type = type(attr)
if is_convertible_to_dict(attr_type):
attr = convert_to_dict(attr)
elif attr_type in sequence_origin_types:
attr = attr.copy()
for num, item in enumerate(attr):
if is_convertible_to_dict(type(item)):
attr[num] = convert_to_dict(item)
init_args[name] = attr

if is_not_subclass_type(value_type):
return init_args
return {"class_path": get_import_path(value_type), "init_args": init_args}


def compose_dataclasses(*args):
Expand Down
73 changes: 73 additions & 0 deletions jsonargparse_tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
type_alias_type,
typing_extensions_import,
)
from jsonargparse._signatures import convert_to_dict
from jsonargparse.typing import PositiveFloat, PositiveInt
from jsonargparse_tests.conftest import (
get_parse_args_stdout,
Expand Down Expand Up @@ -852,3 +853,75 @@ def test_dataclass_nested_as_subclass(parser, subclass_behavior):
assert isinstance(init.parent, ParentData)
assert isinstance(init.parent.data, DataSub)
assert dataclasses.asdict(init.parent.data) == {"p1": 3, "p2": "x"}


@dataclasses.dataclass
class Pet:
name: str


@dataclasses.dataclass
class Cat(Pet):
meows: int


@dataclasses.dataclass
class SpecialCat(Cat):
number_of_tails: int


@dataclasses.dataclass
class Dog(Pet):
barks: float
friend: Pet


@dataclasses.dataclass
class Person(Pet):
name: str
pets: list[Pet]


person = Person(
name="jt",
pets=[
SpecialCat(name="sc", number_of_tails=2, meows=3),
Dog(name="dog", barks=2, friend=Cat(name="cc", meows=2)),
],
)


def test_convert_to_dict_not_subclass():
person_dict = convert_to_dict(person)
assert person_dict == {
"name": "jt",
"pets": [
{"name": "sc", "meows": 3, "number_of_tails": 2},
{
"name": "dog",
"barks": 2.0,
"friend": {"name": "cc", "meows": 2},
},
],
}


def test_convert_to_dict_subclass(subclass_behavior):
person_dict = convert_to_dict(person)
assert person_dict == {
"class_path": f"{__name__}.Person",
"init_args": {
"name": "jt",
"pets": [
{"class_path": f"{__name__}.SpecialCat", "init_args": {"name": "sc", "meows": 3, "number_of_tails": 2}},
{
"class_path": f"{__name__}.Dog",
"init_args": {
"name": "dog",
"barks": 2.0,
"friend": {"class_path": f"{__name__}.Cat", "init_args": {"name": "cc", "meows": 2}},
},
},
],
},
}
122 changes: 122 additions & 0 deletions jsonargparse_tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import dataclasses
import json
from copy import deepcopy
from typing import Dict, List, Literal, Optional, Union
from unittest.mock import patch

import pytest

Expand All @@ -13,7 +15,9 @@
pydantic_supports_field_init,
typing_extensions_import,
)
from jsonargparse._signatures import convert_to_dict
from jsonargparse_tests.conftest import (
get_parse_args_stdout,
get_parser_help,
json_or_yaml_load,
)
Expand All @@ -35,6 +39,13 @@ def missing_pydantic():
pytest.skip("pydantic package is required")


@pytest.fixture
def subclass_behavior():
with patch.dict("jsonargparse._common.not_subclass_type_selectors") as not_subclass_type_selectors:
not_subclass_type_selectors.pop("pydantic")
yield


@skip_if_pydantic_v1_on_v2
def test_pydantic_secret_str(parser):
parser.add_argument("--password", type=pydantic.SecretStr)
Expand Down Expand Up @@ -142,6 +153,10 @@ class PydanticDataNested:
class PydanticDataFieldInitFalse:
p1: str = PydanticV2Field("-", init=False)

@pydantic_v2_dataclass
class ParentPydanticDataFieldInitFalse:
y: PydanticDataFieldInitFalse = PydanticV2Field(default_factory=PydanticDataFieldInitFalse)

@pydantic.dataclasses.dataclass
class PydanticDataStdlibField:
p1: str = dataclasses.field(default="-")
Expand Down Expand Up @@ -263,6 +278,17 @@ def test_dataclass_field_init_false(self, parser):
init = parser.instantiate_classes(cfg)
assert init.data.p1 == "-"

@pytest.mark.skipif(not pydantic_supports_field_init, reason="Field.init is required")
def test_nested_dataclass_field_init_false(self, parser):
parser.add_class_arguments(ParentPydanticDataFieldInitFalse, "data")
assert parser.get_defaults() == Namespace()
cfg = parser.parse_args([])
assert cfg == Namespace()
init = parser.instantiate_classes(cfg)
assert isinstance(init.data, ParentPydanticDataFieldInitFalse)
assert isinstance(init.data.y, PydanticDataFieldInitFalse)
assert init.data.y.p1 == "-"

def test_dataclass_stdlib_field(self, parser):
parser.add_argument("--data", type=PydanticDataStdlibField)
cfg = parser.parse_args(["--data", "{}"])
Expand Down Expand Up @@ -310,3 +336,99 @@ def test_nested_dict(self, parser):
init = parser.instantiate_classes(cfg)
assert isinstance(init.model, PydanticNestedDict)
assert isinstance(init.model.nested["key"], NestedModel)


if pydantic_support:

class Pet(pydantic.BaseModel):
name: str

class Cat(Pet):
meows: int

class SpecialCat(Cat):
number_of_tails: int

class Dog(Pet):
barks: float
friend: Pet

class Person(Pet):
name: str
pets: list[Pet]

person = Person(
name="jt",
pets=[
SpecialCat(name="sc", number_of_tails=2, meows=3),
Dog(name="dog", barks=2, friend=Cat(name="cc", meows=2)),
],
)

person_expected_dict = {
"name": "jt",
"pets": [
{"name": "sc", "meows": 3, "number_of_tails": 2},
{
"name": "dog",
"barks": 2.0,
"friend": {"name": "cc", "meows": 2},
},
],
}

person_expected_subclass_dict = {
"class_path": f"{__name__}.Person",
"init_args": {
"name": "jt",
"pets": [
{"class_path": f"{__name__}.SpecialCat", "init_args": {"name": "sc", "meows": 3, "number_of_tails": 2}},
{
"class_path": f"{__name__}.Dog",
"init_args": {
"name": "dog",
"barks": 2.0,
"friend": {"class_path": f"{__name__}.Cat", "init_args": {"name": "cc", "meows": 2}},
},
},
],
},
}


def test_model_argument_as_subclass(parser, subtests, subclass_behavior):
parser.add_argument("--person", type=Person, default=person)

with subtests.test("help"):
help_str = get_parser_help(parser)
assert "--person.help [CLASS_PATH_OR_NAME]" in help_str
assert f"{__name__}.Person" in help_str
help_str = get_parse_args_stdout(parser, ["--person.help"])
assert f"Help for --person.help={__name__}.Person" in help_str

with subtests.test("defaults"):
defaults = parser.get_defaults()
dump = json_or_yaml_load(parser.dump(defaults))["person"]
assert dump == person_expected_subclass_dict

with subtests.test("sub-param"):
cfg = parser.parse_args(["--person.pets.name=lucky"])
init = parser.instantiate_classes(cfg)
assert isinstance(init.person, Person)
assert isinstance(init.person.pets[0], SpecialCat)
assert isinstance(init.person.pets[1], Dog)
assert init.person.pets[1].name == "lucky"
dump = json_or_yaml_load(parser.dump(cfg))["person"]
expected = deepcopy(person_expected_subclass_dict)
expected["init_args"]["pets"][1]["init_args"]["name"] = "lucky"
assert dump == expected


def test_convert_to_dict_not_subclass():
converted = convert_to_dict(person)
assert converted == person_expected_dict


def test_convert_to_dict_subclass(subclass_behavior):
converted = convert_to_dict(person)
assert converted == person_expected_subclass_dict