@@ -159,39 +159,50 @@ def test_diff_no_op():
159159def test_diff_prepend_append_py_scalars (sh , axis ):
160160 get_queue_or_skip ()
161161
162- arrs = [
163- dpt .ones (sh , dtype = "?" ),
164- dpt .ones (sh , dtype = "i4" ),
165- dpt .ones (sh , dtype = "f4" ),
166- dpt .ones (sh , dtype = "c8" ),
167- ]
168-
169- py_zeros = [
170- False ,
171- 0 ,
172- 0.0 ,
173- complex (0 , 0 ),
174- ]
175-
176- py_ones = [
177- True ,
178- 1 ,
179- 1.0 ,
180- complex (1 , 0 ),
181- ]
182-
183- for zero , one , arr in zip (py_zeros , py_ones , arrs ):
184- n = 1
185- r = dpt .diff (arr , n = n , axis = axis , prepend = zero , append = one )
186- assert isinstance (r , dpt .usm_ndarray )
187- assert all (
188- r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
189- )
190- assert r .shape [axis ] == arr .shape [axis ] + 2 - n
191-
192- r = dpt .diff (arr , n = n , axis = axis , prepend = zero )
193- assert isinstance (r , dpt .usm_ndarray )
194- assert all (
195- r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
196- )
197- assert r .shape [axis ] == arr .shape [axis ] + 1 - n
162+ n = 1
163+
164+ arr = dpt .ones (sh , dtype = "i4" )
165+ zero = 0
166+
167+ # first and last elements along axis
168+ # will be checked for correctness
169+ sl1 = [slice (None )] * arr .ndim
170+ sl1 [axis ] = slice (1 )
171+ sl1 = tuple (sl1 )
172+
173+ sl2 = [slice (None )] * arr .ndim
174+ sl2 [axis ] = slice (- 1 , None , None )
175+ sl2 = tuple (sl2 )
176+
177+ r = dpt .diff (arr , axis = axis , prepend = zero , append = zero )
178+ assert isinstance (r , dpt .usm_ndarray )
179+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
180+ assert r .shape [axis ] == arr .shape [axis ] + 2 - n
181+ assert dpt .all (r [sl1 ] == 1 )
182+ assert dpt .all (r [sl2 ] == - 1 )
183+
184+ r = dpt .diff (arr , axis = axis , prepend = zero )
185+ assert isinstance (r , dpt .usm_ndarray )
186+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
187+ assert r .shape [axis ] == arr .shape [axis ] + 1 - n
188+ assert dpt .all (r [sl1 ] == 1 )
189+
190+ r = dpt .diff (arr , axis = axis , append = zero )
191+ assert isinstance (r , dpt .usm_ndarray )
192+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
193+ assert r .shape [axis ] == arr .shape [axis ] + 1 - n
194+ assert dpt .all (r [sl2 ] == - 1 )
195+
196+ r = dpt .diff (arr , axis = axis , prepend = dpt .asarray (zero ), append = zero )
197+ assert isinstance (r , dpt .usm_ndarray )
198+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
199+ assert r .shape [axis ] == arr .shape [axis ] + 2 - n
200+ assert dpt .all (r [sl1 ] == 1 )
201+ assert dpt .all (r [sl2 ] == - 1 )
202+
203+ r = dpt .diff (arr , axis = axis , prepend = zero , append = dpt .asarray (zero ))
204+ assert isinstance (r , dpt .usm_ndarray )
205+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
206+ assert r .shape [axis ] == arr .shape [axis ] + 2 - n
207+ assert dpt .all (r [sl1 ] == 1 )
208+ assert dpt .all (r [sl2 ] == - 1 )
0 commit comments