Skip to content

Commit cbf39ae

Browse files
committed
fully type-safe TimSort using PEP 695 generics
mypy + ruff compliant
1 parent ef59573 commit cbf39ae

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

sorts/tim_sort.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
1-
from typing import Protocol, TypeVar
1+
from typing import Protocol
22

33

44
class Comparable(Protocol):
55
def __lt__(self, other: object) -> bool: ...
66
def __le__(self, other: object) -> bool: ...
77

88

9-
T = TypeVar("T", bound=Comparable)
10-
11-
12-
def binary_search(arr: list[T], item: T, left: int, right: int) -> int:
9+
def binary_search[T: Comparable](arr: list[T], item: T, left: int, right: int) -> int:
1310
"""
1411
Return the index where `item` should be inserted in `arr[left:right+1]`
1512
to keep it sorted.
@@ -32,7 +29,7 @@ def binary_search(arr: list[T], item: T, left: int, right: int) -> int:
3229
return left
3330

3431

35-
def insertion_sort[T_contra](arr: list[T_contra]) -> list[T_contra]: # type: ignore[valid-type]
32+
def insertion_sort[T: Comparable](arr: list[T]) -> list[T]:
3633
"""
3734
Sort the list in-place using binary insertion sort.
3835
@@ -46,7 +43,7 @@ def insertion_sort[T_contra](arr: list[T_contra]) -> list[T_contra]: # type: ig
4643
return arr
4744

4845

49-
def merge(left: list[T], right: list[T]) -> list[T]:
46+
def merge[T: Comparable](left: list[T], right: list[T]) -> list[T]:
5047
"""
5148
Merge two sorted lists into one sorted list.
5249
@@ -67,7 +64,7 @@ def merge(left: list[T], right: list[T]) -> list[T]:
6764
return merged
6865

6966

70-
def tim_sort[T_contra](arr: list[T_contra]) -> list[T_contra]: # type: ignore[valid-type]
67+
def tim_sort[T: Comparable](arr: list[T]) -> list[T]:
7168
"""
7269
Simplified version of TimSort for educational purposes.
7370
@@ -96,14 +93,14 @@ def tim_sort[T_contra](arr: list[T_contra]) -> list[T_contra]: # type: ignore[v
9693
if n == 1:
9794
return arr.copy()
9895

99-
runs: list[list[T_contra]] = []
96+
runs: list[list[T]] = []
10097
for start in range(0, n, min_run):
10198
end = min(start + min_run, n)
10299
run = insertion_sort(arr[start:end])
103100
runs.append(run)
104101

105102
while len(runs) > 1:
106-
new_runs: list[list[T_contra]] = []
103+
new_runs: list[list[T]] = []
107104
for i in range(0, len(runs), 2):
108105
if i + 1 < len(runs):
109106
new_runs.append(merge(runs[i], runs[i + 1]))

0 commit comments

Comments
 (0)