@@ -599,21 +599,22 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
599599 std::shared_ptr<shT> host_ind_offsets_shp =
600600 std::make_shared<shT>(k, ind_allocator);
601601
602+ std::copy (ind_sh_sts.begin (), ind_sh_sts.end (),
603+ host_ind_shapes_strides_shp->begin ());
604+ std::copy (ind_ptrs.begin (), ind_ptrs.end (), host_ind_ptrs_shp->begin ());
605+ std::copy (ind_offsets.begin (), ind_offsets.end (),
606+ host_ind_offsets_shp->begin ());
607+
602608 std::vector<sycl::event> host_task_events;
603609 host_task_events.reserve (5 );
604610
605- std::copy (ind_sh_sts.begin (), ind_sh_sts.end (),
606- host_ind_shapes_strides_shp->begin ());
607611 sycl::event packed_ind_ptrs_copy_ev = exec_q.copy <char *>(
608612 host_ind_ptrs_shp->data (), packed_ind_ptrs, host_ind_ptrs_shp->size ());
609613
610- std::copy (ind_ptrs.begin (), ind_ptrs.end (), host_ind_ptrs_shp->begin ());
611614 sycl::event packed_ind_shapes_strides_copy_ev = exec_q.copy <py::ssize_t >(
612615 host_ind_shapes_strides_shp->data (), packed_ind_shapes_strides,
613616 host_ind_shapes_strides_shp->size ());
614617
615- std::copy (ind_offsets.begin (), ind_offsets.end (),
616- host_ind_offsets_shp->begin ());
617618 sycl::event packed_ind_offsets_copy_ev = exec_q.copy <py::ssize_t >(
618619 host_ind_offsets_shp->data (), packed_ind_offsets,
619620 host_ind_offsets_shp->size ());
@@ -1010,38 +1011,39 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
10101011 std::shared_ptr<shT> host_ind_offsets_shp =
10111012 std::make_shared<shT>(k, ind_allocator);
10121013
1014+ std::copy (ind_sh_sts.begin (), ind_sh_sts.end (),
1015+ host_ind_shapes_strides_shp->begin ());
1016+ std::copy (ind_ptrs.begin (), ind_ptrs.end (), host_ind_ptrs_shp->begin ());
1017+ std::copy (ind_offsets.begin (), ind_offsets.end (),
1018+ host_ind_offsets_shp->begin ());
1019+
10131020 std::vector<sycl::event> host_task_events;
10141021 host_task_events.reserve (5 );
10151022
1016- std::copy (ind_ptrs.begin (), ind_ptrs.end (), host_ind_ptrs_shp->begin ());
1017- sycl::event device_ind_ptrs_copy_ev = exec_q.copy <char *>(
1023+ sycl::event packed_ind_ptrs_copy_ev = exec_q.copy <char *>(
10181024 host_ind_ptrs_shp->data (), packed_ind_ptrs, host_ind_ptrs_shp->size ());
10191025
1020- std::copy (ind_sh_sts.begin (), ind_sh_sts.end (),
1021- host_ind_shapes_strides_shp->begin ());
1022- sycl::event device_ind_shapes_strides_copy_ev = exec_q.copy <py::ssize_t >(
1026+ sycl::event packed_ind_shapes_strides_copy_ev = exec_q.copy <py::ssize_t >(
10231027 host_ind_shapes_strides_shp->data (), packed_ind_shapes_strides,
10241028 host_ind_shapes_strides_shp->size ());
10251029
1026- std::copy (ind_offsets.begin (), ind_offsets.end (),
1027- host_ind_offsets_shp->begin ());
1028- sycl::event device_ind_offsets_copy_ev = exec_q.copy <py::ssize_t >(
1030+ sycl::event packed_ind_offsets_copy_ev = exec_q.copy <py::ssize_t >(
10291031 host_ind_offsets_shp->data (), packed_ind_offsets,
10301032 host_ind_offsets_shp->size ());
10311033
10321034 sycl::event shared_ptr_cleanup_host_task =
10331035 exec_q.submit ([&](sycl::handler &cgh) {
1034- cgh.depends_on (device_ind_ptrs_copy_ev);
1035- cgh. depends_on (device_ind_shapes_strides_copy_ev);
1036- cgh. depends_on (device_ind_offsets_copy_ev );
1037- cgh.host_task ([host_ind_ptrs_shp , host_ind_shapes_strides_shp,
1038- host_ind_offsets_shp ]() {});
1036+ cgh.depends_on ({packed_ind_offsets_copy_ev,
1037+ packed_ind_shapes_strides_copy_ev,
1038+ packed_ind_ptrs_copy_ev} );
1039+ cgh.host_task ([host_ind_offsets_shp , host_ind_shapes_strides_shp,
1040+ host_ind_ptrs_shp ]() {});
10391041 });
10401042 host_task_events.push_back (shared_ptr_cleanup_host_task);
10411043
1042- std::vector<sycl::event> ind_pack_depends{device_ind_ptrs_copy_ev ,
1043- device_ind_shapes_strides_copy_ev ,
1044- device_ind_offsets_copy_ev };
1044+ std::vector<sycl::event> ind_pack_depends{packed_ind_ptrs_copy_ev ,
1045+ packed_ind_shapes_strides_copy_ev ,
1046+ packed_ind_offsets_copy_ev };
10451047
10461048 bool is_dst_c_contig = dst.is_c_contiguous ();
10471049 bool is_dst_f_contig = dst.is_f_contiguous ();
0 commit comments