22
33import dataclasses
44import json
5+ from copy import deepcopy
56from typing import Dict , List , Literal , Optional , Union
7+ from unittest .mock import patch
68
79import pytest
810
1315 pydantic_supports_field_init ,
1416 typing_extensions_import ,
1517)
18+ from jsonargparse ._signatures import convert_to_dict
1619from jsonargparse_tests .conftest import (
20+ get_parse_args_stdout ,
1721 get_parser_help ,
1822 json_or_yaml_load ,
1923)
@@ -35,6 +39,13 @@ def missing_pydantic():
3539 pytest .skip ("pydantic package is required" )
3640
3741
42+ @pytest .fixture
43+ def subclass_behavior ():
44+ with patch .dict ("jsonargparse._common.not_subclass_type_selectors" ) as not_subclass_type_selectors :
45+ not_subclass_type_selectors .pop ("pydantic" )
46+ yield
47+
48+
3849@skip_if_pydantic_v1_on_v2
3950def test_pydantic_secret_str (parser ):
4051 parser .add_argument ("--password" , type = pydantic .SecretStr )
@@ -142,6 +153,10 @@ class PydanticDataNested:
142153 class PydanticDataFieldInitFalse :
143154 p1 : str = PydanticV2Field ("-" , init = False )
144155
156+ @pydantic_v2_dataclass
157+ class ParentPydanticDataFieldInitFalse :
158+ y : PydanticDataFieldInitFalse = PydanticV2Field (default_factory = PydanticDataFieldInitFalse )
159+
145160 @pydantic .dataclasses .dataclass
146161 class PydanticDataStdlibField :
147162 p1 : str = dataclasses .field (default = "-" )
@@ -263,6 +278,17 @@ def test_dataclass_field_init_false(self, parser):
263278 init = parser .instantiate_classes (cfg )
264279 assert init .data .p1 == "-"
265280
281+ @pytest .mark .skipif (not pydantic_supports_field_init , reason = "Field.init is required" )
282+ def test_nested_dataclass_field_init_false (self , parser ):
283+ parser .add_class_arguments (ParentPydanticDataFieldInitFalse , "data" )
284+ assert parser .get_defaults () == Namespace ()
285+ cfg = parser .parse_args ([])
286+ assert cfg == Namespace ()
287+ init = parser .instantiate_classes (cfg )
288+ assert isinstance (init .data , ParentPydanticDataFieldInitFalse )
289+ assert isinstance (init .data .y , PydanticDataFieldInitFalse )
290+ assert init .data .y .p1 == "-"
291+
266292 def test_dataclass_stdlib_field (self , parser ):
267293 parser .add_argument ("--data" , type = PydanticDataStdlibField )
268294 cfg = parser .parse_args (["--data" , "{}" ])
@@ -310,3 +336,99 @@ def test_nested_dict(self, parser):
310336 init = parser .instantiate_classes (cfg )
311337 assert isinstance (init .model , PydanticNestedDict )
312338 assert isinstance (init .model .nested ["key" ], NestedModel )
339+
340+
341+ if pydantic_support :
342+
343+ class Pet (pydantic .BaseModel ):
344+ name : str
345+
346+ class Cat (Pet ):
347+ meows : int
348+
349+ class SpecialCat (Cat ):
350+ number_of_tails : int
351+
352+ class Dog (Pet ):
353+ barks : float
354+ friend : Pet
355+
356+ class Person (Pet ):
357+ name : str
358+ pets : list [Pet ]
359+
360+ person = Person (
361+ name = "jt" ,
362+ pets = [
363+ SpecialCat (name = "sc" , number_of_tails = 2 , meows = 3 ),
364+ Dog (name = "dog" , barks = 2 , friend = Cat (name = "cc" , meows = 2 )),
365+ ],
366+ )
367+
368+ person_expected_dict = {
369+ "name" : "jt" ,
370+ "pets" : [
371+ {"name" : "sc" , "meows" : 3 , "number_of_tails" : 2 },
372+ {
373+ "name" : "dog" ,
374+ "barks" : 2.0 ,
375+ "friend" : {"name" : "cc" , "meows" : 2 },
376+ },
377+ ],
378+ }
379+
380+ person_expected_subclass_dict = {
381+ "class_path" : f"{ __name__ } .Person" ,
382+ "init_args" : {
383+ "name" : "jt" ,
384+ "pets" : [
385+ {"class_path" : f"{ __name__ } .SpecialCat" , "init_args" : {"name" : "sc" , "meows" : 3 , "number_of_tails" : 2 }},
386+ {
387+ "class_path" : f"{ __name__ } .Dog" ,
388+ "init_args" : {
389+ "name" : "dog" ,
390+ "barks" : 2.0 ,
391+ "friend" : {"class_path" : f"{ __name__ } .Cat" , "init_args" : {"name" : "cc" , "meows" : 2 }},
392+ },
393+ },
394+ ],
395+ },
396+ }
397+
398+
399+ def test_model_argument_as_subclass (parser , subtests , subclass_behavior ):
400+ parser .add_argument ("--person" , type = Person , default = person )
401+
402+ with subtests .test ("help" ):
403+ help_str = get_parser_help (parser )
404+ assert "--person.help [CLASS_PATH_OR_NAME]" in help_str
405+ assert f"{ __name__ } .Person" in help_str
406+ help_str = get_parse_args_stdout (parser , ["--person.help" ])
407+ assert f"Help for --person.help={ __name__ } .Person" in help_str
408+
409+ with subtests .test ("defaults" ):
410+ defaults = parser .get_defaults ()
411+ dump = json_or_yaml_load (parser .dump (defaults ))["person" ]
412+ assert dump == person_expected_subclass_dict
413+
414+ with subtests .test ("sub-param" ):
415+ cfg = parser .parse_args (["--person.pets.name=lucky" ])
416+ init = parser .instantiate_classes (cfg )
417+ assert isinstance (init .person , Person )
418+ assert isinstance (init .person .pets [0 ], SpecialCat )
419+ assert isinstance (init .person .pets [1 ], Dog )
420+ assert init .person .pets [1 ].name == "lucky"
421+ dump = json_or_yaml_load (parser .dump (cfg ))["person" ]
422+ expected = deepcopy (person_expected_subclass_dict )
423+ expected ["init_args" ]["pets" ][1 ]["init_args" ]["name" ] = "lucky"
424+ assert dump == expected
425+
426+
427+ def test_convert_to_dict_not_subclass ():
428+ converted = convert_to_dict (person )
429+ assert converted == person_expected_dict
430+
431+
432+ def test_convert_to_dict_subclass (subclass_behavior ):
433+ converted = convert_to_dict (person )
434+ assert converted == person_expected_subclass_dict
0 commit comments