@@ -2743,3 +2743,193 @@ reveal_type(Wrapper.Nested.FOO) # N: Revealed type is "Literal[__main__.Wrapper
27432743reveal_type(Wrapper.Nested.FOO.value) # N: Revealed type is "builtins.ellipsis"
27442744reveal_type(Wrapper.Nested.FOO._value_) # N: Revealed type is "builtins.ellipsis"
27452745[builtins fixtures/enum.pyi]
2746+
2747+ [case testStrEnumEqualityReachability]
2748+ # flags: --strict-equality --warn-unreachable
2749+
2750+ from __future__ import annotations
2751+ import enum
2752+
2753+ # https://github.com/python/mypy/issues/17162
2754+
2755+ class MyEnum(str, enum.Enum):
2756+ A = 'a'
2757+
2758+ class MyEnum2(enum.StrEnum):
2759+ A = 'a'
2760+
2761+ def f1():
2762+ if MyEnum.A == 'a':
2763+ 1 + ''
2764+
2765+ if MyEnum2.A == 'a':
2766+ 1 + ''
2767+ [builtins fixtures/primitives.pyi]
2768+
2769+ [case testStrEnumEqualityNarrowing]
2770+ # flags: --strict-equality --warn-unreachable
2771+
2772+ from __future__ import annotations
2773+ import enum
2774+ from typing import Literal
2775+
2776+ # https://github.com/python/mypy/issues/18029
2777+
2778+ class Foo(enum.StrEnum):
2779+ FOO = 'a'
2780+
2781+ def f1(a: Foo | Literal['foo']) -> Foo:
2782+ if a == 'foo':
2783+ # Ideally this is narrowed to just Literal['foo'] (if we learn to narrow based on enum value)
2784+ reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
2785+ return Foo.FOO
2786+
2787+ # Ideally this passes
2788+ reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']"
2789+ return a # E: Incompatible return value type (got "Foo | Literal['foo']", expected "Foo")
2790+ [builtins fixtures/primitives.pyi]
2791+
2792+ [case testStrEnumEqualityAlias]
2793+ # flags: --strict-equality --warn-unreachable
2794+ # https://github.com/python/mypy/issues/16830
2795+ from __future__ import annotations
2796+ from enum import Enum, auto
2797+
2798+ # https://github.com/python/mypy/issues/16830
2799+
2800+ class TrafficLight(Enum):
2801+ RED = auto()
2802+ AMBER = auto()
2803+ GREEN = auto()
2804+
2805+ YELLOW = AMBER # alias
2806+
2807+ # Behaviour here is not yet ideal, because we don't model enum aliases
2808+
2809+ def demo1(inst: TrafficLight) -> None:
2810+ if inst is TrafficLight.AMBER:
2811+ reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.AMBER]"
2812+ else:
2813+ reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.RED] | Literal[__main__.TrafficLight.GREEN] | Literal[__main__.TrafficLight.YELLOW]"
2814+
2815+ if inst == TrafficLight.AMBER:
2816+ reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.AMBER]"
2817+ else:
2818+ reveal_type(inst) # N: Revealed type is "Literal[__main__.TrafficLight.RED] | Literal[__main__.TrafficLight.GREEN] | Literal[__main__.TrafficLight.YELLOW]"
2819+
2820+ def demo2() -> None:
2821+ if TrafficLight.AMBER is TrafficLight.YELLOW: # E: Non-overlapping identity check (left operand type: "Literal[TrafficLight.AMBER]", right operand type: "Literal[TrafficLight.YELLOW]")
2822+ 1 + '' # E: Unsupported operand types for + ("int" and "str")
2823+ else:
2824+ 1 + '' # E: Unsupported operand types for + ("int" and "str")
2825+ [builtins fixtures/primitives.pyi]
2826+
2827+
2828+ [case testEnumEqualityNarrowing]
2829+ # flags: --strict-equality --warn-unreachable
2830+
2831+ from __future__ import annotations
2832+ from typing import cast
2833+ from enum import Enum, StrEnum
2834+
2835+ # https://github.com/python/mypy/issues/16830
2836+
2837+ class E(Enum):
2838+ A = "a"
2839+ B = "b"
2840+ C = "c"
2841+
2842+ class Custom:
2843+ def __eq__(self, other: object) -> bool: return True
2844+
2845+ class SE(StrEnum):
2846+ A = "a"
2847+ B = "b"
2848+ C = "c"
2849+
2850+ def f1(x: int | str | E):
2851+ if x == E.A:
2852+ reveal_type(x) # N: Revealed type is "Literal[__main__.E.A]"
2853+ else:
2854+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | Literal[__main__.E.B] | Literal[__main__.E.C]"
2855+
2856+ if x in cast(list[E], []):
2857+ reveal_type(x) # N: Revealed type is "__main__.E"
2858+ else:
2859+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.E"
2860+
2861+ if x == str():
2862+ reveal_type(x) # N: Revealed type is "builtins.str"
2863+ else:
2864+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.E"
2865+
2866+ def f2(x: int | Custom | E):
2867+ if x == E.A:
2868+ reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[__main__.E.A]"
2869+ else:
2870+ reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | Literal[__main__.E.B] | Literal[__main__.E.C]"
2871+
2872+ if x in cast(list[E], []):
2873+ reveal_type(x) # N: Revealed type is "__main__.Custom | __main__.E"
2874+ else:
2875+ reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.E"
2876+
2877+ if x == str():
2878+ reveal_type(x) # N: Revealed type is "__main__.Custom"
2879+ else:
2880+ reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.E"
2881+
2882+ def f3_simple(x: str | SE):
2883+ if x == SE.A:
2884+ reveal_type(x) # N: Revealed type is "builtins.str | __main__.SE"
2885+
2886+ def f3(x: int | str | SE):
2887+ # Ideally we filter out some of these ints
2888+ if x == SE.A:
2889+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2890+ else:
2891+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2892+
2893+ if x in cast(list[SE], []):
2894+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2895+ else:
2896+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2897+
2898+ if x == str():
2899+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2900+ else:
2901+ reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE"
2902+
2903+ def f4(x: int | Custom | SE):
2904+ if x == SE.A:
2905+ reveal_type(x) # N: Revealed type is "__main__.Custom | Literal[__main__.SE.A]"
2906+ else:
2907+ reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | Literal[__main__.SE.B] | Literal[__main__.SE.C]"
2908+
2909+ if x in cast(list[SE], []):
2910+ reveal_type(x) # N: Revealed type is "__main__.Custom | __main__.SE"
2911+ else:
2912+ reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.SE"
2913+
2914+ if x == str():
2915+ reveal_type(x) # N: Revealed type is "__main__.Custom | __main__.SE"
2916+ else:
2917+ reveal_type(x) # N: Revealed type is "builtins.int | __main__.Custom | __main__.SE"
2918+
2919+ def f5(x: str | Custom | SE):
2920+ if x == SE.A:
2921+ reveal_type(x) # N: Revealed type is "Literal[__main__.SE.A] | __main__.Custom"
2922+ else:
2923+ reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom"
2924+
2925+ if x in cast(list[SE], []):
2926+ reveal_type(x) # N: Revealed type is "__main__.SE | __main__.Custom"
2927+ else:
2928+ reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom | __main__.SE"
2929+
2930+ if x == str():
2931+ reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom"
2932+ else:
2933+ reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom | __main__.SE"
2934+
2935+ [builtins fixtures/primitives.pyi]
0 commit comments