@@ -141,3 +141,56 @@ def test_diff_empty_array():
141141 prepend = dpt .ones ((3 , 2 , 5 ))
142142 res = dpt .diff (x , axis = 1 , prepend = prepend )
143143 assert res .shape == (3 , 1 , 5 )
144+
145+
146+ def test_diff_no_op ():
147+ get_queue_or_skip ()
148+
149+ x = dpt .ones (10 , dtype = "i4" )
150+ res = dpt .diff (x , n = 0 )
151+ assert dpt .all (x == res )
152+
153+ res = dpt .diff (dpt .reshape (x , (2 , 5 )), n = 0 , axis = 0 )
154+ assert dpt .all (x == res )
155+
156+
157+ @pytest .mark .parametrize ("sh,axis" , [((1 ,), 0 ), ((3 , 4 , 5 ), 1 )])
158+ def test_diff_prepend_append_py_scalars (sh , axis ):
159+ get_queue_or_skip ()
160+
161+ arrs = [
162+ dpt .ones (sh , dtype = "?" ),
163+ dpt .ones (sh , dtype = "i4" ),
164+ dpt .ones (sh , dtype = "f4" ),
165+ dpt .ones (sh , dtype = "c8" ),
166+ ]
167+
168+ py_zeros = [
169+ False ,
170+ 0 ,
171+ 0.0 ,
172+ complex (0 , 0 ),
173+ ]
174+
175+ py_ones = [
176+ True ,
177+ 1 ,
178+ 1.0 ,
179+ complex (1 , 0 ),
180+ ]
181+
182+ for zero , one , arr in zip (py_zeros , py_ones , arrs ):
183+ n = 1
184+ r = dpt .diff (arr , n = n , axis = axis , prepend = zero , append = one )
185+ assert isinstance (r , dpt .usm_ndarray )
186+ assert all (
187+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
188+ )
189+ assert r .shape [axis ] == arr .shape [axis ] + 2 - n
190+
191+ r = dpt .diff (arr , n = n , axis = axis , prepend = zero )
192+ assert isinstance (r , dpt .usm_ndarray )
193+ assert all (
194+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
195+ )
196+ assert r .shape [axis ] == arr .shape [axis ] + 1 - n
0 commit comments