2020from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
2121
2222
23+ @pytest .fixture
24+ def skip_known_failues_on_cpu (request ):
25+ return request .config .getoption ("--skip-known-top-k-failures-on-cpu" )
26+
27+
2328def _expected_largest_inds (inp , n , shift , k ):
2429 "Computed expected top_k indices for mode='largest'"
2530 assert k < n
@@ -52,10 +57,17 @@ def _expected_largest_inds(inp, n, shift, k):
5257 return expected_inds
5358
5459
60+ def _skip_if_workaround_is_needed (q , dtype , n , enabled ):
61+ if enabled :
62+ dev = q .sycl_device
63+ if dev .is_cpu and dtype in ["i1" , "i2" ] and n > 128 :
64+ pytest .skip (reason = "CPU driver bug" )
65+
66+
5567@pytest .mark .parametrize (
5668 "dtype" ,
5769 [
58- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
70+ "i1" ,
5971 "u1" ,
6072 "i2" ,
6173 "u2" ,
@@ -71,11 +83,10 @@ def _expected_largest_inds(inp, n, shift, k):
7183 ],
7284)
7385@pytest .mark .parametrize ("n" , [33 , 43 , 255 , 511 , 1021 , 8193 ])
74- def test_top_k_1d_largest (dtype , n ):
86+ def test_top_k_1d_largest (dtype , n , skip_known_failues_on_cpu ):
7587 q = get_queue_or_skip ()
7688 skip_if_dtype_not_supported (dtype , q )
77- if dtype == "i1" :
78- pytest .skip ()
89+ _skip_if_workaround_is_needed (q , dtype , n , skip_known_failues_on_cpu )
7990
8091 shift , k = 734 , 5
8192 o = dpt .ones (n , dtype = dtype )
@@ -128,7 +139,7 @@ def _expected_smallest_inds(inp, n, shift, k):
128139@pytest .mark .parametrize (
129140 "dtype" ,
130141 [
131- pytest . param ( "i1" , marks = pytest . mark . skip ( reason = "CPU bug" )) ,
142+ "i1" ,
132143 "u1" ,
133144 "i2" ,
134145 "u2" ,
@@ -144,10 +155,12 @@ def _expected_smallest_inds(inp, n, shift, k):
144155 ],
145156)
146157@pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
147- def test_top_k_1d_smallest (dtype , n ):
158+ def test_top_k_1d_smallest (dtype , n , skip_known_failues_on_cpu ):
148159 q = get_queue_or_skip ()
149160 skip_if_dtype_not_supported (dtype , q )
150161
162+ _skip_if_workaround_is_needed (q , dtype , n , skip_known_failues_on_cpu )
163+
151164 shift , k = 734 , 5
152165 o = dpt .ones (n , dtype = dtype )
153166 z = dpt .zeros (n , dtype = dtype )
@@ -163,3 +176,91 @@ def test_top_k_1d_smallest(dtype, n):
163176 assert dpt .all (s .indices == expected_inds )
164177 assert dpt .all (s .values == dpt .zeros (k , dtype = dtype )), s .values
165178 assert dpt .all (s .values == inp [s .indices ]), s .indices
179+
180+
181+ @pytest .mark .parametrize (
182+ "dtype" ,
183+ [
184+ # skip short types to ensure that m*n can be represented
185+ # in the type
186+ "i4" ,
187+ "u4" ,
188+ "i8" ,
189+ "u8" ,
190+ "f2" ,
191+ "f4" ,
192+ "f8" ,
193+ "c8" ,
194+ "c16" ,
195+ ],
196+ )
197+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
198+ def test_top_k_2d_largest (dtype , n ):
199+ q = get_queue_or_skip ()
200+ skip_if_dtype_not_supported (dtype , q )
201+
202+ m , k = 8 , 3
203+ if dtype == "f2" and m * n > 2000 :
204+ pytest .skip (
205+ "f2 can not distinguish between large integers used in this test"
206+ )
207+
208+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
209+
210+ r = dpt .top_k (x , k , axis = 1 )
211+
212+ assert r .values .shape == (m , k )
213+ assert r .indices .shape == (m , k )
214+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
215+ :, - k :
216+ ]
217+ assert expected_inds .shape == (1 , k )
218+ assert dpt .all (
219+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
220+ ), (r .indices , expected_inds )
221+ expected_vals = x [:, - k :]
222+ assert dpt .all (
223+ dpt .sort (r .values , axis = 1 ) == dpt .sort (expected_vals , axis = 1 )
224+ )
225+
226+
227+ @pytest .mark .parametrize (
228+ "dtype" ,
229+ [
230+ # skip short types to ensure that m*n can be represented
231+ # in the type
232+ "i4" ,
233+ "u4" ,
234+ "i8" ,
235+ "u8" ,
236+ "f2" ,
237+ "f4" ,
238+ "f8" ,
239+ "c8" ,
240+ "c16" ,
241+ ],
242+ )
243+ @pytest .mark .parametrize ("n" , [37 , 39 , 61 , 255 , 257 , 513 , 1021 , 8193 ])
244+ def test_top_k_2d_smallest (dtype , n ):
245+ q = get_queue_or_skip ()
246+ skip_if_dtype_not_supported (dtype , q )
247+
248+ m , k = 8 , 3
249+ if dtype == "f2" and m * n > 2000 :
250+ pytest .skip (
251+ "f2 can not distinguish between large integers used in this test"
252+ )
253+
254+ x = dpt .reshape (dpt .arange (m * n , dtype = dtype ), (m , n ))
255+
256+ r = dpt .top_k (x , k , axis = 1 , mode = "smallest" )
257+
258+ assert r .values .shape == (m , k )
259+ assert r .indices .shape == (m , k )
260+ expected_inds = dpt .reshape (dpt .arange (n , dtype = r .indices .dtype ), (1 , n ))[
261+ :, :k
262+ ]
263+ assert dpt .all (
264+ dpt .sort (r .indices , axis = 1 ) == dpt .sort (expected_inds , axis = 1 )
265+ )
266+ assert dpt .all (dpt .sort (r .values , axis = 1 ) == dpt .sort (x [:, :k ], axis = 1 ))
0 commit comments