2626import dpctl .utils
2727from dpctl .tensor ._data_types import _get_dtype
2828from dpctl .tensor ._device import normalize_queue_device
29+ from dpctl .tensor ._type_utils import _dtype_supported_by_device_impl
2930
3031__doc__ = (
3132 "Implementation module for copy- and cast- operations on "
@@ -121,7 +122,7 @@ def from_numpy(np_ary, device=None, usm_type="device", sycl_queue=None):
121122 output array is created. Device can be specified by a
122123 a filter selector string, an instance of
123124 :class:`dpctl.SyclDevice`, an instance of
124- :class:`dpctl.SyclQueue`, an instance of
125+ :class:`dpctl.SyclQueue`, or an instance of
125126 :class:`dpctl.tensor.Device`. If the value is `None`,
126127 returned array is created on the default-selected device.
127128 Default: `None`.
@@ -564,9 +565,11 @@ def copy(usm_ary, order="K"):
564565 return R
565566
566567
567- def astype (usm_ary , newdtype , order = "K" , casting = "unsafe" , copy = True ):
568+ def astype (
569+ usm_ary , newdtype , / , order = "K" , casting = "unsafe" , * , copy = True , device = None
570+ ):
568571 """ astype(array, new_dtype, order="K", casting="unsafe", \
569- copy=True)
572+ copy=True, device=None )
570573
571574 Returns a copy of the :class:`dpctl.tensor.usm_ndarray`, cast to a
572575 specified type.
@@ -576,7 +579,8 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
576579 An input array.
577580 new_dtype (dtype):
578581 The data type of the resulting array. If `None`, gives default
579- floating point type supported by device where `array` is allocated.
582+ floating point type supported by device where the resulting array
583+ will be located.
580584 order ({"C", "F", "A", "K"}, optional):
581585 Controls memory layout of the resulting array if a copy
582586 is returned.
@@ -587,6 +591,14 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
587591 By default, `astype` always returns a newly allocated array.
588592 If this keyword is set to `False`, a view of the input array
589593 may be returned when possible.
594+ device (object): array API specification of device where the
595+ output array is created. Device can be specified by a
596+ a filter selector string, an instance of
597+ :class:`dpctl.SyclDevice`, an instance of
598+ :class:`dpctl.SyclQueue`, or an instance of
599+ :class:`dpctl.tensor.Device`. If the value is `None`,
600+ returned array is created on the same device as `array`.
601+ Default: `None`.
590602
591603 Returns:
592604 usm_ndarray:
@@ -604,7 +616,25 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
604616 )
605617 order = order [0 ].upper ()
606618 ary_dtype = usm_ary .dtype
607- target_dtype = _get_dtype (newdtype , usm_ary .sycl_queue )
619+ if device is not None :
620+ if not isinstance (device , dpctl .SyclQueue ):
621+ if isinstance (device , dpt .Device ):
622+ device = device .sycl_queue
623+ else :
624+ device = dpt .Device .create_device (device ).sycl_queue
625+ d = device .sycl_device
626+ target_dtype = _get_dtype (newdtype , device )
627+ if not _dtype_supported_by_device_impl (
628+ target_dtype , d .has_aspect_fp16 , d .has_aspect_fp64
629+ ):
630+ raise ValueError (
631+ f"Requested dtype `{ target_dtype } ` is not supported by the "
632+ "target device"
633+ )
634+ usm_ary = usm_ary .to_device (device )
635+ else :
636+ target_dtype = _get_dtype (newdtype , usm_ary .sycl_queue )
637+
608638 if not dpt .can_cast (ary_dtype , target_dtype , casting = casting ):
609639 raise TypeError (
610640 f"Can not cast from { ary_dtype } to { newdtype } "
0 commit comments