@@ -146,9 +146,9 @@ struct ReductionOverGroupWithAtomicFunctor
146146
147147 void operator ()(sycl::nd_item<1 > it) const
148148 {
149- const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
150- const size_t iter_gid = it.get_global_id (0 ) / red_gws_ ;
151- const size_t reduction_batch_id = get_reduction_batch_id (it);
149+ const size_t iter_gid = it.get_group (0 ) % iter_gws_;
150+ const size_t reduction_batch_id = it.get_group (0 ) / iter_gws_ ;
151+
152152 const size_t reduction_lid = it.get_local_id (0 );
153153 const size_t wg = it.get_local_range (0 ); // 0 <= reduction_lid < wg
154154
@@ -204,14 +204,6 @@ struct ReductionOverGroupWithAtomicFunctor
204204 }
205205 }
206206 }
207-
208- private:
209- size_t get_reduction_batch_id (sycl::nd_item<1 > const &it) const
210- {
211- const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
212- const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
213- return reduction_batch_id;
214- }
215207};
216208
217209typedef sycl::event (*sum_reduction_strided_impl_fn_ptr)(
@@ -241,6 +233,12 @@ class sum_reduction_seq_strided_krn;
241233template <typename T1, typename T2, typename T3, typename T4, typename T5>
242234class sum_reduction_seq_contig_krn ;
243235
236+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
237+ class sum_reduction_axis0_over_group_with_atomics_contig_krn ;
238+
239+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
240+ class sum_reduction_axis1_over_group_with_atomics_contig_krn ;
241+
244242using dpctl::tensor::sycl_utils::choose_workgroup_size;
245243
246244template <typename argTy, typename resTy>
@@ -344,20 +342,6 @@ sycl::event sum_reduction_over_group_with_atomics_strided_impl(
344342 (reduction_nelems + reductions_per_wi * wg - 1 ) /
345343 (reductions_per_wi * wg);
346344
347- if (reduction_groups > 1 ) {
348- const size_t &max_wg =
349- d.get_info <sycl::info::device::max_work_group_size>();
350-
351- if (reduction_nelems < preferrered_reductions_per_wi * max_wg) {
352- wg = max_wg;
353- reductions_per_wi =
354- std::max<size_t >(1 , (reduction_nelems + wg - 1 ) / wg);
355- reduction_groups =
356- (reduction_nelems + reductions_per_wi * wg - 1 ) /
357- (reductions_per_wi * wg);
358- }
359- }
360-
361345 auto globalRange =
362346 sycl::range<1 >{iter_nelems * reduction_groups * wg};
363347 auto localRange = sycl::range<1 >{wg};
@@ -395,7 +379,7 @@ typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)(
395379
396380/* @brief Reduce rows in a matrix */
397381template <typename argTy, typename resTy>
398- sycl::event sum_reduction_over_group_with_atomics_contig_impl (
382+ sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl (
399383 sycl::queue exec_q,
400384 size_t iter_nelems, // number of reductions (num. of rows in a matrix
401385 // when reducing over rows)
@@ -417,7 +401,7 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
417401
418402 const sycl::device &d = exec_q.get_device ();
419403 const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
420- size_t wg = choose_workgroup_size<2 >(reduction_nelems, sg_sizes);
404+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
421405
422406 if (reduction_nelems < wg) {
423407 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
@@ -463,11 +447,11 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
463447 RowsIndexerT, NoOpIndexerT>;
464448 using ReductionIndexerT = NoOpIndexerT;
465449
466- RowsIndexerT columns_indexer {
450+ RowsIndexerT rows_indexer {
467451 0 , static_cast <py::ssize_t >(iter_nelems),
468452 static_cast <py::ssize_t >(reduction_nelems)};
469453 NoOpIndexerT result_indexer{};
470- InputOutputIterIndexerT in_out_iter_indexer{columns_indexer ,
454+ InputOutputIterIndexerT in_out_iter_indexer{rows_indexer ,
471455 result_indexer};
472456 ReductionIndexerT reduction_indexer{};
473457
@@ -481,27 +465,95 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
481465 (reduction_nelems + reductions_per_wi * wg - 1 ) /
482466 (reductions_per_wi * wg);
483467
484- if (reduction_groups > 1 ) {
485- const size_t &max_wg =
486- d.get_info <sycl::info::device::max_work_group_size>();
487-
488- if (reduction_nelems < preferrered_reductions_per_wi * max_wg) {
489- wg = max_wg;
490- reductions_per_wi =
491- std::max<size_t >(1 , (reduction_nelems + wg - 1 ) / wg);
492- reduction_groups =
493- (reduction_nelems + reductions_per_wi * wg - 1 ) /
494- (reductions_per_wi * wg);
495- }
496- }
468+ auto globalRange =
469+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
470+ auto localRange = sycl::range<1 >{wg};
471+
472+ using KernelName =
473+ class sum_reduction_axis1_over_group_with_atomics_contig_krn <
474+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
475+ ReductionIndexerT>;
476+
477+ cgh.parallel_for <KernelName>(
478+ sycl::nd_range<1 >(globalRange, localRange),
479+ ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
480+ InputOutputIterIndexerT,
481+ ReductionIndexerT>(
482+ arg_tp, res_tp, ReductionOpT (), identity_val,
483+ in_out_iter_indexer, reduction_indexer, reduction_nelems,
484+ iter_nelems, reductions_per_wi));
485+ });
486+
487+ return comp_ev;
488+ }
489+ }
490+
491+ /* @brief Reduce rows in a matrix */
492+ template <typename argTy, typename resTy>
493+ sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl (
494+ sycl::queue exec_q,
495+ size_t iter_nelems, // number of reductions (num. of cols in a matrix
496+ // when reducing over cols)
497+ size_t reduction_nelems, // size of each reduction (length of cols, i.e.
498+ // number of rows)
499+ const char *arg_cp,
500+ char *res_cp,
501+ py::ssize_t iter_arg_offset,
502+ py::ssize_t iter_res_offset,
503+ py::ssize_t reduction_arg_offset,
504+ const std::vector<sycl::event> &depends)
505+ {
506+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
507+ iter_arg_offset + reduction_arg_offset;
508+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
509+
510+ using ReductionOpT = sycl::plus<resTy>;
511+ constexpr resTy identity_val = resTy{0 };
512+
513+ const sycl::device &d = exec_q.get_device ();
514+ const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
515+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
516+
517+ {
518+ sycl::event res_init_ev = exec_q.fill <resTy>(
519+ res_tp, resTy (identity_val), iter_nelems, depends);
520+
521+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
522+ cgh.depends_on (res_init_ev);
523+
524+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
525+ using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
526+ using InputOutputIterIndexerT =
527+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
528+ NoOpIndexerT, NoOpIndexerT>;
529+ using ReductionIndexerT = ColsIndexerT;
530+
531+ NoOpIndexerT columns_indexer{};
532+ NoOpIndexerT result_indexer{};
533+ InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
534+ result_indexer};
535+ ReductionIndexerT reduction_indexer{
536+ 0 , /* size */ static_cast <py::ssize_t >(reduction_nelems),
537+ /* step */ static_cast <py::ssize_t >(iter_nelems)};
538+
539+ constexpr size_t preferrered_reductions_per_wi = 8 ;
540+ size_t reductions_per_wi =
541+ (reduction_nelems < preferrered_reductions_per_wi * wg)
542+ ? std::max<size_t >(1 , (reduction_nelems + wg - 1 ) / wg)
543+ : preferrered_reductions_per_wi;
544+
545+ size_t reduction_groups =
546+ (reduction_nelems + reductions_per_wi * wg - 1 ) /
547+ (reductions_per_wi * wg);
497548
498549 auto globalRange =
499550 sycl::range<1 >{iter_nelems * reduction_groups * wg};
500551 auto localRange = sycl::range<1 >{wg};
501552
502- using KernelName = class sum_reduction_over_group_with_atomics_krn <
503- argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
504- ReductionIndexerT>;
553+ using KernelName =
554+ class sum_reduction_axis0_over_group_with_atomics_contig_krn <
555+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
556+ ReductionIndexerT>;
505557
506558 cgh.parallel_for <KernelName>(
507559 sycl::nd_range<1 >(globalRange, localRange),
@@ -558,14 +610,13 @@ struct ReductionOverGroupNoAtomicFunctor
558610
559611 void operator ()(sycl::nd_item<1 > it) const
560612 {
561-
562- const size_t red_gws_ = it.get_global_range (0 ) / iter_gws_;
563- const size_t iter_gid = it.get_global_id (0 ) / red_gws_;
564- const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
565- const size_t reduction_batch_id = it.get_group (0 ) % n_reduction_groups;
566613 const size_t reduction_lid = it.get_local_id (0 );
567614 const size_t wg = it.get_local_range (0 ); // 0 <= reduction_lid < wg
568615
616+ const size_t iter_gid = it.get_group (0 ) % iter_gws_;
617+ const size_t reduction_batch_id = it.get_group (0 ) / iter_gws_;
618+ const size_t n_reduction_groups = it.get_group_range (0 ) / iter_gws_;
619+
569620 // work-items sums over input with indices
570621 // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg
571622 // + reduction_lid
@@ -1079,15 +1130,34 @@ struct SumOverAxisTempsStridedFactory
10791130};
10801131
10811132template <typename fnT, typename srcTy, typename dstTy>
1082- struct SumOverAxisAtomicContigFactory
1133+ struct SumOverAxis1AtomicContigFactory
1134+ {
1135+ fnT get () const
1136+ {
1137+ if constexpr (TypePairSupportDataForSumReductionAtomic<
1138+ srcTy, dstTy>::is_defined)
1139+ {
1140+ return dpctl::tensor::kernels::
1141+ sum_reduction_axis1_over_group_with_atomics_contig_impl<srcTy,
1142+ dstTy>;
1143+ }
1144+ else {
1145+ return nullptr ;
1146+ }
1147+ }
1148+ };
1149+
1150+ template <typename fnT, typename srcTy, typename dstTy>
1151+ struct SumOverAxis0AtomicContigFactory
10831152{
10841153 fnT get () const
10851154 {
10861155 if constexpr (TypePairSupportDataForSumReductionAtomic<
10871156 srcTy, dstTy>::is_defined)
10881157 {
10891158 return dpctl::tensor::kernels::
1090- sum_reduction_over_group_with_atomics_contig_impl<srcTy, dstTy>;
1159+ sum_reduction_axis0_over_group_with_atomics_contig_impl<srcTy,
1160+ dstTy>;
10911161 }
10921162 else {
10931163 return nullptr ;
0 commit comments