1818import numpy as np
1919
2020import dpctl .tensor as dpt
21- from dpctl .tensor ._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
2221from dpctl .tensor ._tensor_impl import (
2322 _copy_usm_ndarray_for_reshape ,
2423 _ravel_multi_index ,
@@ -155,24 +154,25 @@ def reshape(X, /, shape, *, order="C", copy=None):
155154 "Reshaping the array requires a copy, but no copying was "
156155 "requested by using copy=False"
157156 )
157+ copy_q = X .sycl_queue
158158 if copy_required or (copy is True ):
159159 # must perform a copy
160160 flat_res = dpt .usm_ndarray (
161161 (X .size ,),
162162 dtype = X .dtype ,
163163 buffer = X .usm_type ,
164- buffer_ctor_kwargs = {"queue" : X . sycl_queue },
164+ buffer_ctor_kwargs = {"queue" : copy_q },
165165 )
166166 if order == "C" :
167167 hev , _ = _copy_usm_ndarray_for_reshape (
168- src = X , dst = flat_res , sycl_queue = X . sycl_queue
168+ src = X , dst = flat_res , sycl_queue = copy_q
169169 )
170- hev .wait ()
171170 else :
172- for i in range (X .size ):
173- _copy_from_usm_ndarray_to_usm_ndarray (
174- flat_res [i ], X [np .unravel_index (i , X .shape , order = order )]
175- )
171+ X_t = dpt .permute_dims (X , range (X .ndim - 1 , - 1 , - 1 ))
172+ hev , _ = _copy_usm_ndarray_for_reshape (
173+ src = X_t , dst = flat_res , sycl_queue = copy_q
174+ )
175+ hev .wait ()
176176 return dpt .usm_ndarray (
177177 tuple (shape ), dtype = X .dtype , buffer = flat_res , order = order
178178 )
@@ -182,5 +182,5 @@ def reshape(X, /, shape, *, order="C", copy=None):
182182 dtype = X .dtype ,
183183 buffer = X ,
184184 strides = tuple (newsts ),
185- offset = X .__sycl_usm_array_interface__ . get ( "offset" , 0 ) ,
185+ offset = X ._element_offset ,
186186 )
0 commit comments