Skip to content

Commit c6786ab

Browse files
Fix method override check
1 parent d0d1fe2 commit c6786ab

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

mypy/checker.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@
155155
from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS
156156
from mypy.semanal_shared import SemanticAnalyzerCoreInterface
157157
from mypy.sharedparse import BINARY_MAGIC_METHODS
158-
from mypy.disallow_str_iteration_state import disallow_str_iteration_state
158+
from mypy.disallow_str_iteration_state import (
159+
STR_ITERATION_PROTOCOL_BASES,
160+
disallow_str_iteration_state,
161+
)
159162
from mypy.state import state
160163
from mypy.subtypes import (
161164
find_member,
@@ -2215,7 +2218,15 @@ def check_method_override(
22152218
)
22162219
)
22172220
found_method_base_classes: list[TypeInfo] = []
2221+
is_str_or_has_str_base = defn.info.fullname == "builtins.str"
22182222
for base in defn.info.mro[1:]:
2223+
if disallow_str_iteration_state.disallow_str_iteration:
2224+
if base.fullname == "builtins.str":
2225+
is_str_or_has_str_base = True
2226+
2227+
if is_str_or_has_str_base and base.fullname in STR_ITERATION_PROTOCOL_BASES:
2228+
continue
2229+
22192230
result = self.check_method_or_accessor_override_for_base(
22202231
defn, base, check_override_compatibility
22212232
)

test-data/unit/check-flags.test

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2454,7 +2454,7 @@ f(memoryview(b"asdf")) # E: Argument 1 to "f" has incompatible type "memoryview
24542454
[case testDisallowStrIteration]
24552455
# flags: --disallow-str-iteration
24562456
from abc import abstractmethod
2457-
from typing import Collection, Container, Iterable, Mapping, Protocol, Sequence, TypeVar, Union
2457+
from typing import Collection, Container, Iterable, Mapping, Protocol, Sequence, SupportsIndex, TypeVar, Union
24582458

24592459
def takes_str(x: str):
24602460
for ch in x: # E: Iterating over "str" is disallowed # N: This is because --disallow-str-iteration is enabled
@@ -2489,6 +2489,7 @@ def takes_str_upper_bound(x: T) -> None:
24892489

24902490
class StrSubclass(str):
24912491
def __contains__(self, x: object) -> bool: ...
2492+
def __getitem__(self, key: Union[SupportsIndex, slice], /) -> str: ...
24922493

24932494
def takes_str_subclass(x: StrSubclass):
24942495
for ch in x: # E: Iterating over "StrSubclass" is disallowed # N: This is because --disallow-str-iteration is enabled

test-data/unit/fixtures/str-iter.pyi

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Builtins stub used in disallow-str-iteration tests.
22

33

4-
from typing import Generic, Iterator, Mapping, Sequence, TypeVar, overload
4+
from typing import Generic, Iterator, Mapping, Sequence, SupportsIndex,TypeVar, overload
55

66
_T = TypeVar("_T")
77
_KT = TypeVar("_KT")
@@ -20,10 +20,7 @@ class str(Sequence[str]):
2020
def __iter__(self) -> Iterator[str]: pass
2121
def __len__(self) -> int: pass
2222
def __contains__(self, item: object) -> bool: pass
23-
@overload
24-
def __getitem__(self, i: int, /) -> str: ...
25-
@overload
26-
def __getitem__(self, s: slice, /) -> Sequence[str]: ...
23+
def __getitem__(self, key: SupportsIndex | slice, /) -> str: pass
2724

2825
class list(Sequence[_T], Generic[_T]):
2926
def __iter__(self) -> Iterator[_T]: pass

test-data/unit/fixtures/typing-str-iter.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,29 @@ _KT_co = TypeVar("_KT_co", covariant=True) # Key type covariant containers.
1717
_VT_co = TypeVar("_VT_co", covariant=True) # Value type covariant containers.
1818
_TC = TypeVar("_TC", bound=type[object])
1919

20-
@runtime_checkable
2120
class Iterable(Protocol[_T_co]):
2221
@abstractmethod
2322
def __iter__(self) -> Iterator[_T_co]: ...
2423

25-
@runtime_checkable
2624
class Iterator(Iterable[_T_co], Protocol[_T_co]):
2725
@abstractmethod
2826
def __next__(self) -> _T_co: ...
2927
def __iter__(self) -> Iterator[_T_co]: ...
3028

31-
@runtime_checkable
3229
class Container(Protocol[_T_co]):
3330
# This is generic more on vibes than anything else
3431
@abstractmethod
3532
def __contains__(self, x: object, /) -> bool: ...
3633

37-
@runtime_checkable
3834
class Collection(Iterable[_T_co], Container[_T_co], Protocol[_T_co]):
3935
# Implement Sized (but don't have it as a base class).
4036
@abstractmethod
4137
def __len__(self) -> int: ...
4238

39+
class SupportsIndex(Protocol, metaclass=ABCMeta):
40+
@abstractmethod
41+
def __index__(self) -> int: ...
42+
4343
class Sequence(Collection[_T_co]):
4444
@overload
4545
@abstractmethod

0 commit comments

Comments
 (0)