@@ -6679,7 +6679,8 @@ def narrow_type_by_identity_equality(
66796679 else :
66806680 raise AssertionError
66816681
6682- partial_type_maps = []
6682+ all_if_maps : list [TypeMap ] = []
6683+ all_else_maps : list [TypeMap ] = []
66836684
66846685 # For each narrowable index, we see what we can narrow based on each relevant target
66856686 for i in expr_indices :
@@ -6690,10 +6691,8 @@ def narrow_type_by_identity_equality(
66906691 continue
66916692
66926693 expr_type = operand_types [i ]
6693- expanded_expr_type = try_expanding_sum_type_to_union (
6694- coerce_to_literal (expr_type ), None
6695- )
66966694 expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
6695+ expr_type = try_expanding_sum_type_to_union (coerce_to_literal (expr_type ), None )
66976696 for j in expr_indices :
66986697 if i == j :
66996698 continue
@@ -6703,11 +6702,6 @@ def narrow_type_by_identity_equality(
67036702 continue
67046703 target_type = operand_types [j ]
67056704 if should_coerce_literals :
6706- # TODO: doing this prevents narrowing a single-member Enum to literal
6707- # of its member, because we expand it here and then refuse to add equal
6708- # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6709- # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6710- # See testMatchEnumSingleChoice
67116705 target_type = coerce_to_literal (target_type )
67126706
67136707 if (
@@ -6718,24 +6712,21 @@ def narrow_type_by_identity_equality(
67186712 continue
67196713
67206714 target = TypeRange (target_type , is_upper_bound = False )
6721- is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
67226715
6723- if is_value_target :
6724- if_map , else_map = conditional_types_to_typemaps (
6725- operands [i ], * conditional_types (expanded_expr_type , [target ])
6726- )
6727- partial_type_maps .append ((if_map , else_map ))
6716+ if_map , else_map = conditional_types_to_typemaps (
6717+ operands [i ], * conditional_types (expr_type , [target ])
6718+ )
6719+ if is_target_for_value_narrowing (get_proper_type (target_type )):
6720+ all_if_maps .append (if_map )
6721+ all_else_maps .append (else_map )
67286722 else :
6729- if_map , else_map = conditional_types_to_typemaps (
6730- operands [i ], * conditional_types (expr_type , [target ])
6731- )
67326723 # For value targets, it is safe to narrow in the negative case.
67336724 # e.g. if (x: Literal[5] | None) != (y: Literal[5]), we can narrow x to None
67346725 # However, for non-value targets, we cannot do this narrowing,
67356726 # and so we ignore else_map
67366727 # e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6737- if if_map :
6738- partial_type_maps .append (( if_map , {}) )
6728+ if if_map is not None : # TODO: this gate is incorrect and should be removed
6729+ all_if_maps .append (if_map )
67396730
67406731 # Handle narrowing for operands with custom __eq__ methods specially
67416732 # In most cases, we won't be able to do any narrowing
@@ -6757,14 +6748,12 @@ def narrow_type_by_identity_equality(
67576748 if should_coerce_literals :
67586749 target_type = coerce_to_literal (target_type )
67596750 target = TypeRange (target_type , is_upper_bound = False )
6760- is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
6761-
6762- if is_value_target :
6751+ if is_target_for_value_narrowing (get_proper_type (target_type )):
67636752 if_map , else_map = conditional_types_to_typemaps (
67646753 operands [i ], * conditional_types (expr_type , [target ])
67656754 )
67666755 if else_map :
6767- partial_type_maps .append (({}, else_map ) )
6756+ all_else_maps .append (else_map )
67686757 continue
67696758
67706759 # If our operand with custom __eq__ is a union, where only some members of the union
@@ -6778,37 +6767,24 @@ def narrow_type_by_identity_equality(
67786767 # we narrow to in the if_map
67796768 or_if_maps .append ({operands [i ]: expr_type })
67806769
6770+ expr_type = coerce_to_literal (try_expanding_sum_type_to_union (expr_type , None ))
67816771 for j in expr_indices :
67826772 if j in custom_eq_indices :
67836773 continue
67846774 target_type = operand_types [j ]
67856775 if should_coerce_literals :
67866776 target_type = coerce_to_literal (target_type )
67876777 target = TypeRange (target_type , is_upper_bound = False )
6788- is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
67896778
6790- if is_value_target :
6791- expr_type = coerce_to_literal (expr_type )
6792- expr_type = try_expanding_sum_type_to_union (expr_type , None )
67936779 if_map , else_map = conditional_types_to_typemaps (
67946780 operands [i ], * conditional_types (expr_type , [target ], default = expr_type )
67956781 )
67966782 or_if_maps .append (if_map )
6797- if is_value_target :
6783+ if is_target_for_value_narrowing ( get_proper_type ( target_type )) :
67986784 or_else_maps .append (else_map )
67996785
6800- final_if_map : TypeMap = {}
6801- final_else_map : TypeMap = {}
6802- if or_if_maps :
6803- final_if_map = or_if_maps [0 ]
6804- for if_map in or_if_maps [1 :]:
6805- final_if_map = or_conditional_maps (final_if_map , if_map )
6806- if or_else_maps :
6807- final_else_map = or_else_maps [0 ]
6808- for else_map in or_else_maps [1 :]:
6809- final_else_map = or_conditional_maps (final_else_map , else_map )
6810-
6811- partial_type_maps .append ((final_if_map , final_else_map ))
6786+ all_if_maps .append (reduce_or_conditional_type_maps (or_if_maps ))
6787+ all_else_maps .append (reduce_or_conditional_type_maps (or_else_maps ))
68126788
68136789 # Handle narrowing for comparisons that produce additional narrowing, like
68146790 # `type(x) == T` or `x.__class__ is T`
@@ -6849,13 +6825,16 @@ def narrow_type_by_identity_equality(
68496825 if isinstance (expr , RefExpr ) and isinstance (expr .node , TypeInfo )
68506826 else False
68516827 )
6852- if not is_final :
6853- else_map = {}
6854- partial_type_maps .append ((if_map , else_map ))
6828+ all_if_maps .append (if_map )
6829+ if is_final :
6830+ # We can only narrow `type(x) == T` in the negative case if T is final
6831+ all_else_maps .append (else_map )
68556832
68566833 # We will not have duplicate entries in our type maps if we only have two operands,
68576834 # so we can skip running meets on the intersections
6858- return reduce_conditional_maps (partial_type_maps , use_meet = len (operands ) > 2 )
6835+ if_map = reduce_and_conditional_type_maps (all_if_maps , use_meet = len (operands ) > 2 )
6836+ else_map = reduce_or_conditional_type_maps (all_else_maps )
6837+ return if_map , else_map
68596838
68606839 def propagate_up_typemap_info (self , new_types : TypeMap ) -> TypeMap :
68616840 """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
@@ -8491,7 +8470,7 @@ def builtin_item_type(tp: Type) -> Type | None:
84918470 return None
84928471
84938472
8494- def and_conditional_maps (m1 : TypeMap , m2 : TypeMap , use_meet : bool = False ) -> TypeMap :
8473+ def and_conditional_maps (m1 : TypeMap , m2 : TypeMap , * , use_meet : bool = False ) -> TypeMap :
84958474 """Calculate what information we can learn from the truth of (e1 and e2)
84968475 in terms of the information that we can learn from the truth of e1 and
84978476 the truth of e2.
@@ -8524,7 +8503,7 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty
85248503 return result
85258504
85268505
8527- def or_conditional_maps (m1 : TypeMap , m2 : TypeMap , coalesce_any : bool = False ) -> TypeMap :
8506+ def or_conditional_maps (m1 : TypeMap , m2 : TypeMap , * , coalesce_any : bool = False ) -> TypeMap :
85288507 """Calculate what information we can learn from the truth of (e1 or e2)
85298508 in terms of the information that we can learn from the truth of e1 and
85308509 the truth of e2. If coalesce_any is True, consider Any a supertype when
@@ -8589,6 +8568,30 @@ def reduce_conditional_maps(
85898568 return final_if_map , final_else_map
85908569
85918570
8571+ def reduce_or_conditional_type_maps (ms : list [TypeMap ]) -> TypeMap :
8572+ """Reduces a list of TypeMaps into a single TypeMap by "or"-ing them together."""
8573+ if len (ms ) == 0 :
8574+ return {}
8575+ if len (ms ) == 1 :
8576+ return ms [0 ]
8577+ result = ms [0 ]
8578+ for m in ms [1 :]:
8579+ result = or_conditional_maps (result , m )
8580+ return result
8581+
8582+
8583+ def reduce_and_conditional_type_maps (ms : list [TypeMap ], * , use_meet : bool ) -> TypeMap :
8584+ """Reduces a list of TypeMaps into a single TypeMap by "and"-ing them together."""
8585+ if len (ms ) == 0 :
8586+ return {}
8587+ if len (ms ) == 1 :
8588+ return ms [0 ]
8589+ result = ms [0 ]
8590+ for m in ms [1 :]:
8591+ result = and_conditional_maps (result , m , use_meet = use_meet )
8592+ return result
8593+
8594+
85928595BUILTINS_CUSTOM_EQ_CHECKS : Final = {
85938596 "builtins.bytes" ,
85948597 "builtins.bytearray" ,
0 commit comments