Skip to content

Commit 07d0243

Browse files
committed
In order to resolve gh-2156 move definition of the mask_positions
and _cumsum_1d functions to _tensor_accumulations_impl Changed Python scripts accordingly, as well as CMake scripts to add implementation cpp file to the list of source files for the _tensor_accumulations_impl MODULE library. Also moved find_package(Python) to find Module.Development component before pybind11 is being activated to resolve CMake warning. Incidentally, this change also results in reduced binary size and improved compilation tiles, since accumulation kernels are not being generated in duplicates (once for _tensor_ctor module, and once for _tensor_accumulation_impl module).
1 parent 878cc19 commit 07d0243

File tree

9 files changed

+35
-20
lines changed

9 files changed

+35
-20
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ install(DIRECTORY
114114
FILES_MATCHING REGEX "\\.h(pp)?$"
115115
)
116116

117+
# find Python before enabling pybind11
118+
find_package(Python REQUIRED COMPONENTS Development.Module)
119+
117120
# Define CMAKE_INSTALL_xxx: LIBDIR, INCLUDEDIR
118121
include(GNUInstallDirs)
119122

dpctl/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
find_package(Python REQUIRED COMPONENTS Development.Module NumPy)
1+
find_package(Python REQUIRED COMPONENTS NumPy)
22

33
# -t is to only Cythonize sources with timestamps newer than existing CXX files (if present)
44
# -w is to set working directory (and correctly set __pyx_f[] array of filenames)

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ set(_accumulator_sources
171171
)
172172
set(_tensor_accumulation_impl_sources
173173
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
174+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators.cpp
174175
${_accumulator_sources}
175176
)
176177

dpctl/tensor/_copy_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import dpctl.memory as dpm
2424
import dpctl.tensor as dpt
2525
import dpctl.tensor._tensor_impl as ti
26+
from dpctl.tensor._tensor_accumulation_impl import mask_positions
2627
import dpctl.utils
2728
from dpctl.tensor._data_types import _get_dtype
2829
from dpctl.tensor._device import normalize_queue_device
@@ -792,7 +793,7 @@ def _extract_impl(ary, ary_mask, axis=0):
792793
exec_q = cumsum.sycl_queue
793794
_manager = dpctl.utils.SequentialOrderManager[exec_q]
794795
dep_evs = _manager.submitted_events
795-
mask_count = ti.mask_positions(
796+
mask_count = mask_positions(
796797
ary_mask, cumsum, sycl_queue=exec_q, depends=dep_evs
797798
)
798799
dst_shape = ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
@@ -828,7 +829,7 @@ def _nonzero_impl(ary):
828829
)
829830
_manager = dpctl.utils.SequentialOrderManager[exec_q]
830831
dep_evs = _manager.submitted_events
831-
mask_count = ti.mask_positions(
832+
mask_count = mask_positions(
832833
ary, cumsum, sycl_queue=exec_q, depends=dep_evs
833834
)
834835
indexes_dt = ti.default_device_index_type(exec_q.sycl_device)
@@ -1050,7 +1051,7 @@ def _place_impl(ary, ary_mask, vals, axis=0):
10501051
exec_q = cumsum.sycl_queue
10511052
_manager = dpctl.utils.SequentialOrderManager[exec_q]
10521053
dep_ev = _manager.submitted_events
1053-
mask_count = ti.mask_positions(
1054+
mask_count = mask_positions(
10541055
ary_mask, cumsum, sycl_queue=exec_q, depends=dep_ev
10551056
)
10561057
expected_vals_shape = (

dpctl/tensor/_indexing_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dpctl
2020
import dpctl.tensor as dpt
2121
import dpctl.tensor._tensor_impl as ti
22+
from dpctl.tensor._tensor_accumulation_impl import mask_positions
2223
import dpctl.utils
2324

2425
from ._copy_utils import (
@@ -413,7 +414,7 @@ def place(arr, mask, vals):
413414
cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q)
414415
_manager = dpctl.utils.SequentialOrderManager[exec_q]
415416
deps_ev = _manager.submitted_events
416-
nz_count = ti.mask_positions(
417+
nz_count = mask_positions(
417418
mask, cumsum, sycl_queue=exec_q, depends=deps_ev
418419
)
419420
if nz_count == 0:

dpctl/tensor/_manipulation_functions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import dpctl
2424
import dpctl.tensor as dpt
2525
import dpctl.tensor._tensor_impl as ti
26+
from dpctl.tensor._tensor_accumulation_impl import _cumsum_1d
2627
import dpctl.utils as dputils
2728

2829
from ._copy_utils import _broadcast_strides
@@ -908,7 +909,7 @@ def repeat(x, repeats, /, *, axis=None):
908909
sycl_queue=exec_q,
909910
)
910911
# _cumsum_1d synchronizes so `depends` ends here safely
911-
res_axis_size = ti._cumsum_1d(
912+
res_axis_size = _cumsum_1d(
912913
rep_buf, cumsum, sycl_queue=exec_q, depends=[copy_ev]
913914
)
914915
if axis is not None:
@@ -940,7 +941,7 @@ def repeat(x, repeats, /, *, axis=None):
940941
usm_type=usm_type,
941942
sycl_queue=exec_q,
942943
)
943-
res_axis_size = ti._cumsum_1d(
944+
res_axis_size = _cumsum_1d(
944945
repeats, cumsum, sycl_queue=exec_q, depends=dep_evs
945946
)
946947
if axis is not None:

dpctl/tensor/_set_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
_linspace_step,
3535
_take,
3636
default_device_index_type,
37-
mask_positions,
3837
)
38+
from ._tensor_accumulation_impl import mask_positions
3939
from ._tensor_sorting_impl import (
4040
_argsort_ascending,
4141
_isin,

dpctl/tensor/libtensor/source/tensor_accumulation.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,30 @@
2424
//===----------------------------------------------------------------------===//
2525

2626
#include <pybind11/pybind11.h>
27+
#include <pybind11/stl.h>
2728

2829
#include "accumulators/accumulators_common.hpp"
30+
#include "accumulators.hpp"
2931

3032
namespace py = pybind11;
3133

34+
namespace py_int = dpctl::tensor::py_internal;
35+
36+
using py_int::py_mask_positions;
37+
using py_int::py_cumsum_1d;
38+
3239
PYBIND11_MODULE(_tensor_accumulation_impl, m)
3340
{
41+
py_int::populate_mask_positions_dispatch_vectors();
42+
py_int::populate_cumsum_1d_dispatch_vectors();
43+
3444
dpctl::tensor::py_internal::init_accumulator_functions(m);
45+
46+
m.def("mask_positions", &py_mask_positions, "", py::arg("mask"),
47+
py::arg("cumsum"), py::arg("sycl_queue"),
48+
py::arg("depends") = py::list());
49+
50+
m.def("_cumsum_1d", &py_cumsum_1d, "", py::arg("src"), py::arg("cumsum"),
51+
py::arg("sycl_queue"), py::arg("depends") = py::list());
52+
3553
}

dpctl/tensor/libtensor/source/tensor_ctors.cpp

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,12 @@ using dpctl::tensor::py_internal::usm_ndarray_put;
105105
using dpctl::tensor::py_internal::usm_ndarray_take;
106106

107107
using dpctl::tensor::py_internal::py_extract;
108-
using dpctl::tensor::py_internal::py_mask_positions;
108+
// using dpctl::tensor::py_internal::py_mask_positions;
109109
using dpctl::tensor::py_internal::py_nonzero;
110110
using dpctl::tensor::py_internal::py_place;
111111

112112
/* ================= Repeat ====================*/
113-
using dpctl::tensor::py_internal::py_cumsum_1d;
113+
// using dpctl::tensor::py_internal::py_cumsum_1d;
114114
using dpctl::tensor::py_internal::py_repeat_by_scalar;
115115
using dpctl::tensor::py_internal::py_repeat_by_sequence;
116116

@@ -158,9 +158,6 @@ void init_dispatch_vectors(void)
158158
populate_masked_extract_dispatch_vectors();
159159
populate_masked_place_dispatch_vectors();
160160

161-
populate_mask_positions_dispatch_vectors();
162-
163-
populate_cumsum_1d_dispatch_vectors();
164161
init_repeat_dispatch_vectors();
165162

166163
init_clip_dispatch_vectors();
@@ -402,13 +399,6 @@ PYBIND11_MODULE(_tensor_impl, m)
402399
py::arg("dst"), py::arg("k") = 0, py::arg("sycl_queue"),
403400
py::arg("depends") = py::list());
404401

405-
m.def("mask_positions", &py_mask_positions, "", py::arg("mask"),
406-
py::arg("cumsum"), py::arg("sycl_queue"),
407-
py::arg("depends") = py::list());
408-
409-
m.def("_cumsum_1d", &py_cumsum_1d, "", py::arg("src"), py::arg("cumsum"),
410-
py::arg("sycl_queue"), py::arg("depends") = py::list());
411-
412402
m.def("_extract", &py_extract, "", py::arg("src"), py::arg("cumsum"),
413403
py::arg("axis_start"), py::arg("axis_end"), py::arg("dst"),
414404
py::arg("sycl_queue"), py::arg("depends") = py::list());

0 commit comments

Comments
 (0)