@@ -71,13 +71,11 @@ py_gemv(sycl::queue q,
7171 " USM allocation is not bound to the context in execution queue." );
7272 }
7373
74- int mat_flags = matrix.get_flags ();
75- int v_flags = vector.get_flags ();
76- int r_flags = result.get_flags ();
74+ auto &api = dpctl::detail::dpctl_capi::get ();
7775
78- if (!((mat_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
79- (v_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
80- (r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS ))))
76+ if (!((matrix. is_c_contiguous ( )) &&
77+ (vector. is_c_contiguous () || vector. is_f_contiguous ( )) &&
78+ (result. is_c_contiguous () || result. is_f_contiguous ( ))))
8179 {
8280 throw std::runtime_error (" Arrays must be contiguous." );
8381 }
@@ -87,8 +85,8 @@ py_gemv(sycl::queue q,
8785 int r_typenum = result.get_typenum ();
8886
8987 if ((mat_typenum != v_typenum) || (r_typenum != v_typenum) ||
90- !((v_typenum == UAR_DOUBLE ) || (v_typenum == UAR_FLOAT ) ||
91- (v_typenum == UAR_CDOUBLE ) || (v_typenum == UAR_CFLOAT )))
88+ !((v_typenum == api. UAR_DOUBLE_ ) || (v_typenum == api. UAR_FLOAT_ ) ||
89+ (v_typenum == api. UAR_CDOUBLE_ ) || (v_typenum == api. UAR_CFLOAT_ )))
9290 {
9391 std::cout << " Found: [" << mat_typenum << " , " << v_typenum << " , "
9492 << r_typenum << " ]" << std::endl;
@@ -103,7 +101,7 @@ py_gemv(sycl::queue q,
103101 char *r_typeless_ptr = result.get_data ();
104102
105103 sycl::event res_ev;
106- if (v_typenum == UAR_DOUBLE ) {
104+ if (v_typenum == api. UAR_DOUBLE_ ) {
107105 using T = double ;
108106 sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
109107 q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -112,7 +110,7 @@ py_gemv(sycl::queue q,
112110 reinterpret_cast <T *>(r_typeless_ptr), 1 , depends);
113111 res_ev = gemv_ev;
114112 }
115- else if (v_typenum == UAR_FLOAT ) {
113+ else if (v_typenum == api. UAR_FLOAT_ ) {
116114 using T = float ;
117115 sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
118116 q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -121,7 +119,7 @@ py_gemv(sycl::queue q,
121119 reinterpret_cast <T *>(r_typeless_ptr), 1 , depends);
122120 res_ev = gemv_ev;
123121 }
124- else if (v_typenum == UAR_CDOUBLE ) {
122+ else if (v_typenum == api. UAR_CDOUBLE_ ) {
125123 using T = std::complex <double >;
126124 sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
127125 q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -130,7 +128,7 @@ py_gemv(sycl::queue q,
130128 reinterpret_cast <T *>(r_typeless_ptr), 1 , depends);
131129 res_ev = gemv_ev;
132130 }
133- else if (v_typenum == UAR_CFLOAT ) {
131+ else if (v_typenum == api. UAR_CFLOAT_ ) {
134132 using T = std::complex <float >;
135133 sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv (
136134 q, oneapi::mkl::transpose::nontrans, n, m, T (1 ),
@@ -185,21 +183,18 @@ py_sub(sycl::queue q,
185183 throw std::runtime_error (" Vectors must have the same length" );
186184 }
187185
188- if (q.get_context () != in_v1.get_queue ().get_context () ||
189- q.get_context () != in_v2.get_queue ().get_context () ||
190- q.get_context () != out_r.get_queue ().get_context ())
186+ if (!dpctl::utils::queues_are_compatible (
187+ q, {in_v1.get_queue (), in_v2.get_queue (), out_r.get_queue ()}))
191188 {
192189 throw std::runtime_error (
193190 " USM allocation is not bound to the context in execution queue" );
194191 }
195192
196- int in_v1_flags = in_v1.get_flags ();
197- int in_v2_flags = in_v2.get_flags ();
198- int out_r_flags = out_r.get_flags ();
193+ auto &api = dpctl::detail::dpctl_capi::get ();
199194
200- if (!((in_v1_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
201- (in_v2_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS )) &&
202- (out_r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS ))))
195+ if (!((in_v1. is_c_contiguous () || in_v1. is_f_contiguous ( )) &&
196+ (in_v2. is_c_contiguous () || in_v2. is_f_contiguous ( )) &&
197+ (out_r. is_c_contiguous () || out_r. is_f_contiguous ( ))))
203198 {
204199 throw std::runtime_error (" Vectors must be contiguous." );
205200 }
@@ -209,8 +204,10 @@ py_sub(sycl::queue q,
209204 int out_r_typenum = out_r.get_typenum ();
210205
211206 if ((in_v2_typenum != in_v1_typenum) || (out_r_typenum != in_v1_typenum) ||
212- !((in_v1_typenum == UAR_DOUBLE) || (in_v1_typenum == UAR_FLOAT) ||
213- (in_v1_typenum == UAR_CDOUBLE) || (in_v1_typenum == UAR_CFLOAT)))
207+ !((in_v1_typenum == api.UAR_DOUBLE_ ) ||
208+ (in_v1_typenum == api.UAR_FLOAT_ ) ||
209+ (in_v1_typenum == api.UAR_CDOUBLE_ ) ||
210+ (in_v1_typenum == api.UAR_CFLOAT_ )))
214211 {
215212 throw std::runtime_error (
216213 " Only real and complex floating point arrays are supported." );
@@ -221,22 +218,22 @@ py_sub(sycl::queue q,
221218 char *out_r_typeless_ptr = out_r.get_data ();
222219
223220 sycl::event res_ev;
224- if (out_r_typenum == UAR_DOUBLE ) {
221+ if (out_r_typenum == api. UAR_DOUBLE_ ) {
225222 using T = double ;
226223 res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
227224 out_r_typeless_ptr, depends);
228225 }
229- else if (out_r_typenum == UAR_FLOAT ) {
226+ else if (out_r_typenum == api. UAR_FLOAT_ ) {
230227 using T = float ;
231228 res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
232229 out_r_typeless_ptr, depends);
233230 }
234- else if (out_r_typenum == UAR_CDOUBLE ) {
231+ else if (out_r_typenum == api. UAR_CDOUBLE_ ) {
235232 using T = std::complex <double >;
236233 res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
237234 out_r_typeless_ptr, depends);
238235 }
239- else if (out_r_typenum == UAR_CFLOAT ) {
236+ else if (out_r_typenum == api. UAR_CFLOAT_ ) {
240237 using T = std::complex <float >;
241238 res_ev = sub_impl<T>(q, n, in_v1_typeless_ptr, in_v2_typeless_ptr,
242239 out_r_typeless_ptr, depends);
@@ -294,18 +291,15 @@ py_axpby_inplace(sycl::queue q,
294291 throw std::runtime_error (" Vectors must have the same length" );
295292 }
296293
297- if (q.get_context () != x.get_queue ().get_context () ||
298- q.get_context () != y.get_queue ().get_context ())
294+ if (!dpctl::utils::queues_are_compatible (q, {x.get_queue (), y.get_queue ()}))
299295 {
300296 throw std::runtime_error (
301297 " USM allocation is not bound to the context in execution queue" );
302298 }
299+ auto &api = dpctl::detail::dpctl_capi::get ();
303300
304- int x_flags = x.get_flags ();
305- int y_flags = y.get_flags ();
306-
307- if (!((x_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) &&
308- (y_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS))))
301+ if (!((x.is_c_contiguous () || x.is_f_contiguous ()) &&
302+ (y.is_c_contiguous () || y.is_f_contiguous ())))
309303 {
310304 throw std::runtime_error (" Vectors must be contiguous." );
311305 }
@@ -314,8 +308,8 @@ py_axpby_inplace(sycl::queue q,
314308 int y_typenum = y.get_typenum ();
315309
316310 if ((x_typenum != y_typenum) ||
317- !((x_typenum == UAR_DOUBLE ) || (x_typenum == UAR_FLOAT ) ||
318- (x_typenum == UAR_CDOUBLE ) || (x_typenum == UAR_CFLOAT )))
311+ !((x_typenum == api. UAR_DOUBLE_ ) || (x_typenum == api. UAR_FLOAT_ ) ||
312+ (x_typenum == api. UAR_CDOUBLE_ ) || (x_typenum == api. UAR_CFLOAT_ )))
319313 {
320314 throw std::runtime_error (
321315 " Only real and complex floating point arrays are supported." );
@@ -325,22 +319,22 @@ py_axpby_inplace(sycl::queue q,
325319 char *y_typeless_ptr = y.get_data ();
326320
327321 sycl::event res_ev;
328- if (x_typenum == UAR_DOUBLE ) {
322+ if (x_typenum == api. UAR_DOUBLE_ ) {
329323 using T = double ;
330324 res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
331325 y_typeless_ptr, depends);
332326 }
333- else if (x_typenum == UAR_FLOAT ) {
327+ else if (x_typenum == api. UAR_FLOAT_ ) {
334328 using T = float ;
335329 res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
336330 y_typeless_ptr, depends);
337331 }
338- else if (x_typenum == UAR_CDOUBLE ) {
332+ else if (x_typenum == api. UAR_CDOUBLE_ ) {
339333 using T = std::complex <double >;
340334 res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
341335 y_typeless_ptr, depends);
342336 }
343- else if (x_typenum == UAR_CFLOAT ) {
337+ else if (x_typenum == api. UAR_CFLOAT_ ) {
344338 using T = std::complex <float >;
345339 res_ev = axpby_inplace_impl<T>(q, n, a, x_typeless_ptr, b,
346340 y_typeless_ptr, depends);
@@ -393,18 +387,20 @@ py::object py_norm_squared_blocking(sycl::queue q,
393387
394388 int r_flags = r.get_flags ();
395389
396- if (!(r_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS ))) {
390+ if (!(r. is_c_contiguous () || r. is_f_contiguous ( ))) {
397391 throw std::runtime_error (" Vector must be contiguous." );
398392 }
399393
400- if (q. get_context () != r.get_queue (). get_context ( )) {
394+ if (! dpctl::utils::queues_are_compatible (q, { r.get_queue ()} )) {
401395 throw std::runtime_error (
402396 " USM allocation is not bound to the context in execution queue" );
403397 }
404398
399+ auto &api = dpctl::detail::dpctl_capi::get ();
400+
405401 int r_typenum = r.get_typenum ();
406- if ((r_typenum != UAR_DOUBLE ) && (r_typenum != UAR_FLOAT ) &&
407- (r_typenum != UAR_CDOUBLE ) && (r_typenum != UAR_CFLOAT ))
402+ if ((r_typenum != api. UAR_DOUBLE_ ) && (r_typenum != api. UAR_FLOAT_ ) &&
403+ (r_typenum != api. UAR_CDOUBLE_ ) && (r_typenum != api. UAR_CFLOAT_ ))
408404 {
409405 throw std::runtime_error (
410406 " Only real and complex floating point arrays are supported." );
@@ -413,23 +409,23 @@ py::object py_norm_squared_blocking(sycl::queue q,
413409 const char *r_typeless_ptr = r.get_data ();
414410 py::object res;
415411
416- if (r_typenum == UAR_DOUBLE ) {
412+ if (r_typenum == api. UAR_DOUBLE_ ) {
417413 using T = double ;
418414 T n_sq = norm_squared_blocking_impl<T>(q, n, r_typeless_ptr, depends);
419415 res = py::float_ (n_sq);
420416 }
421- else if (r_typenum == UAR_FLOAT ) {
417+ else if (r_typenum == api. UAR_FLOAT_ ) {
422418 using T = float ;
423419 T n_sq = norm_squared_blocking_impl<T>(q, n, r_typeless_ptr, depends);
424420 res = py::float_ (n_sq);
425421 }
426- else if (r_typenum == UAR_CDOUBLE ) {
422+ else if (r_typenum == api. UAR_CDOUBLE_ ) {
427423 using T = std::complex <double >;
428424 double n_sq = complex_norm_squared_blocking_impl<double >(
429425 q, n, r_typeless_ptr, depends);
430426 res = py::float_ (n_sq);
431427 }
432- else if (r_typenum == UAR_CFLOAT ) {
428+ else if (r_typenum == api. UAR_CFLOAT_ ) {
433429 using T = std::complex <float >;
434430 float n_sq = complex_norm_squared_blocking_impl<float >(
435431 q, n, r_typeless_ptr, depends);
@@ -457,28 +453,27 @@ py::object py_dot_blocking(sycl::queue q,
457453 throw std::runtime_error (" Length of vectors are not the same" );
458454 }
459455
460- int v1_flags = v1.get_flags ();
461- int v2_flags = v2.get_flags ();
462-
463- if (!(v1_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)) ||
464- !(v2_flags & (USM_ARRAY_C_CONTIGUOUS | USM_ARRAY_F_CONTIGUOUS)))
456+ if (!(v1.is_c_contiguous () || v1.is_f_contiguous ()) ||
457+ !(v2.is_c_contiguous () || v2.is_f_contiguous ()))
465458 {
466459 throw std::runtime_error (" Vectors must be contiguous." );
467460 }
468461
469- if (q. get_context () != v1. get_queue (). get_context () ||
470- q. get_context () != v2.get_queue (). get_context ( ))
462+ if (! dpctl::utils::queues_are_compatible (q,
463+ {v1. get_queue (), v2.get_queue ()} ))
471464 {
472465 throw std::runtime_error (
473466 " USM allocation is not bound to the context in execution queue" );
474467 }
475468
469+ auto &api = dpctl::detail::dpctl_capi::get ();
470+
476471 int v1_typenum = v1.get_typenum ();
477472 int v2_typenum = v2.get_typenum ();
478473
479474 if ((v1_typenum != v2_typenum) ||
480- ((v1_typenum != UAR_DOUBLE ) && (v1_typenum != UAR_FLOAT ) &&
481- (v1_typenum != UAR_CDOUBLE ) && (v1_typenum != UAR_CFLOAT )))
475+ ((v1_typenum != api. UAR_DOUBLE_ ) && (v1_typenum != api. UAR_FLOAT_ ) &&
476+ (v1_typenum != api. UAR_CDOUBLE_ ) && (v1_typenum != api. UAR_CFLOAT_ )))
482477 {
483478 throw py::value_error (
484479 " Data types of vectors must be the same. "
@@ -489,7 +484,7 @@ py::object py_dot_blocking(sycl::queue q,
489484 const char *v2_typeless_ptr = v2.get_data ();
490485 py::object res;
491486
492- if (v1_typenum == UAR_DOUBLE ) {
487+ if (v1_typenum == api. UAR_DOUBLE_ ) {
493488 using T = double ;
494489 T *res_usm = sycl::malloc_device<T>(1 , q);
495490 sycl::event dot_ev = oneapi::mkl::blas::row_major::dot (
@@ -500,7 +495,7 @@ py::object py_dot_blocking(sycl::queue q,
500495 sycl::free (res_usm, q);
501496 res = py::float_ (res_v);
502497 }
503- else if (v1_typenum == UAR_FLOAT ) {
498+ else if (v1_typenum == api. UAR_FLOAT_ ) {
504499 using T = float ;
505500 T *res_usm = sycl::malloc_device<T>(1 , q);
506501 sycl::event dot_ev = oneapi::mkl::blas::row_major::dot (
@@ -511,7 +506,7 @@ py::object py_dot_blocking(sycl::queue q,
511506 sycl::free (res_usm, q);
512507 res = py::float_ (res_v);
513508 }
514- else if (v1_typenum == UAR_CDOUBLE ) {
509+ else if (v1_typenum == api. UAR_CDOUBLE_ ) {
515510 using T = std::complex <double >;
516511 T *res_usm = sycl::malloc_device<T>(1 , q);
517512 sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc (
@@ -522,7 +517,7 @@ py::object py_dot_blocking(sycl::queue q,
522517 sycl::free (res_usm, q);
523518 res = py::cast (res_v);
524519 }
525- else if (v1_typenum == UAR_CFLOAT ) {
520+ else if (v1_typenum == api. UAR_CFLOAT_ ) {
526521 using T = std::complex <float >;
527522 T *res_usm = sycl::malloc_device<T>(1 , q);
528523 sycl::event dotc_ev = oneapi::mkl::blas::row_major::dotc (
@@ -563,9 +558,8 @@ int py_cg_solve(sycl::queue exec_q,
563558 " Dimensions of the matrix and vectors are not consistent." );
564559 }
565560
566- bool all_contig = (Amat.get_flags () & USM_ARRAY_C_CONTIGUOUS) &&
567- (bvec.get_flags () & USM_ARRAY_C_CONTIGUOUS) &&
568- (xvec.get_flags () & USM_ARRAY_C_CONTIGUOUS);
561+ bool all_contig = (Amat.is_c_contiguous ()) && (bvec.is_c_contiguous ()) &&
562+ (xvec.is_c_contiguous ());
569563 if (!all_contig) {
570564 throw py::value_error (" All inputs must be C-contiguous" );
571565 }
@@ -578,19 +572,20 @@ int py_cg_solve(sycl::queue exec_q,
578572 throw py::value_error (" All arrays must have the same type" );
579573 }
580574
581- if (exec_q.get_context () != Amat.get_queue ().get_context () ||
582- exec_q.get_context () != bvec.get_queue ().get_context () ||
583- exec_q.get_context () != xvec.get_queue ().get_context ())
575+ if (!dpctl::utils::queues_are_compatible (
576+ exec_q, {Amat.get_queue (), bvec.get_queue (), xvec.get_queue ()}))
584577 {
585578 throw std::runtime_error (
586- " USM allocations are not bound to context in execution queue" );
579+ " USM allocation queues are not the same as the execution queue" );
587580 }
588581
589582 const char *A_ch = Amat.get_data ();
590583 const char *b_ch = bvec.get_data ();
591584 char *x_ch = xvec.get_data ();
592585
593- if (A_typenum == UAR_DOUBLE) {
586+ auto &api = dpctl::detail::dpctl_capi::get ();
587+
588+ if (A_typenum == api.UAR_DOUBLE_ ) {
594589 using T = double ;
595590 int iters = cg_solver::cg_solve<T>(
596591 exec_q, n0, reinterpret_cast <const T *>(A_ch),
@@ -599,7 +594,7 @@ int py_cg_solve(sycl::queue exec_q,
599594
600595 return iters;
601596 }
602- else if (A_typenum == UAR_FLOAT ) {
597+ else if (A_typenum == api. UAR_FLOAT_ ) {
603598 using T = float ;
604599 int iters = cg_solver::cg_solve<T>(
605600 exec_q, n0, reinterpret_cast <const T *>(A_ch),
@@ -616,9 +611,6 @@ int py_cg_solve(sycl::queue exec_q,
616611
617612PYBIND11_MODULE (_onemkl, m)
618613{
619- // Import the dpctl extensions
620- import_dpctl ();
621-
622614 m.def (" gemv" , &py_gemv, " Uses oneMKL to compute dot(matrix, vector)" ,
623615 py::arg (" exec_queue" ), py::arg (" Amatrix" ), py::arg (" xvec" ),
624616 py::arg (" resvec" ), py::arg (" depends" ) = py::list ());
0 commit comments