Skip to content

Commit 9be7cd3

Browse files
committed
Fix mypy type handling in partial function and add regression tests
1 parent ff4b2d4 commit 9be7cd3

File tree

8 files changed

+233
-104
lines changed

8 files changed

+233
-104
lines changed

docs/conf.py

100644100755
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#!/usr/bin/env python
12
# Configuration file for the Sphinx documentation builder.
23
#
34
# This file does only contain a selection of the most common options. For a

returns/contrib/hypothesis/laws.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
22
import inspect
3+
import sys
34
from collections.abc import Callable, Iterator
45
from contextlib import ExitStack, contextmanager
56
from typing import Any, TypeVar, final, overload
@@ -242,7 +243,21 @@ def _create_law_test_case(
242243
)
243244

244245
called_from = inspect.stack()[2]
245-
module = inspect.getmodule(called_from[0])
246+
# `inspect.getmodule(frame)` is surprisingly fragile under some import
247+
# modes (notably `pytest` collection with assertion rewriting) and can
248+
# return `None`. Use the module name from the caller's globals instead.
249+
module_name = called_from.frame.f_globals.get('__name__')
250+
if module_name is None:
251+
module = None
252+
else:
253+
module = sys.modules.get(module_name)
254+
if module is None:
255+
module = inspect.getmodule(called_from.frame)
256+
if module is None:
257+
raise RuntimeError(
258+
'Cannot determine a module to attach generated law tests to. '
259+
'Please call `check_all_laws` from an imported module scope.',
260+
)
246261

247262
template = 'test_{container}_{interface}_{name}'
248263
test_function.__name__ = template.format( # noqa: WPS125

returns/contrib/mypy/_features/partial.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from mypy.nodes import ARG_STAR, ARG_STAR2
66
from mypy.plugin import FunctionContext
77
from mypy.types import (
8+
AnyType,
89
CallableType,
910
FunctionLike,
1011
Instance,
1112
Overloaded,
1213
ProperType,
14+
TypeOfAny,
1315
TypeType,
1416
get_proper_type,
1517
)
@@ -51,30 +53,55 @@ def analyze(ctx: FunctionContext) -> ProperType:
5153
default_return = get_proper_type(ctx.default_return_type)
5254
if not isinstance(default_return, CallableType):
5355
return default_return
56+
return _analyze_partial(ctx, default_return)
57+
58+
59+
def _analyze_partial(
60+
ctx: FunctionContext,
61+
default_return: CallableType,
62+
) -> ProperType:
63+
if not ctx.arg_types or not ctx.arg_types[0]:
64+
# No function passed: treat as decorator factory and fallback to Any.
65+
return AnyType(TypeOfAny.implementation_artifact)
5466

5567
function_def = get_proper_type(ctx.arg_types[0][0])
5668
func_args = _AppliedArgs(ctx)
5769

58-
if len(list(filter(len, ctx.arg_types))) == 1:
59-
return function_def # this means, that `partial(func)` is called
60-
if not isinstance(function_def, _SUPPORTED_TYPES):
70+
is_valid, applied_args = func_args.build_from_context()
71+
if not is_valid:
6172
return default_return
62-
if isinstance(function_def, Instance | TypeType):
63-
# We force `Instance` and similar types to coercse to callable:
64-
function_def = func_args.get_callable_from_context()
73+
if not applied_args:
74+
return function_def # this means, that `partial(func)` is called
6575

66-
is_valid, applied_args = func_args.build_from_context()
67-
if not isinstance(function_def, CallableType | Overloaded) or not is_valid:
76+
callable_def = _coerce_to_callable(function_def, func_args)
77+
if callable_def is None:
6878
return default_return
6979

7080
return _PartialFunctionReducer(
7181
default_return,
72-
function_def,
82+
callable_def,
7383
applied_args,
7484
ctx,
7585
).new_partial()
7686

7787

88+
def _coerce_to_callable(
89+
function_def: ProperType,
90+
func_args: '_AppliedArgs',
91+
) -> CallableType | Overloaded | None:
92+
if not isinstance(function_def, _SUPPORTED_TYPES):
93+
return None
94+
if isinstance(function_def, Instance | TypeType):
95+
# We force `Instance` and similar types to coerce to callable:
96+
from_context = func_args.get_callable_from_context()
97+
return (
98+
from_context
99+
if isinstance(from_context, CallableType | Overloaded)
100+
else None
101+
)
102+
return function_def
103+
104+
78105
@final
79106
class _PartialFunctionReducer:
80107
"""
@@ -219,16 +246,10 @@ def __init__(self, function_ctx: FunctionContext) -> None:
219246
"""
220247
We need the function default context.
221248
222-
The first arguments of ``partial`` is skipped:
249+
The first argument of ``partial`` is skipped:
223250
it is the applied function itself.
224251
"""
225252
self._function_ctx = function_ctx
226-
self._parts = zip(
227-
self._function_ctx.arg_names[1:],
228-
self._function_ctx.arg_types[1:],
229-
self._function_ctx.arg_kinds[1:],
230-
strict=False,
231-
)
232253

233254
def get_callable_from_context(self) -> ProperType:
234255
"""Returns callable type from the context."""
@@ -254,17 +275,29 @@ def build_from_context(self) -> tuple[bool, list[FuncArg]]:
254275
Here ``*args`` and ``**kwargs`` can be literally anything!
255276
In these cases we fallback to the default return type.
256277
"""
257-
applied_args = []
258-
for names, types, kinds in self._parts:
278+
applied_args: list[FuncArg] = []
279+
for arg in self._iter_applied_args():
280+
if arg.kind in {ARG_STAR, ARG_STAR2}:
281+
# We cannot really work with `*args`, `**kwargs`.
282+
return False, []
283+
applied_args.append(arg)
284+
return True, applied_args
285+
286+
def _iter_applied_args(self) -> Iterator[FuncArg]:
287+
skipped_applied_function = False
288+
for names, types, kinds in zip(
289+
self._function_ctx.arg_names,
290+
self._function_ctx.arg_types,
291+
self._function_ctx.arg_kinds,
292+
strict=False,
293+
):
259294
for arg in self._generate_applied_args(
260-
zip(names, types, kinds, strict=False)
295+
zip(names, types, kinds, strict=False),
261296
):
262-
if arg.kind in {ARG_STAR, ARG_STAR2}:
263-
# We cannot really work with `*args`, `**kwargs`.
264-
return False, []
265-
266-
applied_args.append(arg)
267-
return True, applied_args
297+
if not skipped_applied_function:
298+
skipped_applied_function = True
299+
continue
300+
yield arg
268301

269302
def _generate_applied_args(self, arg_parts) -> Iterator[FuncArg]:
270303
yield from starmap(FuncArg, arg_parts)

returns/curry.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,41 @@
22
from functools import partial as _partial
33
from functools import wraps
44
from inspect import BoundArguments, Signature
5-
from typing import Any, TypeAlias, TypeVar
5+
from typing import Any, Generic, TypeAlias, TypeVar, overload
66

77
_ReturnType = TypeVar('_ReturnType')
8+
_Decorator: TypeAlias = Callable[
9+
[Callable[..., _ReturnType]],
10+
Callable[..., _ReturnType],
11+
]
812

913

14+
class _PartialDecorator(Generic[_ReturnType]):
15+
"""Wraps ``functools.partial`` into a decorator without nesting."""
16+
__slots__ = ('_args', '_kwargs')
17+
18+
def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
19+
self._args = args
20+
self._kwargs = kwargs
21+
22+
def __call__(self, inner: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:
23+
return _partial(inner, *self._args, **self._kwargs)
24+
25+
26+
@overload
1027
def partial(
1128
func: Callable[..., _ReturnType],
29+
/,
1230
*args: Any,
1331
**kwargs: Any,
14-
) -> Callable[..., _ReturnType]:
32+
) -> Callable[..., _ReturnType]: ...
33+
34+
35+
@overload
36+
def partial(*args: Any, **kwargs: Any) -> _Decorator: ...
37+
38+
39+
def partial(*args: Any, **kwargs: Any) -> Any:
1540
"""
1641
Typed partial application.
1742
@@ -35,7 +60,11 @@ def partial(
3560
- https://docs.python.org/3/library/functools.html#functools.partial
3661
3762
"""
38-
return _partial(func, *args, **kwargs)
63+
if args and callable(args[0]):
64+
return _partial(args[0], *args[1:], **kwargs)
65+
if args and args[0] is None:
66+
args = args[1:]
67+
return _PartialDecorator(args, kwargs)
3968

4069

4170
def curry(function: Callable[..., _ReturnType]) -> Callable[..., _ReturnType]:

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ select = WPS, E999
1717

1818
extend-exclude =
1919
.venv
20+
.cache
2021
build
2122
# Bad code that I write to test things:
2223
ex.py

tests/test_curry/test_partial.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import Callable, TypeAlias, TypeVar, cast
2+
3+
from returns.curry import partial
4+
5+
_ReturnType = TypeVar('_ReturnType')
6+
_Decorator: TypeAlias = Callable[
7+
[Callable[..., _ReturnType]],
8+
Callable[..., _ReturnType],
9+
]
10+
11+
12+
def add(first: int, second: int) -> int:
13+
return first + second
14+
15+
16+
def test_partial_direct_call() -> None:
17+
add_one = partial(add, 1)
18+
assert add_one(2) == 3
19+
20+
21+
def test_partial_as_decorator_factory() -> None:
22+
decorator = cast(_Decorator[int], partial())
23+
add_with_decorator = decorator(add)
24+
assert add_with_decorator(1, 2) == 3
25+
26+
27+
def test_partial_with_none_placeholder() -> None:
28+
decorator = cast(_Decorator[int], partial(None, 1))
29+
add_with_none_decorator = decorator(add)
30+
assert add_with_none_decorator(2) == 3

typesafety/test_curry/test_partial/test_partial.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

0 commit comments

Comments
 (0)