1+ """Tests for runpod.__init__ module exports."""
2+
3+ import inspect
4+ import runpod
5+
6+
7+ class TestRunpodInit :
8+ """Test runpod module __all__ exports."""
9+
10+ def test_all_defined (self ):
11+ """Test that __all__ is defined in the module."""
12+ assert hasattr (runpod , '__all__' )
13+ assert isinstance (runpod .__all__ , list )
14+ assert len (runpod .__all__ ) > 0
15+
16+ def test_all_symbols_importable (self ):
17+ """Test that all symbols in __all__ are actually importable."""
18+ for symbol in runpod .__all__ :
19+ assert hasattr (runpod , symbol ), f"Symbol '{ symbol } ' in __all__ but not found in module"
20+
21+ def test_api_functions_accessible (self ):
22+ """Test that API functions are accessible and callable."""
23+ api_functions = [
24+ 'create_container_registry_auth' , 'create_endpoint' , 'create_pod' , 'create_template' ,
25+ 'delete_container_registry_auth' , 'get_endpoints' , 'get_gpu' , 'get_gpus' ,
26+ 'get_pod' , 'get_pods' , 'get_user' , 'resume_pod' , 'stop_pod' , 'terminate_pod' ,
27+ 'update_container_registry_auth' , 'update_endpoint_template' , 'update_user_settings'
28+ ]
29+
30+ for func_name in api_functions :
31+ assert func_name in runpod .__all__
32+ assert hasattr (runpod , func_name )
33+ assert callable (getattr (runpod , func_name ))
34+
35+ def test_config_functions_accessible (self ):
36+ """Test that config functions are accessible and callable."""
37+ config_functions = ['check_credentials' , 'get_credentials' , 'set_credentials' ]
38+
39+ for func_name in config_functions :
40+ assert func_name in runpod .__all__
41+ assert hasattr (runpod , func_name )
42+ assert callable (getattr (runpod , func_name ))
43+
44+ def test_endpoint_classes_accessible (self ):
45+ """Test that endpoint classes are accessible."""
46+ endpoint_classes = ['AsyncioEndpoint' , 'AsyncioJob' , 'Endpoint' ]
47+
48+ for class_name in endpoint_classes :
49+ assert class_name in runpod .__all__
50+ assert hasattr (runpod , class_name )
51+ assert inspect .isclass (getattr (runpod , class_name ))
52+
53+ def test_serverless_module_accessible (self ):
54+ """Test that serverless module is accessible."""
55+ assert 'serverless' in runpod .__all__
56+ assert hasattr (runpod , 'serverless' )
57+ assert inspect .ismodule (runpod .serverless )
58+
59+ def test_logger_class_accessible (self ):
60+ """Test that RunPodLogger class is accessible."""
61+ assert 'RunPodLogger' in runpod .__all__
62+ assert hasattr (runpod , 'RunPodLogger' )
63+ assert inspect .isclass (runpod .RunPodLogger )
64+
65+ def test_version_accessible (self ):
66+ """Test that __version__ is accessible."""
67+ assert '__version__' in runpod .__all__
68+ assert hasattr (runpod , '__version__' )
69+ assert isinstance (runpod .__version__ , str )
70+
71+ def test_module_variables_accessible (self ):
72+ """Test that module variables are accessible."""
73+ module_vars = ['SSH_KEY_PATH' , 'profile' , 'api_key' , 'endpoint_url_base' ]
74+
75+ for var_name in module_vars :
76+ assert var_name in runpod .__all__
77+ assert hasattr (runpod , var_name )
78+
79+ def test_private_imports_not_exported (self ):
80+ """Test that private imports are not in __all__."""
81+ private_symbols = {
82+ 'logging' , 'os' , '_credentials'
83+ }
84+ all_symbols = set (runpod .__all__ )
85+
86+ for private_symbol in private_symbols :
87+ assert private_symbol not in all_symbols , f"Private symbol '{ private_symbol } ' should not be in __all__"
88+
89+ def test_all_covers_expected_public_api (self ):
90+ """Test that __all__ contains the expected public API symbols."""
91+ expected_symbols = {
92+ # API functions
93+ 'create_container_registry_auth' , 'create_endpoint' , 'create_pod' , 'create_template' ,
94+ 'delete_container_registry_auth' , 'get_endpoints' , 'get_gpu' , 'get_gpus' ,
95+ 'get_pod' , 'get_pods' , 'get_user' , 'resume_pod' , 'stop_pod' , 'terminate_pod' ,
96+ 'update_container_registry_auth' , 'update_endpoint_template' , 'update_user_settings' ,
97+ # Config functions
98+ 'check_credentials' , 'get_credentials' , 'set_credentials' ,
99+ # Endpoint classes
100+ 'AsyncioEndpoint' , 'AsyncioJob' , 'Endpoint' ,
101+ # Serverless module
102+ 'serverless' ,
103+ # Logger class
104+ 'RunPodLogger' ,
105+ # Version
106+ '__version__' ,
107+ # Module variables
108+ 'SSH_KEY_PATH' , 'profile' , 'api_key' , 'endpoint_url_base'
109+ }
110+
111+ actual_symbols = set (runpod .__all__ )
112+ assert expected_symbols == actual_symbols , f"Expected { expected_symbols } , got { actual_symbols } "
113+
114+ def test_no_duplicate_symbols_in_all (self ):
115+ """Test that __all__ contains no duplicate symbols."""
116+ all_symbols = runpod .__all__
117+ unique_symbols = set (all_symbols )
118+ assert len (all_symbols ) == len (unique_symbols ), f"Duplicates found in __all__: { [x for x in all_symbols if all_symbols .count (x ) > 1 ]} "
0 commit comments