1- from typing import TypeVar
1+ from typing import Protocol , TypeVar
22
3- T = TypeVar ("T" )
3+
4+ class Comparable (Protocol ):
5+ def __lt__ (self , other : object ) -> bool : ...
6+ def __le__ (self , other : object ) -> bool : ...
7+
8+
9+ T = TypeVar ("T" , bound = Comparable )
410
511
612def binary_search (arr : list [T ], item : T , left : int , right : int ) -> int :
@@ -26,17 +32,13 @@ def binary_search(arr: list[T], item: T, left: int, right: int) -> int:
2632 return left
2733
2834
29- def insertion_sort (arr : list [T ]) -> list [T ]:
35+ def insertion_sort [ T_contra ] (arr : list [T_contra ]) -> list [T_contra ]: # type: ignore[valid-type]
3036 """
3137 Sort the list in-place using binary insertion sort.
3238
3339 >>> insertion_sort([3, 1, 2, 4])
3440 [1, 2, 3, 4]
3541 """
36- from typing import TypeVar
37-
38- T = TypeVar ("T" )
39-
4042 for i in range (1 , len (arr )):
4143 key = arr [i ]
4244 j = binary_search (arr , key , 0 , i - 1 )
@@ -65,7 +67,7 @@ def merge(left: list[T], right: list[T]) -> list[T]:
6567 return merged
6668
6769
68- def tim_sort (arr : list [T ]) -> list [T ]:
70+ def tim_sort [ T_contra ] (arr : list [T_contra ]) -> list [T_contra ]: # type: ignore[valid-type]
6971 """
7072 Simplified version of TimSort for educational purposes.
7173
@@ -83,10 +85,6 @@ def tim_sort(arr: list[T]) -> list[T]:
8385 >>> tim_sort([]) # empty input
8486 []
8587 """
86- from typing import TypeVar
87-
88- T = TypeVar ("T" )
89-
9088 if not isinstance (arr , list ):
9189 arr = list (arr )
9290 if not arr :
@@ -98,14 +96,14 @@ def tim_sort(arr: list[T]) -> list[T]:
9896 if n == 1 :
9997 return arr .copy ()
10098
101- runs : list [list [T ]] = []
99+ runs : list [list [T_contra ]] = []
102100 for start in range (0 , n , min_run ):
103101 end = min (start + min_run , n )
104102 run = insertion_sort (arr [start :end ])
105103 runs .append (run )
106104
107105 while len (runs ) > 1 :
108- new_runs : list [list [T ]] = []
106+ new_runs : list [list [T_contra ]] = []
109107 for i in range (0 , len (runs ), 2 ):
110108 if i + 1 < len (runs ):
111109 new_runs .append (merge (runs [i ], runs [i + 1 ]))
@@ -119,4 +117,4 @@ def tim_sort(arr: list[T]) -> list[T]:
119117if __name__ == "__main__" :
120118 import doctest
121119
122- doctest .testmod ()
120+ doctest .testmod ()
0 commit comments