@@ -55,15 +55,12 @@ template <typename T> struct boolean_predicate
     }
 };
 
-template <typename inpT,
-          typename outT,
-          typename PredicateT,
-          std::uint8_t wg_dim = 2>
+template <typename inpT, typename outT, typename PredicateT>
 struct all_reduce_wg_contig
 {
-    void operator()(sycl::nd_item<wg_dim> &ndit,
+    void operator()(sycl::nd_item<1> &ndit,
                     outT *out,
-                    size_t &out_idx,
+                    const size_t &out_idx,
                     const inpT *start,
                     const inpT *end) const
     {
@@ -82,15 +79,12 @@ struct all_reduce_wg_contig
     }
 };
 
-template <typename inpT,
-          typename outT,
-          typename PredicateT,
-          std::uint8_t wg_dim = 2>
+template <typename inpT, typename outT, typename PredicateT>
 struct any_reduce_wg_contig
 {
-    void operator()(sycl::nd_item<wg_dim> &ndit,
+    void operator()(sycl::nd_item<1> &ndit,
                     outT *out,
-                    size_t &out_idx,
+                    const size_t &out_idx,
                     const inpT *start,
                     const inpT *end) const
     {
@@ -109,9 +103,9 @@ struct any_reduce_wg_contig
     }
 };
 
-template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
+template <typename T> struct all_reduce_wg_strided
 {
-    void operator()(sycl::nd_item<wg_dim> &ndit,
+    void operator()(sycl::nd_item<1> &ndit,
                     T *out,
                     const size_t &out_idx,
                     const T &local_val) const
@@ -129,9 +123,9 @@ template <typename T, std::uint8_t wg_dim = 2> struct all_reduce_wg_strided
     }
 };
 
-template <typename T, std::uint8_t wg_dim = 2> struct any_reduce_wg_strided
+template <typename T> struct any_reduce_wg_strided
 {
-    void operator()(sycl::nd_item<wg_dim> &ndit,
+    void operator()(sycl::nd_item<1> &ndit,
                     T *out,
                     const size_t &out_idx,
                     const T &local_val) const
@@ -215,35 +209,46 @@ struct ContigBooleanReduction
     outT *out_ = nullptr;
     GroupOp group_op_;
     size_t reduction_max_gid_ = 0;
+    size_t iter_gws_ = 1;
     size_t reductions_per_wi = 16;
 
 public:
     ContigBooleanReduction(const argT *inp,
                            outT *res,
                            GroupOp group_op,
                            size_t reduction_size,
+                           size_t iteration_size,
                            size_t reduction_size_per_wi)
         : inp_(inp), out_(res), group_op_(group_op),
-          reduction_max_gid_(reduction_size),
+          reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
           reductions_per_wi(reduction_size_per_wi)
     {
     }
 
-    void operator()(sycl::nd_item<2> it) const
+    void operator()(sycl::nd_item<1> it) const
     {
-
-        size_t reduction_id = it.get_group(0);
-        size_t reduction_batch_id = it.get_group(1);
-        size_t wg_size = it.get_local_range(1);
-
-        size_t base = reduction_id * reduction_max_gid_;
-        size_t start = base + reduction_batch_id * wg_size * reductions_per_wi;
-        size_t end = std::min((start + (reductions_per_wi * wg_size)),
-                              base + reduction_max_gid_);
+        const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
+        const size_t reduction_id = it.get_global_id(0) / red_gws_;
+        const size_t reduction_batch_id = get_reduction_batch_id(it);
+        const size_t wg_size = it.get_local_range(0);
+
+        const size_t base = reduction_id * reduction_max_gid_;
+        const size_t start =
+            base + reduction_batch_id * wg_size * reductions_per_wi;
+        const size_t end = std::min((start + (reductions_per_wi * wg_size)),
+                                    base + reduction_max_gid_);
         // reduction and atomic operations are performed
         // in group_op_
         group_op_(it, out_, reduction_id, inp_ + start, inp_ + end);
     }
+
+private:
+    size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
+    {
+        const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
+        const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
+        return reduction_batch_id;
+    }
 };
 
 typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -332,7 +337,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
         red_ev = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(init_ev);
 
-            constexpr std::uint8_t group_dim = 2;
+            constexpr std::uint8_t dim = 1;
 
             constexpr size_t preferred_reductions_per_wi = 4;
             size_t reductions_per_wi =
@@ -344,15 +349,14 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
                 (reduction_nelems + reductions_per_wi * wg - 1) /
                 (reductions_per_wi * wg);
 
-            auto gws =
-                sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
-            auto lws = sycl::range<group_dim>{1, wg};
+            auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
+            auto lws = sycl::range<dim>{wg};
 
             cgh.parallel_for<
                 class boolean_reduction_contig_krn<argTy, resTy, GroupOpT>>(
-                sycl::nd_range<group_dim>(gws, lws),
+                sycl::nd_range<dim>(gws, lws),
                 ContigBooleanReduction<argTy, resTy, GroupOpT>(
-                    arg_tp, res_tp, GroupOpT(), reduction_nelems,
+                    arg_tp, res_tp, GroupOpT(), reduction_nelems, iter_nelems,
                     reductions_per_wi));
         });
     }
@@ -404,6 +408,7 @@ struct StridedBooleanReduction
     InputOutputIterIndexerT inp_out_iter_indexer_;
     InputRedIndexerT inp_reduced_dims_indexer_;
     size_t reduction_max_gid_ = 0;
+    size_t iter_gws_ = 1;
     size_t reductions_per_wi = 16;
 
 public:
@@ -415,23 +420,24 @@ struct StridedBooleanReduction
                             InputOutputIterIndexerT arg_res_iter_indexer,
                             InputRedIndexerT arg_reduced_dims_indexer,
                             size_t reduction_size,
+                            size_t iteration_size,
                             size_t reduction_size_per_wi)
         : inp_(inp), out_(res), reduction_op_(reduction_op),
           group_op_(group_op), identity_(identity_val),
           inp_out_iter_indexer_(arg_res_iter_indexer),
           inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
-          reduction_max_gid_(reduction_size),
+          reduction_max_gid_(reduction_size), iter_gws_(iteration_size),
           reductions_per_wi(reduction_size_per_wi)
     {
     }
 
-    void operator()(sycl::nd_item<2> it) const
+    void operator()(sycl::nd_item<1> it) const
     {
-
-        size_t reduction_id = it.get_group(0);
-        size_t reduction_batch_id = it.get_group(1);
-        size_t reduction_lid = it.get_local_id(1);
-        size_t wg_size = it.get_local_range(1);
+        const size_t red_gws_ = it.get_global_range(0) / iter_gws_;
+        const size_t reduction_id = it.get_global_id(0) / red_gws_;
+        const size_t reduction_batch_id = get_reduction_batch_id(it);
+        const size_t reduction_lid = it.get_local_id(0);
+        const size_t wg_size = it.get_local_range(0);
 
         auto inp_out_iter_offsets_ = inp_out_iter_indexer_(reduction_id);
         const py::ssize_t &inp_iter_offset =
@@ -442,26 +448,34 @@ struct StridedBooleanReduction
         outT local_red_val(identity_);
         size_t arg_reduce_gid0 =
             reduction_lid + reduction_batch_id * wg_size * reductions_per_wi;
-        for (size_t m = 0; m < reductions_per_wi; ++m) {
-            size_t arg_reduce_gid = arg_reduce_gid0 + m * wg_size;
-
-            if (arg_reduce_gid < reduction_max_gid_) {
-                py::ssize_t inp_reduction_offset = static_cast<py::ssize_t>(
-                    inp_reduced_dims_indexer_(arg_reduce_gid));
-                py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset;
+        size_t arg_reduce_gid_max = std::min(
+            reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg_size);
+        for (size_t arg_reduce_gid = arg_reduce_gid0;
+             arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg_size)
+        {
+            py::ssize_t inp_reduction_offset = static_cast<py::ssize_t>(
+                inp_reduced_dims_indexer_(arg_reduce_gid));
+            py::ssize_t inp_offset = inp_iter_offset + inp_reduction_offset;
 
-                // must convert to boolean first to handle nans
-                using dpctl::tensor::type_utils::convert_impl;
-                bool val = convert_impl<bool, argT>(inp_[inp_offset]);
-                ReductionOp op = reduction_op_;
+            // must convert to boolean first to handle nans
+            using dpctl::tensor::type_utils::convert_impl;
+            bool val = convert_impl<bool, argT>(inp_[inp_offset]);
+            ReductionOp op = reduction_op_;
 
-                local_red_val = op(local_red_val, static_cast<outT>(val));
-            }
+            local_red_val = op(local_red_val, static_cast<outT>(val));
         }
         // reduction and atomic operations are performed
         // in group_op_
         group_op_(it, out_, out_iter_offset, local_red_val);
     }
+
+private:
+    size_t get_reduction_batch_id(sycl::nd_item<1> const &it) const
+    {
+        const size_t n_reduction_groups = it.get_group_range(0) / iter_gws_;
+        const size_t reduction_batch_id = it.get_group(0) % n_reduction_groups;
+        return reduction_batch_id;
+    }
 };
 
 template <typename T1,
@@ -564,7 +578,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
         red_ev = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(res_init_ev);
 
-            constexpr std::uint8_t group_dim = 2;
+            constexpr std::uint8_t dim = 1;
 
             using InputOutputIterIndexerT =
                 dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
@@ -587,20 +601,19 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
                 (reduction_nelems + reductions_per_wi * wg - 1) /
                 (reductions_per_wi * wg);
 
-            auto gws =
-                sycl::range<group_dim>{iter_nelems, reduction_groups * wg};
-            auto lws = sycl::range<group_dim>{1, wg};
+            auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
+            auto lws = sycl::range<dim>{wg};
 
             cgh.parallel_for<class boolean_reduction_strided_krn<
                 argTy, resTy, RedOpT, GroupOpT, InputOutputIterIndexerT,
                 ReductionIndexerT>>(
-                sycl::nd_range<group_dim>(gws, lws),
+                sycl::nd_range<dim>(gws, lws),
                 StridedBooleanReduction<argTy, resTy, RedOpT, GroupOpT,
                                         InputOutputIterIndexerT,
                                         ReductionIndexerT>(
                     arg_tp, res_tp, RedOpT(), GroupOpT(), identity_val,
                     in_out_iter_indexer, reduction_indexer, reduction_nelems,
-                    reductions_per_wi));
+                    iter_nelems, reductions_per_wi));
         });
     }
     return red_ev;
0 commit comments