55from mypy .nodes import ARG_STAR , ARG_STAR2
66from mypy .plugin import FunctionContext
77from mypy .types import (
8+ AnyType ,
89 CallableType ,
910 FunctionLike ,
1011 Instance ,
1112 Overloaded ,
1213 ProperType ,
14+ TypeOfAny ,
1315 TypeType ,
1416 get_proper_type ,
1517)
@@ -51,30 +53,55 @@ def analyze(ctx: FunctionContext) -> ProperType:
5153 default_return = get_proper_type (ctx .default_return_type )
5254 if not isinstance (default_return , CallableType ):
5355 return default_return
56+ return _analyze_partial (ctx , default_return )
57+
58+
59+ def _analyze_partial (
60+ ctx : FunctionContext ,
61+ default_return : CallableType ,
62+ ) -> ProperType :
63+ if not ctx .arg_types or not ctx .arg_types [0 ]:
64+ # No function passed: treat as decorator factory and fallback to Any.
65+ return AnyType (TypeOfAny .implementation_artifact )
5466
5567 function_def = get_proper_type (ctx .arg_types [0 ][0 ])
5668 func_args = _AppliedArgs (ctx )
5769
58- if len (list (filter (len , ctx .arg_types ))) == 1 :
59- return function_def # this means, that `partial(func)` is called
60- if not isinstance (function_def , _SUPPORTED_TYPES ):
70+ is_valid , applied_args = func_args .build_from_context ()
71+ if not is_valid :
6172 return default_return
62- if isinstance (function_def , Instance | TypeType ):
63- # We force `Instance` and similar types to coercse to callable:
64- function_def = func_args .get_callable_from_context ()
73+ if not applied_args :
74+ return function_def # this means, that `partial(func)` is called
6575
66- is_valid , applied_args = func_args . build_from_context ( )
67- if not isinstance ( function_def , CallableType | Overloaded ) or not is_valid :
76+ callable_def = _coerce_to_callable ( function_def , func_args )
77+ if callable_def is None :
6878 return default_return
6979
7080 return _PartialFunctionReducer (
7181 default_return ,
72- function_def ,
82+ callable_def ,
7383 applied_args ,
7484 ctx ,
7585 ).new_partial ()
7686
7787
88+ def _coerce_to_callable (
89+ function_def : ProperType ,
90+ func_args : '_AppliedArgs' ,
91+ ) -> CallableType | Overloaded | None :
92+ if not isinstance (function_def , _SUPPORTED_TYPES ):
93+ return None
94+ if isinstance (function_def , Instance | TypeType ):
95+ # We force `Instance` and similar types to coerce to callable:
96+ from_context = func_args .get_callable_from_context ()
97+ return (
98+ from_context
99+ if isinstance (from_context , CallableType | Overloaded )
100+ else None
101+ )
102+ return function_def
103+
104+
78105@final
79106class _PartialFunctionReducer :
80107 """
@@ -219,16 +246,10 @@ def __init__(self, function_ctx: FunctionContext) -> None:
219246 """
220247 We need the function default context.
221248
222- The first arguments of ``partial`` is skipped:
249+ The first argument of ``partial`` is skipped:
223250 it is the applied function itself.
224251 """
225252 self ._function_ctx = function_ctx
226- self ._parts = zip (
227- self ._function_ctx .arg_names [1 :],
228- self ._function_ctx .arg_types [1 :],
229- self ._function_ctx .arg_kinds [1 :],
230- strict = False ,
231- )
232253
233254 def get_callable_from_context (self ) -> ProperType :
234255 """Returns callable type from the context."""
@@ -254,17 +275,29 @@ def build_from_context(self) -> tuple[bool, list[FuncArg]]:
254275 Here ``*args`` and ``**kwargs`` can be literally anything!
255276 In these cases we fallback to the default return type.
256277 """
257- applied_args = []
258- for names , types , kinds in self ._parts :
278+ applied_args : list [FuncArg ] = []
279+ for arg in self ._iter_applied_args ():
280+ if arg .kind in {ARG_STAR , ARG_STAR2 }:
281+ # We cannot really work with `*args`, `**kwargs`.
282+ return False , []
283+ applied_args .append (arg )
284+ return True , applied_args
285+
286+ def _iter_applied_args (self ) -> Iterator [FuncArg ]:
287+ skipped_applied_function = False
288+ for names , types , kinds in zip (
289+ self ._function_ctx .arg_names ,
290+ self ._function_ctx .arg_types ,
291+ self ._function_ctx .arg_kinds ,
292+ strict = False ,
293+ ):
259294 for arg in self ._generate_applied_args (
260- zip (names , types , kinds , strict = False )
295+ zip (names , types , kinds , strict = False ),
261296 ):
262- if arg .kind in {ARG_STAR , ARG_STAR2 }:
263- # We cannot really work with `*args`, `**kwargs`.
264- return False , []
265-
266- applied_args .append (arg )
267- return True , applied_args
297+ if not skipped_applied_function :
298+ skipped_applied_function = True
299+ continue
300+ yield arg
268301
269302 def _generate_applied_args (self , arg_parts ) -> Iterator [FuncArg ]:
270303 yield from starmap (FuncArg , arg_parts )
0 commit comments