Skip to content

Commit 94a3cf6

Browse files
authored
Add more tests for narrowing logic (#20672)
1 parent 2c93a2c commit 94a3cf6

File tree

3 files changed

+278
-39
lines changed

3 files changed

+278
-39
lines changed

test-data/unit/check-enum.test

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2743,3 +2743,193 @@ reveal_type(Wrapper.Nested.FOO) # N: Revealed type is "Literal[__main__.Wrapper
27432743
reveal_type(Wrapper.Nested.FOO.value) # N: Revealed type is "builtins.ellipsis"
27442744
reveal_type(Wrapper.Nested.FOO._value_) # N: Revealed type is "builtins.ellipsis"
27452745
[builtins fixtures/enum.pyi]
2746+
2747+
[case testStrEnumEqualityReachability]
2748+
# flags: --strict-equality --warn-unreachable
2749+
2750+
from __future__ import annotations
2751+
import enum
2752+
2753+
# https://github.com/python/mypy/issues/17162
2754+
2755+
class MyEnum(str, enum.Enum):
2756+
A = 'a'
2757+
2758+
class MyEnum2(enum.StrEnum):
2759+
A = 'a'
2760+
2761+
def f1():
2762+
if MyEnum.A == 'a':
2763+
1 + ''
2764+
2765+
if MyEnum2.A == 'a':
2766+
1 + ''
2767+
[builtins fixtures/primitives.pyi]
2768+
2769+
[case testStrEnumEqualityNarrowing]
2770+
# flags: --strict-equality --warn-unreachable
2771+
2772+
from __future__ import annotations
2773+
import enum
2774+
from typing import Literal
2775+
2776+
# https://github.com/python/mypy/issues/18029
2777+
2778+
class Foo(enum.StrEnum):
2779+
FOO = 'a'
2780+
2781+
def f1(a: Foo | Literal['foo']) -> Foo:
2782+
if a == 'foo':
2783+
# Ideally this is narrowed to just Literal['foo'] (if we learn to narrow based on enum value)
2784+
reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
2785+
return Foo.FOO
2786+
2787+
# Ideally this passes
2788+
reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
2789+
return a # E: Incompatible return value type (got "Foo | Literal['foo']", expected "Foo")
2790+
[builtins fixtures/primitives.pyi]
2791+
2792+
[case testStrEnumEqualityAlias]
2793+
# flags: --strict-equality --warn-unreachable
2794+
# https://github.com/python/mypy/issues/16830
2795+
from __future__ import annotations
2796+
from enum import Enum, auto
2797+
2798+
# https://github.com/python/mypy/issues/16830
2799+
2800+
class TrafficLight(Enum):
2801+
RED = auto()
2802+
AMBER = auto()
2803+
GREEN = auto()
2804+
2805+
YELLOW = AMBER # alias
2806+
2807+
# Behaviour here is not yet ideal, because we don't model enum aliases
2808+
2809+
def demo1(inst: TrafficLight) -> None:
2810+
if inst is TrafficLight.AMBER:
2811+
reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.AMBER]"
2812+
else:
2813+
reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.RED] | Literal[__main__.TrafficLight.GREEN] | Literal[__main__.TrafficLight.YELLOW]"
2814+
2815+
if inst == TrafficLight.AMBER:
2816+
reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.AMBER]"
2817+
else:
2818+
reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.RED] | Literal[__main__.TrafficLight.GREEN] | Literal[__main__.TrafficLight.YELLOW]"
2819+
2820+
def demo2() -> None:
2821+
if TrafficLight.AMBER is TrafficLight.YELLOW: # E: Non-overlapping identity check (left operand type: "Literal[TrafficLight.AMBER]", right operand type: "Literal[TrafficLight.YELLOW]")
2822+
1 + '' # E: Unsupported operand types for + ("int" and "str")
2823+
else:
2824+
1 + '' # E: Unsupported operand types for + ("int" and "str")
2825+
[builtins fixtures/primitives.pyi]
2826+
2827+
2828+
[case testEnumEqualityNarrowing]
2829+
# flags: --strict-equality --warn-unreachable
2830+
2831+
from __future__ import annotations
2832+
from typing import cast
2833+
from enum import Enum, StrEnum
2834+
2835+
# https://github.com/python/mypy/issues/16830
2836+
2837+
class E(Enum):
2838+
A = "a"
2839+
B = "b"
2840+
C = "c"
2841+
2842+
class Custom:
2843+
def __eq__(self, other: object) -> bool: return True
2844+
2845+
class SE(StrEnum):
2846+
A = "a"
2847+
B = "b"
2848+
C = "c"
2849+
2850+
def f1(x: int | str | E):
2851+
if x == E.A:
2852+
reveal_type(x) # N: Revealed type is "Literal[__main__.E.A]"
2853+
else:
2854+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | Literal[__main__.E.B] | Literal[__main__.E.C]"
2855+
2856+
if x in cast(list[E], []):
2857+
reveal_type(x) # N: Revealed type is "__main__.E"
2858+
else:
2859+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.E"
2860+
2861+
if x == str():
2862+
reveal_type(x) # N: Revealed type is "builtins.str"
2863+
else:
2864+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.E"
2865+
2866+
def f2(x: int | Custom | E):
2867+
if x == E.A:
2868+
reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[__main__.E.A]"
2869+
else:
2870+
reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | Literal[__main__.E.B] | Literal[__main__.E.C]"
2871+
2872+
if x in cast(list[E], []):
2873+
reveal_type(x) # N: Revealed type is "__main__.Custom | __main__.E"
2874+
else:
2875+
reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.E"
2876+
2877+
if x == str():
2878+
reveal_type(x) # N: Revealed type is "__main__.Custom"
2879+
else:
2880+
reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.E"
2881+
2882+
def f3_simple(x: str | SE):
2883+
if x == SE.A:
2884+
reveal_type(x) # N: Revealed type is "builtins.str | __main__.SE"
2885+
2886+
def f3(x: int | str | SE):
2887+
# Ideally we filter out some of these ints
2888+
if x == SE.A:
2889+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2890+
else:
2891+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2892+
2893+
if x in cast(list[SE], []):
2894+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2895+
else:
2896+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2897+
2898+
if x == str():
2899+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2900+
else:
2901+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2902+
2903+
def f4(x: int | Custom | SE):
2904+
if x == SE.A:
2905+
reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[__main__.SE.A]"
2906+
else:
2907+
reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | Literal[__main__.SE.B] | Literal[__main__.SE.C]"
2908+
2909+
if x in cast(list[SE], []):
2910+
reveal_type(x) # N: Revealed type is "__main__.Custom | __main__.SE"
2911+
else:
2912+
reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.SE"
2913+
2914+
if x == str():
2915+
reveal_type(x) # N: Revealed type is "__main__.Custom | __main__.SE"
2916+
else:
2917+
reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.SE"
2918+
2919+
def f5(x: str | Custom | SE):
2920+
if x == SE.A:
2921+
reveal_type(x) # N: Revealed type is "Literal[__main__.SE.A] | __main__.Custom"
2922+
else:
2923+
reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom"
2924+
2925+
if x in cast(list[SE], []):
2926+
reveal_type(x) # N: Revealed type is "__main__.SE | __main__.Custom"
2927+
else:
2928+
reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom | __main__.SE"
2929+
2930+
if x == str():
2931+
reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom"
2932+
else:
2933+
reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom | __main__.SE"
2934+
2935+
[builtins fixtures/primitives.pyi]

test-data/unit/check-isinstance.test

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2008,6 +2008,7 @@ else:
20082008
[out]
20092009

20102010
[case testNarrowTypeAfterInListNested]
2011+
# flags: --warn-unreachable
20112012
from typing import List, Optional, Any
20122013

20132014
x: Optional[int]
@@ -2019,7 +2020,6 @@ if lst in nested_any:
20192020
if x in nested_any:
20202021
reveal_type(x) # N: Revealed type is "builtins.int | None"
20212022
[builtins fixtures/list.pyi]
2022-
[out]
20232023

20242024
[case testNarrowTypeAfterInTuple]
20252025
from typing import Optional
@@ -2885,6 +2885,24 @@ if hasattr(mod, "y"):
28852885
def __getattr__(attr: str) -> str: ...
28862886
[builtins fixtures/module.pyi]
28872887

2888+
[case testMultipleHasAttr-xfail]
2889+
# flags: --warn-unreachable
2890+
# https://github.com/python/mypy/issues/20596
2891+
from __future__ import annotations
2892+
from typing import Any
2893+
2894+
def len(obj: object) -> Any: ...
2895+
2896+
def f(x: type | str):
2897+
if (
2898+
hasattr(x, "__origin__")
2899+
and x.__origin__ is list
2900+
and hasattr(x, "__args__")
2901+
and len(x.__args__) == 1
2902+
):
2903+
reveal_type(x.__args__[0]) # N: Revealed type is "Any"
2904+
[builtins fixtures/module.pyi]
2905+
28882906
[case testTypeIsntLostAfterNarrowing]
28892907
from typing import Any
28902908

@@ -2912,3 +2930,11 @@ if isinstance(a, B):
29122930
c = a
29132931

29142932
[builtins fixtures/isinstance.pyi]
2933+
2934+
[case testIsInstanceTypeAny]
2935+
from typing import Any
2936+
2937+
def foo(x: object, t: type[Any]):
2938+
if isinstance(x, t):
2939+
reveal_type(x) # N: Revealed type is "Any"
2940+
[builtins fixtures/isinstance.pyi]

test-data/unit/check-narrowing.test

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -829,19 +829,19 @@ from typing import Literal, Union
829829
class Custom:
830830
def __eq__(self, other: object) -> bool: return True
831831

832-
class Default: pass
832+
def f1(x: Union[Custom, Literal[1], Literal[2]]):
833+
if x == 1:
834+
reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[1]"
835+
else:
836+
reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[2]"
833837

834-
x1: Union[Custom, Literal[1], Literal[2]]
835-
if x1 == 1:
836-
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[1]"
837-
else:
838-
reveal_type(x1) # N: Revealed type is "__main__.Custom | Literal[2]"
838+
class Default: pass
839839

840-
x2: Union[Default, Literal[1], Literal[2]]
841-
if x2 == 1:
842-
reveal_type(x2) # N: Revealed type is "Literal[1]"
843-
else:
844-
reveal_type(x2) # N: Revealed type is "__main__.Default | Literal[2]"
840+
def f2(x: Union[Default, Literal[1], Literal[2]]):
841+
if x == 1:
842+
reveal_type(x) # N: Revealed type is "Literal[1]"
843+
else:
844+
reveal_type(x) # N: Revealed type is "__main__.Default | Literal[2]"
845845
[builtins fixtures/primitives.pyi]
846846

847847
[case testNarrowingEqualityCustomEqualityEnum]
@@ -875,25 +875,23 @@ from typing import Literal, Union
875875
class Custom:
876876
def __eq__(self, other: object) -> bool: return True
877877

878-
class Default: pass
879-
880-
x: Literal[1, 2, None]
881-
y: Custom
882-
z: Default
878+
def f1(x: Literal[1, 2, None], y: Custom):
879+
if 1 == x == y:
880+
reveal_type(x) # N: Revealed type is "Literal[1]"
881+
reveal_type(y) # N: Revealed type is "__main__.Custom"
882+
else:
883+
reveal_type(x) # N: Revealed type is "Literal[2] | None"
884+
reveal_type(y) # N: Revealed type is "__main__.Custom"
883885

884-
if 1 == x == y:
885-
reveal_type(x) # N: Revealed type is "Literal[1]"
886-
reveal_type(y) # N: Revealed type is "__main__.Custom"
887-
else:
888-
reveal_type(x) # N: Revealed type is "Literal[2] | None"
889-
reveal_type(y) # N: Revealed type is "__main__.Custom"
886+
class Default: pass
890887

891-
if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Literal[1, 2] | None", right operand type: "Default")
892-
reveal_type(x) # E: Statement is unreachable
893-
reveal_type(z)
894-
else:
895-
reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2] | None"
896-
reveal_type(z) # N: Revealed type is "__main__.Default"
888+
def f2(x: Literal[1, 2, None], z: Default):
889+
if 1 == x == z: # E: Non-overlapping equality check (left operand type: "Literal[1, 2] | None", right operand type: "Default")
890+
reveal_type(x) # E: Statement is unreachable
891+
reveal_type(z)
892+
else:
893+
reveal_type(x) # N: Revealed type is "Literal[1] | Literal[2] | None"
894+
reveal_type(z) # N: Revealed type is "__main__.Default"
897895
[builtins fixtures/primitives.pyi]
898896

899897
[case testNarrowingCustomEqualityLiteralElseBranch]
@@ -1471,19 +1469,23 @@ if val not in (None,):
14711469
reveal_type(val) # N: Revealed type is "__main__.A"
14721470
else:
14731471
reveal_type(val) # N: Revealed type is "None"
1472+
[builtins fixtures/primitives.pyi]
14741473

1475-
class Hmm:
1474+
[case testNarrowingCustomEqualityOptionalEqualsNone]
1475+
# flags: --strict-equality --warn-unreachable
1476+
from __future__ import annotations
1477+
class Custom:
14761478
def __eq__(self, other) -> bool: ...
14771479

1478-
hmm: Optional[Hmm]
1479-
if hmm == None:
1480-
reveal_type(hmm) # N: Revealed type is "__main__.Hmm | None"
1481-
else:
1482-
reveal_type(hmm) # N: Revealed type is "__main__.Hmm"
1483-
if hmm != None:
1484-
reveal_type(hmm) # N: Revealed type is "__main__.Hmm"
1485-
else:
1486-
reveal_type(hmm) # N: Revealed type is "__main__.Hmm | None"
1480+
def f(x: Custom | None):
1481+
if x == None:
1482+
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
1483+
else:
1484+
reveal_type(x) # N: Revealed type is "__main__.Custom"
1485+
if x != None:
1486+
reveal_type(x) # N: Revealed type is "__main__.Custom"
1487+
else:
1488+
reveal_type(x) # N: Revealed type is "__main__.Custom | None"
14871489
[builtins fixtures/primitives.pyi]
14881490

14891491
[case testNarrowingWithTupleOfTypes]
@@ -3325,6 +3327,7 @@ def bar(y: Any):
33253327
[builtins fixtures/dict-full.pyi]
33263328

33273329
[case testNarrowTypeVarType]
3330+
# flags: --strict-equality --warn-unreachable
33283331
from typing import TypeVar
33293332

33303333
T = TypeVar("T")
@@ -3343,3 +3346,23 @@ def bar(X: type[T]) -> T:
33433346
return A() # E: Incompatible return value type (got "A", expected "T")
33443347
raise
33453348
[builtins fixtures/type.pyi]
3349+
3350+
[case testNarrowingAnyNegativeIntersection-xfail]
3351+
# flags: --strict-equality --warn-unreachable
3352+
# https://github.com/python/mypy/issues/20597
3353+
from __future__ import annotations
3354+
from typing import Any
3355+
3356+
class array: ...
3357+
3358+
def get_result() -> Any: ...
3359+
3360+
def foo(x: str | array) -> str:
3361+
result = get_result()
3362+
if isinstance(result, array):
3363+
return "asdf"
3364+
if result is x:
3365+
reveal_type(result) # N: Revealed type is "Any"
3366+
return result
3367+
raise
3368+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)