11from __future__ import annotations
2+
23from functools import partial
4+ from typing import TYPE_CHECKING
35
4- from ...common import _aliases
5- from ...common ._helpers import _check_device
6+ import numpy as np
67
78from ..._internal import get_xp
9+ from ...common import _aliases , _linalg
10+ from ...common ._helpers import _check_device
811
9- import numpy as np
10-
11- from typing import TYPE_CHECKING
1212if TYPE_CHECKING :
13- from typing import Optional , Union
14- from ...common ._typing import ndarray , Device , Dtype
13+ from typing import Optional , Tuple , Union
14+
15+ from ...common ._typing import Device , Dtype , ndarray
1516
1617import dask .array as da
1718
2425# not pass stop/step as keyword arguments, which will cause
2526# an error with dask
2627
28+
2729# TODO: delete the xp stuff, it shouldn't be necessary
2830def dask_arange (
2931 start : Union [int , float ],
@@ -34,7 +36,7 @@ def dask_arange(
3436 xp ,
3537 dtype : Optional [Dtype ] = None ,
3638 device : Optional [Device ] = None ,
37- ** kwargs
39+ ** kwargs ,
3840) -> ndarray :
3941 _check_device (xp , device )
4042 args = [start ]
@@ -47,10 +49,11 @@ def dask_arange(
4749 args .append (step )
4850 return xp .arange (* args , dtype = dtype , ** kwargs )
4951
52+
5053arange = get_xp (da )(dask_arange )
5154eye = get_xp (da )(_aliases .eye )
5255
53- asarray = partial (_aliases ._asarray , namespace = ' dask.array' )
56+ asarray = partial (_aliases ._asarray , namespace = " dask.array" )
5457asarray .__doc__ = _aliases ._asarray .__doc__
5558
5659linspace = get_xp (da )(_aliases .linspace )
@@ -86,3 +89,22 @@ def dask_arange(
8689matmul = get_xp (np )(_aliases .matmul )
8790tensordot = get_xp (np )(_aliases .tensordot )
8891
92+
93+ EighResult = _linalg .EighResult
94+ QRResult = _linalg .QRResult
95+ SlogdetResult = _linalg .SlogdetResult
96+ SVDResult = _linalg .SVDResult
97+ qr = get_xp (da )(_linalg .qr )
98+ cholesky = get_xp (da )(_linalg .cholesky )
99+ matrix_rank = get_xp (da )(_linalg .matrix_rank )
100+ matrix_norm = get_xp (da )(_linalg .matrix_norm )
101+
102+
103+ def svdvals (x : ndarray ) -> Union [ndarray , Tuple [ndarray , ...]]:
104+ # TODO: can't avoid computing U or V for dask
105+ _ , s , _ = da .linalg .svd (x )
106+ return s
107+
108+
109+ vector_norm = get_xp (da )(_linalg .vector_norm )
110+ diagonal = get_xp (da )(_linalg .diagonal )
0 commit comments