diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8e1fd60b..b1a00c2d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -25,9 +25,11 @@ Fixed ^^^^^ - ``set_parsing_settings(validate_defaults=True)`` fails when the parser has a config action (`#718 `__). -- Regression causing dump/save to fail when ``skip_link_targets=True`` and target - being an entire required dataclass (`#717 +- Regression causing dump/save to fail when ``skip_link_targets=True`` and + target being an entire required dataclass (`#717 `__). +- ``TypedDict`` values not validated when types are forward references (`#722 + `__). Changed ^^^^^^^ diff --git a/jsonargparse/_typehints.py b/jsonargparse/_typehints.py index c4518ff8..c093eab5 100644 --- a/jsonargparse/_typehints.py +++ b/jsonargparse/_typehints.py @@ -11,11 +11,13 @@ from copy import deepcopy from enum import Enum from functools import partial +from importlib import import_module from types import FunctionType, MappingProxyType from typing import ( Any, Callable, Dict, + ForwardRef, Iterable, List, Literal, @@ -719,6 +721,15 @@ def raise_union_unexpected_value(subtypes, val: Any, exceptions: List[Exception] ) from exceptions[0] +def resolve_forward_ref(ref): + if not isinstance(ref, ForwardRef) or not ref.__forward_module__: + return ref + + aliases = __builtins__.copy() + aliases.update(vars(import_module(ref.__forward_module__))) + return aliases.get(ref.__forward_arg__, ref) + + def adapt_typehints( val, typehint, @@ -954,7 +965,8 @@ def adapt_typehints( if extra_keys: raise_unexpected_value(f"Unexpected keys: {extra_keys}", val) for k, v in val.items(): - val[k] = adapt_typehints(v, typehint.__annotations__[k], **adapt_kwargs) + subtypehint = resolve_forward_ref(typehint.__annotations__[k]) + val[k] = adapt_typehints(v, subtypehint, **adapt_kwargs) if typehint_origin is MappingProxyType and not serialize: val = MappingProxyType(val) elif typehint_origin is OrderedDict: @@ -1100,6 +1112,11 @@ def adapt_typehints( elif is_alias_type(typehint): return adapt_typehints(val, get_alias_target(typehint), **adapt_kwargs) + else: + if str(typehint) == "+VT_co": + return val # required for typing.Mapping in python 3.8 + raise RuntimeError(f"The code should never reach here: typehint={typehint}") # pragma: no cover + return val diff --git a/jsonargparse_tests/test_typehints.py b/jsonargparse_tests/test_typehints.py index 754f52ff..0e2ac72a 100644 --- a/jsonargparse_tests/test_typehints.py +++ b/jsonargparse_tests/test_typehints.py @@ -753,8 +753,7 @@ def test_invalid_inherited_unpack_typeddict(parser, init_args): parser.parse_args([f"--testclass={json.dumps(test_config)}"]) -@pytest.mark.skipif(sys.version_info < (3, 9), reason="Python 3.8 lacked runtime inspection of TypedDict required keys") -def test_typeddict_totality_inheritance(parser): +if sys.version_info >= (3, 9): class BottomDict(TypedDict, total=True): a: int @@ -765,6 +764,9 @@ class MiddleDict(BottomDict, total=False): class TopDict(MiddleDict, total=True): c: int + +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Python 3.8 lacked runtime inspection of TypedDict required keys") +def test_typeddict_totality_inheritance(parser): parser.add_argument("--middledict", type=MiddleDict, required=False) parser.add_argument("--topdict", type=TopDict, required=False) assert {"a": 1} == parser.parse_args(['--middledict={"a": 1}'])["middledict"]