@@ -41,25 +41,29 @@ def take(x, indices, /, *, axis=None, mode="wrap"):
4141 Takes elements from an array along a given axis at given indices.
4242
4343 Args:
44- x (usm_ndarray):
45- The array that elements will be taken from.
46- indices (usm_ndarray):
47- One-dimensional array of indices.
48- axis:
49- The axis along which the values will be selected.
50- If ``x`` is one-dimensional, this argument is optional.
51- Default: ``None``.
52- mode:
53- How out-of-bounds indices will be handled.
54- ``"wrap"`` - clamps indices to (-n <= i < n), then wraps
55- negative indices.
56- ``"clip"`` - clips indices to (0 <= i < n)
57- Default: ``"wrap"``.
44+ x (usm_ndarray):
45+ The array that elements will be taken from.
46+ indices (usm_ndarray):
47+ One-dimensional array of indices.
48+ axis (int, optional):
49+ The axis along which the values will be selected.
50+ If ``x`` is one-dimensional, this argument is optional.
51+ Default: ``None``.
52+ mode (str, optional):
53+ How out-of-bounds indices will be handled. Possible values
54+ are:
55+
56+ - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
57+ negative indices.
58+ - ``"clip"``: clips indices to (``0 <= i < n``).
59+
60+ Default: ``"wrap"``.
5861
5962 Returns:
6063 usm_ndarray:
61- Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
62- filled with elements from x.
64+ Array with shape
65+ ``x.shape[:axis] + indices.shape + x.shape[axis + 1:]``
66+ filled with elements from ``x``.
6367 """
6468 if not isinstance (x , dpt .usm_ndarray ):
6569 raise TypeError (
@@ -128,30 +132,71 @@ def put(x, indices, vals, /, *, axis=None, mode="wrap"):
128132 Puts values into an array along a given axis at given indices.
129133
130134 Args:
131- x (usm_ndarray):
132- The array the values will be put into.
133- indices (usm_ndarray)
134- One-dimensional array of indices.
135-
136- Note that if indices are not unique, a race
137- condition will result, and the value written to
138- ``x`` will not be deterministic.
139- :py:func:`dpctl.tensor.unique` can be used to
140- guarantee unique elements in ``indices``.
141- vals:
142- Array of values to be put into ``x``.
143- Must be broadcastable to the result shape
144- ``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
145- axis:
146- The axis along which the values will be placed.
147- If ``x`` is one-dimensional, this argument is optional.
148- Default: ``None``.
149- mode:
150- How out-of-bounds indices will be handled.
151- ``"wrap"`` - clamps indices to (-n <= i < n), then wraps
152- negative indices.
153- ``"clip"`` - clips indices to (0 <= i < n)
154- Default: ``"wrap"``.
135+ x (usm_ndarray):
136+ The array the values will be put into.
137+ indices (usm_ndarray):
138+ One-dimensional array of indices.
139+ vals (usm_ndarray):
140+ Array of values to be put into ``x``.
141+ Must be broadcastable to the result shape
142+ ``x.shape[:axis] + indices.shape + x.shape[axis+1:]``.
143+ axis (int, optional):
144+ The axis along which the values will be placed.
145+ If ``x`` is one-dimensional, this argument is optional.
146+ Default: ``None``.
147+ mode (str, optional):
148+ How out-of-bounds indices will be handled. Possible values
149+ are:
150+
151+ - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
152+ negative indices.
153+ - ``"clip"``: clips indices to (``0 <= i < n``).
154+
155+ Default: ``"wrap"``.
156+
157+ .. note::
158+
159+ If input array ``indices`` contains duplicates, a race condition
160+ occurs, and the value written into corresponding positions in ``x``
161+ may vary from run to run. Preserving sequential semantics in handing
162+ the duplicates to achieve deterministic behavior requires additional
163+ work, e.g.
164+
165+ :Example:
166+
167+ .. code-block:: python
168+
169+ from dpctl import tensor as dpt
170+
171+ def put_vec_duplicates(vec, ind, vals):
172+ "Put values into vec, handling possible duplicates in ind"
173+ assert vec.ndim, ind.ndim, vals.ndim == 1, 1, 1
174+
175+ # find positions of last occurences of each
176+ # unique index
177+ ind_flipped = dpt.flip(ind)
178+ ind_uniq = dpt.unique_all(ind_flipped).indices
179+ has_dups = len(ind) != len(ind_uniq)
180+
181+ if has_dups:
182+ ind_uniq = dpt.subtract(vec.size - 1, ind_uniq)
183+ ind = dpt.take(ind, ind_uniq)
184+ vals = dpt.take(vals, ind_uniq)
185+
186+ dpt.put(vec, ind, vals)
187+
188+ n = 512
189+ ind = dpt.concat((dpt.arange(n), dpt.arange(n, -1, step=-1)))
190+ x = dpt.zeros(ind.size, dtype="int32")
191+ vals = dpt.arange(ind.size, dtype=x.dtype)
192+
193+ # Values corresponding to last positions of
194+ # duplicate indices are written into the vector x
195+ put_vec_duplicates(x, ind, vals)
196+
197+ parts = (vals[-1:-n-2:-1], dpt.zeros(n, dtype=x.dtype))
198+ expected = dpt.concat(parts)
199+ assert dpt.all(x == expected)
155200 """
156201 if not isinstance (x , dpt .usm_ndarray ):
157202 raise TypeError (
@@ -237,22 +282,24 @@ def extract(condition, arr):
237282
238283 Returns the elements of an array that satisfies the condition.
239284
240- If `condition` is boolean ``dpctl.tensor.extract`` is
285+ If `` condition` ` is boolean ``dpctl.tensor.extract`` is
241286 equivalent to ``arr[condition]``.
242287
243288 Note that ``dpctl.tensor.place`` does the opposite of
244289 ``dpctl.tensor.extract``.
245290
246291 Args:
247292 conditions (usm_ndarray):
248- An array whose non-zero or True entries indicate the element
249- of `arr` to extract.
293+ An array whose non-zero or ``True`` entries indicate the element
294+ of ``arr`` to extract.
295+
250296 arr (usm_ndarray):
251- Input array of the same size as `condition`.
297+ Input array of the same size as `` condition` `.
252298
253299 Returns:
254300 usm_ndarray:
255- Rank 1 array of values from `arr` where `condition` is True.
301+ Rank 1 array of values from ``arr`` where ``condition`` is
302+ ``True``.
256303 """
257304 if not isinstance (condition , dpt .usm_ndarray ):
258305 raise TypeError (
@@ -280,20 +327,20 @@ def place(arr, mask, vals):
280327
281328 Change elements of an array based on conditional and input values.
282329
283- If `mask` is boolean ``dpctl.tensor.place`` is
330+ If `` mask` ` is boolean ``dpctl.tensor.place`` is
284331 equivalent to ``arr[condition] = vals``.
285332
286333 Args:
287334 arr (usm_ndarray):
288335 Array to put data into.
289336 mask (usm_ndarray):
290- Boolean mask array. Must have the same size as `arr`.
337+ Boolean mask array. Must have the same size as `` arr` `.
291338 vals (usm_ndarray, sequence):
292- Values to put into `arr`. Only the first N elements are
293- used, where N is the number of True values in `mask`. If
294- `vals` is smaller than N, it will be repeated, and if
295- elements of `arr` are to be masked, this sequence must be
296- non-empty. Array `vals` must be one dimensional.
339+ Values to put into `` arr` `. Only the first N elements are
340+ used, where N is the number of True values in `` mask` `. If
341+ `` vals` ` is smaller than N, it will be repeated, and if
342+ elements of `` arr` ` are to be masked, this sequence must be
343+ non-empty. Array `` vals` ` must be one dimensional.
297344 """
298345 if not isinstance (arr , dpt .usm_ndarray ):
299346 raise TypeError (
@@ -345,13 +392,14 @@ def nonzero(arr):
345392 Return the indices of non-zero elements.
346393
347394 Returns a tuple of usm_ndarrays, one for each dimension
348- of `arr`, containing the indices of the non-zero elements
349- in that dimension. The values of `arr` are always tested in
395+ of `` arr` `, containing the indices of the non-zero elements
396+ in that dimension. The values of `` arr` ` are always tested in
350397 row-major, C-style order.
351398
352399 Args:
353400 arr (usm_ndarray):
354401 Input array, which has non-zero array rank.
402+
355403 Returns:
356404 Tuple[usm_ndarray, ...]:
357405 Indices of non-zero array elements.
0 commit comments