 #include <type_traits>

 #include "dpctl4pybind11.hpp"
-#include "kernels/constructors.hpp"
-#include "kernels/copy_and_cast.hpp"
-#include "utils/strided_iters.hpp"
-#include "utils/type_dispatch.hpp"
-#include "utils/type_utils.hpp"

 #include "copy_and_cast_usm_to_usm.hpp"
 #include "copy_for_reshape.hpp"
 #include "copy_numpy_ndarray_into_usm_ndarray.hpp"
 #include "eye_ctor.hpp"
 #include "full_ctor.hpp"
 #include "linear_sequences.hpp"
-#include "simplify_iteration_space.hpp"
+#include "triul_ctor.hpp"
+#include "utils/strided_iters.hpp"

 namespace py = pybind11;
-namespace _ns = dpctl::tensor::detail;

 namespace
 {

 using dpctl::tensor::c_contiguous_strides;
 using dpctl::tensor::f_contiguous_strides;

-using dpctl::utils::keep_args_alive;
-
 using dpctl::tensor::py_internal::copy_usm_ndarray_into_usm_ndarray;
-using dpctl::tensor::py_internal::simplify_iteration_space;

 /* =========================== Copy for reshape ============================= */

@@ -84,253 +76,28 @@ using dpctl::tensor::py_internal::usm_ndarray_eye;

 /* =========================== Tril and triu ============================== */

-using dpctl::tensor::kernels::constructors::tri_fn_ptr_t;
-
-static tri_fn_ptr_t tril_generic_dispatch_vector[_ns::num_types];
-static tri_fn_ptr_t triu_generic_dispatch_vector[_ns::num_types];
-
-std::pair<sycl::event, sycl::event>
-tri(sycl::queue &exec_q,
-    dpctl::tensor::usm_ndarray src,
-    dpctl::tensor::usm_ndarray dst,
-    char part,
-    py::ssize_t k = 0,
-    const std::vector<sycl::event> &depends = {})
-{
-    // array dimensions must be the same
-    int src_nd = src.get_ndim();
-    int dst_nd = dst.get_ndim();
-    if (src_nd != dst_nd) {
-        throw py::value_error("Array dimensions are not the same.");
-    }
-
-    if (src_nd < 2) {
-        throw py::value_error("Array dimensions less than 2.");
-    }
-
-    // shapes must be the same
-    const py::ssize_t *src_shape = src.get_shape_raw();
-    const py::ssize_t *dst_shape = dst.get_shape_raw();
-
-    bool shapes_equal(true);
-    size_t src_nelems(1);
-
-    for (int i = 0; shapes_equal && i < src_nd; ++i) {
-        src_nelems *= static_cast<size_t>(src_shape[i]);
-        shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
-    }
-    if (!shapes_equal) {
-        throw py::value_error("Array shapes are not the same.");
-    }
-
-    if (src_nelems == 0) {
-        // nothing to do
-        return std::make_pair(sycl::event(), sycl::event());
-    }
-
-    char *src_data = src.get_data();
-    char *dst_data = dst.get_data();
-
-    // check that arrays do not overlap, and concurrent copying is safe.
-    auto src_offsets = src.get_minmax_offsets();
-    auto dst_offsets = dst.get_minmax_offsets();
-    int src_elem_size = src.get_elemsize();
-    int dst_elem_size = dst.get_elemsize();
-
-    bool memory_overlap =
-        ((dst_data - src_data > src_offsets.second * src_elem_size -
-                                    dst_offsets.first * dst_elem_size) &&
-         (src_data - dst_data > dst_offsets.second * dst_elem_size -
-                                    src_offsets.first * src_elem_size));
-    if (memory_overlap) {
-        // TODO: could use a temporary, but this is done by the caller
-        throw py::value_error("Arrays index overlapping segments of memory");
-    }
-
-    auto array_types = dpctl::tensor::detail::usm_ndarray_types();
-
-    int src_typenum = src.get_typenum();
-    int dst_typenum = dst.get_typenum();
-    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
-    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
-
-    if (dst_typeid != src_typeid) {
-        throw py::value_error("Array dtypes are not the same.");
-    }
-
-    // check same contexts
-    sycl::queue src_q = src.get_queue();
-    sycl::queue dst_q = dst.get_queue();
-
-    if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
-        throw py::value_error(
-            "Execution queue context is not the same as allocation contexts");
-    }
-
-    using shT = std::vector<py::ssize_t>;
-    shT src_strides(src_nd);
-
-    bool is_src_c_contig = src.is_c_contiguous();
-    bool is_src_f_contig = src.is_f_contiguous();
-
-    const py::ssize_t *src_strides_raw = src.get_strides_raw();
-    if (src_strides_raw == nullptr) {
-        if (is_src_c_contig) {
-            src_strides = c_contiguous_strides(src_nd, src_shape);
-        }
-        else if (is_src_f_contig) {
-            src_strides = f_contiguous_strides(src_nd, src_shape);
-        }
-        else {
-            throw std::runtime_error("Source array has null strides but has "
-                                     "neither C- nor F- contiguous flag set");
-        }
-    }
-    else {
-        std::copy(src_strides_raw, src_strides_raw + src_nd,
-                  src_strides.begin());
-    }
-
-    shT dst_strides(src_nd);
-
-    bool is_dst_c_contig = dst.is_c_contiguous();
-    bool is_dst_f_contig = dst.is_f_contiguous();
-
-    const py::ssize_t *dst_strides_raw = dst.get_strides_raw();
-    if (dst_strides_raw == nullptr) {
-        if (is_dst_c_contig) {
-            dst_strides =
-                dpctl::tensor::c_contiguous_strides(src_nd, src_shape);
-        }
-        else if (is_dst_f_contig) {
-            dst_strides =
-                dpctl::tensor::f_contiguous_strides(src_nd, src_shape);
-        }
-        else {
-            throw std::runtime_error(
-                "Destination array has null strides but has "
-                "neither C- nor F- contiguous flag set");
-        }
-    }
-    else {
-        std::copy(dst_strides_raw, dst_strides_raw + dst_nd,
-                  dst_strides.begin());
-    }
-
-    shT simplified_shape;
-    shT simplified_src_strides;
-    shT simplified_dst_strides;
-    py::ssize_t src_offset(0);
-    py::ssize_t dst_offset(0);
-
-    constexpr py::ssize_t src_itemsize = 1; // item size in elements
-    constexpr py::ssize_t dst_itemsize = 1; // item size in elements
-
-    int nd = src_nd - 2;
-    const py::ssize_t *shape = src_shape;
-    const py::ssize_t *p_src_strides = src_strides.data();
-    const py::ssize_t *p_dst_strides = dst_strides.data();
-
-    simplify_iteration_space(nd, shape, p_src_strides, src_itemsize,
-                             is_src_c_contig, is_src_f_contig, p_dst_strides,
-                             dst_itemsize, is_dst_c_contig, is_dst_f_contig,
-                             simplified_shape, simplified_src_strides,
-                             simplified_dst_strides, src_offset, dst_offset);
-
-    if (src_offset != 0 || dst_offset != 0) {
-        throw py::value_error("Reversed slice for dst is not supported");
-    }
-
-    nd += 2;
-
-    using usm_host_allocatorT =
-        sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
-    using usmshT = std::vector<py::ssize_t, usm_host_allocatorT>;
-
-    usm_host_allocatorT allocator(exec_q);
-    auto shp_host_shape_and_strides =
-        std::make_shared<usmshT>(3 * nd, allocator);
-
-    std::copy(simplified_shape.begin(), simplified_shape.end(),
-              shp_host_shape_and_strides->begin());
-    (*shp_host_shape_and_strides)[nd - 2] = src_shape[src_nd - 2];
-    (*shp_host_shape_and_strides)[nd - 1] = src_shape[src_nd - 1];
-
-    std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
-              shp_host_shape_and_strides->begin() + nd);
-    (*shp_host_shape_and_strides)[2 * nd - 2] = src_strides[src_nd - 2];
-    (*shp_host_shape_and_strides)[2 * nd - 1] = src_strides[src_nd - 1];
-
-    std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
-              shp_host_shape_and_strides->begin() + 2 * nd);
-    (*shp_host_shape_and_strides)[3 * nd - 2] = dst_strides[src_nd - 2];
-    (*shp_host_shape_and_strides)[3 * nd - 1] = dst_strides[src_nd - 1];
-
-    py::ssize_t *dev_shape_and_strides =
-        sycl::malloc_device<ssize_t>(3 * nd, exec_q);
-    if (dev_shape_and_strides == nullptr) {
-        throw std::runtime_error("Unable to allocate device memory");
-    }
-    sycl::event copy_shape_and_strides = exec_q.copy<ssize_t>(
-        shp_host_shape_and_strides->data(), dev_shape_and_strides, 3 * nd);
-
-    py::ssize_t inner_range = src_shape[src_nd - 1] * src_shape[src_nd - 2];
-    py::ssize_t outer_range = src_nelems / inner_range;
-
-    sycl::event tri_ev;
-    if (part == 'l') {
-        auto fn = tril_generic_dispatch_vector[src_typeid];
-        tri_ev =
-            fn(exec_q, inner_range, outer_range, src_data, dst_data, nd,
-               dev_shape_and_strides, k, depends, {copy_shape_and_strides});
-    }
-    else {
-        auto fn = triu_generic_dispatch_vector[src_typeid];
-        tri_ev =
-            fn(exec_q, inner_range, outer_range, src_data, dst_data, nd,
-               dev_shape_and_strides, k, depends, {copy_shape_and_strides});
-    }
-
-    exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on({tri_ev});
-        auto ctx = exec_q.get_context();
-        cgh.host_task(
-            [shp_host_shape_and_strides, dev_shape_and_strides, ctx]() {
-                // the capture of shp_host_shape_and_strides ensures the
-                // underlying vector exists for the entire execution of the
-                // copying kernel
-                sycl::free(dev_shape_and_strides, ctx);
-            });
-    });
-
-    return std::make_pair(keep_args_alive(exec_q, {src, dst}, {tri_ev}),
-                          tri_ev);
-}
+using dpctl::tensor::py_internal::usm_ndarray_triul;

 // populate dispatch tables
 void init_dispatch_tables(void)
 {
-    dpctl::tensor::py_internal::init_copy_and_cast_usm_to_usm_dispatch_tables();
-    dpctl::tensor::py_internal::
-        init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
+    using namespace dpctl::tensor::py_internal;
+
+    init_copy_and_cast_usm_to_usm_dispatch_tables();
+    init_copy_numpy_ndarray_into_usm_ndarray_dispatch_tables();
     return;
 }

 // populate dispatch vectors
 void init_dispatch_vectors(void)
 {
-    dpctl::tensor::py_internal::init_copy_for_reshape_dispatch_vectors();
-    dpctl::tensor::py_internal::init_linear_sequences_dispatch_vectors();
-    dpctl::tensor::py_internal::init_full_ctor_dispatch_vectors();
-    dpctl::tensor::py_internal::init_eye_ctor_dispatch_vectors();
-
-    using namespace dpctl::tensor::detail;
-    using dpctl::tensor::kernels::constructors::TrilGenericFactory;
-    using dpctl::tensor::kernels::constructors::TriuGenericFactory;
-
-    DispatchVectorBuilder<tri_fn_ptr_t, TrilGenericFactory, num_types> dvb5;
-    dvb5.populate_dispatch_vector(tril_generic_dispatch_vector);
+    using namespace dpctl::tensor::py_internal;

-    DispatchVectorBuilder<tri_fn_ptr_t, TriuGenericFactory, num_types> dvb6;
-    dvb6.populate_dispatch_vector(triu_generic_dispatch_vector);
+    init_copy_for_reshape_dispatch_vectors();
+    init_linear_sequences_dispatch_vectors();
+    init_full_ctor_dispatch_vectors();
+    init_eye_ctor_dispatch_vectors();
+    init_triul_ctor_dispatch_vectors();

     return;
 }
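
Note: the DispatchVectorBuilder code removed above (now housed behind init_triul_ctor_dispatch_vectors()) implements dpctl's type-dispatch pattern: the kernel template is instantiated once per supported dtype, and the instantiations are stored in a flat array indexed by a runtime type id. A minimal self-contained sketch of that pattern; the names tri_impl and type_id and the three-type list are hypothetical, not the dpctl implementation:

#include <array>
#include <cstdint>
#include <cstdio>

// Hypothetical function-pointer type standing in for tri_fn_ptr_t.
using tri_fn_ptr_t = void (*)(void *dst, const void *src);

// One template instantiation per supported element type.
template <typename T> void tri_impl(void *dst, const void *src)
{
    // Placeholder body; a real kernel would launch device work here.
    *static_cast<T *>(dst) = *static_cast<const T *>(src);
}

// Runtime type ids index into the dispatch vector.
enum type_id : std::size_t { id_int32 = 0, id_float, id_double, num_types };

static std::array<tri_fn_ptr_t, num_types> tril_dispatch_vector;

void init_tril_dispatch_vector(void)
{
    tril_dispatch_vector[id_int32] = &tri_impl<std::int32_t>;
    tril_dispatch_vector[id_float] = &tri_impl<float>;
    tril_dispatch_vector[id_double] = &tri_impl<double>;
}

int main(void)
{
    init_tril_dispatch_vector();
    double src = 3.5, dst = 0.0;
    // One array lookup selects the typed kernel, as tri() above did
    // with tril_generic_dispatch_vector[src_typeid].
    tril_dispatch_vector[id_double](&dst, &src);
    std::printf("%g\n", dst); // prints 3.5
}
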
@@ -478,7 +245,7 @@ PYBIND11_MODULE(_tensor_impl, m)
            py::ssize_t k, sycl::queue exec_q,
            const std::vector<sycl::event> depends)
             -> std::pair<sycl::event, sycl::event> {
-            return tri(exec_q, src, dst, 'l', k, depends);
+            return usm_ndarray_triul(exec_q, src, dst, 'l', k, depends);
         },
         "Tril helper function.", py::arg("src"), py::arg("dst"),
         py::arg("k") = 0, py::arg("sycl_queue"),
@@ -490,7 +257,7 @@ PYBIND11_MODULE(_tensor_impl, m)
            py::ssize_t k, sycl::queue exec_q,
            const std::vector<sycl::event> depends)
             -> std::pair<sycl::event, sycl::event> {
-            return tri(exec_q, src, dst, 'u', k, depends);
+            return usm_ndarray_triul(exec_q, src, dst, 'u', k, depends);
         },
         "Triu helper function.", py::arg("src"), py::arg("dst"),
         py::arg("k") = 0, py::arg("sycl_queue"),
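
Note: both bindings keep the two-event return convention used throughout this module: the first event of the pair (built with keep_args_alive in the deleted tri() above) guards the lifetime of the Python arguments, while the second is the kernel's own event. A minimal sketch of that convention, assuming a SYCL 2020 compiler; fill_ones is a hypothetical stand-in, not a dpctl API:

#include <sycl/sycl.hpp>
#include <cstdio>
#include <utility>
#include <vector>

// Returns (keep-alive event, computation event), mirroring the
// std::pair<sycl::event, sycl::event> shape of the bindings above.
std::pair<sycl::event, sycl::event>
fill_ones(sycl::queue &q, int *data, std::size_t n,
          const std::vector<sycl::event> &depends = {})
{
    // The computation event: the kernel itself.
    sycl::event comp_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(depends);
        cgh.parallel_for(sycl::range<1>{n},
                         [=](sycl::id<1> i) { data[i] = 1; });
    });

    // The keep-alive event: a host task ordered after the kernel. In
    // dpctl, keep_args_alive uses such a task to hold references to
    // the Python arguments until the kernel completes.
    sycl::event ht_ev = q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(comp_ev);
        cgh.host_task([]() { /* references released here */ });
    });

    return std::make_pair(ht_ev, comp_ev);
}

int main(void)
{
    sycl::queue q;
    constexpr std::size_t n = 8;
    int *data = sycl::malloc_shared<int>(n, q);

    auto [ht_ev, comp_ev] = fill_ones(q, data, n);
    ht_ev.wait(); // also waits for the kernel, via the dependency

    std::printf("%d %d\n", data[0], data[n - 1]); // prints "1 1"
    sycl::free(data, q);
}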