Skip to content

Commit e3a5b39

Browse files
committed
Improve the existing TimSort implementation
1 parent e2a78d4 commit e3a5b39

File tree

1 file changed

+91
-58
lines changed

1 file changed

+91
-58
lines changed

sorts/tim_sort.py

Lines changed: 91 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,115 @@
1-
def binary_search(lst, item, start, end):
2-
if start == end:
3-
return start if lst[start] > item else start + 1
4-
if start > end:
5-
return start
1+
from typing import List, TypeVar
62

7-
mid = (start + end) // 2
8-
if lst[mid] < item:
9-
return binary_search(lst, item, mid + 1, end)
10-
elif lst[mid] > item:
11-
return binary_search(lst, item, start, mid - 1)
12-
else:
13-
return mid
3+
T = TypeVar("T")
144

155

16-
def insertion_sort(lst):
17-
length = len(lst)
18-
19-
for index in range(1, length):
20-
value = lst[index]
21-
pos = binary_search(lst, value, 0, index - 1)
22-
lst = [*lst[:pos], value, *lst[pos:index], *lst[index + 1 :]]
6+
def binary_search(arr: List[T], item: T, left: int, right: int) -> int:
7+
"""
8+
Return the index where `item` should be inserted in `arr[left:right+1]`
9+
to keep it sorted.
10+
11+
>>> binary_search([1, 3, 5, 7], 6, 0, 3)
12+
3
13+
>>> binary_search([1, 3, 5, 7], 0, 0, 3)
14+
0
15+
>>> binary_search([1, 3, 5, 7], 8, 0, 3)
16+
4
17+
"""
18+
while left <= right:
19+
mid = (left + right) // 2
20+
if arr[mid] == item:
21+
return mid
22+
elif arr[mid] < item:
23+
left = mid + 1
24+
else:
25+
right = mid - 1
26+
return left
2327

24-
return lst
2528

29+
def insertion_sort(arr: List[T]) -> List[T]:
30+
"""
31+
Sort the list in-place using binary insertion sort.
2632
27-
def merge(left, right):
28-
if not left:
29-
return right
33+
>>> insertion_sort([3, 1, 2, 4])
34+
[1, 2, 3, 4]
35+
"""
36+
for i in range(1, len(arr)):
37+
key = arr[i]
38+
j = binary_search(arr, key, 0, i - 1)
39+
arr[:] = arr[:j] + [key] + arr[j:i] + arr[i + 1 :]
40+
return arr
3041

31-
if not right:
32-
return left
3342

34-
if left[0] < right[0]:
35-
return [left[0], *merge(left[1:], right)]
43+
def merge(left: List[T], right: List[T]) -> List[T]:
44+
"""
45+
Merge two sorted lists into one sorted list.
3646
37-
return [right[0], *merge(left, right[1:])]
47+
>>> merge([1, 3, 5], [2, 4, 6])
48+
[1, 2, 3, 4, 5, 6]
49+
"""
50+
merged = []
51+
i = j = 0
52+
while i < len(left) and j < len(right):
53+
if left[i] <= right[j]:
54+
merged.append(left[i])
55+
i += 1
56+
else:
57+
merged.append(right[j])
58+
j += 1
59+
merged.extend(left[i:])
60+
merged.extend(right[j:])
61+
return merged
3862

3963

40-
def tim_sort(lst):
64+
def tim_sort(arr: List[T]) -> List[T]:
4165
"""
66+
Simplified version of TimSort for educational purposes.
67+
68+
TimSort is a hybrid stable sorting algorithm that combines merge sort
69+
and insertion sort. It was originally designed by Tim Peters for Python (2002).
70+
71+
Source: https://en.wikipedia.org/wiki/Timsort
72+
4273
>>> tim_sort("Python")
4374
['P', 'h', 'n', 'o', 't', 'y']
44-
>>> tim_sort((1.1, 1, 0, -1, -1.1))
45-
[-1.1, -1, 0, 1, 1.1]
46-
>>> tim_sort(list(reversed(list(range(7)))))
47-
[0, 1, 2, 3, 4, 5, 6]
48-
>>> tim_sort([3, 2, 1]) == insertion_sort([3, 2, 1])
49-
True
75+
>>> tim_sort([5, 4, 3, 2, 1])
76+
[1, 2, 3, 4, 5]
5077
>>> tim_sort([3, 2, 1]) == sorted([3, 2, 1])
5178
True
79+
>>> tim_sort([]) # empty input
80+
[]
5281
"""
53-
length = len(lst)
54-
runs, sorted_runs = [], []
55-
new_run = [lst[0]]
56-
sorted_array = []
57-
i = 1
58-
while i < length:
59-
if lst[i] < lst[i - 1]:
60-
runs.append(new_run)
61-
new_run = [lst[i]]
62-
else:
63-
new_run.append(lst[i])
64-
i += 1
65-
runs.append(new_run)
82+
if not isinstance(arr, list):
83+
arr = list(arr)
84+
85+
if not arr:
86+
return []
6687

67-
for run in runs:
68-
sorted_runs.append(insertion_sort(run))
69-
for run in sorted_runs:
70-
sorted_array = merge(sorted_array, run)
88+
min_run = 32
89+
n = len(arr)
7190

72-
return sorted_array
91+
if n == 1:
92+
return arr.copy()
7393

94+
runs = []
95+
for start in range(0, n, min_run):
96+
end = min(start + min_run, n)
97+
run = insertion_sort(arr[start:end])
98+
runs.append(run)
7499

75-
def main():
76-
lst = [5, 9, 10, 3, -4, 5, 178, 92, 46, -18, 0, 7]
77-
sorted_lst = tim_sort(lst)
78-
print(sorted_lst)
100+
while len(runs) > 1:
101+
new_runs = []
102+
for i in range(0, len(runs), 2):
103+
if i + 1 < len(runs):
104+
new_runs.append(merge(runs[i], runs[i + 1]))
105+
else:
106+
new_runs.append(runs[i])
107+
runs = new_runs
108+
109+
return runs[0] if runs else []
79110

80111

81112
if __name__ == "__main__":
82-
main()
113+
import doctest
114+
115+
doctest.testmod()

0 commit comments

Comments
 (0)