1212if TYPE_CHECKING :
1313 from typing import Optional , Tuple , Union
1414
15- from ...common ._typing import Device , Dtype , ndarray
15+ from ...common ._typing import Device , Dtype , Array
1616
1717import dask .array as da
1818
@@ -37,7 +37,7 @@ def dask_arange(
3737 dtype : Optional [Dtype ] = None ,
3838 device : Optional [Device ] = None ,
3939 ** kwargs ,
40- ) -> ndarray :
40+ ) -> Array :
4141 _check_device (xp , device )
4242 args = [start ]
4343 if stop is not None :
@@ -99,8 +99,18 @@ def dask_arange(
9999matrix_rank = get_xp (da )(_linalg .matrix_rank )
100100matrix_norm = get_xp (da )(_linalg .matrix_norm )
101101
102+ # Wrap the svd functions to not pass full_matrices to dask
103+ # when full_matrices=False (as that is the defualt behavior for dask),
104+ # and dask doesn't have the full_matrices keyword
105+ _svd = get_xp (da )(_linalg .svd )
102106
103- def svdvals (x : ndarray ) -> Union [ndarray , Tuple [ndarray , ...]]:
107+ def svd (x : Array , full_matrices : bool = True , ** kwargs ) -> SVDResult :
108+ if full_matrices :
109+ return _svd (x , full_matrices = full_matrices , ** kwargs )
110+ return _svd (x , ** kwargs )
111+
112+
113+ def svdvals (x : Array ) -> Array :
104114 # TODO: can't avoid computing U or V for dask
105115 _ , s , _ = da .linalg .svd (x )
106116 return s
0 commit comments