@@ -70,6 +70,7 @@ using dpctl::utils::keep_args_alive;
7070
7171std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing (
7272 sycl::queue exec_q,
73+ std::vector<sycl::event> &host_task_events,
7374 py::ssize_t *device_orthog_shapes_strides,
7475 py::ssize_t *device_axes_shapes_strides,
7576 const py::ssize_t *inp_shape,
@@ -210,20 +211,21 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
210211 exec_q.copy <py::ssize_t >(packed_host_shapes_strides_shp->data (),
211212 device_orthog_shapes_strides,
212213 packed_host_shapes_strides_shp->size ());
213- exec_q.submit ([&](sycl::handler &cgh) {
214- cgh.depends_on (device_orthog_shapes_strides_copy_ev);
215- cgh.host_task ([packed_host_shapes_strides_shp] {});
216- });
217214
218215 sycl::event device_axes_shapes_strides_copy_ev =
219216 exec_q.copy <py::ssize_t >(
220217 packed_host_axes_shapes_strides_shp->data (),
221218 device_axes_shapes_strides,
222219 packed_host_axes_shapes_strides_shp->size ());
223- exec_q.submit ([&](sycl::handler &cgh) {
224- cgh.depends_on (device_axes_shapes_strides_copy_ev);
225- cgh.host_task ([packed_host_axes_shapes_strides_shp]() {});
226- });
220+
221+ sycl::event clean_up_host_task_ev =
222+ exec_q.submit ([&](sycl::handler &cgh) {
223+ cgh.depends_on (device_axes_shapes_strides_copy_ev);
224+ cgh.depends_on (device_orthog_shapes_strides_copy_ev);
225+ cgh.host_task ([packed_host_axes_shapes_strides_shp,
226+ packed_host_shapes_strides_shp]() {});
227+ });
228+ host_task_events.push_back (clean_up_host_task_ev);
227229
228230 std::vector<sycl::event> v = {device_orthog_shapes_strides_copy_ev,
229231 device_axes_shapes_strides_copy_ev};
@@ -268,10 +270,13 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
268270 packed_host_axes_shapes_strides_shp->data (),
269271 device_axes_shapes_strides,
270272 packed_host_axes_shapes_strides_shp->size ());
271- exec_q.submit ([&](sycl::handler &cgh) {
272- cgh.depends_on (device_axes_shapes_strides_copy_ev);
273- cgh.host_task ([packed_host_axes_shapes_strides_shp]() {});
274- });
273+
274+ sycl::event clean_up_host_task_ev =
275+ exec_q.submit ([&](sycl::handler &cgh) {
276+ cgh.depends_on (device_axes_shapes_strides_copy_ev);
277+ cgh.host_task ([packed_host_axes_shapes_strides_shp]() {});
278+ });
279+ host_task_events.push_back (clean_up_host_task_ev);
275280
276281 std::vector<sycl::event> v = {device_orthog_shapes_strides_fill_ev,
277282 device_axes_shapes_strides_copy_ev};
@@ -590,28 +595,33 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
590595 std::copy (ind_offsets.begin (), ind_offsets.end (),
591596 host_ind_offsets_shp->begin ());
592597
598+ std::vector<sycl::event> host_task_events (5 );
599+
593600 sycl::event packed_ind_ptrs_copy_ev = exec_q.copy <char *>(
594601 host_ind_ptrs_shp->data (), packed_ind_ptrs, host_ind_ptrs_shp->size ());
595- exec_q.submit ([&](sycl::handler &cgh) {
602+ sycl::event ind_ptrs_host_task = exec_q.submit ([&](sycl::handler &cgh) {
596603 cgh.depends_on (packed_ind_ptrs_copy_ev);
597604 cgh.host_task ([host_ind_ptrs_shp]() {});
598605 });
606+ host_task_events.push_back (ind_ptrs_host_task);
599607
600608 sycl::event packed_ind_shapes_strides_copy_ev = exec_q.copy <py::ssize_t >(
601609 host_ind_shapes_strides_shp->data (), packed_ind_shapes_strides,
602610 host_ind_shapes_strides_shp->size ());
603- exec_q.submit ([&](sycl::handler &cgh) {
611+ sycl::event ind_sh_st_host_task = exec_q.submit ([&](sycl::handler &cgh) {
604612 cgh.depends_on (packed_ind_shapes_strides_copy_ev);
605613 cgh.host_task ([host_ind_shapes_strides_shp]() {});
606614 });
615+ host_task_events.push_back (ind_sh_st_host_task);
607616
608617 sycl::event packed_ind_offsets_copy_ev = exec_q.copy <py::ssize_t >(
609618 host_ind_offsets_shp->data (), packed_ind_offsets,
610619 host_ind_offsets_shp->size ());
611- exec_q.submit ([&](sycl::handler &cgh) {
620+ sycl::event ind_offsets_host_task = exec_q.submit ([&](sycl::handler &cgh) {
612621 cgh.depends_on (packed_ind_offsets_copy_ev);
613622 cgh.host_task ([host_ind_offsets_shp]() {});
614623 });
624+ host_task_events.push_back (ind_offsets_host_task);
615625
616626 std::vector<sycl::event> ind_pack_depends{packed_ind_ptrs_copy_ev,
617627 packed_ind_shapes_strides_copy_ev,
@@ -650,10 +660,10 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
650660
651661 std::vector<sycl::event> src_dst_pack_deps =
652662 _populate_packed_shapes_strides_for_indexing (
653- exec_q, packed_shapes_strides, packed_axes_shapes_strides ,
654- src_shape, src_strides, is_src_c_contig, is_src_f_contig, dst_shape ,
655- dst_strides, is_dst_c_contig, is_dst_f_contig, axis_start, k ,
656- ind_nd, src_nd, dst_nd);
663+ exec_q, host_task_events, packed_shapes_strides ,
664+ packed_axes_shapes_strides, src_shape, src_strides, is_src_c_contig,
665+ is_src_f_contig, dst_shape, dst_strides, is_dst_c_contig ,
666+ is_dst_f_contig, axis_start, k, ind_nd, src_nd, dst_nd);
657667
658668 std::vector<sycl::event> all_deps (depends.size () + ind_pack_depends.size () +
659669 src_dst_pack_deps.size ());
@@ -690,9 +700,10 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
690700 sycl::free (packed_ind_offsets, ctx);
691701 });
692702 });
703+ host_task_events.push_back (take_generic_ev);
693704
694705 sycl::event host_task_ev =
695- keep_args_alive (exec_q, {src, py_ind, dst}, {take_generic_ev} );
706+ keep_args_alive (exec_q, {src, py_ind, dst}, host_task_events );
696707
697708 return std::make_pair (host_task_ev, take_generic_ev);
698709}
@@ -977,28 +988,33 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
977988 std::copy (ind_offsets.begin (), ind_offsets.end (),
978989 host_ind_offsets_shp->begin ());
979990
991+ std::vector<sycl::event> host_task_events (5 );
992+
980993 sycl::event device_ind_ptrs_copy_ev = exec_q.copy <char *>(
981994 host_ind_ptrs_shp->data (), packed_ind_ptrs, host_ind_ptrs_shp->size ());
982- exec_q.submit ([&](sycl::handler &cgh) {
995+ sycl::event ind_ptrs_host_task = exec_q.submit ([&](sycl::handler &cgh) {
983996 cgh.depends_on (device_ind_ptrs_copy_ev);
984997 cgh.host_task ([host_ind_ptrs_shp]() {});
985998 });
999+ host_task_events.push_back (ind_ptrs_host_task);
9861000
9871001 sycl::event device_ind_shapes_strides_copy_ev = exec_q.copy <py::ssize_t >(
9881002 host_ind_shapes_strides_shp->data (), packed_ind_shapes_strides,
9891003 host_ind_shapes_strides_shp->size ());
990- exec_q.submit ([&](sycl::handler &cgh) {
1004+ sycl::event ind_sh_st_host_task = exec_q.submit ([&](sycl::handler &cgh) {
9911005 cgh.depends_on (device_ind_shapes_strides_copy_ev);
9921006 cgh.host_task ([host_ind_shapes_strides_shp]() {});
9931007 });
1008+ host_task_events.push_back (ind_sh_st_host_task);
9941009
9951010 sycl::event device_ind_offsets_copy_ev = exec_q.copy <py::ssize_t >(
9961011 host_ind_offsets_shp->data (), packed_ind_offsets,
9971012 host_ind_offsets_shp->size ());
998- exec_q.submit ([&](sycl::handler &cgh) {
1013+ sycl::event ind_offsets_host_task = exec_q.submit ([&](sycl::handler &cgh) {
9991014 cgh.depends_on (device_ind_offsets_copy_ev);
10001015 cgh.host_task ([host_ind_offsets_shp]() {});
10011016 });
1017+ host_task_events.push_back (ind_offsets_host_task);
10021018
10031019 std::vector<sycl::event> ind_pack_depends{device_ind_ptrs_copy_ev,
10041020 device_ind_shapes_strides_copy_ev,
@@ -1037,10 +1053,10 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
10371053
10381054 std::vector<sycl::event> copy_shapes_strides_deps =
10391055 _populate_packed_shapes_strides_for_indexing (
1040- exec_q, packed_shapes_strides, packed_axes_shapes_strides ,
1041- dst_shape, dst_strides, is_dst_c_contig, is_dst_f_contig, val_shape ,
1042- val_strides, is_val_c_contig, is_val_f_contig, axis_start, k ,
1043- ind_nd, dst_nd, val_nd);
1056+ exec_q, host_task_events, packed_shapes_strides ,
1057+ packed_axes_shapes_strides, dst_shape, dst_strides, is_dst_c_contig,
1058+ is_dst_f_contig, val_shape, val_strides, is_val_c_contig ,
1059+ is_val_f_contig, axis_start, k, ind_nd, dst_nd, val_nd);
10441060
10451061 std::vector<sycl::event> all_deps (depends.size () +
10461062 copy_shapes_strides_deps.size () +
@@ -1078,9 +1094,10 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
10781094 sycl::free (packed_ind_offsets, ctx);
10791095 });
10801096 });
1097+ host_task_events.push_back (put_generic_ev);
10811098
10821099 return std::make_pair (
1083- keep_args_alive (exec_q, {dst, py_ind, val}, {put_generic_ev} ),
1100+ keep_args_alive (exec_q, {dst, py_ind, val}, host_task_events ),
10841101 put_generic_ev);
10851102}
10861103
0 commit comments