Skip to content

Commit a6c96a1

Browse files
fix union simplification
1 parent 446d5da commit a6c96a1

File tree

4 files changed

+34
-9
lines changed

4 files changed

+34
-9
lines changed

mypy/typeops.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,15 @@ def make_simplified_union(
636636

637637

638638
def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
639-
from mypy.subtypes import is_proper_subtype
639+
from mypy.subtypes import SubtypeContext, is_proper_subtype
640+
641+
subtype_context = SubtypeContext(
642+
ignore_promotions=True,
643+
keep_erased_types=keep_erased,
644+
options=(
645+
checker_state.type_checker.options if checker_state.type_checker is not None else None
646+
),
647+
)
640648

641649
# The first pass through this loop, we check if later items are subtypes of earlier items.
642650
# The second pass through this loop, we check if earlier items are subtypes of later items
@@ -685,9 +693,7 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[
685693
):
686694
continue
687695

688-
if is_proper_subtype(
689-
ti, tj, keep_erased_types=keep_erased, ignore_promotions=True
690-
):
696+
if is_proper_subtype(ti, tj, subtype_context=subtype_context):
691697
duplicate_index = j
692698
break
693699
if duplicate_index != -1:

test-data/unit/check-flags.test

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2454,7 +2454,7 @@ f(memoryview(b"asdf")) # E: Argument 1 to "f" has incompatible type "memoryview
24542454
[case testDisallowStrIteration]
24552455
# flags: --disallow-str-iteration
24562456
from abc import abstractmethod
2457-
from typing import Collection, Container, Iterable, Protocol, Sequence, TypeVar
2457+
from typing import Collection, Container, Iterable, Mapping, Protocol, Sequence, TypeVar, Union
24582458

24592459
def takes_str(x: str):
24602460
for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled
@@ -2508,6 +2508,9 @@ takes_collection_subclass(StrSubclass()) # E: Argument 1 to "takes_collection_s
25082508
# N: "StrSubclass" is missing following "CollectionSubclass" protocol member: \
25092509
# N: __missing_impl__
25102510

2511+
def repro(x: Mapping[str, Union[str, Sequence[str]]]) -> None:
2512+
x = {**x}
2513+
25112514
[builtins fixtures/str-iter.pyi]
25122515
[typing fixtures/typing-str-iter.pyi]
25132516

test-data/unit/fixtures/str-iter.pyi

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Builtins stub used in disallow-str-iteration tests.
22

33

4-
from typing import Generic, Iterator, Sequence, TypeVar, overload
4+
from typing import Generic, Iterator, Mapping, Sequence, TypeVar, overload
55

66
_T = TypeVar("_T")
7+
_KT = TypeVar("_KT")
8+
_VT = TypeVar("_VT")
79

810
class object:
911
def __init__(self) -> None: pass
@@ -14,11 +16,14 @@ class bool(int): pass
1416
class ellipsis: pass
1517
class slice: pass
1618

17-
class str:
19+
class str(Sequence[str]):
1820
def __iter__(self) -> Iterator[str]: pass
1921
def __len__(self) -> int: pass
2022
def __contains__(self, item: object) -> bool: pass
21-
def __getitem__(self, i: int) -> str: pass
23+
@overload
24+
def __getitem__(self, i: int, /) -> str: ...
25+
@overload
26+
def __getitem__(self, s: slice, /) -> Sequence[str]: ...
2227

2328
class list(Sequence[_T], Generic[_T]):
2429
def __iter__(self) -> Iterator[_T]: pass
@@ -38,4 +43,8 @@ class tuple(Sequence[_T], Generic[_T]):
3843
@overload
3944
def __getitem__(self, s: slice, /) -> list[_T]: ...
4045

41-
class dict: pass
46+
class dict(Mapping[_KT, _VT], Generic[_KT, _VT]):
47+
def __iter__(self) -> Iterator[_KT]: pass
48+
def __len__(self) -> int: pass
49+
def __contains__(self, item: object) -> bool: pass
50+
def __getitem__(self, key: _KT) -> _VT: pass

test-data/unit/fixtures/typing-str-iter.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
# Minimal typing fixture for disallow-str-iteration tests.
22

3+
import _typeshed
34
from abc import ABCMeta, abstractmethod
45

56
Any = object()
67
TypeVar = 0
78
Generic = 0
89
Protocol = 0
10+
Union = 0
911
overload = 0
1012

1113
_T = TypeVar("_T")
1214
_KT = TypeVar("_KT")
1315
_T_co = TypeVar("_T_co", covariant=True)
16+
_KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers.
1417
_VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers.
1518
_TC = TypeVar("_TC", bound=type[object])
1619

@@ -47,10 +50,14 @@ class Sequence(Collection[_T_co]):
4750
def __contains__(self, value: object) -> bool: ...
4851
def __iter__(self) -> Iterator[_T_co]: ...
4952

53+
class KeysView(Protocol[_KT_co]):
54+
def __iter__(self) -> Iterator[_KT_co]: ...
55+
5056
class Mapping(Collection[_KT], Generic[_KT, _VT_co]):
5157
@abstractmethod
5258
def __getitem__(self, key: _KT, /) -> _VT_co: ...
5359
def __contains__(self, key: object, /) -> bool: ...
60+
def keys(self) -> KeysView[_KT]: ...
5461

5562
def runtime_checkable(cls: _TC) -> _TC:
5663
return cls

0 commit comments

Comments
 (0)