2727
2828
2929def take (x , indices , / , * , axis = None , mode = "clip" ):
30+ """take(x, indices, axis=None, mode="clip")
31+
32+ Takes elements from array along a given axis.
33+
34+ Args:
35+ x: usm_ndarray
36+ The array that elements will be taken from.
37+ indices: usm_ndarray
38+ One-dimensional array of indices.
39+ axis:
40+ The axis over which the values will be selected.
41+ If x is one-dimensional, this argument is optional.
42+ mode:
43+ How out-of-bounds indices will be handled.
44+ "Clip" - clamps indices to (-n <= i < n), then wraps
45+ negative indices.
46+ "Wrap" - wraps both negative and positive indices.
47+
48+ Returns:
49+ out: usm_ndarray
50+ Array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:]
51+ filled with elements .
52+ """
3053 if not isinstance (x , dpt .usm_ndarray ):
3154 raise TypeError (
3255 "Expected instance of `dpt.usm_ndarray`, got `{}`." .format (type (x ))
3356 )
3457
35- if not isinstance (indices , list ) and not isinstance (indices , tuple ):
36- indices = (indices ,)
37-
38- queues_ = [
39- x .sycl_queue ,
40- ]
41- usm_types_ = [
42- x .usm_type ,
43- ]
44-
45- for i in indices :
46- if not isinstance (i , dpt .usm_ndarray ):
47- raise TypeError (
48- "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
49- type (i )
50- )
58+ if not isinstance (indices , dpt .usm_ndarray ):
59+ raise TypeError (
60+ "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
61+ type (indices )
5162 )
52- if not np .issubdtype (i .dtype , np .integer ):
53- raise IndexError (
54- "`indices` expected integer data type, got `{}`" .format (i .dtype )
63+ )
64+ if not np .issubdtype (indices .dtype , np .integer ):
65+ raise IndexError (
66+ "`indices` expected integer data type, got `{}`" .format (
67+ indices .dtype
5568 )
56- queues_ .append (i .sycl_queue )
57- usm_types_ .append (i .usm_type )
58- exec_q = dpctl .utils .get_execution_queue (queues_ )
59- if exec_q is None :
60- raise dpctl .utils .ExecutionPlacementError (
61- "Can not automatically determine where to allocate the "
62- "result or performance execution. "
63- "Use `usm_ndarray.to_device` method to migrate data to "
64- "be associated with the same queue."
6569 )
66- res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
70+ if indices .ndim != 1 :
71+ raise ValueError (
72+ "`indices` expected a 1D array, got `{}`" .format (indices .ndim )
73+ )
74+ exec_q = dpctl .utils .get_execution_queue ([x .sycl_queue , indices .sycl_queue ])
75+ if exec_q is None :
76+ raise dpctl .utils .ExecutionPlacementError
77+ res_usm_type = dpctl .utils .get_coerced_usm_type (
78+ [x .usm_type , indices .usm_type ]
79+ )
6780
6881 modes = {"clip" : 0 , "wrap" : 1 }
6982 try :
@@ -81,27 +94,47 @@ def take(x, indices, /, *, axis=None, mode="clip"):
8194 )
8295 axis = 0
8396
84- if len (indices ) > 1 :
85- indices = dpt .broadcast_arrays (* indices )
8697 if x_ndim > 0 :
8798 axis = normalize_axis_index (operator .index (axis ), x_ndim )
88- res_shape = (
89- x .shape [:axis ] + indices [0 ].shape + x .shape [axis + len (indices ) :]
90- )
99+ res_shape = x .shape [:axis ] + indices .shape + x .shape [axis + 1 :]
91100 else :
92- res_shape = indices [0 ].shape
101+ if axis != 0 :
102+ raise ValueError ("`axis` must be 0 for an array of dimension 0." )
103+ res_shape = indices .shape
93104
94105 res = dpt .empty (
95106 res_shape , dtype = x .dtype , usm_type = res_usm_type , sycl_queue = exec_q
96107 )
97108
98- hev , _ = ti ._take (x , indices , res , axis , mode , sycl_queue = exec_q )
109+ hev , _ = ti ._take (x , ( indices ,) , res , axis , mode , sycl_queue = exec_q )
99110 hev .wait ()
100111
101112 return res
102113
103114
104115def put (x , indices , vals , / , * , axis = None , mode = "clip" ):
116+ """put(x, indices, vals, axis=None, mode="clip")
117+
118+ Puts values of an array into another array
119+ along a given axis.
120+
121+ Args:
122+ x: usm_ndarray
123+ The array the values will be put into.
124+ indices: usm_ndarray
125+ One-dimensional array of indices.
126+ vals:
127+ Array of values to be put into `x`.
128+ Must be broadcastable to the shape of `indices`.
129+ axis:
130+ The axis over which the values will be placed.
131+ If x is one-dimensional, this argument is optional.
132+ mode:
133+ How out-of-bounds indices will be handled.
134+ "Clip" - clamps indices to (-axis_size <= i < axis_size),
135+ then wraps negative indices.
136+ "Wrap" - wraps both negative and positive indices.
137+ """
105138 if not isinstance (x , dpt .usm_ndarray ):
106139 raise TypeError (
107140 "Expected instance of `dpt.usm_ndarray`, got `{}`." .format (type (x ))
@@ -116,66 +149,61 @@ def put(x, indices, vals, /, *, axis=None, mode="clip"):
116149 usm_types_ = [
117150 x .usm_type ,
118151 ]
119-
120- if not isinstance (indices , list ) and not isinstance (indices , tuple ):
121- indices = (indices ,)
122-
123- for i in indices :
124- if not isinstance (i , dpt .usm_ndarray ):
125- raise TypeError (
126- "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
127- type (i )
128- )
152+ if not isinstance (indices , dpt .usm_ndarray ):
153+ raise TypeError (
154+ "`indices` expected `dpt.usm_ndarray`, got `{}`." .format (
155+ type (indices )
129156 )
130- if not np .issubdtype (i .dtype , np .integer ):
131- raise IndexError (
132- "`indices` expected integer data type, got `{}`" .format (i .dtype )
157+ )
158+ if indices .ndim != 1 :
159+ raise ValueError (
160+ "`indices` expected a 1D array, got `{}`" .format (indices .ndim )
161+ )
162+ if not np .issubdtype (indices .dtype , np .integer ):
163+ raise IndexError (
164+ "`indices` expected integer data type, got `{}`" .format (
165+ indices .dtype
133166 )
134- queues_ .append (i .sycl_queue )
135- usm_types_ .append (i .usm_type )
167+ )
168+ queues_ .append (indices .sycl_queue )
169+ usm_types_ .append (indices .usm_type )
136170 exec_q = dpctl .utils .get_execution_queue (queues_ )
137171 if exec_q is None :
138- raise dpctl .utils .ExecutionPlacementError (
139- "Can not automatically determine where to allocate the "
140- "result or performance execution. "
141- "Use `usm_ndarray.to_device` method to migrate data to "
142- "be associated with the same queue."
143- )
144- val_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
145-
172+ raise dpctl .utils .ExecutionPlacementError
173+ vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
146174 modes = {"clip" : 0 , "wrap" : 1 }
147175 try :
148176 mode = modes [mode ]
149177 except KeyError :
150- raise ValueError ("`mode` must be `wrap`, or `clip `." )
178+ raise ValueError ("`mode` must be `clip` or `wrap `." )
151179
152- # when axis is none, array is treated as 1D
153- if axis is None :
154- try :
155- x = dpt .reshape (x , (x .size ,), copy = False )
156- axis = 0
157- except ValueError :
158- raise ValueError ("Cannot create 1D view of input array" )
159- if len (indices ) > 1 :
160- indices = dpt .broadcast_arrays (* indices )
161180 x_ndim = x .ndim
181+ if axis is None :
182+ if x_ndim > 1 :
183+ raise ValueError (
184+ "`axis` cannot be `None` for array of dimension `{}`" .format (
185+ x_ndim
186+ )
187+ )
188+ axis = 0
189+
162190 if x_ndim > 0 :
163191 axis = normalize_axis_index (operator .index (axis ), x_ndim )
164192
165- val_shape = (
166- x .shape [:axis ] + indices [0 ].shape + x .shape [axis + len (indices ) :]
167- )
193+ val_shape = x .shape [:axis ] + indices .shape + x .shape [axis + 1 :]
168194 else :
169- val_shape = indices [0 ].shape
195+ if axis != 0 :
196+ raise ValueError ("`axis` must be 0 for an array of dimension 0." )
197+ val_shape = indices .shape
170198
171199 if not isinstance (vals , dpt .usm_ndarray ):
172200 vals = dpt .asarray (
173- vals , dtype = x .dtype , usm_type = val_usm_type , sycl_queue = exec_q
201+ vals , dtype = x .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
174202 )
175203
176204 vals = dpt .broadcast_to (vals , val_shape )
177205
178- hev , _ = ti ._put (x , indices , vals , axis , mode , sycl_queue = exec_q )
206+ hev , _ = ti ._put (x , ( indices ,) , vals , axis , mode , sycl_queue = exec_q )
179207 hev .wait ()
180208
181209
0 commit comments