|
18 | 18 | import numpy as np |
19 | 19 |
|
20 | 20 | import dpctl.tensor as dpt |
21 | | -from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray |
22 | 21 | from dpctl.tensor._tensor_impl import ( |
23 | 22 | _copy_usm_ndarray_for_reshape, |
24 | 23 | _ravel_multi_index, |
@@ -155,32 +154,37 @@ def reshape(X, /, shape, *, order="C", copy=None): |
155 | 154 | "Reshaping the array requires a copy, but no copying was " |
156 | 155 | "requested by using copy=False" |
157 | 156 | ) |
| 157 | + copy_q = X.sycl_queue |
158 | 158 | if copy_required or (copy is True): |
159 | 159 | # must perform a copy |
160 | 160 | flat_res = dpt.usm_ndarray( |
161 | 161 | (X.size,), |
162 | 162 | dtype=X.dtype, |
163 | 163 | buffer=X.usm_type, |
164 | | - buffer_ctor_kwargs={"queue": X.sycl_queue}, |
| 164 | + buffer_ctor_kwargs={"queue": copy_q}, |
165 | 165 | ) |
166 | 166 | if order == "C": |
167 | 167 | hev, _ = _copy_usm_ndarray_for_reshape( |
168 | | - src=X, dst=flat_res, sycl_queue=X.sycl_queue |
| 168 | + src=X, dst=flat_res, sycl_queue=copy_q |
169 | 169 | ) |
170 | | - hev.wait() |
171 | 170 | else: |
172 | | - for i in range(X.size): |
173 | | - _copy_from_usm_ndarray_to_usm_ndarray( |
174 | | - flat_res[i], X[np.unravel_index(i, X.shape, order=order)] |
175 | | - ) |
| 171 | + X_t = dpt.permute_dims(X, range(X.ndim - 1, -1, -1)) |
| 172 | + hev, _ = _copy_usm_ndarray_for_reshape( |
| 173 | + src=X_t, dst=flat_res, sycl_queue=copy_q |
| 174 | + ) |
| 175 | + hev.wait() |
176 | 176 | return dpt.usm_ndarray( |
177 | 177 | tuple(shape), dtype=X.dtype, buffer=flat_res, order=order |
178 | 178 | ) |
179 | 179 | # can form a view |
| 180 | + if (len(shape) == X.ndim) and all( |
| 181 | + s1 == s2 for s1, s2 in zip(shape, X.shape) |
| 182 | + ): |
| 183 | + return X |
180 | 184 | return dpt.usm_ndarray( |
181 | 185 | shape, |
182 | 186 | dtype=X.dtype, |
183 | 187 | buffer=X, |
184 | 188 | strides=tuple(newsts), |
185 | | - offset=X.__sycl_usm_array_interface__.get("offset", 0), |
| 189 | + offset=X._element_offset, |
186 | 190 | ) |
0 commit comments