@@ -61,9 +61,9 @@ typedef sycl::event (*full_contig_fn_ptr_t)(sycl::queue &,
  *
  * @param exec_q Sycl queue to which kernel is submitted for execution.
  * @param nelems Length of the sequence
- * @param py_value Python object representing the value to fill the array with.
+ * @param py_value Python object representing the value to fill the array with.
  * Must be convertible to `dstTy`.
- * @param dst_p Kernel accessible USM pointer to the start of array to be
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
  * populated.
  * @param depends List of events to wait for before starting computations, if
  * any.
@@ -152,7 +152,62 @@ template <typename fnT, typename Ty> struct FullContigFactory
     }
 };
 
+typedef sycl::event (*full_strided_fn_ptr_t)(sycl::queue &,
+                                             int,
+                                             size_t,
+                                             py::ssize_t *,
+                                             const py::object &,
+                                             char *,
+                                             const std::vector<sycl::event> &);
+
+/*!
+ * @brief Function to submit kernel to fill given strided memory allocation
+ * with specified value.
+ *
+ * @param exec_q Sycl queue to which kernel is submitted for execution.
+ * @param nd Array dimensionality
+ * @param nelems Length of the sequence
+ * @param shape_strides Kernel accessible USM pointer to packed shape and
+ * strides of array.
+ * @param py_value Python object representing the value to fill the array with.
+ * Must be convertible to `dstTy`.
+ * @param dst_p Kernel accessible USM pointer to the start of array to be
+ * populated.
+ * @param depends List of events to wait for before starting computations, if
+ * any.
+ *
+ * @return Event to wait on to ensure that computation completes.
+ * @defgroup CtorKernels
+ */
+template <typename dstTy>
+sycl::event full_strided_impl(sycl::queue &exec_q,
+                              int nd,
+                              size_t nelems,
+                              py::ssize_t *shape_strides,
+                              const py::object &py_value,
+                              char *dst_p,
+                              const std::vector<sycl::event> &depends)
+{
+    dstTy fill_v = py::cast<dstTy>(py_value);
+
+    using dpctl::tensor::kernels::constructors::full_strided_impl;
+    sycl::event fill_ev = full_strided_impl<dstTy>(
+        exec_q, nd, nelems, shape_strides, fill_v, dst_p, depends);
+
+    return fill_ev;
+}
+
+template <typename fnT, typename Ty> struct FullStridedFactory
+{
+    fnT get()
+    {
+        fnT f = full_strided_impl<Ty>;
+        return f;
+    }
+};
+
 static full_contig_fn_ptr_t full_contig_dispatch_vector[td_ns::num_types];
+static full_strided_fn_ptr_t full_strided_dispatch_vector[td_ns::num_types];
 
 std::pair<sycl::event, sycl::event>
 usm_ndarray_full(const py::object &py_value,
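The `shape_strides` argument documented in the hunk above is a single USM
allocation that packs the destination array's shape and strides. As a point of
reference only, here is a minimal host-side sketch of how a strided fill kernel
can turn a flat element id into a memory offset from such a buffer. It assumes
the packing convention of nd shape entries followed by nd stride entries
(strides counted in elements); the helper name `linear_id_to_offset` is
invented for illustration and is not part of dpctl.

// Illustrative sketch, not dpctl code: map a flat element id to an element
// offset using a packed shape_strides buffer (nd shape values, then nd
// stride values, strides counted in elements).
#include <cstddef>

std::ptrdiff_t linear_id_to_offset(std::size_t id,
                                   int nd,
                                   const std::ptrdiff_t *shape_strides)
{
    const std::ptrdiff_t *shape = shape_strides;        // first nd entries
    const std::ptrdiff_t *strides = shape_strides + nd; // next nd entries

    std::ptrdiff_t offset = 0;
    // Unravel `id` in C (row-major) order and add each axis' stride term.
    for (int i = nd - 1; i >= 0; --i) {
        const std::size_t dim = static_cast<std::size_t>(shape[i]);
        offset += static_cast<std::ptrdiff_t>(id % dim) * strides[i];
        id /= dim;
    }
    return offset; // element offset from the start of the array
}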
@@ -194,8 +249,42 @@ usm_ndarray_full(const py::object &py_value,
             full_contig_event);
     }
     else {
-        throw std::runtime_error(
-            "Only population of contiguous usm_ndarray objects is supported.");
+        int nd = dst.get_ndim();
+        auto const &dst_shape = dst.get_shape_vector();
+        auto const &dst_strides = dst.get_strides_vector();
+
+        auto fn = full_strided_dispatch_vector[dst_typeid];
+
+        std::vector<sycl::event> host_task_events;
+        host_task_events.reserve(2);
+        using dpctl::tensor::offset_utils::device_allocate_and_pack;
+        const auto &ptr_size_event_tuple =
+            device_allocate_and_pack<py::ssize_t>(exec_q, host_task_events,
+                                                  dst_shape, dst_strides);
+        py::ssize_t *shape_strides = std::get<0>(ptr_size_event_tuple);
+        if (shape_strides == nullptr) {
+            throw std::runtime_error("Unable to allocate device memory");
+        }
+        const sycl::event &copy_shape_ev = std::get<2>(ptr_size_event_tuple);
+
+        const sycl::event &full_strided_ev =
+            fn(exec_q, nd, dst_nelems, shape_strides, py_value, dst_data,
+               {copy_shape_ev});
+
+        // free shape_strides
+        const auto &ctx = exec_q.get_context();
+        const auto &temporaries_cleanup_ev =
+            exec_q.submit([&](sycl::handler &cgh) {
+                cgh.depends_on(full_strided_ev);
+                using dpctl::tensor::alloc_utils::sycl_free_noexcept;
+                cgh.host_task([ctx, shape_strides]() {
+                    sycl_free_noexcept(shape_strides, ctx);
+                });
+            });
+        host_task_events.push_back(temporaries_cleanup_ev);
+
+        return std::make_pair(keep_args_alive(exec_q, {dst}, host_task_events),
+                              full_strided_ev);
     }
 }
 
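The else-branch above follows a common SYCL pattern for temporary device
metadata: the packed shape_strides buffer is filled asynchronously, the fill
kernel depends on that packing event, and the buffer is released from a
host_task that depends on the kernel, so the host thread never blocks. A
self-contained sketch of just that lifetime pattern, in plain SYCL 2020 and
without dpctl's keep_args_alive/sycl_free_noexcept helpers (the function
fill_with_temp and its metadata use are invented for illustration):

// Generic sketch of the temporary-USM lifetime pattern used in the strided
// branch: initialize a device temporary, make the kernel depend on that
// initialization, and free the temporary from a host_task that depends on
// the kernel.
#include <sycl/sycl.hpp>
#include <cstddef>

sycl::event fill_with_temp(sycl::queue &q, int *dst, std::size_t n)
{
    // Device temporary standing in for packed metadata (shape_strides above).
    std::size_t *meta = sycl::malloc_device<std::size_t>(1, q);
    sycl::event init_ev = q.fill(meta, n, 1);

    // The kernel consumes the temporary, so it waits for its initialization.
    sycl::event fill_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(init_ev);
        cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
            dst[i] = static_cast<int>(*meta);
        });
    });

    // Release the temporary only once the kernel is done with it.
    sycl::event cleanup_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(fill_ev);
        cgh.host_task(
            [ctx = q.get_context(), meta]() { sycl::free(meta, ctx); });
    });
    return cleanup_ev;
}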
@@ -204,10 +293,12 @@ void init_full_ctor_dispatch_vectors(void)
     using namespace td_ns;
 
     DispatchVectorBuilder<full_contig_fn_ptr_t, FullContigFactory, num_types>
-        dvb;
-    dvb.populate_dispatch_vector(full_contig_dispatch_vector);
+        dvb1;
+    dvb1.populate_dispatch_vector(full_contig_dispatch_vector);
 
-    return;
+    DispatchVectorBuilder<full_strided_fn_ptr_t, FullStridedFactory, num_types>
+        dvb2;
+    dvb2.populate_dispatch_vector(full_strided_dispatch_vector);
 }
 
 } // namespace py_internal
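The hunk above extends dpctl's type-dispatch setup: one function-pointer table
per operation, indexed by the array's type id and populated once at module
initialization by a factory template. A stripped-down sketch of that pattern
(generic illustration; the names fill_fn_t, fill_zero and init_dispatch_vector
are invented, and the real DispatchVectorBuilder lives in dpctl's
type_dispatch utilities):

// Minimal illustration of the factory/dispatch-vector pattern: a table of
// function pointers, one per supported element type, filled at init time and
// looked up by type id at call time.
#include <cstddef>

using fill_fn_t = void (*)(void *dst, std::size_t n);

template <typename T> void fill_zero(void *dst, std::size_t n)
{
    T *p = static_cast<T *>(dst);
    for (std::size_t i = 0; i < n; ++i) {
        p[i] = T(0);
    }
}

static fill_fn_t dispatch_vector[2];

void init_dispatch_vector()
{
    dispatch_vector[0] = fill_zero<float>;  // type id 0 -> float
    dispatch_vector[1] = fill_zero<double>; // type id 1 -> double
}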