@@ -65,9 +65,6 @@ struct UnaryContigFunctor
6565 if constexpr (UnaryOperatorT::is_constant::value) {
6666 // value of operator is known to be a known constant
6767 constexpr resT const_val = UnaryOperatorT::constant_value;
68- using out_ptrT =
69- sycl::multi_ptr<resT,
70- sycl::access::address_space::global_space>;
7168
7269 auto sg = ndit.get_sub_group ();
7370 std::uint8_t sgSize = sg.get_local_range ()[0 ];
@@ -80,8 +77,11 @@ struct UnaryContigFunctor
8077 sycl::vec<resT, vec_sz> res_vec (const_val);
8178#pragma unroll
8279 for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
83- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
84- res_vec);
80+ auto out_multi_ptr = sycl::address_space_cast<
81+ sycl::access::address_space::global_space,
82+ sycl::access::decorated::yes>(&out[base + it * sgSize]);
83+
84+ sg.store <vec_sz>(out_multi_ptr, res_vec);
8585 }
8686 }
8787 else {
@@ -94,13 +94,6 @@ struct UnaryContigFunctor
9494 else if constexpr (UnaryOperatorT::supports_sg_loadstore::value &&
9595 UnaryOperatorT::supports_vec::value)
9696 {
97- using in_ptrT =
98- sycl::multi_ptr<const argT,
99- sycl::access::address_space::global_space>;
100- using out_ptrT =
101- sycl::multi_ptr<resT,
102- sycl::access::address_space::global_space>;
103-
10497 auto sg = ndit.get_sub_group ();
10598 std::uint16_t sgSize = sg.get_local_range ()[0 ];
10699 std::uint16_t max_sgSize = sg.get_max_local_range ()[0 ];
@@ -113,10 +106,16 @@ struct UnaryContigFunctor
113106
114107#pragma unroll
115108 for (std::uint16_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
116- x = sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
109+ auto in_multi_ptr = sycl::address_space_cast<
110+ sycl::access::address_space::global_space,
111+ sycl::access::decorated::yes>(&in[base + it * sgSize]);
112+ auto out_multi_ptr = sycl::address_space_cast<
113+ sycl::access::address_space::global_space,
114+ sycl::access::decorated::yes>(&out[base + it * sgSize]);
115+
116+ x = sg.load <vec_sz>(in_multi_ptr);
117117 sycl::vec<resT, vec_sz> res_vec = op (x);
118- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
119- res_vec);
118+ sg.store <vec_sz>(out_multi_ptr, res_vec);
120119 }
121120 }
122121 else {
@@ -141,23 +140,23 @@ struct UnaryContigFunctor
141140
142141 if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
143142 (maxsgSize == sgSize)) {
144- using in_ptrT =
145- sycl::multi_ptr<const argT,
146- sycl::access::address_space::global_space>;
147- using out_ptrT =
148- sycl::multi_ptr<resT,
149- sycl::access::address_space::global_space>;
150143 sycl::vec<argT, vec_sz> arg_vec;
151144
152145#pragma unroll
153146 for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
154- arg_vec = sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
147+ auto in_multi_ptr = sycl::address_space_cast<
148+ sycl::access::address_space::global_space,
149+ sycl::access::decorated::yes>(&in[base + it * sgSize]);
150+ auto out_multi_ptr = sycl::address_space_cast<
151+ sycl::access::address_space::global_space,
152+ sycl::access::decorated::yes>(&out[base + it * sgSize]);
153+
154+ arg_vec = sg.load <vec_sz>(in_multi_ptr);
155155#pragma unroll
156156 for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
157157 arg_vec[k] = op (arg_vec[k]);
158158 }
159- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
160- arg_vec);
159+ sg.store <vec_sz>(out_multi_ptr, arg_vec);
161160 }
162161 }
163162 else {
@@ -179,24 +178,24 @@ struct UnaryContigFunctor
179178
180179 if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
181180 (maxsgSize == sgSize)) {
182- using in_ptrT =
183- sycl::multi_ptr<const argT,
184- sycl::access::address_space::global_space>;
185- using out_ptrT =
186- sycl::multi_ptr<resT,
187- sycl::access::address_space::global_space>;
188181 sycl::vec<argT, vec_sz> arg_vec;
189182 sycl::vec<resT, vec_sz> res_vec;
190183
191184#pragma unroll
192185 for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
193- arg_vec = sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
186+ auto in_multi_ptr = sycl::address_space_cast<
187+ sycl::access::address_space::global_space,
188+ sycl::access::decorated::yes>(&in[base + it * sgSize]);
189+ auto out_multi_ptr = sycl::address_space_cast<
190+ sycl::access::address_space::global_space,
191+ sycl::access::decorated::yes>(&out[base + it * sgSize]);
192+
193+ arg_vec = sg.load <vec_sz>(in_multi_ptr);
194194#pragma unroll
195195 for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
196196 res_vec[k] = op (arg_vec[k]);
197197 }
198- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
199- res_vec);
198+ sg.store <vec_sz>(out_multi_ptr, res_vec);
200199 }
201200 }
202201 else {
@@ -365,28 +364,26 @@ struct BinaryContigFunctor
365364
366365 if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
367366 (sgSize == maxsgSize)) {
368- using in_ptrT1 =
369- sycl::multi_ptr<const argT1,
370- sycl::access::address_space::global_space>;
371- using in_ptrT2 =
372- sycl::multi_ptr<const argT2,
373- sycl::access::address_space::global_space>;
374- using out_ptrT =
375- sycl::multi_ptr<resT,
376- sycl::access::address_space::global_space>;
377367 sycl::vec<argT1, vec_sz> arg1_vec;
378368 sycl::vec<argT2, vec_sz> arg2_vec;
379369 sycl::vec<resT, vec_sz> res_vec;
380370
381371#pragma unroll
382372 for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
383- arg1_vec =
384- sg.load <vec_sz>(in_ptrT1 (&in1[base + it * sgSize]));
385- arg2_vec =
386- sg.load <vec_sz>(in_ptrT2 (&in2[base + it * sgSize]));
373+ auto in1_multi_ptr = sycl::address_space_cast<
374+ sycl::access::address_space::global_space,
375+ sycl::access::decorated::yes>(&in1[base + it * sgSize]);
376+ auto in2_multi_ptr = sycl::address_space_cast<
377+ sycl::access::address_space::global_space,
378+ sycl::access::decorated::yes>(&in2[base + it * sgSize]);
379+ auto out_multi_ptr = sycl::address_space_cast<
380+ sycl::access::address_space::global_space,
381+ sycl::access::decorated::yes>(&out[base + it * sgSize]);
382+
383+ arg1_vec = sg.load <vec_sz>(in1_multi_ptr);
384+ arg2_vec = sg.load <vec_sz>(in2_multi_ptr);
387385 res_vec = op (arg1_vec, arg2_vec);
388- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
389- res_vec);
386+ sg.store <vec_sz>(out_multi_ptr, res_vec);
390387 }
391388 }
392389 else {
@@ -407,32 +404,30 @@ struct BinaryContigFunctor
407404
408405 if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
409406 (sgSize == maxsgSize)) {
410- using in_ptrT1 =
411- sycl::multi_ptr<const argT1,
412- sycl::access::address_space::global_space>;
413- using in_ptrT2 =
414- sycl::multi_ptr<const argT2,
415- sycl::access::address_space::global_space>;
416- using out_ptrT =
417- sycl::multi_ptr<resT,
418- sycl::access::address_space::global_space>;
419407 sycl::vec<argT1, vec_sz> arg1_vec;
420408 sycl::vec<argT2, vec_sz> arg2_vec;
421409 sycl::vec<resT, vec_sz> res_vec;
422410
423411#pragma unroll
424412 for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
425- arg1_vec =
426- sg.load <vec_sz>(in_ptrT1 (&in1[base + it * sgSize]));
427- arg2_vec =
428- sg.load <vec_sz>(in_ptrT2 (&in2[base + it * sgSize]));
413+ auto in1_multi_ptr = sycl::address_space_cast<
414+ sycl::access::address_space::global_space,
415+ sycl::access::decorated::yes>(&in1[base + it * sgSize]);
416+ auto in2_multi_ptr = sycl::address_space_cast<
417+ sycl::access::address_space::global_space,
418+ sycl::access::decorated::yes>(&in2[base + it * sgSize]);
419+ auto out_multi_ptr = sycl::address_space_cast<
420+ sycl::access::address_space::global_space,
421+ sycl::access::decorated::yes>(&out[base + it * sgSize]);
422+
423+ arg1_vec = sg.load <vec_sz>(in1_multi_ptr);
424+ arg2_vec = sg.load <vec_sz>(in2_multi_ptr);
429425#pragma unroll
430426 for (std::uint8_t vec_id = 0 ; vec_id < vec_sz; ++vec_id) {
431427 res_vec[vec_id] =
432428 op (arg1_vec[vec_id], arg2_vec[vec_id]);
433429 }
434- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
435- res_vec);
430+ sg.store <vec_sz>(out_multi_ptr, res_vec);
436431 }
437432 }
438433 else {
@@ -530,22 +525,24 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
530525 size_t base = gid - sg.get_local_id ()[0 ];
531526
532527 if (base + sgSize < n_elems) {
533- using in_ptrT1 =
534- sycl::multi_ptr<const argT1,
535- sycl::access::address_space::global_space>;
536- using in_ptrT2 =
537- sycl::multi_ptr<const argT2,
538- sycl::access::address_space::global_space>;
539- using res_ptrT =
540- sycl::multi_ptr<resT,
541- sycl::access::address_space::global_space>;
542-
543- const argT1 mat_el = sg.load (in_ptrT1 (&mat[base]));
544- const argT2 vec_el = sg.load (in_ptrT2 (&padded_vec[base % n1]));
528+ auto in1_multi_ptr = sycl::address_space_cast<
529+ sycl::access::address_space::global_space,
530+ sycl::access::decorated::yes>(&mat[base]);
531+
532+ auto in2_multi_ptr = sycl::address_space_cast<
533+ sycl::access::address_space::global_space,
534+ sycl::access::decorated::yes>(&padded_vec[base % n1]);
535+
536+ auto out_multi_ptr = sycl::address_space_cast<
537+ sycl::access::address_space::global_space,
538+ sycl::access::decorated::yes>(&res[base]);
539+
540+ const argT1 mat_el = sg.load (in1_multi_ptr);
541+ const argT2 vec_el = sg.load (in2_multi_ptr);
545542
546543 resT res_el = op (mat_el, vec_el);
547544
548- sg.store (res_ptrT (&res[base]) , res_el);
545+ sg.store (out_multi_ptr , res_el);
549546 }
550547 else {
551548 for (size_t k = base + sg.get_local_id ()[0 ]; k < n_elems;
@@ -592,22 +589,24 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
592589 size_t base = gid - sg.get_local_id ()[0 ];
593590
594591 if (base + sgSize < n_elems) {
595- using in_ptrT1 =
596- sycl::multi_ptr<const argT1,
597- sycl::access::address_space::global_space>;
598- using in_ptrT2 =
599- sycl::multi_ptr<const argT2,
600- sycl::access::address_space::global_space>;
601- using res_ptrT =
602- sycl::multi_ptr<resT,
603- sycl::access::address_space::global_space>;
604-
605- const argT2 mat_el = sg.load (in_ptrT2 (&mat[base]));
606- const argT1 vec_el = sg.load (in_ptrT1 (&padded_vec[base % n1]));
592+ auto in1_multi_ptr = sycl::address_space_cast<
593+ sycl::access::address_space::global_space,
594+ sycl::access::decorated::yes>(&padded_vec[base % n1]);
595+
596+ auto in2_multi_ptr = sycl::address_space_cast<
597+ sycl::access::address_space::global_space,
598+ sycl::access::decorated::yes>(&mat[base]);
599+
600+ auto out_multi_ptr = sycl::address_space_cast<
601+ sycl::access::address_space::global_space,
602+ sycl::access::decorated::yes>(&res[base]);
603+
604+ const argT2 mat_el = sg.load (in2_multi_ptr);
605+ const argT1 vec_el = sg.load (in1_multi_ptr);
607606
608607 resT res_el = op (vec_el, mat_el);
609608
610- sg.store (res_ptrT (&res[base]) , res_el);
609+ sg.store (out_multi_ptr , res_el);
611610 }
612611 else {
613612 for (size_t k = base + sg.get_local_id ()[0 ]; k < n_elems;
0 commit comments