@@ -193,6 +193,28 @@ def _asarray_from_usm_ndarray(
193193 return res
194194
195195
196+ def _map_to_device_dtype (dt , q ):
197+ if dt .char == "?" or np .issubdtype (dt , np .integer ):
198+ return dt
199+ d = q .sycl_device
200+ dtc = dt .char
201+ if np .issubdtype (dt , np .floating ):
202+ if dtc == "f" :
203+ return dt
204+ else :
205+ if dtc == "d" and d .has_aspect_fp64 :
206+ return dt
207+ if dtc == "h" and d .has_aspect_fp16 :
208+ return dt
209+ return dpt .dtype ("f4" )
210+ elif np .issubdtype (dt , np .complexfloating ):
211+ if dtc == "F" :
212+ return dt
213+ if dtc == "D" and d .has_aspect_fp64 :
214+ return dt
215+ return dpt .dtype ("c8" )
216+
217+
196218def _asarray_from_numpy_ndarray (
197219 ary , dtype = None , usm_type = None , sycl_queue = None , order = "K"
198220):
@@ -207,10 +229,8 @@ def _asarray_from_numpy_ndarray(
207229 "Please convert the input to an array with numeric data type."
208230 )
209231 if dtype is None :
210- ary_dtype = ary .dtype
211- dtype = _get_dtype (dtype , copy_q , ref_type = ary_dtype )
212- if dtype .itemsize > ary_dtype .itemsize or ary_dtype == np .uint64 :
213- dtype = ary_dtype
232+ # deduce device-representable output data type
233+ dtype = _map_to_device_dtype (ary .dtype , copy_q )
214234 f_contig = ary .flags ["F" ]
215235 c_contig = ary .flags ["C" ]
216236 fc_contig = f_contig or c_contig
@@ -246,7 +266,7 @@ def _asarray_from_numpy_ndarray(
246266 order = order ,
247267 buffer_ctor_kwargs = {"queue" : copy_q },
248268 )
249- ti . _copy_numpy_ndarray_into_usm_ndarray ( src = ary , dst = res , sycl_queue = copy_q )
269+ res [...] = ary
250270 return res
251271
252272
0 commit comments