1- // === boolean_advance_indexing.hpp - ---*-C++-*--/===//
1+ // === boolean_advance_indexing.hpp - --- ---*-C++-*--/===//
22//
33// Data Parallel Control (dpctl)
44//
1616// See the License for the specific language governing permissions and
1717// limitations under the License.
1818//
19- // ===---------------------------------------------------------------------- ===//
19+ // ===---------------------------------------------------------------------===//
2020// /
2121// / \file
2222// / This file defines kernels for advanced tensor index operations.
23- // ===---------------------------------------------------------------------- ===//
23+ // ===---------------------------------------------------------------------===//
2424
2525#pragma once
2626#include < CL/sycl.hpp>
@@ -114,6 +114,26 @@ struct Strided1DIndexer
114114 py::ssize_t step = 1 ;
115115};
116116
117+ struct Strided1DCyclicIndexer
118+ {
119+ Strided1DCyclicIndexer (py::ssize_t _offset,
120+ py::ssize_t _size,
121+ py::ssize_t _step)
122+ : offset(_offset), size(static_cast <size_t >(_size)), step(_step)
123+ {
124+ }
125+
126+ size_t operator ()(size_t gid) const
127+ {
128+ return static_cast <size_t >(offset + (gid % size) * step);
129+ }
130+
131+ private:
132+ py::ssize_t offset = 0 ;
133+ size_t size = 1 ;
134+ py::ssize_t step = 1 ;
135+ };
136+
117137template <typename _IndexerFn> struct ZeroChecker
118138{
119139
@@ -762,27 +782,22 @@ sycl::event masked_place_all_slices_strided_impl(
762782 py::ssize_t rhs_stride,
763783 const std::vector<sycl::event> &depends = {})
764784{
765- // using MaskedPlaceStridedFunctor;
766- // using Strided1DIndexer;
767- // using StridedIndexer;
768- // using TwoZeroOffsets_Indexer;
769-
770785 TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
771786
772787 /* StridedIndexer(int _nd, py::ssize_t _offset, py::ssize_t const
773788 * *_packed_shape_strides) */
774789 StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
775- Strided1DIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
790+ Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
776791
777792 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
778793 cgh.depends_on (depends);
779794
780795 cgh.parallel_for <class masked_place_all_slices_strided_impl_krn <
781- TwoZeroOffsets_Indexer, StridedIndexer, Strided1DIndexer, dataT ,
782- indT>>(
796+ TwoZeroOffsets_Indexer, StridedIndexer, Strided1DCyclicIndexer ,
797+ dataT, indT>>(
783798 sycl::range<1 >(static_cast <size_t >(iteration_size)),
784799 MaskedPlaceStridedFunctor<TwoZeroOffsets_Indexer, StridedIndexer,
785- Strided1DIndexer , dataT, indT>(
800+ Strided1DCyclicIndexer , dataT, indT>(
786801 dst_p, cumsum_p, rhs_p, 1 , iteration_size,
787802 orthog_dst_rhs_indexer, masked_dst_indexer,
788803 masked_rhs_indexer));
@@ -838,11 +853,6 @@ sycl::event masked_place_some_slices_strided_impl(
838853 py::ssize_t masked_rhs_stride,
839854 const std::vector<sycl::event> &depends = {})
840855{
841- // using MaskedPlaceStridedFunctor;
842- // using Strided1DIndexer;
843- // using StridedIndexer;
844- // using TwoOffsets_StridedIndexer;
845-
846856 TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
847857 orthog_nd, ortho_dst_offset, ortho_rhs_offset,
848858 packed_ortho_dst_rhs_shape_strides};
@@ -851,17 +861,18 @@ sycl::event masked_place_some_slices_strided_impl(
851861 * *_packed_shape_strides) */
852862 StridedIndexer masked_dst_indexer{masked_nd, 0 ,
853863 packed_masked_dst_shape_strides};
854- Strided1DIndexer masked_rhs_indexer{0 , masked_rhs_size, masked_rhs_stride};
864+ Strided1DCyclicIndexer masked_rhs_indexer{0 , masked_rhs_size,
865+ masked_rhs_stride};
855866
856867 sycl::event comp_ev = exec_q.submit ([&](sycl::handler &cgh) {
857868 cgh.depends_on (depends);
858869
859870 cgh.parallel_for <class masked_place_some_slices_strided_impl_krn <
860- TwoOffsets_StridedIndexer, StridedIndexer, Strided1DIndexer, dataT ,
861- indT>>(
871+ TwoOffsets_StridedIndexer, StridedIndexer, Strided1DCyclicIndexer ,
872+ dataT, indT>>(
862873 sycl::range<1 >(static_cast <size_t >(orthog_nelems * masked_nelems)),
863874 MaskedPlaceStridedFunctor<TwoOffsets_StridedIndexer, StridedIndexer,
864- Strided1DIndexer , dataT, indT>(
875+ Strided1DCyclicIndexer , dataT, indT>(
865876 dst_p, cumsum_p, rhs_p, orthog_nelems, masked_nelems,
866877 orthog_dst_rhs_indexer, masked_dst_indexer,
867878 masked_rhs_indexer));
0 commit comments