77from copy import deepcopy
88from dataclasses import is_dataclass
99from importlib import import_module
10- from typing import Any , Dict , ForwardRef , FrozenSet , List , Optional , Set , Tuple , Type , Union , get_type_hints
10+ from typing import Any , ForwardRef , List , Optional , Union , get_type_hints
1111
1212from ._optionals import typing_extensions_import
1313from ._typehints import mapping_origin_types , sequence_origin_types , tuple_set_origin_types
1616var_map = namedtuple ("var_map" , "name value" )
1717none_map = var_map (name = "NoneType" , value = type (None ))
1818union_map = var_map (name = "Union" , value = Union )
19- pep585_map = {
20- "dict" : var_map (name = "Dict" , value = Dict ),
21- "frozenset" : var_map (name = "FrozenSet" , value = FrozenSet ),
22- "list" : var_map (name = "List" , value = List ),
23- "set" : var_map (name = "Set" , value = Set ),
24- "tuple" : var_map (name = "Tuple" , value = Tuple ),
25- "type" : var_map (name = "Type" , value = Type ),
26- }
2719
2820
2921class BackportTypeHints (ast .NodeTransformer ):
30- def visit_Subscript (self , node : ast .Subscript ) -> ast .Subscript :
31- if isinstance (node .value , ast .Name ) and node .value .id in pep585_map :
32- value = self .new_name_load (pep585_map [node .value .id ])
33- else :
34- value = node .value # type: ignore[assignment]
35- return ast .Subscript (
36- value = value ,
37- slice = self .visit (node .slice ),
38- ctx = ast .Load (),
39- )
40-
4122 def visit_Constant (self , node : ast .Constant ) -> Union [ast .Constant , ast .Name ]:
4223 if node .value is None :
4324 return self .new_name_load (none_map )
@@ -193,44 +174,24 @@ def get_arg_type(arg_ast, aliases):
193174 return exec_vars ["___arg_type___" ]
194175
195176
196- def getattr_recursive (obj , attr ):
197- if "." in attr :
198- attr , * attrs = attr .split ("." , 1 )
199- return getattr_recursive (getattr (obj , attr ), attrs [0 ])
200- return getattr (obj , attr )
201-
202-
203177def resolve_forward_refs (arg_type , aliases , logger ):
204- if isinstance (arg_type , str ) and arg_type in aliases :
205- arg_type = aliases [arg_type ]
206178
207179 def resolve_subtypes_forward_refs (typehint ):
208180 if has_subtypes (typehint ):
209181 try :
210182 subtypes = []
211183 for arg in typehint .__args__ :
212184 if isinstance (arg , ForwardRef ):
213- forward_arg , * forward_args = arg .__forward_arg__ .split ("." , 1 )
185+ forward_arg , * _ = arg .__forward_arg__ .split ("." , 1 )
214186 if forward_arg in aliases :
215187 arg = aliases [forward_arg ]
216- if forward_args :
217- arg = getattr_recursive (arg , forward_args [0 ])
218188 else :
219189 raise NameError (f"Name '{ forward_arg } ' is not defined" )
220190 else :
221191 arg = resolve_subtypes_forward_refs (arg )
222192 subtypes .append (arg )
223193 if subtypes != list (typehint .__args__ ):
224194 typehint_origin = get_typehint_origin (typehint )
225- if sys .version_info < (3 , 10 ):
226- if typehint_origin in sequence_origin_types :
227- typehint_origin = List
228- elif typehint_origin in tuple_set_origin_types :
229- typehint_origin = Tuple
230- elif typehint_origin in mapping_origin_types :
231- typehint_origin = Dict
232- elif typehint_origin == type :
233- typehint_origin = Type
234195 typehint = typehint_origin [tuple (subtypes )]
235196 except Exception as ex :
236197 if logger :
@@ -292,7 +253,7 @@ def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict:
292253 except Exception as ex2 :
293254 if isinstance (types , Exception ):
294255 if logger :
295- logger .debug (f"Failed to parse to source code for { obj } " , exc_info = ex2 )
256+ logger .debug (f"Failed to parse the source code for { obj } " , exc_info = ex2 )
296257 raise type (types )(f"{ repr (types )} + { repr (ex2 )} " ) from ex2 # type: ignore[arg-type]
297258 return types
298259
@@ -303,19 +264,13 @@ def get_types(obj: Any, logger: Optional[logging.Logger] = None) -> dict:
303264 ex = types
304265 types = {}
305266
306- if isinstance (node , ast .FunctionDef ):
307- arg_asts = [(a .arg , a .annotation ) for a in node .args .args + node .args .kwonlyargs ]
308- else :
309- arg_asts = [(a .target .id , a .annotation ) for a in node .body if isinstance (a , ast .AnnAssign )] # type: ignore[union-attr]
267+ arg_asts = [(a .arg , a .annotation ) for a in node .args .args + node .args .kwonlyargs ] # type: ignore[union-attr]
310268
311269 for name , annotation in arg_asts :
312270 if annotation and (name not in types or type_requires_eval (types [name ])):
313271 try :
314- if isinstance (annotation , ast .Constant ) and annotation .value in aliases :
315- types [name ] = aliases [annotation .value ]
316- else :
317- arg_type = get_arg_type (annotation , aliases )
318- types [name ] = resolve_forward_refs (arg_type , aliases , logger )
272+ arg_type = get_arg_type (annotation , aliases )
273+ types [name ] = resolve_forward_refs (arg_type , aliases , logger )
319274 except Exception as ex3 :
320275 types [name ] = ex3
321276
@@ -355,8 +310,6 @@ def get_return_type(component, logger=None):
355310 global_vars = get_global_vars (component , logger )
356311 try :
357312 return_type = get_type_hints (component , global_vars )["return" ]
358- if isinstance (return_type , ForwardRef ):
359- return_type = resolve_forward_refs (return_type .__forward_arg__ , global_vars , logger )
360313 except Exception as ex :
361314 if logger :
362315 logger .debug (f"Unable to evaluate types for { component } " , exc_info = ex )
0 commit comments