|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import itertools |
5 | 6 | from collections import defaultdict |
6 | 7 | from typing import Final, NamedTuple |
7 | 8 |
|
@@ -247,37 +248,91 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: |
247 | 248 | if star_position is not None: |
248 | 249 | required_patterns -= 1 |
249 | 250 |
|
250 | | - # |
251 | | - # get inner types of original type |
252 | | - # |
| 251 | + # 1. Go through all possible types and filter to only those which are sequences that |
| 252 | + # could match that number of items |
| 253 | + # 2. If there is exactly one tuple left with an unpack, then use that type |
| 254 | + # and the unpack index |
| 255 | + # 3. Otherwise, take the product of the item types so that each index can have a |
| 256 | + # unique type. For tuples with unpack fallback to merging all of their types |
| 257 | + # for each index, since we can't handle multiple unpacked items at once yet. |
| 258 | + |
| 259 | + # Whether we have encountered a type that we don't know how to handle in the union |
| 260 | + unknown_type = False |
| 261 | + # A list of types that could match any of the items in the sequence. |
| 262 | + sequence_types: list[Type] = [] |
| 263 | + # A list of tuple types that could match the sequence, per index |
| 264 | + tuple_types: list[list[Type]] = [] |
| 265 | + # A list of all the unpack tuple types that we encountered, each containing the |
| 266 | + # tuple type, unpack index, and union index |
| 267 | + unpack_tuple_types: list[tuple[TupleType, int, int]] = [] |
| 268 | + for i, t in enumerate( |
| 269 | + current_type.items if isinstance(current_type, UnionType) else [current_type] |
| 270 | + ): |
| 271 | + t = get_proper_type(t) |
| 272 | + if isinstance(t, TupleType): |
| 273 | + tuple_items = list(t.items) |
| 274 | + unpack_index = find_unpack_in_list(tuple_items) |
| 275 | + if unpack_index is None: |
| 276 | + size_diff = len(tuple_items) - required_patterns |
| 277 | + if size_diff < 0: |
| 278 | + continue |
| 279 | + if size_diff > 0 and star_position is None: |
| 280 | + continue |
| 281 | + if not size_diff and star_position is not None: |
| 282 | + # Above we subtract from required_patterns if star_position is not None |
| 283 | + tuple_items.append(UninhabitedType()) |
| 284 | + tuple_types.append(tuple_items) |
| 285 | + else: |
| 286 | + normalized_inner_types = [] |
| 287 | + for it in tuple_items: |
| 288 | + # Unfortunately, it is not possible to "split" the TypeVarTuple |
| 289 | + # into individual items, so we just use its upper bound for the whole |
| 290 | + # analysis instead. |
| 291 | + if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType): |
| 292 | + it = UnpackType(it.type.upper_bound) |
| 293 | + normalized_inner_types.append(it) |
| 294 | + if ( |
| 295 | + len(normalized_inner_types) - 1 > required_patterns |
| 296 | + and star_position is None |
| 297 | + ): |
| 298 | + continue |
| 299 | + t = t.copy_modified(items=normalized_inner_types) |
| 300 | + unpack_tuple_types.append((t, unpack_index, i)) |
| 301 | + # In case we have multiple unpacks we want to combine them all, so add |
| 302 | + # the combined tuple type to the sequence types. |
| 303 | + sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o)) |
| 304 | + elif isinstance(t, AnyType): |
| 305 | + sequence_types.append(AnyType(TypeOfAny.from_another_any, t)) |
| 306 | + elif self.chk.type_is_iterable(t) and isinstance(t, Instance): |
| 307 | + sequence_types.append(self.chk.iterable_item_type(t, o)) |
| 308 | + else: |
| 309 | + unknown_type = True |
| 310 | + |
| 311 | + inner_types: list[Type] |
| 312 | + |
| 313 | + # If we only got one unpack tuple type, we can use that |
253 | 314 | unpack_index = None |
254 | | - if isinstance(current_type, TupleType): |
255 | | - inner_types = current_type.items |
256 | | - unpack_index = find_unpack_in_list(inner_types) |
257 | | - if unpack_index is None: |
258 | | - size_diff = len(inner_types) - required_patterns |
259 | | - if size_diff < 0: |
260 | | - return self.early_non_match() |
261 | | - elif size_diff > 0 and star_position is None: |
262 | | - return self.early_non_match() |
| 315 | + if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types: |
| 316 | + update_tuple_type, unpack_index, union_index = unpack_tuple_types[0] |
| 317 | + inner_types = update_tuple_type.items |
| 318 | + if isinstance(current_type, UnionType): |
| 319 | + union_items = list(current_type.items) |
| 320 | + union_items[union_index] = update_tuple_type |
| 321 | + current_type = get_proper_type(UnionType.make_union(items=union_items)) |
263 | 322 | else: |
264 | | - normalized_inner_types = [] |
265 | | - for it in inner_types: |
266 | | - # Unfortunately, it is not possible to "split" the TypeVarTuple |
267 | | - # into individual items, so we just use its upper bound for the whole |
268 | | - # analysis instead. |
269 | | - if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType): |
270 | | - it = UnpackType(it.type.upper_bound) |
271 | | - normalized_inner_types.append(it) |
272 | | - inner_types = normalized_inner_types |
273 | | - current_type = current_type.copy_modified(items=normalized_inner_types) |
274 | | - if len(inner_types) - 1 > required_patterns and star_position is None: |
275 | | - return self.early_non_match() |
| 323 | + current_type = update_tuple_type |
| 324 | + # If we only got tuples we can't match, then exit early |
| 325 | + elif not tuple_types and not sequence_types and not unknown_type: |
| 326 | + return self.early_non_match() |
| 327 | + elif tuple_types: |
| 328 | + inner_types = [ |
| 329 | + make_simplified_union([*sequence_types, *[t for t in group if t is not None]]) |
| 330 | + for group in itertools.zip_longest(*tuple_types) |
| 331 | + ] |
| 332 | + elif sequence_types: |
| 333 | + inner_types = [make_simplified_union(sequence_types)] * len(o.patterns) |
276 | 334 | else: |
277 | | - inner_type = self.get_sequence_type(current_type, o) |
278 | | - if inner_type is None: |
279 | | - inner_type = self.chk.named_type("builtins.object") |
280 | | - inner_types = [inner_type] * len(o.patterns) |
| 335 | + inner_types = [self.chk.named_type("builtins.object")] * len(o.patterns) |
281 | 336 |
|
282 | 337 | # |
283 | 338 | # match inner patterns |
@@ -356,25 +411,6 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: |
356 | 411 | new_type = self.narrow_sequence_child(current_type, new_inner_type, o) |
357 | 412 | return PatternType(new_type, rest_type, captures) |
358 | 413 |
|
359 | | - def get_sequence_type(self, t: Type, context: Context) -> Type | None: |
360 | | - t = get_proper_type(t) |
361 | | - if isinstance(t, AnyType): |
362 | | - return AnyType(TypeOfAny.from_another_any, t) |
363 | | - if isinstance(t, UnionType): |
364 | | - items = [self.get_sequence_type(item, context) for item in t.items] |
365 | | - not_none_items = [item for item in items if item is not None] |
366 | | - if not_none_items: |
367 | | - return make_simplified_union(not_none_items) |
368 | | - else: |
369 | | - return None |
370 | | - |
371 | | - if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)): |
372 | | - if isinstance(t, TupleType): |
373 | | - t = tuple_fallback(t) |
374 | | - return self.chk.iterable_item_type(t, context) |
375 | | - else: |
376 | | - return None |
377 | | - |
378 | 414 | def contract_starred_pattern_types( |
379 | 415 | self, types: list[Type], star_pos: int | None, num_patterns: int |
380 | 416 | ) -> list[Type]: |
|
0 commit comments