@@ -25,11 +25,8 @@ def test_racing_heapify(self):
2525 heap = list (range (OBJECT_COUNT ))
2626 shuffle (heap )
2727
28- def heapify_func (heap : list [int ]):
29- heapq .heapify (heap )
30-
3128 self .run_concurrently (
32- worker_func = heapify_func , args = (heap ,), nthreads = NTHREADS
29+ worker_func = heapq . heapify , args = (heap ,), nthreads = NTHREADS
3330 )
3431 self .assertTrue (self .is_min_heap_property_satisfied (heap ))
3532
@@ -46,12 +43,10 @@ def heappush_func(heap: list[int]):
4643 self .assertTrue (self .is_min_heap_property_satisfied (heap ))
4744
4845 def test_racing_heappop (self ):
49- heap = list (range (OBJECT_COUNT ))
50- shuffle (heap )
51- heapq .heapify (heap )
46+ heap = self .create_heap (OBJECT_COUNT , HeapKind .MIN )
5247
5348 # Each thread pops (OBJECT_COUNT / NTHREADS) items
54- self .assertEqual (0 , OBJECT_COUNT % NTHREADS )
49+ self .assertEqual (OBJECT_COUNT % NTHREADS , 0 )
5550 per_thread_pop_count = OBJECT_COUNT // NTHREADS
5651
5752 def heappop_func (heap : list [int ], pop_count : int ):
@@ -68,16 +63,11 @@ def heappop_func(heap: list[int], pop_count: int):
6863 args = (heap , per_thread_pop_count ),
6964 nthreads = NTHREADS ,
7065 )
71- self .assertEqual (0 , len (heap ))
66+ self .assertEqual (len (heap ), 0 )
7267
7368 def test_racing_heappushpop (self ):
74- heap = list (range (OBJECT_COUNT ))
75- shuffle (heap )
76- heapq .heapify (heap )
77-
78- pushpop_items = [
79- randint (- OBJECT_COUNT , OBJECT_COUNT ) for _ in range (OBJECT_COUNT )
80- ]
69+ heap = self .create_heap (OBJECT_COUNT , HeapKind .MIN )
70+ pushpop_items = self .create_random_list (- 5_000 , 10_000 , OBJECT_COUNT )
8171
8272 def heappushpop_func (heap : list [int ], pushpop_items : list [int ]):
8373 for item in pushpop_items :
@@ -89,17 +79,12 @@ def heappushpop_func(heap: list[int], pushpop_items: list[int]):
8979 args = (heap , pushpop_items ),
9080 nthreads = NTHREADS ,
9181 )
92- self .assertEqual (OBJECT_COUNT , len (heap ))
82+ self .assertEqual (len (heap ), OBJECT_COUNT )
9383 self .assertTrue (self .is_min_heap_property_satisfied (heap ))
9484
9585 def test_racing_heapreplace (self ):
96- heap = list (range (OBJECT_COUNT ))
97- shuffle (heap )
98- heapq .heapify (heap )
99-
100- replace_items = [
101- randint (- OBJECT_COUNT , OBJECT_COUNT ) for _ in range (OBJECT_COUNT )
102- ]
86+ heap = self .create_heap (OBJECT_COUNT , HeapKind .MIN )
87+ replace_items = self .create_random_list (- 5_000 , 10_000 , OBJECT_COUNT )
10388
10489 def heapreplace_func (heap : list [int ], replace_items : list [int ]):
10590 for item in replace_items :
@@ -110,18 +95,15 @@ def heapreplace_func(heap: list[int], replace_items: list[int]):
11095 args = (heap , replace_items ),
11196 nthreads = NTHREADS ,
11297 )
113- self .assertEqual (OBJECT_COUNT , len (heap ))
98+ self .assertEqual (len (heap ), OBJECT_COUNT )
11499 self .assertTrue (self .is_min_heap_property_satisfied (heap ))
115100
116101 def test_racing_heapify_max (self ):
117102 max_heap = list (range (OBJECT_COUNT ))
118103 shuffle (max_heap )
119104
120- def heapify_max_func (max_heap : list [int ]):
121- heapq .heapify_max (max_heap )
122-
123105 self .run_concurrently (
124- worker_func = heapify_max_func , args = (max_heap ,), nthreads = NTHREADS
106+ worker_func = heapq . heapify_max , args = (max_heap ,), nthreads = NTHREADS
125107 )
126108 self .assertTrue (self .is_max_heap_property_satisfied (max_heap ))
127109
@@ -138,12 +120,10 @@ def heappush_max_func(max_heap: list[int]):
138120 self .assertTrue (self .is_max_heap_property_satisfied (max_heap ))
139121
140122 def test_racing_heappop_max (self ):
141- max_heap = list (range (OBJECT_COUNT ))
142- shuffle (max_heap )
143- heapq .heapify_max (max_heap )
123+ max_heap = self .create_heap (OBJECT_COUNT , HeapKind .MAX )
144124
145125 # Each thread pops (OBJECT_COUNT / NTHREADS) items
146- self .assertEqual (0 , OBJECT_COUNT % NTHREADS )
126+ self .assertEqual (OBJECT_COUNT % NTHREADS , 0 )
147127 per_thread_pop_count = OBJECT_COUNT // NTHREADS
148128
149129 def heappop_max_func (max_heap : list [int ], pop_count : int ):
@@ -160,16 +140,11 @@ def heappop_max_func(max_heap: list[int], pop_count: int):
160140 args = (max_heap , per_thread_pop_count ),
161141 nthreads = NTHREADS ,
162142 )
163- self .assertEqual (0 , len (max_heap ))
143+ self .assertEqual (len (max_heap ), 0 )
164144
165145 def test_racing_heappushpop_max (self ):
166- max_heap = list (range (OBJECT_COUNT ))
167- shuffle (max_heap )
168- heapq .heapify_max (max_heap )
169-
170- pushpop_items = [
171- randint (- OBJECT_COUNT , OBJECT_COUNT ) for _ in range (OBJECT_COUNT )
172- ]
146+ max_heap = self .create_heap (OBJECT_COUNT , HeapKind .MAX )
147+ pushpop_items = self .create_random_list (- 5_000 , 10_000 , OBJECT_COUNT )
173148
174149 def heappushpop_max_func (
175150 max_heap : list [int ], pushpop_items : list [int ]
@@ -183,17 +158,12 @@ def heappushpop_max_func(
183158 args = (max_heap , pushpop_items ),
184159 nthreads = NTHREADS ,
185160 )
186- self .assertEqual (OBJECT_COUNT , len (max_heap ))
161+ self .assertEqual (len (max_heap ), OBJECT_COUNT )
187162 self .assertTrue (self .is_max_heap_property_satisfied (max_heap ))
188163
189164 def test_racing_heapreplace_max (self ):
190- max_heap = list (range (OBJECT_COUNT ))
191- shuffle (max_heap )
192- heapq .heapify_max (max_heap )
193-
194- replace_items = [
195- randint (- OBJECT_COUNT , OBJECT_COUNT ) for _ in range (OBJECT_COUNT )
196- ]
165+ max_heap = self .create_heap (OBJECT_COUNT , HeapKind .MAX )
166+ replace_items = self .create_random_list (- 5_000 , 10_000 , OBJECT_COUNT )
197167
198168 def heapreplace_max_func (
199169 max_heap : list [int ], replace_items : list [int ]
@@ -206,7 +176,7 @@ def heapreplace_max_func(
206176 args = (max_heap , replace_items ),
207177 nthreads = NTHREADS ,
208178 )
209- self .assertEqual (OBJECT_COUNT , len (max_heap ))
179+ self .assertEqual (len (max_heap ), OBJECT_COUNT )
210180 self .assertTrue (self .is_max_heap_property_satisfied (max_heap ))
211181
212182 def is_min_heap_property_satisfied (self , heap : list [object ]) -> bool :
@@ -254,25 +224,47 @@ def is_sorted_descending(lst: list[object]) -> bool:
254224 return all (lst [i - 1 ] >= lst [i ] for i in range (1 , len (lst )))
255225
256226 @staticmethod
257- def run_concurrently (worker_func , args , nthreads ) -> None :
227+ def create_heap (size : int , heap_kind : HeapKind ) -> list [int ]:
228+ """
229+ Create a min/max heap where elements are in the range (0, size - 1) and
230+ shuffled before heapify.
231+ """
232+ heap = list (range (OBJECT_COUNT ))
233+ shuffle (heap )
234+ if heap_kind == HeapKind .MIN :
235+ heapq .heapify (heap )
236+ else :
237+ heapq .heapify_max (heap )
238+
239+ return heap
240+
241+ @staticmethod
242+ def create_random_list (a : int , b : int , size : int ) -> list [int ]:
243+ """
244+ Create a random list where elements are in the range a <= elem <= b
245+ """
246+ return [randint (- a , b ) for _ in range (size )]
247+
248+ def run_concurrently (self , worker_func , args , nthreads ) -> None :
258249 """
259250 Run the worker function concurrently in multiple threads.
260251 """
261- barrier = Barrier (NTHREADS )
252+ barrier = Barrier (nthreads )
262253
263254 def wrapper_func (* args ):
264255 # Wait for all threadss to reach this point before proceeding.
265256 barrier .wait ()
266257 worker_func (* args )
267258
268- workers = []
269- for _ in range (nthreads ):
270- worker = Thread (target = wrapper_func , args = args )
271- workers .append (worker )
272- worker .start ()
259+ with threading_helper .catch_threading_exception () as cm :
260+ workers = (
261+ Thread (target = wrapper_func , args = args ) for _ in range (nthreads )
262+ )
263+ with threading_helper .start_threads (workers ):
264+ pass
273265
274- for worker in workers :
275- worker . join ( )
266+ # Worker threads should not raise any exceptions
267+ self . assertIsNone ( cm . exc_value )
276268
277269
278270if __name__ == "__main__" :
0 commit comments