@@ -234,7 +234,10 @@ template <typename T1, typename T2, typename T3, typename T4, typename T5>
234234class sum_reduction_seq_contig_krn ;
235235
236236template <typename T1, typename T2, typename T3, typename T4, typename T5>
237- class sum_reduction_over_group_with_atomics_contig_krn ;
237+ class sum_reduction_axis0_over_group_with_atomics_contig_krn ;
238+
239+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
240+ class sum_reduction_axis1_over_group_with_atomics_contig_krn ;
238241
239242using dpctl::tensor::sycl_utils::choose_workgroup_size;
240243
@@ -390,7 +393,7 @@ typedef sycl::event (*sum_reduction_contig_impl_fn_ptr)(
390393
391394/* @brief Reduce rows in a matrix */
392395template <typename argTy, typename resTy>
393- sycl::event sum_reduction_over_group_with_atomics_contig_impl (
396+ sycl::event sum_reduction_axis1_over_group_with_atomics_contig_impl (
394397 sycl::queue exec_q,
395398 size_t iter_nelems, // number of reductions (num. of rows in a matrix
396399 // when reducing over rows)
@@ -458,11 +461,11 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
458461 RowsIndexerT, NoOpIndexerT>;
459462 using ReductionIndexerT = NoOpIndexerT;
460463
461- RowsIndexerT columns_indexer {
464+ RowsIndexerT rows_indexer {
462465 0 , static_cast <py::ssize_t >(iter_nelems),
463466 static_cast <py::ssize_t >(reduction_nelems)};
464467 NoOpIndexerT result_indexer{};
465- InputOutputIterIndexerT in_out_iter_indexer{columns_indexer ,
468+ InputOutputIterIndexerT in_out_iter_indexer{rows_indexer ,
466469 result_indexer};
467470 ReductionIndexerT reduction_indexer{};
468471
@@ -495,7 +498,102 @@ sycl::event sum_reduction_over_group_with_atomics_contig_impl(
495498 auto localRange = sycl::range<1 >{wg};
496499
497500 using KernelName =
498- class sum_reduction_over_group_with_atomics_contig_krn <
501+ class sum_reduction_axis1_over_group_with_atomics_contig_krn <
502+ argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
503+ ReductionIndexerT>;
504+
505+ cgh.parallel_for <KernelName>(
506+ sycl::nd_range<1 >(globalRange, localRange),
507+ ReductionOverGroupWithAtomicFunctor<argTy, resTy, ReductionOpT,
508+ InputOutputIterIndexerT,
509+ ReductionIndexerT>(
510+ arg_tp, res_tp, ReductionOpT (), identity_val,
511+ in_out_iter_indexer, reduction_indexer, reduction_nelems,
512+ iter_nelems, reductions_per_wi));
513+ });
514+
515+ return comp_ev;
516+ }
517+ }
518+
519+ /* @brief Reduce rows in a matrix */
520+ template <typename argTy, typename resTy>
521+ sycl::event sum_reduction_axis0_over_group_with_atomics_contig_impl (
522+ sycl::queue exec_q,
523+ size_t iter_nelems, // number of reductions (num. of cols in a matrix
524+ // when reducing over cols)
525+ size_t reduction_nelems, // size of each reduction (length of cols, i.e.
526+ // number of rows)
527+ const char *arg_cp,
528+ char *res_cp,
529+ py::ssize_t iter_arg_offset,
530+ py::ssize_t iter_res_offset,
531+ py::ssize_t reduction_arg_offset,
532+ const std::vector<sycl::event> &depends)
533+ {
534+ const argTy *arg_tp = reinterpret_cast <const argTy *>(arg_cp) +
535+ iter_arg_offset + reduction_arg_offset;
536+ resTy *res_tp = reinterpret_cast <resTy *>(res_cp) + iter_res_offset;
537+
538+ using ReductionOpT = sycl::plus<resTy>;
539+ constexpr resTy identity_val = resTy{0 };
540+
541+ const sycl::device &d = exec_q.get_device ();
542+ const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
543+ size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
544+
545+ {
546+ sycl::event res_init_ev = exec_q.fill <resTy>(
547+ res_tp, resTy (identity_val), iter_nelems, depends);
548+
549+ sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
550+ cgh.depends_on (res_init_ev);
551+
552+ using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
553+ using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
554+ using InputOutputIterIndexerT =
555+ dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
556+ NoOpIndexerT, NoOpIndexerT>;
557+ using ReductionIndexerT = ColsIndexerT;
558+
559+ NoOpIndexerT columns_indexer{};
560+ NoOpIndexerT result_indexer{};
561+ InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
562+ result_indexer};
563+ ReductionIndexerT reduction_indexer{
564+ 0 , /* size */ static_cast <py::ssize_t >(reduction_nelems),
565+ /* step */ static_cast <py::ssize_t >(iter_nelems)};
566+
567+ constexpr size_t preferrered_reductions_per_wi = 8 ;
568+ size_t reductions_per_wi =
569+ (reduction_nelems < preferrered_reductions_per_wi * wg)
570+ ? std::max<size_t >(1 , (reduction_nelems + wg - 1 ) / wg)
571+ : preferrered_reductions_per_wi;
572+
573+ size_t reduction_groups =
574+ (reduction_nelems + reductions_per_wi * wg - 1 ) /
575+ (reductions_per_wi * wg);
576+
577+ if (reduction_groups > 1 ) {
578+ const size_t &max_wg =
579+ d.get_info <sycl::info::device::max_work_group_size>();
580+
581+ if (reduction_nelems < preferrered_reductions_per_wi * max_wg) {
582+ wg = max_wg;
583+ reductions_per_wi =
584+ std::max<size_t >(1 , (reduction_nelems + wg - 1 ) / wg);
585+ reduction_groups =
586+ (reduction_nelems + reductions_per_wi * wg - 1 ) /
587+ (reductions_per_wi * wg);
588+ }
589+ }
590+
591+ auto globalRange =
592+ sycl::range<1 >{iter_nelems * reduction_groups * wg};
593+ auto localRange = sycl::range<1 >{wg};
594+
595+ using KernelName =
596+ class sum_reduction_axis0_over_group_with_atomics_contig_krn <
499597 argTy, resTy, ReductionOpT, InputOutputIterIndexerT,
500598 ReductionIndexerT>;
501599
@@ -1075,15 +1173,34 @@ struct SumOverAxisTempsStridedFactory
10751173};
10761174
10771175template <typename fnT, typename srcTy, typename dstTy>
1078- struct SumOverAxisAtomicContigFactory
1176+ struct SumOverAxis1AtomicContigFactory
1177+ {
1178+ fnT get () const
1179+ {
1180+ if constexpr (TypePairSupportDataForSumReductionAtomic<
1181+ srcTy, dstTy>::is_defined)
1182+ {
1183+ return dpctl::tensor::kernels::
1184+ sum_reduction_axis1_over_group_with_atomics_contig_impl<srcTy,
1185+ dstTy>;
1186+ }
1187+ else {
1188+ return nullptr ;
1189+ }
1190+ }
1191+ };
1192+
1193+ template <typename fnT, typename srcTy, typename dstTy>
1194+ struct SumOverAxis0AtomicContigFactory
10791195{
10801196 fnT get () const
10811197 {
10821198 if constexpr (TypePairSupportDataForSumReductionAtomic<
10831199 srcTy, dstTy>::is_defined)
10841200 {
10851201 return dpctl::tensor::kernels::
1086- sum_reduction_over_group_with_atomics_contig_impl<srcTy, dstTy>;
1202+ sum_reduction_axis0_over_group_with_atomics_contig_impl<srcTy,
1203+ dstTy>;
10871204 }
10881205 else {
10891206 return nullptr ;
0 commit comments