Skip to content

Commit d11582e

Browse files
authored
Fix introspection of postponed annotations from jax (#749)
1 parent d1f5e57 commit d11582e

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ Fixed
2121
<https://github.com/omni-us/jsonargparse/pull/743>`__).
2222
- Linking entire dataclasses on instantiation not working (`#746
2323
<https://github.com/omni-us/jsonargparse/pull/746>`__).
24+
- Introspection of postponed annotations from jax not working (`#749
25+
<https://github.com/omni-us/jsonargparse/pull/749>`__).
2426

2527

2628
v4.40.1 (2025-07-24)
@@ -441,9 +443,8 @@ Added
441443
- Allow adding config argument with ``action="config"`` avoiding need to import
442444
action class (`#512
443445
<https://github.com/omni-us/jsonargparse/pull/512>`__).
444-
- Allow providing a function with return type a class in ``class_path``
445-
(`lightning#13613
446-
<https://github.com/Lightning-AI/pytorch-lightning/discussions/13613>`__).
446+
- Allow providing a function with return type a class in ``class_path`` (`#513
447+
<https://github.com/omni-us/jsonargparse/pull/513>`__).
447448
- Automatic ``--print_shtab`` option when ``shtab`` is installed, providing
448449
completions for many type hints without the need to modify code (`#528
449450
<https://github.com/omni-us/jsonargparse/pull/528>`__).

jsonargparse/_postponed_annotations.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,10 @@ def type_requires_eval(typehint):
260260

261261

262262
def get_global_vars(obj: Any, logger: Optional[logging.Logger]) -> dict:
263-
global_vars = vars(import_module(obj.__module__))
263+
global_vars = obj.__globals__.copy() if hasattr(obj, "__globals__") else {}
264+
for key, value in vars(import_module(obj.__module__)).items(): # needed for pydantic-v1
265+
if key not in global_vars:
266+
global_vars[key] = value
264267
try:
265268
module_source = inspect.getsource(sys.modules[obj.__module__]) if obj.__module__ in sys.modules else ""
266269
if "TYPE_CHECKING" in module_source:
@@ -349,7 +352,7 @@ def evaluate_postponed_annotations(params, component, parent, logger):
349352
def get_return_type(component, logger=None):
350353
return_type = inspect.signature(component).return_annotation
351354
if type_requires_eval(return_type):
352-
global_vars = vars(import_module(component.__module__))
355+
global_vars = get_global_vars(component, logger)
353356
try:
354357
return_type = get_type_hints(component, global_vars)["return"]
355358
if isinstance(return_type, ForwardRef):

jsonargparse_tests/test_postponed_annotations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ class DataclassForwardRef:
313313

314314
@pytest.mark.skipif(sys.version_info < (3, 9), reason="not working in python 3.8")
315315
def test_get_types_type_checking_dataclass_init_forward_ref():
316+
import xml.dom
317+
316318
types = get_types(DataclassForwardRef.__init__)
317319
assert types == {"p1": int, "p2": Optional[xml.dom.Node], "return": type(None)}
318320

0 commit comments

Comments
 (0)