Skip to content

Commit 7fdbed0

Browse files
Fix matching against union of tuples (#19600)
This pull request fixes handling of union types containing tuples in match statements. Previously, when a tuple was part of a union, all its items would be unioned together and treated as a homogeneous tuple of that union type, which was incorrect. It still fallbacks on this behavior if we there are multiple tuples in the union with Unpack in them, but otherwise now it should be handled correctly. I attempted to keep as much of the existing semantics the same besides for this change. I also tried to keep the performance roughly similar, not unioning types more than needed. Fixes #19599 Fixes #19082 --------- Co-authored-by: Shantanu Jain <hauntsaninja@gmail.com>
1 parent 3d7f6c8 commit 7fdbed0

File tree

2 files changed

+150
-47
lines changed

2 files changed

+150
-47
lines changed

mypy/checkpattern.py

Lines changed: 83 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import itertools
56
from collections import defaultdict
67
from typing import Final, NamedTuple
78

@@ -247,37 +248,91 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
247248
if star_position is not None:
248249
required_patterns -= 1
249250

250-
#
251-
# get inner types of original type
252-
#
251+
# 1. Go through all possible types and filter to only those which are sequences that
252+
# could match that number of items
253+
# 2. If there is exactly one tuple left with an unpack, then use that type
254+
# and the unpack index
255+
# 3. Otherwise, take the product of the item types so that each index can have a
256+
# unique type. For tuples with unpack fallback to merging all of their types
257+
# for each index, since we can't handle multiple unpacked items at once yet.
258+
259+
# Whether we have encountered a type that we don't know how to handle in the union
260+
unknown_type = False
261+
# A list of types that could match any of the items in the sequence.
262+
sequence_types: list[Type] = []
263+
# A list of tuple types that could match the sequence, per index
264+
tuple_types: list[list[Type]] = []
265+
# A list of all the unpack tuple types that we encountered, each containing the
266+
# tuple type, unpack index, and union index
267+
unpack_tuple_types: list[tuple[TupleType, int, int]] = []
268+
for i, t in enumerate(
269+
current_type.items if isinstance(current_type, UnionType) else [current_type]
270+
):
271+
t = get_proper_type(t)
272+
if isinstance(t, TupleType):
273+
tuple_items = list(t.items)
274+
unpack_index = find_unpack_in_list(tuple_items)
275+
if unpack_index is None:
276+
size_diff = len(tuple_items) - required_patterns
277+
if size_diff < 0:
278+
continue
279+
if size_diff > 0 and star_position is None:
280+
continue
281+
if not size_diff and star_position is not None:
282+
# Above we subtract from required_patterns if star_position is not None
283+
tuple_items.append(UninhabitedType())
284+
tuple_types.append(tuple_items)
285+
else:
286+
normalized_inner_types = []
287+
for it in tuple_items:
288+
# Unfortunately, it is not possible to "split" the TypeVarTuple
289+
# into individual items, so we just use its upper bound for the whole
290+
# analysis instead.
291+
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
292+
it = UnpackType(it.type.upper_bound)
293+
normalized_inner_types.append(it)
294+
if (
295+
len(normalized_inner_types) - 1 > required_patterns
296+
and star_position is None
297+
):
298+
continue
299+
t = t.copy_modified(items=normalized_inner_types)
300+
unpack_tuple_types.append((t, unpack_index, i))
301+
# In case we have multiple unpacks we want to combine them all, so add
302+
# the combined tuple type to the sequence types.
303+
sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o))
304+
elif isinstance(t, AnyType):
305+
sequence_types.append(AnyType(TypeOfAny.from_another_any, t))
306+
elif self.chk.type_is_iterable(t) and isinstance(t, Instance):
307+
sequence_types.append(self.chk.iterable_item_type(t, o))
308+
else:
309+
unknown_type = True
310+
311+
inner_types: list[Type]
312+
313+
# If we only got one unpack tuple type, we can use that
253314
unpack_index = None
254-
if isinstance(current_type, TupleType):
255-
inner_types = current_type.items
256-
unpack_index = find_unpack_in_list(inner_types)
257-
if unpack_index is None:
258-
size_diff = len(inner_types) - required_patterns
259-
if size_diff < 0:
260-
return self.early_non_match()
261-
elif size_diff > 0 and star_position is None:
262-
return self.early_non_match()
315+
if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types:
316+
update_tuple_type, unpack_index, union_index = unpack_tuple_types[0]
317+
inner_types = update_tuple_type.items
318+
if isinstance(current_type, UnionType):
319+
union_items = list(current_type.items)
320+
union_items[union_index] = update_tuple_type
321+
current_type = get_proper_type(UnionType.make_union(items=union_items))
263322
else:
264-
normalized_inner_types = []
265-
for it in inner_types:
266-
# Unfortunately, it is not possible to "split" the TypeVarTuple
267-
# into individual items, so we just use its upper bound for the whole
268-
# analysis instead.
269-
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
270-
it = UnpackType(it.type.upper_bound)
271-
normalized_inner_types.append(it)
272-
inner_types = normalized_inner_types
273-
current_type = current_type.copy_modified(items=normalized_inner_types)
274-
if len(inner_types) - 1 > required_patterns and star_position is None:
275-
return self.early_non_match()
323+
current_type = update_tuple_type
324+
# If we only got tuples we can't match, then exit early
325+
elif not tuple_types and not sequence_types and not unknown_type:
326+
return self.early_non_match()
327+
elif tuple_types:
328+
inner_types = [
329+
make_simplified_union([*sequence_types, *[t for t in group if t is not None]])
330+
for group in itertools.zip_longest(*tuple_types)
331+
]
332+
elif sequence_types:
333+
inner_types = [make_simplified_union(sequence_types)] * len(o.patterns)
276334
else:
277-
inner_type = self.get_sequence_type(current_type, o)
278-
if inner_type is None:
279-
inner_type = self.chk.named_type("builtins.object")
280-
inner_types = [inner_type] * len(o.patterns)
335+
inner_types = [self.chk.named_type("builtins.object")] * len(o.patterns)
281336

282337
#
283338
# match inner patterns
@@ -356,25 +411,6 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
356411
new_type = self.narrow_sequence_child(current_type, new_inner_type, o)
357412
return PatternType(new_type, rest_type, captures)
358413

359-
def get_sequence_type(self, t: Type, context: Context) -> Type | None:
360-
t = get_proper_type(t)
361-
if isinstance(t, AnyType):
362-
return AnyType(TypeOfAny.from_another_any, t)
363-
if isinstance(t, UnionType):
364-
items = [self.get_sequence_type(item, context) for item in t.items]
365-
not_none_items = [item for item in items if item is not None]
366-
if not_none_items:
367-
return make_simplified_union(not_none_items)
368-
else:
369-
return None
370-
371-
if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)):
372-
if isinstance(t, TupleType):
373-
t = tuple_fallback(t)
374-
return self.chk.iterable_item_type(t, context)
375-
else:
376-
return None
377-
378414
def contract_starred_pattern_types(
379415
self, types: list[Type], star_pos: int | None, num_patterns: int
380416
) -> list[Type]:

test-data/unit/check-python310.test

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,6 +1746,73 @@ match m6:
17461746

17471747
[builtins fixtures/tuple.pyi]
17481748

1749+
[case testMatchTupleUnions]
1750+
from typing_extensions import Unpack
1751+
1752+
m1: tuple[int, str] | None
1753+
match m1:
1754+
case (a1, b1):
1755+
reveal_type(a1) # N: Revealed type is "builtins.int"
1756+
reveal_type(b1) # N: Revealed type is "builtins.str"
1757+
1758+
m2: tuple[int, str] | tuple[float, str]
1759+
match m2:
1760+
case (a2, b2):
1761+
reveal_type(a2) # N: Revealed type is "builtins.int | builtins.float"
1762+
reveal_type(b2) # N: Revealed type is "builtins.str"
1763+
1764+
m3: tuple[int] | tuple[float, str]
1765+
match m3:
1766+
case (a3, b3):
1767+
reveal_type(a3) # N: Revealed type is "builtins.float"
1768+
reveal_type(b3) # N: Revealed type is "builtins.str"
1769+
1770+
m4: tuple[int] | list[str]
1771+
match m4:
1772+
case (a4, b4):
1773+
reveal_type(a4) # N: Revealed type is "builtins.str"
1774+
reveal_type(b4) # N: Revealed type is "builtins.str"
1775+
1776+
# properly handles unpack when all other patterns are not sequences
1777+
m5: tuple[int, Unpack[tuple[float, ...]]] | None
1778+
match m5:
1779+
case (a5, b5):
1780+
reveal_type(a5) # N: Revealed type is "builtins.int"
1781+
reveal_type(b5) # N: Revealed type is "builtins.float"
1782+
1783+
# currently can't handle combing unpacking with other sequence patterns, if this happens revert to worst case
1784+
# of combing all types
1785+
m6: tuple[int, Unpack[tuple[float, ...]]] | list[str]
1786+
match m6:
1787+
case (a6, b6):
1788+
reveal_type(a6) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1789+
reveal_type(b6) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1790+
1791+
# but do still separate types from non unpacked types
1792+
m7: tuple[int, Unpack[tuple[float, ...]]] | tuple[str, str]
1793+
match m7:
1794+
case (a7, b7, *rest7):
1795+
reveal_type(a7) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1796+
reveal_type(b7) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
1797+
reveal_type(rest7) # N: Revealed type is "builtins.list[builtins.int | builtins.float]"
1798+
1799+
# verify that if we are unpacking, it will get the type of the sequence if the tuple is too short
1800+
m8: tuple[int, str] | list[float]
1801+
match m8:
1802+
case (a8, b8, *rest8):
1803+
reveal_type(a8) # N: Revealed type is "builtins.float | builtins.int"
1804+
reveal_type(b8) # N: Revealed type is "builtins.float | builtins.str"
1805+
reveal_type(rest8) # N: Revealed type is "builtins.list[builtins.float]"
1806+
1807+
m9: tuple[str, str, int] | tuple[str, str]
1808+
match m9:
1809+
case (a9, *rest9):
1810+
reveal_type(a9) # N: Revealed type is "builtins.str"
1811+
reveal_type(rest9) # N: Revealed type is "builtins.list[builtins.str | builtins.int]"
1812+
1813+
[builtins fixtures/tuple.pyi]
1814+
1815+
17491816
[case testMatchEnumSingleChoice]
17501817
from enum import Enum
17511818
from typing import NoReturn

0 commit comments

Comments
 (0)