@@ -99,8 +99,6 @@ std::vector<sycl::event> _populate_packed_shapes_strides_for_indexing(
9999 std::shared_ptr<shT> packed_host_axes_shapes_strides_shp =
100100 std::make_shared<shT>(2 * k + along_sh_elems, allocator);
101101
102- // can be made more efficient by checking if inp_nd > 1, then performing
103- // same treatment of orthog_sh_elems as for 0D (orthog will not exist)
104102 if (inp_nd > 0 ) {
105103 std::copy (inp_shape, inp_shape + axis_start,
106104 packed_host_shapes_strides_shp->begin ());
@@ -403,6 +401,17 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
403401 }
404402 }
405403
404+ // destination must be ample enough to accommodate all elements
405+ {
406+ size_t range =
407+ static_cast <size_t >(dst_offsets.second - dst_offsets.first );
408+ if ((range + 1 ) < (orthog_nelems * ind_nelems)) {
409+ throw py::value_error (
410+ " Destination array can not accommodate all the "
411+ " elements of source array." );
412+ }
413+ }
414+
406415 auto ind_sh_elems = (ind_nd > 0 ) ? ind_nd : 1 ;
407416
408417 std::vector<char *> ind_ptrs;
@@ -580,17 +589,6 @@ usm_ndarray_take(dpctl::tensor::usm_ndarray src,
580589 const py::ssize_t *src_strides = src.get_strides_raw ();
581590 const py::ssize_t *dst_strides = dst.get_strides_raw ();
582591
583- // destination must be ample enough to accommodate all elements
584- {
585- size_t range =
586- static_cast <size_t >(dst_offsets.second - dst_offsets.first );
587- if ((range + 1 ) < (orthog_nelems * ind_nelems)) {
588- throw py::value_error (
589- " Destination array can not accommodate all the "
590- " elements of source array." );
591- }
592- }
593-
594592 // packed_shapes_strides = [src_shape[:axis] + src_shape[axis+k:],
595593 // src_strides[:axis] + src_strides[axis+k:],
596594 // dst_strides[:axis] + dst_strides[axis+k:]]
@@ -765,6 +763,17 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
765763 throw py::value_error (" Arrays index overlapping segments of memory" );
766764 }
767765
766+ // destination must be ample enough to accommodate all possible elements
767+ {
768+ size_t range =
769+ static_cast <size_t >(dst_offsets.second - dst_offsets.first );
770+ if ((range + 1 ) < dst_nelems) {
771+ throw py::value_error (
772+ " Destination array can not accommodate all the "
773+ " elements of source array." );
774+ }
775+ }
776+
768777 int dst_typenum = dst.get_typenum ();
769778 int val_typenum = val.get_typenum ();
770779
@@ -965,17 +974,6 @@ usm_ndarray_put(dpctl::tensor::usm_ndarray dst,
965974 const py::ssize_t *dst_strides = dst.get_strides_raw ();
966975 const py::ssize_t *val_strides = val.get_strides_raw ();
967976
968- // destination must be ample enough to accommodate all possible elements
969- {
970- size_t range =
971- static_cast <size_t >(dst_offsets.second - dst_offsets.first );
972- if ((range + 1 ) < dst_nelems) {
973- throw py::value_error (
974- " Destination array can not accommodate all the "
975- " elements of source array." );
976- }
977- }
978-
979977 // packed_shapes_strides = [dst_shape[:axis] + dst_shape[axis+k:],
980978 // dst_strides[:axis] + dst_strides[axis+k:],
981979 // val_strides[:axis] + val_strides[axis+k:]]
0 commit comments