1010from ipaddress import ip_network
1111from random import Random , SystemRandom , uniform
1212from tarfile import TarFile
13+ from typing import Optional
1314from unittest .mock import patch
1415from uuid import UUID , uuid5
1516
1617import pytest
1718
19+ from jsonargparse import set_parsing_settings
1820from jsonargparse ._parameter_resolvers import get_signature_parameters as get_params
1921from jsonargparse ._stubs_resolver import get_arg_type , get_mro_method_parent , get_stubs_resolver
2022from jsonargparse_tests .conftest import (
2527)
2628
# Availability flags for optional heavy dependencies; tests below that need
# torch/torchvision are skipped when the corresponding package is absent.
torch_available = bool(find_spec("torch"))
torchvision_available = bool(find_spec("torchvision"))
2831
2932
3033@pytest .fixture (autouse = True )
@@ -33,6 +36,29 @@ def skip_if_typeshed_client_unavailable():
3336 pytest .skip ("typeshed-client package is required" )
3437
3538
@pytest.fixture(autouse=True)
def clear_stubs_resolver():
    """Drop the module-level cached stubs resolver so every test starts fresh."""
    from jsonargparse import _stubs_resolver as stubs_module

    stubs_module.stubs_resolver = None
    yield
45+
46+
@pytest.fixture
def allow_py_files():
    """Temporarily enable stubs-resolver parsing of .py files.

    patch.dict snapshots parsing_settings, so the original setting is
    restored automatically when the fixture finalizes.
    """
    settings_guard = patch.dict("jsonargparse._common.parsing_settings")
    with settings_guard:
        set_parsing_settings(stubs_resolver_allow_py_files=True)
        yield
52+
53+
@pytest.fixture(params=["allow-py-files-true", "allow-py-files-false"])
def parametrize_allow_py_files(request):
    """Run the dependent test twice: once with .py-file parsing on, once off."""
    enabled = request.param == "allow-py-files-true"
    with patch.dict("jsonargparse._common.parsing_settings"):
        set_parsing_settings(stubs_resolver_allow_py_files=enabled)
        yield
60+
61+
3662@contextmanager
3763def mock_stubs_missing_types ():
3864 with patch ("jsonargparse._parameter_resolvers.add_stub_types" ):
@@ -112,7 +138,7 @@ def test_get_params_class_with_inheritance():
112138 assert [("firstweekday" , inspect ._empty )] == get_param_types (params )
113139
114140
115- def test_get_params_method ():
141+ def test_get_params_method (parametrize_allow_py_files ):
116142 params = get_params (Random , "randint" )
117143 assert [("a" , int ), ("b" , int )] == get_param_types (params )
118144 with mock_stubs_missing_types ():
@@ -148,7 +174,7 @@ def test_get_params_exec_failure(mock_get_stub_types):
148174 assert [("a" , inspect ._empty ), ("version" , inspect ._empty )] == get_param_types (params )
149175
150176
151- def test_get_params_classmethod ():
177+ def test_get_params_classmethod (parametrize_allow_py_files ):
152178 params = get_params (TarFile , "open" )
153179 expected = [
154180 "name" ,
@@ -190,7 +216,7 @@ def test_get_params_staticmethod():
190216 assert [("value" , inspect ._empty )] == get_param_types (params )
191217
192218
193- def test_get_params_function ():
219+ def test_get_params_function (parametrize_allow_py_files ):
194220 params = get_params (ip_network )
195221 assert ["address" , "strict" ] == get_param_names (params )
196222 if sys .version_info >= (3 , 10 ):
@@ -329,19 +355,20 @@ def test_get_params_inspect_signature_failure_missing_type(logger):
329355# pytorch tests
330356
331357
# torch_available gates tests that only need torch installed; the optimizer /
# LR-scheduler stub tests additionally require torch>=2.1,<2.4, tracked by
# this separate flag so the other torch tests still run on other versions.
torch_optimizers_schedulers = torch_available
if torch_available:
    import importlib.metadata

    # (major, minor) of the installed torch distribution.
    torch_version = tuple(int(v) for v in importlib.metadata.version("torch").split(".", 2)[:2])

    if torch_version < (2, 1) or torch_version >= (2, 4):
        torch_optimizers_schedulers = False
    else:
        import torch.optim  # pylint: disable=import-error
        import torch.optim.lr_scheduler  # pylint: disable=import-error
342369
343370
344- @pytest .mark .skipif (not torch_available , reason = "only for torch>=2.1,<2.4" )
371+ @pytest .mark .skipif (not torch_optimizers_schedulers , reason = "only for torch>=2.1,<2.4" )
345372@pytest .mark .parametrize (
346373 "class_name" ,
347374 [
@@ -367,7 +394,7 @@ def test_get_params_torch_optimizer(class_name):
367394 assert any (p .annotation is inspect ._empty for p in params )
368395
369396
370- @pytest .mark .skipif (not torch_available , reason = "only for torch>=2.1,<2.4" )
397+ @pytest .mark .skipif (not torch_optimizers_schedulers , reason = "only for torch>=2.1,<2.4" )
371398@pytest .mark .parametrize (
372399 "class_name" ,
373400 [
@@ -396,3 +423,26 @@ def test_get_params_torch_lr_scheduler(class_name):
396423 with mock_stubs_missing_types ():
397424 params = get_params (cls )
398425 assert any (p .annotation is inspect ._empty for p in params )
426+
427+
@pytest.mark.skipif(not torch_available, reason="torch package is required")
def test_get_params_torch_function_argmax(allow_py_files):
    """torch.argmax parameters resolve with types when .py parsing is allowed."""
    import torch

    params = get_params(torch.argmax)
    assert get_param_names(params) == ["input", "dim", "keepdim", "out"]
    annotations = [p.annotation for p in params]
    assert annotations[0] is torch.Tensor
    assert annotations[1] == Optional[int]
    assert annotations[2] is bool
    assert annotations[3] == Optional[torch.Tensor]
    # Without the stubs resolver nothing can be recovered for this function.
    with mock_stubs_missing_resolver():
        assert get_params(torch.argmax) == []
440+
441+
@pytest.mark.skipif(not torchvision_available, reason="torchvision package is required")
def test_get_params_torchvision_class_resize(allow_py_files):
    """Resize parameter names resolve, but no annotations are recovered."""
    from torchvision.transforms import Resize

    params = get_params(Resize)
    assert get_param_names(params) == ["size", "interpolation", "max_size", "antialias"]
    for param in params:
        assert param.annotation is inspect._empty
0 commit comments