@@ -48,22 +48,18 @@ def test_diff_basic(dt):
4848 skip_if_dtype_not_supported (dt , q )
4949
5050 x = dpt .asarray ([9 , 12 , 7 , 17 , 10 , 18 , 15 , 9 , 8 , 8 ], dtype = dt , sycl_queue = q )
51- res = dpt .diff (x )
5251 op = dpt .not_equal if x .dtype is dpt .bool else dpt .subtract
53- expected_res = op (x [1 :], x [:- 1 ])
54- if dpt .dtype (dt ).kind in "fc" :
55- assert dpt .allclose (res , expected_res )
56- else :
57- assert dpt .all (res == expected_res )
5852
59- res = dpt .diff (x , n = 5 )
60- expected_res = x
61- for _ in range (5 ):
62- expected_res = op (expected_res [1 :], expected_res [:- 1 ])
63- if dpt .dtype (dt ).kind in "fc" :
64- assert dpt .allclose (res , expected_res )
65- else :
66- assert dpt .all (res == expected_res )
53+ # test both n=2 and n>2 branches
54+ for n in [1 , 2 , 5 ]:
55+ res = dpt .diff (x , n = n )
56+ expected_res = x
57+ for _ in range (n ):
58+ expected_res = op (expected_res [1 :], expected_res [:- 1 ])
59+ if dpt .dtype (dt ).kind in "fc" :
60+ assert dpt .allclose (res , expected_res )
61+ else :
62+ assert dpt .all (res == expected_res )
6763
6864
6965def test_diff_axis ():
@@ -73,17 +69,15 @@ def test_diff_axis():
7369 dpt .asarray ([9 , 12 , 7 , 17 , 10 , 18 , 15 , 9 , 8 , 8 ], dtype = "i4" ), (3 , 4 , 1 )
7470 )
7571 x [:, ::2 , :] = 0
76- res = dpt .diff (x , n = 1 , axis = 1 )
77- expected_res = dpt .subtract (x [:, 1 :, :], x [:, :- 1 , :])
78- assert dpt .all (res == expected_res )
79-
80- res = dpt .diff (x , n = 3 , axis = 1 )
81- expected_res = x
82- for _ in range (3 ):
83- expected_res = dpt .subtract (
84- expected_res [:, 1 :, :], expected_res [:, :- 1 , :]
85- )
86- assert dpt .all (res == expected_res )
72+
73+ for n in [1 , 2 , 3 ]:
74+ res = dpt .diff (x , n = 3 , axis = 1 )
75+ expected_res = x
76+ for _ in range (3 ):
77+ expected_res = dpt .subtract (
78+ expected_res [:, 1 :, :], expected_res [:, :- 1 , :]
79+ )
80+ assert dpt .all (res == expected_res )
8781
8882
8983def test_diff_prepend_append_type_promotion ():
@@ -179,33 +173,28 @@ def test_diff_prepend_append_py_scalars(sh, axis):
179173 sl2 = tuple (sl2 )
180174
181175 r = dpt .diff (arr , axis = axis , prepend = zero , append = zero )
182- assert isinstance (r , dpt .usm_ndarray )
183176 assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
184177 assert r .shape [axis ] == arr .shape [axis ] + 2 - n
185178 assert dpt .all (r [sl1 ] == 1 )
186179 assert dpt .all (r [sl2 ] == - 1 )
187180
188181 r = dpt .diff (arr , axis = axis , prepend = zero )
189- assert isinstance (r , dpt .usm_ndarray )
190182 assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
191183 assert r .shape [axis ] == arr .shape [axis ] + 1 - n
192184 assert dpt .all (r [sl1 ] == 1 )
193185
194186 r = dpt .diff (arr , axis = axis , append = zero )
195- assert isinstance (r , dpt .usm_ndarray )
196187 assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
197188 assert r .shape [axis ] == arr .shape [axis ] + 1 - n
198189 assert dpt .all (r [sl2 ] == - 1 )
199190
200191 r = dpt .diff (arr , axis = axis , prepend = dpt .asarray (zero ), append = zero )
201- assert isinstance (r , dpt .usm_ndarray )
202192 assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
203193 assert r .shape [axis ] == arr .shape [axis ] + 2 - n
204194 assert dpt .all (r [sl1 ] == 1 )
205195 assert dpt .all (r [sl2 ] == - 1 )
206196
207197 r = dpt .diff (arr , axis = axis , prepend = zero , append = dpt .asarray (zero ))
208- assert isinstance (r , dpt .usm_ndarray )
209198 assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
210199 assert r .shape [axis ] == arr .shape [axis ] + 2 - n
211200 assert dpt .all (r [sl1 ] == 1 )
@@ -218,54 +207,36 @@ def test_tensor_diff_append_prepend_arrays():
218207 n = 1
219208 axis = 0
220209
221- sz = 5
222- arr = dpt .arange (sz , 2 * sz , dtype = "i4" )
223- prepend = dpt .arange (sz , dtype = "i4" )
224- append = dpt .arange (2 * sz , 3 * sz , dtype = "i4" )
225- const_diff = 1
226-
227- r = dpt .diff (arr , axis = axis , prepend = prepend , append = append )
228- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
229- assert (
230- r .shape [axis ]
231- == arr .shape [axis ] + prepend .shape [axis ] + append .shape [axis ] - n
232- )
233- assert dpt .all (r == const_diff )
234-
235- r = dpt .diff (arr , axis = axis , prepend = prepend )
236- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
237- assert r .shape [axis ] == arr .shape [axis ] + prepend .shape [axis ] - n
238- assert dpt .all (r == const_diff )
239-
240- r = dpt .diff (arr , axis = axis , append = append )
241- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
242- assert r .shape [axis ] == arr .shape [axis ] + append .shape [axis ] - n
243- assert dpt .all (r == const_diff )
244-
245- sh = (3 , 4 , 5 )
246- sz = prod (sh )
247- arr = dpt .reshape (dpt .arange (sz , 2 * sz , dtype = "i4" ), sh )
248- prepend = dpt .reshape (dpt .arange (sz , dtype = "i4" ), sh )
249- append = dpt .reshape (dpt .arange (2 * sz , 3 * sz , dtype = "i4" ), sh )
250- const_diff = prod (sh [axis + 1 :])
210+ for sh in [(5 ,), (3 , 4 , 5 )]:
211+ sz = prod (sh )
212+ arr = dpt .reshape (dpt .arange (sz , 2 * sz , dtype = "i4" ), sh )
213+ prepend = dpt .reshape (dpt .arange (sz , dtype = "i4" ), sh )
214+ append = dpt .reshape (dpt .arange (2 * sz , 3 * sz , dtype = "i4" ), sh )
215+ const_diff = sz / sh [axis ]
251216
252- r = dpt .diff (arr , axis = axis , prepend = prepend , append = append )
253- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
254- assert (
255- r .shape [axis ]
256- == arr .shape [axis ] + prepend .shape [axis ] + append .shape [axis ] - n
257- )
258- assert dpt .all (r == const_diff )
217+ r = dpt .diff (arr , axis = axis , prepend = prepend , append = append )
218+ assert all (
219+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
220+ )
221+ assert (
222+ r .shape [axis ]
223+ == arr .shape [axis ] + prepend .shape [axis ] + append .shape [axis ] - n
224+ )
225+ assert dpt .all (r == const_diff )
259226
260- r = dpt .diff (arr , axis = axis , prepend = prepend )
261- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
262- assert r .shape [axis ] == arr .shape [axis ] + prepend .shape [axis ] - n
263- assert dpt .all (r == const_diff )
227+ r = dpt .diff (arr , axis = axis , prepend = prepend )
228+ assert all (
229+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
230+ )
231+ assert r .shape [axis ] == arr .shape [axis ] + prepend .shape [axis ] - n
232+ assert dpt .all (r == const_diff )
264233
265- r = dpt .diff (arr , axis = axis , append = append )
266- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
267- assert r .shape [axis ] == arr .shape [axis ] + append .shape [axis ] - n
268- assert dpt .all (r == const_diff )
234+ r = dpt .diff (arr , axis = axis , append = append )
235+ assert all (
236+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
237+ )
238+ assert r .shape [axis ] == arr .shape [axis ] + append .shape [axis ] - n
239+ assert dpt .all (r == const_diff )
269240
270241
271242def test_diff_wrong_append_prepend_shape ():
@@ -332,6 +303,26 @@ def test_diff_compute_follows_data():
332303 append = ar3 ,
333304 )
334305
306+ assert_raises_regex (
307+ ExecutionPlacementError ,
308+ "Execution placement can not be unambiguously inferred from input "
309+ "arguments" ,
310+ dpt .diff ,
311+ ar1 ,
312+ prepend = ar2 ,
313+ append = 0 ,
314+ )
315+
316+ assert_raises_regex (
317+ ExecutionPlacementError ,
318+ "Execution placement can not be unambiguously inferred from input "
319+ "arguments" ,
320+ dpt .diff ,
321+ ar1 ,
322+ prepend = 0 ,
323+ append = ar2 ,
324+ )
325+
335326 assert_raises_regex (
336327 ExecutionPlacementError ,
337328 "Execution placement can not be unambiguously inferred from input "
0 commit comments