3434#include " pybind11/pybind11.h"
3535
3636#include " utils/offset_utils.hpp"
37+ #include " utils/sycl_utils.hpp"
3738#include " utils/type_dispatch.hpp"
3839#include " utils/type_utils.hpp"
3940
@@ -227,9 +228,8 @@ struct ContigBooleanReduction
227228
228229 void operator ()(sycl::nd_item<1 > it) const
229230 {
230- const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
231- const size_t reduction_id = it.get_global_id (0 ) / red_gws_;
232- const size_t reduction_batch_id = get_reduction_batch_id (it);
231+ const size_t reduction_id = it.get_group (0 ) % iter_gws_;
232+ const size_t reduction_batch_id = it.get_group (0 ) / iter_gws_;
233233 const size_t wg_size = it.get_local_range (0 );
234234
235235 const size_t base = reduction_id * reduction_max_gid_;
@@ -241,14 +241,6 @@ struct ContigBooleanReduction
241241 // in group_op_
242242 group_op_ (it, out_, reduction_id, inp_ + start, inp_ + end);
243243 }
244-
245- private:
246- size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
247- {
248- const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
249- const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
250- return reduction_batch_id;
251- }
252244};
253245
254246typedef sycl::event (*boolean_reduction_contig_impl_fn_ptr)(
@@ -268,17 +260,19 @@ class boolean_reduction_contig_krn;
268260template <typename T1, typename T2, typename T3, typename T4, typename T5>
269261class boolean_reduction_seq_contig_krn ;
270262
263+ using dpctl::tensor::sycl_utils::choose_workgroup_size;
264+
271265template <typename argTy, typename resTy, typename RedOpT, typename GroupOpT>
272266sycl::event
273- boolean_reduction_contig_impl (sycl::queue exec_q,
274- size_t iter_nelems,
275- size_t reduction_nelems,
276- const char *arg_cp,
277- char *res_cp,
278- py::ssize_t iter_arg_offset,
279- py::ssize_t iter_res_offset,
280- py::ssize_t red_arg_offset,
281- const std::vector<sycl::event> &depends)
267+ boolean_reduction_axis1_contig_impl (sycl::queue exec_q,
268+ size_t iter_nelems,
269+ size_t reduction_nelems,
270+ const char *arg_cp,
271+ char *res_cp,
272+ py::ssize_t iter_arg_offset,
273+ py::ssize_t iter_res_offset,
274+ py::ssize_t red_arg_offset,
275+ const std::vector<sycl::event> &depends)
282276{
283277 const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
284278 iter_arg_offset + red_arg_offset;
@@ -288,8 +282,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
288282
289283 const sycl::device &d = exec_q.get_device ();
290284 const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
291- size_t wg =
292- 4 * (*std::max_element (std::begin (sg_sizes), std::end (sg_sizes)));
285+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
293286
294287 sycl::event red_ev;
295288 if (reduction_nelems < wg) {
@@ -322,18 +315,8 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
322315 });
323316 }
324317 else {
325- sycl::event init_ev = exec_q.submit ([&](sycl::handler &cgh) {
326- using IndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
327-
328- IndexerT res_indexer{};
329-
330- cgh.depends_on (depends);
331-
332- cgh.parallel_for (sycl::range<1 >(iter_nelems), [=](sycl::id<1 > id) {
333- auto res_offset = res_indexer (id[0 ]);
334- res_tp[res_offset] = identity_val;
335- });
336- });
318+ sycl::event init_ev = exec_q.fill <resTy>(res_tp, resTy (identity_val),
319+ iter_nelems, depends);
337320 red_ev = exec_q.submit ([&](sycl::handler &cgh) {
338321 cgh.depends_on (init_ev);
339322
@@ -363,7 +346,7 @@ boolean_reduction_contig_impl(sycl::queue exec_q,
363346 return red_ev;
364347}
365348
366- template <typename fnT, typename srcTy> struct AllContigFactory
349+ template <typename fnT, typename srcTy> struct AllAxis1ContigFactory
367350{
368351 fnT get () const
369352 {
@@ -372,12 +355,12 @@ template <typename fnT, typename srcTy> struct AllContigFactory
372355 using GroupOpT =
373356 all_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
374357
375- return dpctl::tensor::kernels::boolean_reduction_contig_impl <
358+ return dpctl::tensor::kernels::boolean_reduction_axis1_contig_impl <
376359 srcTy, resTy, RedOpT, GroupOpT>;
377360 }
378361};
379362
380- template <typename fnT, typename srcTy> struct AnyContigFactory
363+ template <typename fnT, typename srcTy> struct AnyAxis1ContigFactory
381364{
382365 fnT get () const
383366 {
@@ -386,7 +369,7 @@ template <typename fnT, typename srcTy> struct AnyContigFactory
386369 using GroupOpT =
387370 any_reduce_wg_contig<srcTy, resTy, boolean_predicate<srcTy>>;
388371
389- return dpctl::tensor::kernels::boolean_reduction_contig_impl <
372+ return dpctl::tensor::kernels::boolean_reduction_axis1_contig_impl <
390373 srcTy, resTy, RedOpT, GroupOpT>;
391374 }
392375};
@@ -433,9 +416,9 @@ struct StridedBooleanReduction
433416
434417 void operator ()(sycl::nd_item<1 > it) const
435418 {
436- const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
437- const size_t reduction_id = it.get_global_id (0 ) / red_gws_ ;
438- const size_t reduction_batch_id = get_reduction_batch_id (it);
419+ const size_t reduction_id = it.get_group (0 ) % iter_gws_;
420+ const size_t reduction_batch_id = it.get_group (0 ) / iter_gws_ ;
421+
439422 const size_t reduction_lid = it.get_local_id (0 );
440423 const size_t wg_size = it.get_local_range (0 );
441424
@@ -468,13 +451,112 @@ struct StridedBooleanReduction
468451 // in group_op_
469452 group_op_ (it, out_, out_iter_offset, local_red_val);
470453 }
454+ };
455+
456+ template <typename T1,
457+ typename T2,
458+ typename T3,
459+ typename T4,
460+ typename T5,
461+ typename T6>
462+ class boolean_reduction_axis0_contig_krn ;
463+
464+ template <typename argTy, typename resTy, typename RedOpT, typename GroupOpT>
465+ sycl::event
466+ boolean_reduction_axis0_contig_impl (sycl::queue exec_q,
467+ size_t iter_nelems,
468+ size_t reduction_nelems,
469+ const char *arg_cp,
470+ char *res_cp,
471+ py::ssize_t iter_arg_offset,
472+ py::ssize_t iter_res_offset,
473+ py::ssize_t red_arg_offset,
474+ const std::vector<sycl::event> &depends)
475+ {
476+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
477+ iter_arg_offset + red_arg_offset;
478+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
479+
480+ constexpr resTy identity_val = sycl::known_identity<RedOpT, resTy>::value;
481+
482+ const sycl::device &d = exec_q.get_device ();
483+ const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
484+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
471485
472- private:
473- size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
474486 {
475- const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
476- const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
477- return reduction_batch_id;
487+ sycl::event init_ev = exec_q.fill <resTy>(res_tp, resTy (identity_val),
488+ iter_nelems, depends);
489+ sycl::event red_ev = exec_q.submit ([&](sycl::handler &cgh) {
490+ cgh.depends_on (init_ev);
491+
492+ constexpr std::uint8_t dim = 1 ;
493+
494+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
495+ using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
496+ using InputOutputIterIndexerT =
497+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
498+ NoOpIndexerT, NoOpIndexerT>;
499+ using ReductionIndexerT = ColsIndexerT;
500+
501+ NoOpIndexerT columns_indexer{};
502+ NoOpIndexerT result_indexer{};
503+ InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
504+ result_indexer};
505+ ReductionIndexerT reduction_indexer{
506+ 0 , static_cast <py::ssize_t >(reduction_nelems),
507+ static_cast <py::ssize_t >(iter_nelems)};
508+
509+ constexpr size_t preferred_reductions_per_wi = 4 ;
510+ size_t reductions_per_wi =
511+ (reduction_nelems < preferred_reductions_per_wi * wg)
512+ ? ((reduction_nelems + wg - 1 ) / wg)
513+ : preferred_reductions_per_wi;
514+
515+ size_t reduction_groups =
516+ (reduction_nelems + reductions_per_wi * wg - 1 ) /
517+ (reductions_per_wi * wg);
518+
519+ auto gws = sycl::range<dim>{iter_nelems * reduction_groups * wg};
520+ auto lws = sycl::range<dim>{wg};
521+
522+ cgh.parallel_for <class boolean_reduction_axis0_contig_krn <
523+ argTy, resTy, RedOpT, GroupOpT, InputOutputIterIndexerT,
524+ ReductionIndexerT>>(
525+ sycl::nd_range<dim>(gws, lws),
526+ StridedBooleanReduction<argTy, resTy, RedOpT, GroupOpT,
527+ InputOutputIterIndexerT,
528+ ReductionIndexerT>(
529+ arg_tp, res_tp, RedOpT (), GroupOpT (), identity_val,
530+ in_out_iter_indexer, reduction_indexer, reduction_nelems,
531+ iter_nelems, reductions_per_wi));
532+ });
533+ return red_ev;
534+ }
535+ }
536+
537+ template <typename fnT, typename srcTy> struct AllAxis0ContigFactory
538+ {
539+ fnT get () const
540+ {
541+ using resTy = std::int32_t ;
542+ using RedOpT = sycl::logical_and<resTy>;
543+ using GroupOpT = all_reduce_wg_strided<resTy>;
544+
545+ return dpctl::tensor::kernels::boolean_reduction_axis0_contig_impl<
546+ srcTy, resTy, RedOpT, GroupOpT>;
547+ }
548+ };
549+
550+ template <typename fnT, typename srcTy> struct AnyAxis0ContigFactory
551+ {
552+ fnT get () const
553+ {
554+ using resTy = std::int32_t ;
555+ using RedOpT = sycl::logical_or<resTy>;
556+ using GroupOpT = any_reduce_wg_strided<resTy>;
557+
558+ return dpctl::tensor::kernels::boolean_reduction_axis0_contig_impl<
559+ srcTy, resTy, RedOpT, GroupOpT>;
478560 }
479561};
480562
@@ -527,8 +609,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
527609
528610 const sycl::device &d = exec_q.get_device ();
529611 const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
530- size_t wg =
531- 4 * (*std::max_element (std::begin (sg_sizes), std::end (sg_sizes)));
612+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
532613
533614 sycl::event red_ev;
534615 if (reduction_nelems < wg) {
@@ -558,7 +639,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
558639 });
559640 }
560641 else {
561- sycl::event res_init_ev = exec_q.submit ([&](sycl::handler &cgh) {
642+ sycl::event init_ev = exec_q.submit ([&](sycl::handler &cgh) {
562643 using IndexerT =
563644 dpctl::tensor::offset_utils::UnpackedStridedIndexer;
564645
@@ -576,7 +657,7 @@ boolean_reduction_strided_impl(sycl::queue exec_q,
576657 });
577658 });
578659 red_ev = exec_q.submit ([&](sycl::handler &cgh) {
579- cgh.depends_on (res_init_ev );
660+ cgh.depends_on (init_ev );
580661
581662 constexpr std::uint8_t dim = 1 ;
582663
0 commit comments