Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,21 @@ The semantic versioning only considers the public API as described in
paths are considered internals and can change in minor and patch releases.


v4.42.0 (unreleased)
--------------------

Added
^^^^^
- ``set_parsing_settings`` now supports setting ``allow_py_files`` to enable
stubs resolver searching in ``.py`` files in addition to ``.pyi`` (`#770
<https://github.com/omni-us/jsonargparse/pull/770>`__).

Fixed
^^^^^
- Stubs resolver in some cases failing with maximum recursion error (`#770
<https://github.com/omni-us/jsonargparse/pull/770>`__).


v4.41.0 (2025-09-04)
--------------------

Expand Down
9 changes: 9 additions & 0 deletions jsonargparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def parser_context(**kwargs):
parsing_settings = {
"validate_defaults": False,
"parse_optionals_as_positionals": False,
"stubs_resolver_allow_py_files": False,
}


Expand All @@ -107,6 +108,7 @@ def set_parsing_settings(
docstring_parse_style: Optional["docstring_parser.DocstringStyle"] = None,
docstring_parse_attribute_docstrings: Optional[bool] = None,
parse_optionals_as_positionals: Optional[bool] = None,
stubs_resolver_allow_py_files: Optional[bool] = None,
) -> None:
"""
Modify settings that affect the parsing behavior.
Expand All @@ -129,6 +131,8 @@ def set_parsing_settings(
--key=value as usual, but also as positional. The extra positionals
are applied to optionals in the order that they were added to the
parser. By default, this is False.
stubs_resolver_allow_py_files: Whether the stubs resolver should search
in ``.py`` files in addition to ``.pyi`` files.
"""
# validate_defaults
if isinstance(validate_defaults, bool):
Expand All @@ -150,6 +154,11 @@ def set_parsing_settings(
parsing_settings["parse_optionals_as_positionals"] = parse_optionals_as_positionals
elif parse_optionals_as_positionals is not None:
raise ValueError(f"parse_optionals_as_positionals must be a boolean, but got {parse_optionals_as_positionals}.")
# stubs resolver
if isinstance(stubs_resolver_allow_py_files, bool):
parsing_settings["stubs_resolver_allow_py_files"] = stubs_resolver_allow_py_files
elif stubs_resolver_allow_py_files is not None:
raise ValueError(f"stubs_resolver_allow_py_files must be a boolean, but got {stubs_resolver_allow_py_files}.")


def get_parsing_setting(name: str):
Expand Down
11 changes: 7 additions & 4 deletions jsonargparse/_stubs_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from importlib import import_module
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

from ._common import get_parsing_setting
from ._optionals import import_typeshed_client, typeshed_client_support
from ._postponed_annotations import NamesVisitor, get_arg_type

Expand Down Expand Up @@ -103,7 +104,9 @@ def find(self, node: ast.AST, method_name: str) -> Optional[ast.FunctionDef]:
def get_stubs_resolver():
global stubs_resolver
if not stubs_resolver:
stubs_resolver = StubsResolver()
allow_py_files = get_parsing_setting("stubs_resolver_allow_py_files")
search_context = tc.get_search_context(allow_py_files=allow_py_files)
stubs_resolver = StubsResolver(search_context=search_context)
return stubs_resolver


Expand Down Expand Up @@ -195,10 +198,10 @@ def add_import_aliases(self, aliases, stub_import: tc.ImportedInfo):
self.add_module_aliases(aliases, module_path, module, stub_ast)
return module_path, stub_import.info.ast

def add_module_aliases(self, aliases, module_path, module, node):
def add_module_aliases(self, aliases, module_path, module, node, skip=set()):
names = NamesVisitor().find(node) if node else []
for name in names:
if alias_already_added(aliases, name, module_path):
if alias_already_added(aliases, name, module_path) or name in skip:
continue
source = module_path
if name in __builtins__:
Expand All @@ -208,7 +211,7 @@ def add_module_aliases(self, aliases, module_path, module, node):
value = getattr(module, name)
elif name in self.get_module_stub_assigns(module_path):
value = self.get_module_stub_assigns(module_path)[name]
self.add_module_aliases(aliases, module_path, module, value.value)
self.add_module_aliases(aliases, module_path, module, value.value, skip={name})
elif name in self.get_module_stub_imports(module_path):
imported_module_path, imported_name = self.get_module_stub_imports(module_path)[name]
imported_module = import_module_or_none(imported_module_path)
Expand Down
48 changes: 45 additions & 3 deletions jsonargparse_tests/test_stubs_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from ipaddress import ip_network
from random import Random, SystemRandom, uniform
from tarfile import TarFile
from typing import Optional
from unittest.mock import patch
from uuid import UUID, uuid5

import pytest

from jsonargparse import set_parsing_settings
from jsonargparse._parameter_resolvers import get_signature_parameters as get_params
from jsonargparse._stubs_resolver import get_arg_type, get_mro_method_parent, get_stubs_resolver
from jsonargparse_tests.conftest import (
Expand All @@ -25,6 +27,7 @@
)

torch_available = bool(find_spec("torch"))
torchvision_available = bool(find_spec("torchvision"))


@pytest.fixture(autouse=True)
Expand All @@ -33,6 +36,21 @@
pytest.skip("typeshed-client package is required")


@pytest.fixture(autouse=True)
def clear_stubs_resolver():
import jsonargparse._stubs_resolver

jsonargparse._stubs_resolver.stubs_resolver = None
yield


@pytest.fixture
def allow_py_files():
with patch.dict("jsonargparse._common.parsing_settings"):
set_parsing_settings(stubs_resolver_allow_py_files=True)
yield


@contextmanager
def mock_stubs_missing_types():
with patch("jsonargparse._parameter_resolvers.add_stub_types"):
Expand Down Expand Up @@ -329,19 +347,20 @@
# pytorch tests


torch_optimizers_schedulers = torch_available
if torch_available:
import importlib.metadata

torch_version = tuple(int(v) for v in importlib.metadata.version("torch").split(".", 2)[:2])

if torch_version < (2, 1) or torch_version >= (2, 4):
torch_available = False
torch_optimizers_schedulers = False
else:
import torch.optim # pylint: disable=import-error
import torch.optim.lr_scheduler # pylint: disable=import-error


@pytest.mark.skipif(not torch_available, reason="only for torch>=2.1,<2.4")
@pytest.mark.skipif(not torch_optimizers_schedulers, reason="only for torch>=2.1,<2.4")
@pytest.mark.parametrize(
"class_name",
[
Expand All @@ -367,7 +386,7 @@
assert any(p.annotation is inspect._empty for p in params)


@pytest.mark.skipif(not torch_available, reason="only for torch>=2.1,<2.4")
@pytest.mark.skipif(not torch_optimizers_schedulers, reason="only for torch>=2.1,<2.4")
@pytest.mark.parametrize(
"class_name",
[
Expand Down Expand Up @@ -396,3 +415,26 @@
with mock_stubs_missing_types():
params = get_params(cls)
assert any(p.annotation is inspect._empty for p in params)


@pytest.mark.skipif(not torch_available, reason="torch package is required")
def test_get_params_torch_function_argmax(allow_py_files):
import torch

params = get_params(torch.argmax)
assert ["input", "dim", "keepdim", "out"] == get_param_names(params)
assert params[0].annotation is torch.Tensor
assert params[1].annotation == Optional[int]
assert params[2].annotation is bool
assert params[3].annotation == Optional[torch.Tensor]
with mock_stubs_missing_resolver():
assert [] == get_params(torch.argmax)


@pytest.mark.skipif(not torchvision_available, reason="torchvision package is required")
def test_get_params_torchvision_class_resize(allow_py_files):
from torchvision.transforms import Resize

params = get_params(Resize)
assert ["size", "interpolation", "max_size", "antialias"] == get_param_names(params)
assert all(p.annotation is inspect._empty for p in params)