2121import dpctl .tensor as dpt
2222
2323
24- def empty_like (A ):
25- return dpt .empty (A .shape , A .dtype , device = A .device )
26-
27-
2824def chebyshev (A , b , x0 , nIters , lMax , lMin , depends = []):
2925 """Chebyshev iterative solver using SYCL routines"""
3026 d = (lMax + lMin ) / 2
@@ -33,9 +29,9 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
3329 x = dpt .copy (x0 )
3430 exec_queue = A .sycl_queue
3531 assert exec_queue == x .sycl_queue
36- Ax = empty_like (A [:, 0 ])
37- r = empty_like (Ax )
38- p = empty_like (Ax )
32+ Ax = dpt . empty_like (A [:, 0 ])
33+ r = dpt . empty_like (Ax )
34+ p = dpt . empty_like (Ax )
3935
4036 e_x = dpctl .SyclEvent ()
4137 # Ax = A @ x
@@ -131,12 +127,13 @@ def cg_solve(A, b):
131127 converged is False if solver has not converged, or the iteration number
132128 """
133129 exec_queue = A .sycl_queue
134- x = dpt .zeros ( b . shape , dtype = b . dtype )
135- Ap = empty_like (x )
130+ x = dpt .zeros_like ( b )
131+ Ap = dpt . empty_like (x )
136132
137133 all_host_tasks = []
138- r = dpt .copy (b )
139- p = dpt .copy (b )
134+ r = dpt .copy (b ) # synchronous copy
135+ p = dpt .copy (b ) # synchronous copy
136+
140137 rsold = sycl_gemm .norm_squared_blocking (exec_queue , r )
141138 if rsold < 1e-20 :
142139 return (b , 0 )
@@ -147,22 +144,21 @@ def cg_solve(A, b):
147144 e_x = dpctl .SyclEvent ()
148145 for i in range (max_iters ):
149146 # Ap = A @ p
150- he_dot , e_dot = sycl_gemm .gemv (exec_queue , A , p , Ap , depends = [e_p ])
151- all_host_tasks .append (he_dot )
147+ he_gemv , e_gemv = sycl_gemm .gemv (exec_queue , A , p , Ap , depends = [e_p ])
148+ all_host_tasks .append (he_gemv )
152149 # alpha = rsold / dot(p, Ap)
153150 alpha = rsold / sycl_gemm .dot_blocking (
154- exec_queue , p , Ap , depends = [e_dot ]
151+ exec_queue , p , Ap , depends = [e_p , e_gemv ]
155152 )
156153 # x = x + alpha * p
157154 he1_x_update , e1_x_update = sycl_gemm .axpby_inplace (
158- exec_queue , alpha , p , 1 , x , depends = [e_p , e_x ]
155+ exec_queue , alpha , p , 1 , x , depends = [e_x ]
159156 )
160157 all_host_tasks .append (he1_x_update )
161- e_x = e1_x_update
162158
163159 # r = r - alpha * Ap
164160 he2_r_update , e2_r_update = sycl_gemm .axpby_inplace (
165- exec_queue , - alpha , Ap , 1 , r , depends = [ e_p ]
161+ exec_queue , - alpha , Ap , 1 , r
166162 )
167163 all_host_tasks .append (he2_r_update )
168164
0 commit comments