1616
1717import ctypes
1818import numbers
19+ from math import prod
1920
2021import numpy as np
2122import pytest
@@ -1102,7 +1103,7 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
11021103 skip_if_dtype_not_supported (dtype , q )
11031104 shape = (2 , 4 , 3 )
11041105 Xnp = (
1105- np .random .randint (- 10 , 10 , size = np . prod (shape ))
1106+ np .random .randint (- 10 , 10 , size = prod (shape ))
11061107 .astype (dtype )
11071108 .reshape (shape )
11081109 )
@@ -1307,6 +1308,10 @@ def relaxed_strides_equal(st1, st2, sh):
13071308 X = dpt .usm_ndarray (sh_s , dtype = "?" )
13081309 X .shape = sh_f
13091310 assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
1311+ sz = X .size
1312+ X .shape = sz
1313+ assert X .shape == (sz ,)
1314+ assert relaxed_strides_equal (X .strides , (1 ,), (sz ,))
13101315
13111316 X = dpt .usm_ndarray (sh_s , dtype = "u4" )
13121317 with pytest .raises (TypeError ):
@@ -2077,11 +2082,9 @@ def test_tril(dtype):
20772082 skip_if_dtype_not_supported (dtype , q )
20782083
20792084 shape = (2 , 3 , 4 , 5 , 5 )
2080- X = dpt .reshape (
2081- dpt .arange (np .prod (shape ), dtype = dtype , sycl_queue = q ), shape
2082- )
2085+ X = dpt .reshape (dpt .arange (prod (shape ), dtype = dtype , sycl_queue = q ), shape )
20832086 Y = dpt .tril (X )
2084- Xnp = np .arange (np . prod (shape ), dtype = dtype ).reshape (shape )
2087+ Xnp = np .arange (prod (shape ), dtype = dtype ).reshape (shape )
20852088 Ynp = np .tril (Xnp )
20862089 assert Y .dtype == Ynp .dtype
20872090 assert np .array_equal (Ynp , dpt .asnumpy (Y ))
@@ -2093,11 +2096,9 @@ def test_triu(dtype):
20932096 skip_if_dtype_not_supported (dtype , q )
20942097
20952098 shape = (4 , 5 )
2096- X = dpt .reshape (
2097- dpt .arange (np .prod (shape ), dtype = dtype , sycl_queue = q ), shape
2098- )
2099+ X = dpt .reshape (dpt .arange (prod (shape ), dtype = dtype , sycl_queue = q ), shape )
20992100 Y = dpt .triu (X , k = 1 )
2100- Xnp = np .arange (np . prod (shape ), dtype = dtype ).reshape (shape )
2101+ Xnp = np .arange (prod (shape ), dtype = dtype ).reshape (shape )
21012102 Ynp = np .triu (Xnp , k = 1 )
21022103 assert Y .dtype == Ynp .dtype
21032104 assert np .array_equal (Ynp , dpt .asnumpy (Y ))
@@ -2110,7 +2111,7 @@ def test_tri_usm_type(tri_fn, usm_type):
21102111 dtype = dpt .uint16
21112112
21122113 shape = (2 , 3 , 4 , 5 , 5 )
2113- size = np . prod (shape )
2114+ size = prod (shape )
21142115 X = dpt .reshape (
21152116 dpt .arange (size , dtype = dtype , usm_type = usm_type , sycl_queue = q ), shape
21162117 )
@@ -2129,11 +2130,11 @@ def test_tril_slice():
21292130 q = get_queue_or_skip ()
21302131
21312132 shape = (6 , 10 )
2132- X = dpt .reshape (
2133- dpt . arange ( np . prod ( shape ), dtype = "int" , sycl_queue = q ), shape
2134- )[ 1 :, :: - 2 ]
2133+ X = dpt .reshape (dpt . arange ( prod ( shape ), dtype = "int" , sycl_queue = q ), shape )[
2134+ 1 :, :: - 2
2135+ ]
21352136 Y = dpt .tril (X )
2136- Xnp = np .arange (np . prod (shape ), dtype = "int" ).reshape (shape )[1 :, ::- 2 ]
2137+ Xnp = np .arange (prod (shape ), dtype = "int" ).reshape (shape )[1 :, ::- 2 ]
21372138 Ynp = np .tril (Xnp )
21382139 assert Y .dtype == Ynp .dtype
21392140 assert np .array_equal (Ynp , dpt .asnumpy (Y ))
@@ -2144,14 +2145,12 @@ def test_triu_permute_dims():
21442145
21452146 shape = (2 , 3 , 4 , 5 )
21462147 X = dpt .permute_dims (
2147- dpt .reshape (
2148- dpt .arange (np .prod (shape ), dtype = "int" , sycl_queue = q ), shape
2149- ),
2148+ dpt .reshape (dpt .arange (prod (shape ), dtype = "int" , sycl_queue = q ), shape ),
21502149 (3 , 2 , 1 , 0 ),
21512150 )
21522151 Y = dpt .triu (X )
21532152 Xnp = np .transpose (
2154- np .arange (np . prod (shape ), dtype = "int" ).reshape (shape ), (3 , 2 , 1 , 0 )
2153+ np .arange (prod (shape ), dtype = "int" ).reshape (shape ), (3 , 2 , 1 , 0 )
21552154 )
21562155 Ynp = np .triu (Xnp )
21572156 assert Y .dtype == Ynp .dtype
@@ -2189,12 +2188,12 @@ def test_triu_order_k(order, k):
21892188
21902189 shape = (3 , 3 )
21912190 X = dpt .reshape (
2192- dpt .arange (np . prod (shape ), dtype = "int" , sycl_queue = q ),
2191+ dpt .arange (prod (shape ), dtype = "int" , sycl_queue = q ),
21932192 shape ,
21942193 order = order ,
21952194 )
21962195 Y = dpt .triu (X , k = k )
2197- Xnp = np .arange (np . prod (shape ), dtype = "int" ).reshape (shape , order = order )
2196+ Xnp = np .arange (prod (shape ), dtype = "int" ).reshape (shape , order = order )
21982197 Ynp = np .triu (Xnp , k = k )
21992198 assert Y .dtype == Ynp .dtype
22002199 assert X .flags == Y .flags
@@ -2210,12 +2209,12 @@ def test_tril_order_k(order, k):
22102209 pytest .skip ("Queue could not be created" )
22112210 shape = (3 , 3 )
22122211 X = dpt .reshape (
2213- dpt .arange (np . prod (shape ), dtype = "int" , sycl_queue = q ),
2212+ dpt .arange (prod (shape ), dtype = "int" , sycl_queue = q ),
22142213 shape ,
22152214 order = order ,
22162215 )
22172216 Y = dpt .tril (X , k = k )
2218- Xnp = np .arange (np . prod (shape ), dtype = "int" ).reshape (shape , order = order )
2217+ Xnp = np .arange (prod (shape ), dtype = "int" ).reshape (shape , order = order )
22192218 Ynp = np .tril (Xnp , k = k )
22202219 assert Y .dtype == Ynp .dtype
22212220 assert X .flags == Y .flags
0 commit comments