@@ -177,12 +177,24 @@ def test_argsort_axis0():
177177 x = dpt .reshape (xf , (n , m ))
178178 idx = dpt .argsort (x , axis = 0 )
179179
180- conseq_idx = dpt .arange (m , dtype = idx .dtype )
181- s = x [idx , conseq_idx [dpt .newaxis , :]]
180+ s = dpt .take_along_axis (x , idx , axis = 0 )
182181
183182 assert dpt .all (s [:- 1 , :] <= s [1 :, :])
184183
185184
185+ def test_argsort_axis1 ():
186+ get_queue_or_skip ()
187+
188+ n , m = 200 , 30
189+ xf = dpt .arange (n * m , 0 , step = - 1 , dtype = "i4" )
190+ x = dpt .reshape (xf , (n , m ))
191+ idx = dpt .argsort (x , axis = 1 )
192+
193+ s = dpt .take_along_axis (x , idx , axis = 1 )
194+
195+ assert dpt .all (s [:, :- 1 ] <= s [:, 1 :])
196+
197+
186198def test_sort_strided ():
187199 get_queue_or_skip ()
188200
@@ -199,8 +211,9 @@ def test_argsort_strided():
199211 x_orig = dpt .arange (100 , dtype = "i4" )
200212 x_flipped = dpt .flip (x_orig , axis = 0 )
201213 idx = dpt .argsort (x_flipped )
214+ s = dpt .take_along_axis (x_flipped , idx , axis = 0 )
202215
203- assert dpt .all (x_flipped [ idx ] == x_orig )
216+ assert dpt .all (s == x_orig )
204217
205218
206219def test_sort_0d_array ():
0 commit comments