2525#pragma once
2626#include < cstdint>
2727#include < limits>
28- #include < pybind11/pybind11.h>
2928#include < sycl/sycl.hpp>
3029#include < utility>
3130#include < vector>
3231
32+ #include " dpctl_tensor_types.hpp"
3333#include " utils/offset_utils.hpp"
34- #include " utils/type_dispatch .hpp"
34+ #include " utils/type_dispatch_building .hpp"
3535
3636namespace dpctl
3737{
@@ -42,8 +42,6 @@ namespace kernels
4242namespace indexing
4343{
4444
45- namespace py = pybind11;
46-
4745using namespace dpctl ::tensor::offset_utils;
4846
4947template <typename OrthogIndexerT,
@@ -90,7 +88,7 @@ struct MaskedExtractStridedFunctor
9088 // + 1 : 1)
9189 if (mask_set) {
9290 auto orthog_offsets =
93- orthog_src_dst_indexer (static_cast <py:: ssize_t >(orthog_i));
91+ orthog_src_dst_indexer (static_cast <ssize_t >(orthog_i));
9492
9593 size_t total_src_offset = masked_src_indexer (masked_i) +
9694 orthog_offsets.get_first_offset ();
@@ -161,7 +159,7 @@ struct MaskedPlaceStridedFunctor
161159 // + 1 : 1)
162160 if (mask_set) {
163161 auto orthog_offsets =
164- orthog_dst_rhs_indexer (static_cast <py:: ssize_t >(orthog_i));
162+ orthog_dst_rhs_indexer (static_cast <ssize_t >(orthog_i));
165163
166164 size_t total_dst_offset = masked_dst_indexer (masked_i) +
167165 orthog_offsets.get_first_offset ();
@@ -199,28 +197,28 @@ class masked_extract_all_slices_strided_impl_krn;
199197
200198typedef sycl::event (*masked_extract_all_slices_strided_impl_fn_ptr_t )(
201199 sycl::queue &,
202- py:: ssize_t ,
200+ ssize_t ,
203201 const char *,
204202 const char *,
205203 char *,
206204 int ,
207- py:: ssize_t const *,
208- py:: ssize_t ,
209- py:: ssize_t ,
205+ ssize_t const *,
206+ ssize_t ,
207+ ssize_t ,
210208 const std::vector<sycl::event> &);
211209
212210template <typename dataT, typename indT>
213211sycl::event masked_extract_all_slices_strided_impl (
214212 sycl::queue &exec_q,
215- py:: ssize_t iteration_size,
213+ ssize_t iteration_size,
216214 const char *src_p,
217215 const char *cumsum_p,
218216 char *dst_p,
219217 int nd,
220- const py:: ssize_t
218+ const ssize_t
221219 *packed_src_shape_strides, // [src_shape, src_strides], length 2*nd
222- py:: ssize_t dst_size, // dst is 1D
223- py:: ssize_t dst_stride,
220+ ssize_t dst_size, // dst is 1D
221+ ssize_t dst_stride,
224222 const std::vector<sycl::event> &depends = {})
225223{
226224 // using MaskedExtractStridedFunctor;
@@ -230,7 +228,7 @@ sycl::event masked_extract_all_slices_strided_impl(
230228
231229 TwoZeroOffsets_Indexer orthog_src_dst_indexer{};
232230
233- /* StridedIndexer(int _nd, py:: ssize_t _offset, py:: ssize_t const
231+ /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
234232 * *_packed_shape_strides) */
235233 StridedIndexer masked_src_indexer (nd, 0 , packed_src_shape_strides);
236234 Strided1DIndexer masked_dst_indexer (0 , dst_size, dst_stride);
@@ -254,19 +252,19 @@ sycl::event masked_extract_all_slices_strided_impl(
254252
255253typedef sycl::event (*masked_extract_some_slices_strided_impl_fn_ptr_t )(
256254 sycl::queue &,
257- py:: ssize_t ,
258- py:: ssize_t ,
255+ ssize_t ,
256+ ssize_t ,
259257 const char *,
260258 const char *,
261259 char *,
262260 int ,
263- py:: ssize_t const *,
264- py:: ssize_t ,
265- py:: ssize_t ,
261+ ssize_t const *,
262+ ssize_t ,
263+ ssize_t ,
266264 int ,
267- py:: ssize_t const *,
268- py:: ssize_t ,
269- py:: ssize_t ,
265+ ssize_t const *,
266+ ssize_t ,
267+ ssize_t ,
270268 const std::vector<sycl::event> &);
271269
272270template <typename OrthoIndexerT,
@@ -279,24 +277,24 @@ class masked_extract_some_slices_strided_impl_krn;
279277template <typename dataT, typename indT>
280278sycl::event masked_extract_some_slices_strided_impl (
281279 sycl::queue &exec_q,
282- py:: ssize_t orthog_nelems,
283- py:: ssize_t masked_nelems,
280+ ssize_t orthog_nelems,
281+ ssize_t masked_nelems,
284282 const char *src_p,
285283 const char *cumsum_p,
286284 char *dst_p,
287285 int orthog_nd,
288- const py:: ssize_t
286+ const ssize_t
289287 *packed_ortho_src_dst_shape_strides, // [ortho_shape, ortho_src_strides,
290288 // ortho_dst_strides], length
291289 // 3*ortho_nd
292- py:: ssize_t ortho_src_offset,
293- py:: ssize_t ortho_dst_offset,
290+ ssize_t ortho_src_offset,
291+ ssize_t ortho_dst_offset,
294292 int masked_nd,
295- const py:: ssize_t *packed_masked_src_shape_strides, // [masked_src_shape,
296- // masked_src_strides],
297- // length 2*masked_nd
298- py:: ssize_t masked_dst_size, // mask_dst is 1D
299- py:: ssize_t masked_dst_stride,
293+ const ssize_t *packed_masked_src_shape_strides, // [masked_src_shape,
294+ // masked_src_strides],
295+ // length 2*masked_nd
296+ ssize_t masked_dst_size, // mask_dst is 1D
297+ ssize_t masked_dst_stride,
300298 const std::vector<sycl::event> &depends = {})
301299{
302300 // using MaskedExtractStridedFunctor;
@@ -381,33 +379,33 @@ class masked_place_all_slices_strided_impl_krn;
381379
382380typedef sycl::event (*masked_place_all_slices_strided_impl_fn_ptr_t )(
383381 sycl::queue &,
384- py:: ssize_t ,
382+ ssize_t ,
385383 char *,
386384 const char *,
387385 const char *,
388386 int ,
389- py:: ssize_t const *,
390- py:: ssize_t ,
391- py:: ssize_t ,
387+ ssize_t const *,
388+ ssize_t ,
389+ ssize_t ,
392390 const std::vector<sycl::event> &);
393391
394392template <typename dataT, typename indT>
395393sycl::event masked_place_all_slices_strided_impl (
396394 sycl::queue &exec_q,
397- py:: ssize_t iteration_size,
395+ ssize_t iteration_size,
398396 char *dst_p,
399397 const char *cumsum_p,
400398 const char *rhs_p,
401399 int nd,
402- const py:: ssize_t
400+ const ssize_t
403401 *packed_dst_shape_strides, // [dst_shape, dst_strides], length 2*nd
404- py:: ssize_t rhs_size, // rhs is 1D
405- py:: ssize_t rhs_stride,
402+ ssize_t rhs_size, // rhs is 1D
403+ ssize_t rhs_stride,
406404 const std::vector<sycl::event> &depends = {})
407405{
408406 TwoZeroOffsets_Indexer orthog_dst_rhs_indexer{};
409407
410- /* StridedIndexer(int _nd, py:: ssize_t _offset, py:: ssize_t const
408+ /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
411409 * *_packed_shape_strides) */
412410 StridedIndexer masked_dst_indexer (nd, 0 , packed_dst_shape_strides);
413411 Strided1DCyclicIndexer masked_rhs_indexer (0 , rhs_size, rhs_stride);
@@ -431,19 +429,19 @@ sycl::event masked_place_all_slices_strided_impl(
431429
432430typedef sycl::event (*masked_place_some_slices_strided_impl_fn_ptr_t )(
433431 sycl::queue &,
434- py:: ssize_t ,
435- py:: ssize_t ,
432+ ssize_t ,
433+ ssize_t ,
436434 char *,
437435 const char *,
438436 const char *,
439437 int ,
440- py:: ssize_t const *,
441- py:: ssize_t ,
442- py:: ssize_t ,
438+ ssize_t const *,
439+ ssize_t ,
440+ ssize_t ,
443441 int ,
444- py:: ssize_t const *,
445- py:: ssize_t ,
446- py:: ssize_t ,
442+ ssize_t const *,
443+ ssize_t ,
444+ ssize_t ,
447445 const std::vector<sycl::event> &);
448446
449447template <typename OrthoIndexerT,
@@ -456,31 +454,31 @@ class masked_place_some_slices_strided_impl_krn;
456454template <typename dataT, typename indT>
457455sycl::event masked_place_some_slices_strided_impl (
458456 sycl::queue &exec_q,
459- py:: ssize_t orthog_nelems,
460- py:: ssize_t masked_nelems,
457+ ssize_t orthog_nelems,
458+ ssize_t masked_nelems,
461459 char *dst_p,
462460 const char *cumsum_p,
463461 const char *rhs_p,
464462 int orthog_nd,
465- const py:: ssize_t
463+ const ssize_t
466464 *packed_ortho_dst_rhs_shape_strides, // [ortho_shape, ortho_dst_strides,
467465 // ortho_rhs_strides], length
468466 // 3*ortho_nd
469- py:: ssize_t ortho_dst_offset,
470- py:: ssize_t ortho_rhs_offset,
467+ ssize_t ortho_dst_offset,
468+ ssize_t ortho_rhs_offset,
471469 int masked_nd,
472- const py:: ssize_t *packed_masked_dst_shape_strides, // [masked_dst_shape,
473- // masked_dst_strides],
474- // length 2*masked_nd
475- py:: ssize_t masked_rhs_size, // mask_dst is 1D
476- py:: ssize_t masked_rhs_stride,
470+ const ssize_t *packed_masked_dst_shape_strides, // [masked_dst_shape,
471+ // masked_dst_strides],
472+ // length 2*masked_nd
473+ ssize_t masked_rhs_size, // mask_dst is 1D
474+ ssize_t masked_rhs_stride,
477475 const std::vector<sycl::event> &depends = {})
478476{
479477 TwoOffsets_StridedIndexer orthog_dst_rhs_indexer{
480478 orthog_nd, ortho_dst_offset, ortho_rhs_offset,
481479 packed_ortho_dst_rhs_shape_strides};
482480
483- /* StridedIndexer(int _nd, py:: ssize_t _offset, py:: ssize_t const
481+ /* StridedIndexer(int _nd, ssize_t _offset, ssize_t const
484482 * *_packed_shape_strides) */
485483 StridedIndexer masked_dst_indexer{masked_nd, 0 ,
486484 packed_masked_dst_shape_strides};
@@ -550,22 +548,22 @@ template <typename T1, typename T2> class non_zero_indexes_krn;
550548
551549typedef sycl::event (*non_zero_indexes_fn_ptr_t )(
552550 sycl::queue &,
553- py:: ssize_t ,
554- py:: ssize_t ,
551+ ssize_t ,
552+ ssize_t ,
555553 int ,
556554 const char *,
557555 char *,
558- const py:: ssize_t *,
556+ const ssize_t *,
559557 std::vector<sycl::event> const &);
560558
561559template <typename indT1, typename indT2>
562560sycl::event non_zero_indexes_impl (sycl::queue &exec_q,
563- py:: ssize_t iter_size,
564- py:: ssize_t nz_elems,
561+ ssize_t iter_size,
562+ ssize_t nz_elems,
565563 int nd,
566564 const char *cumsum_cp,
567565 char *indexes_cp,
568- const py:: ssize_t *mask_shape,
566+ const ssize_t *mask_shape,
569567 std::vector<sycl::event> const &depends)
570568{
571569 const indT1 *cumsum_data = reinterpret_cast <const indT1 *>(cumsum_cp);
@@ -582,11 +580,11 @@ sycl::event non_zero_indexes_impl(sycl::queue &exec_q,
582580 auto cs_prev_val = (i > 0 ) ? cumsum_data[i - 1 ] : indT1 (0 );
583581 bool cond = (cs_curr_val == cs_prev_val);
584582
585- py:: ssize_t i_ = static_cast <py:: ssize_t >(i);
583+ ssize_t i_ = static_cast <ssize_t >(i);
586584 for (int dim = nd; --dim > 0 ;) {
587585 auto sd = mask_shape[dim];
588- py:: ssize_t q = i_ / sd;
589- py:: ssize_t r = (i_ - q * sd);
586+ ssize_t q = i_ / sd;
587+ ssize_t r = (i_ - q * sd);
590588 if (cond) {
591589 indexes_data[cs_curr_val + dim * nz_elems] =
592590 static_cast <indT2>(r);
0 commit comments