|
21 | 21 |
|
22 | 22 | import dpctl |
23 | 23 | import dpctl.tensor as dpt |
24 | | -from dpctl.tensor._tensor_impl import _put, _take |
| 24 | +import dpctl.tensor._tensor_impl as ti |
25 | 25 |
|
26 | | -from ._copy_utils import _extract_impl, _nonzero_impl, _place_impl |
| 26 | +from ._copy_utils import _extract_impl, _nonzero_impl |
27 | 27 |
|
28 | 28 |
|
29 | 29 | def take(x, indices, /, *, axis=None, mode="clip"): |
@@ -95,7 +95,7 @@ def take(x, indices, /, *, axis=None, mode="clip"): |
95 | 95 | res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q |
96 | 96 | ) |
97 | 97 |
|
98 | | - hev, _ = _take(x, indices, res, axis, mode, sycl_queue=exec_q) |
| 98 | + hev, _ = ti._take(x, indices, res, axis, mode, sycl_queue=exec_q) |
99 | 99 | hev.wait() |
100 | 100 |
|
101 | 101 | return res |
@@ -175,7 +175,7 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"): |
175 | 175 |
|
176 | 176 | vals = dpt.broadcast_to(vals, val_shape) |
177 | 177 |
|
178 | | - hev, _ = _put(x, indices, vals, axis, mode, sycl_queue=exec_q) |
| 178 | + hev, _ = ti._put(x, indices, vals, axis, mode, sycl_queue=exec_q) |
179 | 179 | hev.wait() |
180 | 180 |
|
181 | 181 |
|
@@ -265,8 +265,23 @@ def place(arr, mask, vals): |
265 | 265 | raise dpctl.utils.ExecutionPlacementError |
266 | 266 | if arr.shape != mask.shape or vals.ndim != 1: |
267 | 267 | raise ValueError("Array sizes are not as required") |
268 | | - # FIXME |
269 | | - _place_impl(arr, mask, vals, axis=0) |
| 268 | + cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q) |
| 269 | + nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q) |
| 270 | + if nz_count == 0: |
| 271 | + return |
| 272 | + if vals.dtype == arr.dtype: |
| 273 | + rhs = vals |
| 274 | + else: |
| 275 | + rhs = dpt.astype(vals, arr.dtype) |
| 276 | + hev, _ = ti._place( |
| 277 | + dst=arr, |
| 278 | + cumsum=cumsum, |
| 279 | + axis_start=0, |
| 280 | + axis_end=mask.ndim, |
| 281 | + rhs=rhs, |
| 282 | + sycl_queue=exec_q, |
| 283 | + ) |
| 284 | + hev.wait() |
270 | 285 |
|
271 | 286 |
|
272 | 287 | def nonzero(arr): |
|
0 commit comments