2727#include < algorithm>
2828#include < complex>
2929#include < cstdint>
30- #include < iostream>
3130#include < pybind11/complex.h>
3231#include < pybind11/pybind11.h>
3332#include < pybind11/stl.h>
@@ -280,15 +279,50 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
280279 }
281280}
282281
282+ /* Utility to parse python object py_ind into vector of `usm_ndarray`s */
283+ std::vector<dpctl::tensor::usm_ndarray> parse_py_ind (const sycl::queue &q,
284+ py::object py_ind)
285+ {
286+ size_t ind_count = py::len (py_ind);
287+ std::vector<dpctl::tensor::usm_ndarray> res;
288+ res.reserve (ind_count);
289+
290+ bool acquired = false ;
291+ int nd = -1 ;
292+ for (size_t i = 0 ; i < ind_count; ++i) {
293+ auto el_i = py_ind[py::cast (i)];
294+ auto arr_i = py::cast<dpctl::tensor::usm_ndarray>(el_i);
295+ if (!dpctl::utils::queues_are_compatible (q, {arr_i})) {
296+ throw py::value_error (" Index allocation queue is not compatible "
297+ " with execution queue" );
298+ }
299+ if (acquired) {
300+ if (nd != arr_i.get_ndim ()) {
301+ throw py::value_error (
302+ " Indices must have the same number of dimensions." );
303+ }
304+ }
305+ else {
306+ acquired = true ;
307+ nd = arr_i.get_ndim ();
308+ }
309+ res.push_back (arr_i);
310+ }
311+
312+ return res;
313+ }
314+
283315std::pair<sycl::event, sycl::event>
284316usm_ndarray_take (dpctl::tensor::usm_ndarray src,
285- std::vector<dpctl::tensor::usm_ndarray> ind ,
317+ py::object py_ind ,
286318 dpctl::tensor::usm_ndarray dst,
287319 int axis_start,
288320 uint8_t mode,
289321 sycl::queue exec_q,
290322 const std::vector<sycl::event> &depends)
291323{
324+ std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind (exec_q, py_ind);
325+
292326 int k = ind.size ();
293327
294328 if (k == 0 ) {
@@ -636,15 +670,12 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
636670 std::to_string (ind_type_id));
637671 }
638672
639- std::cout << " Submitting take" << std::endl;
640673 sycl::event take_generic_ev =
641674 fn (exec_q, orthog_nelems, ind_nelems, orthog_nd, ind_nd, k,
642675 packed_shapes_strides, packed_axes_shapes_strides,
643676 packed_ind_shapes_strides, src_data, dst_data, packed_ind_ptrs,
644677 src_offset, dst_offset, packed_ind_offsets, all_deps);
645678
646- std::cout << " Submitting take clean-up host task" << std::endl;
647-
648679 // free packed temporaries
649680 auto ctx = exec_q.get_context ();
650681 exec_q.submit ([&](sycl::handler &cgh) {
@@ -661,19 +692,20 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
661692 });
662693
663694 return std::make_pair (
664- keep_args_alive (exec_q, {src, dst}, {take_generic_ev}),
695+ keep_args_alive (exec_q, {src, py_ind, dst}, {take_generic_ev}),
665696 take_generic_ev);
666697}
667698
668699std::pair<sycl::event, sycl::event>
669700usm_ndarray_put (dpctl::tensor::usm_ndarray dst,
670- std::vector<dpctl::tensor::usm_ndarray> ind ,
701+ py::object py_ind ,
671702 dpctl::tensor::usm_ndarray val,
672703 int axis_start,
673704 uint8_t mode,
674705 sycl::queue exec_q,
675706 const std::vector<sycl::event> &depends)
676707{
708+ std::vector<dpctl::tensor::usm_ndarray> ind = parse_py_ind (exec_q, py_ind);
677709 int k = ind.size ();
678710
679711 if (k == 0 ) {
@@ -1046,8 +1078,9 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
10461078 });
10471079 });
10481080
1049- return std::make_pair (keep_args_alive (exec_q, {dst, val}, {put_generic_ev}),
1050- put_generic_ev);
1081+ return std::make_pair (
1082+ keep_args_alive (exec_q, {dst, py_ind, val}, {put_generic_ev}),
1083+ put_generic_ev);
10511084}
10521085
10531086void init_advanced_indexing_dispatch_tables (void )
0 commit comments