From 838acb11e5a011555c9b7b61d77a15a17655f048 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas <5780272+mauvilsa@users.noreply.github.com> Date: Wed, 26 Nov 2025 17:19:55 -0500 Subject: [PATCH] Fix evaluation of postponed annotations for dataclass inheritance across modules (https://github.com/omni-us/jsonargparse/issues/287#issuecomment-3570874105) --- CHANGELOG.rst | 10 ++++++++++ jsonargparse/_postponed_annotations.py | 12 +++++------ jsonargparse_tests/test_dataclasses.py | 11 +++++++++- .../test_postponed_annotations.py | 20 +++++++++++++++++++ 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b27ae401..6836551b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -12,6 +12,16 @@ The semantic versioning only considers the public API as described in paths are considered internals and can change in minor and patch releases. +v4.44.1 (unreleased) +-------------------- + +Fixed +^^^^^ +- Evaluation of postponed annotations for dataclass inheritance across modules + not working correctly (`#814 + `__). + + v4.44.0 (2025-11-25) -------------------- diff --git a/jsonargparse/_postponed_annotations.py b/jsonargparse/_postponed_annotations.py index 70a60a5e..0bd0b51e 100644 --- a/jsonargparse/_postponed_annotations.py +++ b/jsonargparse/_postponed_annotations.py @@ -221,7 +221,11 @@ def type_requires_eval(typehint): def get_global_vars(obj: Any, logger: Optional[logging.Logger]) -> dict: - global_vars = obj.__globals__.copy() if hasattr(obj, "__globals__") else {} + global_vars = getattr(obj, "__globals__", {}).copy() + if is_dataclass(obj): + next_mro = inspect.getmro(obj)[1] # type: ignore[arg-type] + if is_dataclass(next_mro): + global_vars.update(get_global_vars(next_mro, logger)) for key, value in vars(import_module(obj.__module__)).items(): # needed for pydantic-v1 if key not in global_vars: global_vars[key] = value @@ -284,11 +288,7 @@ def evaluate_postponed_annotations(params, component, parent, logger): if not (params and any(type_requires_eval(p.annotation) for p in params)): return try: - if ( - is_dataclass(parent) - and component.__name__ == "__init__" - and not component.__qualname__.startswith(parent.__name__ + ".") - ): + if is_dataclass(parent) and component.__name__ == "__init__": types = get_types(parent, logger) else: types = get_types(component, logger) diff --git a/jsonargparse_tests/test_dataclasses.py b/jsonargparse_tests/test_dataclasses.py index 63b9414f..b3601020 100644 --- a/jsonargparse_tests/test_dataclasses.py +++ b/jsonargparse_tests/test_dataclasses.py @@ -21,7 +21,7 @@ typing_extensions_import, ) from jsonargparse._signatures import convert_to_dict -from jsonargparse.typing import PositiveFloat, PositiveInt +from jsonargparse.typing import PositiveFloat, PositiveInt, restricted_number_type from jsonargparse_tests.conftest import ( get_parse_args_stdout, get_parser_help, @@ -31,6 +31,15 @@ annotated = typing_extensions_import("Annotated") +BetweenThreeAndNine = restricted_number_type("BetweenThreeAndNine", float, [(">=", 3), ("<=", 9)]) +ListPositiveInt = List[PositiveInt] + + +@dataclasses.dataclass +class DifferentModuleBaseData: + count: Optional[BetweenThreeAndNine] = None # type: ignore[valid-type] + numbers: ListPositiveInt = dataclasses.field(default_factory=list) + @dataclasses.dataclass(frozen=True) class DataClassA: diff --git a/jsonargparse_tests/test_postponed_annotations.py b/jsonargparse_tests/test_postponed_annotations.py index 726dbd46..779dcbaf 100644 --- a/jsonargparse_tests/test_postponed_annotations.py +++ b/jsonargparse_tests/test_postponed_annotations.py @@ -16,6 +16,7 @@ ) from jsonargparse.typing import Path_drw from jsonargparse_tests.conftest import capture_logs, source_unavailable +from jsonargparse_tests.test_dataclasses import DifferentModuleBaseData def function_pep604(p1: str | None, p2: int | float | bool = 1): @@ -324,3 +325,22 @@ def test_add_dataclass_with_init_pep585(parser, tmp_cwd): parser.add_class_arguments(DataWithInit585, "data") cfg = parser.parse_args(["--data.a=[1, 2]", "--data.b=."]) assert cfg.data == Namespace(a=[1, 2], b=Path_drw(".")) + + +@dataclasses.dataclass +class InheritDifferentModule(DifferentModuleBaseData): + extra: str = "default" + + +def test_get_params_dataclass_inherit_different_module(): + assert "BetweenThreeAndNine" not in globals() + assert "PositiveInt" not in globals() + + params = get_params(InheritDifferentModule) + + assert [p.name for p in params] == ["count", "numbers", "extra"] + assert all(not isinstance(p.annotation, str) for p in params) + assert not isinstance(params[0].annotation.__args__[0], str) + assert "BetweenThreeAndNine" in str(params[0].annotation) + assert not isinstance(params[1].annotation.__args__[0], str) + assert "PositiveInt" in str(params[1].annotation)