Skip to content

Commit 4914ac2

Browse files
committed
Modernize and fix TimSort implementation type-safe, PEP 695 compliant
1 parent e0e84c6 commit 4914ac2

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

sorts/tim_sort.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1-
from typing import TypeVar
1+
from typing import Protocol, TypeVar
22

3-
T = TypeVar("T")
3+
4+
class Comparable(Protocol):
5+
def __lt__(self, other: object) -> bool: ...
6+
def __le__(self, other: object) -> bool: ...
7+
8+
9+
T = TypeVar("T", bound=Comparable)
410

511

612
def binary_search(arr: list[T], item: T, left: int, right: int) -> int:
@@ -26,17 +32,13 @@ def binary_search(arr: list[T], item: T, left: int, right: int) -> int:
2632
return left
2733

2834

29-
def insertion_sort(arr: list[T]) -> list[T]:
35+
def insertion_sort[T_contra](arr: list[T_contra]) -> list[T_contra]: # type: ignore[valid-type]
3036
"""
3137
Sort the list in-place using binary insertion sort.
3238
3339
>>> insertion_sort([3, 1, 2, 4])
3440
[1, 2, 3, 4]
3541
"""
36-
from typing import TypeVar
37-
38-
T = TypeVar("T")
39-
4042
for i in range(1, len(arr)):
4143
key = arr[i]
4244
j = binary_search(arr, key, 0, i - 1)
@@ -65,7 +67,7 @@ def merge(left: list[T], right: list[T]) -> list[T]:
6567
return merged
6668

6769

68-
def tim_sort(arr: list[T]) -> list[T]:
70+
def tim_sort[T_contra](arr: list[T_contra]) -> list[T_contra]: # type: ignore[valid-type]
6971
"""
7072
Simplified version of TimSort for educational purposes.
7173
@@ -83,10 +85,6 @@ def tim_sort(arr: list[T]) -> list[T]:
8385
>>> tim_sort([]) # empty input
8486
[]
8587
"""
86-
from typing import TypeVar
87-
88-
T = TypeVar("T")
89-
9088
if not isinstance(arr, list):
9189
arr = list(arr)
9290
if not arr:
@@ -98,14 +96,14 @@ def tim_sort(arr: list[T]) -> list[T]:
9896
if n == 1:
9997
return arr.copy()
10098

101-
runs: list[list[T]] = []
99+
runs: list[list[T_contra]] = []
102100
for start in range(0, n, min_run):
103101
end = min(start + min_run, n)
104102
run = insertion_sort(arr[start:end])
105103
runs.append(run)
106104

107105
while len(runs) > 1:
108-
new_runs: list[list[T]] = []
106+
new_runs: list[list[T_contra]] = []
109107
for i in range(0, len(runs), 2):
110108
if i + 1 < len(runs):
111109
new_runs.append(merge(runs[i], runs[i + 1]))
@@ -119,4 +117,4 @@ def tim_sort(arr: list[T]) -> list[T]:
119117
if __name__ == "__main__":
120118
import doctest
121119

122-
doctest.testmod()
120+
doctest.testmod()

0 commit comments

Comments
 (0)