Skip to content

Commit ac2e8d2

Browse files
authored
Fix evaluation of postponed annotations for dataclass inheritance across modules (#814)
1 parent e1b7cf8 commit ac2e8d2

File tree

4 files changed

+46
-7
lines changed

4 files changed

+46
-7
lines changed

CHANGELOG.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ The semantic versioning only considers the public API as described in
1212
paths are considered internals and can change in minor and patch releases.
1313

1414

15+
v4.44.1 (unreleased)
16+
--------------------
17+
18+
Fixed
19+
^^^^^
20+
- Evaluation of postponed annotations for dataclass inheritance across modules
21+
not working correctly (`#814
22+
<https://github.com/omni-us/jsonargparse/pull/814>`__).
23+
24+
1525
v4.44.0 (2025-11-25)
1626
--------------------
1727

jsonargparse/_postponed_annotations.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,11 @@ def type_requires_eval(typehint):
221221

222222

223223
def get_global_vars(obj: Any, logger: Optional[logging.Logger]) -> dict:
224-
global_vars = obj.__globals__.copy() if hasattr(obj, "__globals__") else {}
224+
global_vars = getattr(obj, "__globals__", {}).copy()
225+
if is_dataclass(obj):
226+
next_mro = inspect.getmro(obj)[1] # type: ignore[arg-type]
227+
if is_dataclass(next_mro):
228+
global_vars.update(get_global_vars(next_mro, logger))
225229
for key, value in vars(import_module(obj.__module__)).items(): # needed for pydantic-v1
226230
if key not in global_vars:
227231
global_vars[key] = value
@@ -284,11 +288,7 @@ def evaluate_postponed_annotations(params, component, parent, logger):
284288
if not (params and any(type_requires_eval(p.annotation) for p in params)):
285289
return
286290
try:
287-
if (
288-
is_dataclass(parent)
289-
and component.__name__ == "__init__"
290-
and not component.__qualname__.startswith(parent.__name__ + ".")
291-
):
291+
if is_dataclass(parent) and component.__name__ == "__init__":
292292
types = get_types(parent, logger)
293293
else:
294294
types = get_types(component, logger)

jsonargparse_tests/test_dataclasses.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
typing_extensions_import,
2222
)
2323
from jsonargparse._signatures import convert_to_dict
24-
from jsonargparse.typing import PositiveFloat, PositiveInt
24+
from jsonargparse.typing import PositiveFloat, PositiveInt, restricted_number_type
2525
from jsonargparse_tests.conftest import (
2626
get_parse_args_stdout,
2727
get_parser_help,
@@ -31,6 +31,15 @@
3131

3232
annotated = typing_extensions_import("Annotated")
3333

34+
BetweenThreeAndNine = restricted_number_type("BetweenThreeAndNine", float, [(">=", 3), ("<=", 9)])
35+
ListPositiveInt = List[PositiveInt]
36+
37+
38+
@dataclasses.dataclass
39+
class DifferentModuleBaseData:
40+
count: Optional[BetweenThreeAndNine] = None # type: ignore[valid-type]
41+
numbers: ListPositiveInt = dataclasses.field(default_factory=list)
42+
3443

3544
@dataclasses.dataclass(frozen=True)
3645
class DataClassA:

jsonargparse_tests/test_postponed_annotations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
)
1717
from jsonargparse.typing import Path_drw
1818
from jsonargparse_tests.conftest import capture_logs, source_unavailable
19+
from jsonargparse_tests.test_dataclasses import DifferentModuleBaseData
1920

2021

2122
def function_pep604(p1: str | None, p2: int | float | bool = 1):
@@ -324,3 +325,22 @@ def test_add_dataclass_with_init_pep585(parser, tmp_cwd):
324325
parser.add_class_arguments(DataWithInit585, "data")
325326
cfg = parser.parse_args(["--data.a=[1, 2]", "--data.b=."])
326327
assert cfg.data == Namespace(a=[1, 2], b=Path_drw("."))
328+
329+
330+
@dataclasses.dataclass
331+
class InheritDifferentModule(DifferentModuleBaseData):
332+
extra: str = "default"
333+
334+
335+
def test_get_params_dataclass_inherit_different_module():
336+
assert "BetweenThreeAndNine" not in globals()
337+
assert "PositiveInt" not in globals()
338+
339+
params = get_params(InheritDifferentModule)
340+
341+
assert [p.name for p in params] == ["count", "numbers", "extra"]
342+
assert all(not isinstance(p.annotation, str) for p in params)
343+
assert not isinstance(params[0].annotation.__args__[0], str)
344+
assert "BetweenThreeAndNine" in str(params[0].annotation)
345+
assert not isinstance(params[1].annotation.__args__[0], str)
346+
assert "PositiveInt" in str(params[1].annotation)

0 commit comments

Comments
 (0)