Skip to content

Commit c2c7aa4

Browse files
authored
Stubs resolver allow_py_files setting and fix maximum recursion error (#770)
1 parent cb4e633 commit c2c7aa4

File tree

5 files changed

+95
-10
lines changed

5 files changed

+95
-10
lines changed

CHANGELOG.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@ The semantic versioning only considers the public API as described in
1212
paths are considered internals and can change in minor and patch releases.
1313

1414

15+
v4.42.0 (unreleased)
16+
--------------------
17+
18+
Added
19+
^^^^^
20+
- ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable
21+
stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770
22+
<https://github.com/omni-us/jsonargparse/pull/770>`__).
23+
24+
Fixed
25+
^^^^^
26+
- Stubs resolver in some cases failing with maximum recursion error (`#770
27+
<https://github.com/omni-us/jsonargparse/pull/770>`__).
28+
29+
1530
v4.41.0 (2025-09-04)
1631
--------------------
1732

jsonargparse/_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def parser_context(**kwargs):
9696
parsing_settings = {
9797
"validate_defaults": False,
9898
"parse_optionals_as_positionals": False,
99+
"stubs_resolver_allow_py_files": False,
99100
}
100101

101102

@@ -107,6 +108,7 @@ def set_parsing_settings(
107108
docstring_parse_style: Optional["docstring_parser.DocstringStyle"] = None,
108109
docstring_parse_attribute_docstrings: Optional[bool] = None,
109110
parse_optionals_as_positionals: Optional[bool] = None,
111+
stubs_resolver_allow_py_files: Optional[bool] = None,
110112
) -> None:
111113
"""
112114
Modify settings that affect the parsing behavior.
@@ -129,6 +131,8 @@ def set_parsing_settings(
129131
--key=value as usual, but also as positional. The extra positionals
130132
are applied to optionals in the order that they were added to the
131133
parser. By default, this is False.
134+
stubs_resolver_allow_py_files: Whether the stubs resolver should search
135+
in ``.py`` files in addition to ``.pyi`` files.
132136
"""
133137
# validate_defaults
134138
if isinstance(validate_defaults, bool):
@@ -150,6 +154,11 @@ def set_parsing_settings(
150154
parsing_settings["parse_optionals_as_positionals"] = parse_optionals_as_positionals
151155
elif parse_optionals_as_positionals is not None:
152156
raise ValueError(f"parse_optionals_as_positionals must be a boolean, but got {parse_optionals_as_positionals}.")
157+
# stubs resolver
158+
if isinstance(stubs_resolver_allow_py_files, bool):
159+
parsing_settings["stubs_resolver_allow_py_files"] = stubs_resolver_allow_py_files
160+
elif stubs_resolver_allow_py_files is not None:
161+
raise ValueError(f"stubs_resolver_allow_py_files must be a boolean, but got {stubs_resolver_allow_py_files}.")
153162

154163

155164
def get_parsing_setting(name: str):

jsonargparse/_stubs_resolver.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from importlib import import_module
77
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
88

9+
from ._common import get_parsing_setting
910
from ._optionals import import_typeshed_client, typeshed_client_support
1011
from ._postponed_annotations import NamesVisitor, get_arg_type
1112

@@ -103,7 +104,9 @@ def find(self, node: ast.AST, method_name: str) -> Optional[ast.FunctionDef]:
103104
def get_stubs_resolver():
104105
global stubs_resolver
105106
if not stubs_resolver:
106-
stubs_resolver = StubsResolver()
107+
allow_py_files = get_parsing_setting("stubs_resolver_allow_py_files")
108+
search_context = tc.get_search_context(allow_py_files=allow_py_files)
109+
stubs_resolver = StubsResolver(search_context=search_context)
107110
return stubs_resolver
108111

109112

@@ -195,10 +198,10 @@ def add_import_aliases(self, aliases, stub_import: tc.ImportedInfo):
195198
self.add_module_aliases(aliases, module_path, module, stub_ast)
196199
return module_path, stub_import.info.ast
197200

198-
def add_module_aliases(self, aliases, module_path, module, node):
201+
def add_module_aliases(self, aliases, module_path, module, node, skip=set()):
199202
names = NamesVisitor().find(node) if node else []
200203
for name in names:
201-
if alias_already_added(aliases, name, module_path):
204+
if alias_already_added(aliases, name, module_path) or name in skip:
202205
continue
203206
source = module_path
204207
if name in __builtins__:
@@ -208,7 +211,7 @@ def add_module_aliases(self, aliases, module_path, module, node):
208211
value = getattr(module, name)
209212
elif name in self.get_module_stub_assigns(module_path):
210213
value = self.get_module_stub_assigns(module_path)[name]
211-
self.add_module_aliases(aliases, module_path, module, value.value)
214+
self.add_module_aliases(aliases, module_path, module, value.value, skip={name})
212215
elif name in self.get_module_stub_imports(module_path):
213216
imported_module_path, imported_name = self.get_module_stub_imports(module_path)[name]
214217
imported_module = import_module_or_none(imported_module_path)

jsonargparse_tests/test_parsing_settings.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,11 @@ def test_optionals_as_positionals_unsupported_arguments(parser):
175175

176176
help_str = get_parse_args_stdout(parser, ["--o1.help=Adam"])
177177
assert "extra positionals are parsed as optionals in the order shown above" not in help_str
178+
179+
180+
# stubs_resolver_allow_py_files
181+
182+
183+
def test_set_stubs_resolver_allow_py_files_failure():
184+
with pytest.raises(ValueError, match="stubs_resolver_allow_py_files must be a boolean"):
185+
set_parsing_settings(stubs_resolver_allow_py_files="invalid")

jsonargparse_tests/test_stubs_resolver.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from ipaddress import ip_network
1111
from random import Random, SystemRandom, uniform
1212
from tarfile import TarFile
13+
from typing import Optional
1314
from unittest.mock import patch
1415
from uuid import UUID, uuid5
1516

1617
import pytest
1718

19+
from jsonargparse import set_parsing_settings
1820
from jsonargparse._parameter_resolvers import get_signature_parameters as get_params
1921
from jsonargparse._stubs_resolver import get_arg_type, get_mro_method_parent, get_stubs_resolver
2022
from jsonargparse_tests.conftest import (
@@ -25,6 +27,7 @@
2527
)
2628

2729
torch_available = bool(find_spec("torch"))
30+
torchvision_available = bool(find_spec("torchvision"))
2831

2932

3033
@pytest.fixture(autouse=True)
@@ -33,6 +36,29 @@ def skip_if_typeshed_client_unavailable():
3336
pytest.skip("typeshed-client package is required")
3437

3538

39+
@pytest.fixture(autouse=True)
40+
def clear_stubs_resolver():
41+
import jsonargparse._stubs_resolver
42+
43+
jsonargparse._stubs_resolver.stubs_resolver = None
44+
yield
45+
46+
47+
@pytest.fixture
48+
def allow_py_files():
49+
with patch.dict("jsonargparse._common.parsing_settings"):
50+
set_parsing_settings(stubs_resolver_allow_py_files=True)
51+
yield
52+
53+
54+
@pytest.fixture(params=["allow-py-files-true", "allow-py-files-false"])
55+
def parametrize_allow_py_files(request):
56+
allow_py_files = request.param == "allow-py-files-true"
57+
with patch.dict("jsonargparse._common.parsing_settings"):
58+
set_parsing_settings(stubs_resolver_allow_py_files=allow_py_files)
59+
yield
60+
61+
3662
@contextmanager
3763
def mock_stubs_missing_types():
3864
with patch("jsonargparse._parameter_resolvers.add_stub_types"):
@@ -112,7 +138,7 @@ def test_get_params_class_with_inheritance():
112138
assert [("firstweekday", inspect._empty)] == get_param_types(params)
113139

114140

115-
def test_get_params_method():
141+
def test_get_params_method(parametrize_allow_py_files):
116142
params = get_params(Random, "randint")
117143
assert [("a", int), ("b", int)] == get_param_types(params)
118144
with mock_stubs_missing_types():
@@ -148,7 +174,7 @@ def test_get_params_exec_failure(mock_get_stub_types):
148174
assert [("a", inspect._empty), ("version", inspect._empty)] == get_param_types(params)
149175

150176

151-
def test_get_params_classmethod():
177+
def test_get_params_classmethod(parametrize_allow_py_files):
152178
params = get_params(TarFile, "open")
153179
expected = [
154180
"name",
@@ -190,7 +216,7 @@ def test_get_params_staticmethod():
190216
assert [("value", inspect._empty)] == get_param_types(params)
191217

192218

193-
def test_get_params_function():
219+
def test_get_params_function(parametrize_allow_py_files):
194220
params = get_params(ip_network)
195221
assert ["address", "strict"] == get_param_names(params)
196222
if sys.version_info >= (3, 10):
@@ -329,19 +355,20 @@ def test_get_params_inspect_signature_failure_missing_type(logger):
329355
# pytorch tests
330356

331357

358+
torch_optimizers_schedulers = torch_available
332359
if torch_available:
333360
import importlib.metadata
334361

335362
torch_version = tuple(int(v) for v in importlib.metadata.version("torch").split(".", 2)[:2])
336363

337364
if torch_version < (2, 1) or torch_version >= (2, 4):
338-
torch_available = False
365+
torch_optimizers_schedulers = False
339366
else:
340367
import torch.optim # pylint: disable=import-error
341368
import torch.optim.lr_scheduler # pylint: disable=import-error
342369

343370

344-
@pytest.mark.skipif(not torch_available, reason="only for torch>=2.1,<2.4")
371+
@pytest.mark.skipif(not torch_optimizers_schedulers, reason="only for torch>=2.1,<2.4")
345372
@pytest.mark.parametrize(
346373
"class_name",
347374
[
@@ -367,7 +394,7 @@ def test_get_params_torch_optimizer(class_name):
367394
assert any(p.annotation is inspect._empty for p in params)
368395

369396

370-
@pytest.mark.skipif(not torch_available, reason="only for torch>=2.1,<2.4")
397+
@pytest.mark.skipif(not torch_optimizers_schedulers, reason="only for torch>=2.1,<2.4")
371398
@pytest.mark.parametrize(
372399
"class_name",
373400
[
@@ -396,3 +423,26 @@ def test_get_params_torch_lr_scheduler(class_name):
396423
with mock_stubs_missing_types():
397424
params = get_params(cls)
398425
assert any(p.annotation is inspect._empty for p in params)
426+
427+
428+
@pytest.mark.skipif(not torch_available, reason="torch package is required")
429+
def test_get_params_torch_function_argmax(allow_py_files):
430+
import torch
431+
432+
params = get_params(torch.argmax)
433+
assert ["input", "dim", "keepdim", "out"] == get_param_names(params)
434+
assert params[0].annotation is torch.Tensor
435+
assert params[1].annotation == Optional[int]
436+
assert params[2].annotation is bool
437+
assert params[3].annotation == Optional[torch.Tensor]
438+
with mock_stubs_missing_resolver():
439+
assert [] == get_params(torch.argmax)
440+
441+
442+
@pytest.mark.skipif(not torchvision_available, reason="torchvision package is required")
443+
def test_get_params_torchvision_class_resize(allow_py_files):
444+
from torchvision.transforms import Resize
445+
446+
params = get_params(Resize)
447+
assert ["size", "interpolation", "max_size", "antialias"] == get_param_names(params)
448+
assert all(p.annotation is inspect._empty for p in params)

0 commit comments

Comments
 (0)