@@ -159,7 +159,16 @@ def _check_device(xp, device):
159159 if device not in ["cpu" , None ]:
160160 raise ValueError (f"Unsupported device for NumPy: { device !r} " )
161161
162- # device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
162+ # Placeholder object to represent the dask device
163+ # when the array backend is not the CPU.
164+ # (since it is not easy to tell which device a dask array is on)
165+ class _dask_device :
166+ def __repr__ (self ):
167+ return "DASK_DEVICE"
168+
169+ DASK_DEVICE = _dask_device ()
170+
171+ # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
163172# or cupy.ndarray. They are not included in array objects of this library
164173# because this library just reuses the respective ndarray classes without
165174# wrapping or subclassing them. These helper functions can be used instead of
@@ -179,11 +188,19 @@ def device(x: Array, /) -> Device:
179188 out: device
180189 a ``device`` object (see the "Device Support" section of the array API specification).
181190 """
182- if is_numpy_array (x ) or is_dask_array (x ):
183- # TODO: dask technically can support GPU arrays
184- # Detecting the array backend isn't easy for dask, though, so just return CPU for now
191+ if is_numpy_array (x ):
185192 return "cpu"
186- if is_jax_array (x ):
193+ elif is_dask_array (x ):
194+ # Peek at the metadata of the jax array to determine type
195+ try :
196+ import numpy as np
197+ if isinstance (x ._meta , np .ndarray ):
198+ # Must be on CPU since backed by numpy
199+ return "cpu"
200+ except ImportError :
201+ pass
202+ return DASK_DEVICE
203+ elif is_jax_array (x ):
187204 # JAX has .device() as a method, but it is being deprecated so that it
188205 # can become a property, in accordance with the standard. In order for
189206 # this function to not break when JAX makes the flip, we check for
0 commit comments