Skip to content

Commit e35f3c4

Browse files
committed
FEAT/TEST: Adding all functions from lapack.h
1 parent f7cdefc commit e35f3c4

File tree

4 files changed

+170
-0
lines changed

4 files changed

+170
-0
lines changed

arrayfire/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
from .blas import *
1616
from .arith import *
1717
from .statistics import *
18+
from .lapack import *

arrayfire/blas.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,24 @@ def matmul(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
1616
lhs_opts, rhs_opts))
1717
return out
1818

19+
def matmulTN(lhs, rhs):
20+
out = array()
21+
safe_call(clib.af_matmul(pointer(out.arr), lhs.arr, rhs.arr,\
22+
AF_MAT_TRANS, AF_MAT_NONE))
23+
return out
24+
25+
def matmulNT(lhs, rhs):
26+
out = array()
27+
safe_call(clib.af_matmul(pointer(out.arr), lhs.arr, rhs.arr,\
28+
AF_MAT_NONE, AF_MAT_TRANS))
29+
return out
30+
31+
def matmulTT(lhs, rhs):
32+
out = array()
33+
safe_call(clib.af_matmul(pointer(out.arr), lhs.arr, rhs.arr,\
34+
AF_MAT_TRANS, AF_MAT_TRANS))
35+
return out
36+
1937
def dot(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
2038
out = array()
2139
safe_call(clib.af_dot(pointer(out.arr), lhs.arr, rhs.arr,\

arrayfire/lapack.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#######################################################
2+
# Copyright (c) 2014, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
from .library import *
11+
from .array import *
12+
13+
def lu(A):
14+
L = array()
15+
U = array()
16+
P = array()
17+
safe_call(clib.af_lu(pointer(L.arr), pointer(U.arr), pointer(P.arr), A.arr))
18+
return L,U,P
19+
20+
def lu_inplace(A, pivot="lapack"):
21+
P = array()
22+
is_pivot_lapack = False if (pivot == "full") else True
23+
safe_call(clib.af_lu_inplace(pointer(P.arr), A.arr, is_pivot_lapack))
24+
return P
25+
26+
def qr(A):
27+
Q = array()
28+
R = array()
29+
T = array()
30+
safe_call(clib.af_lu(pointer(Q.arr), pointer(R.arr), pointer(T.arr), A.arr))
31+
return Q,R,T
32+
33+
def qr_inplace(A):
34+
T = array()
35+
safe_call(clib.af_qr_inplace(pointer(T.arr), A.arr))
36+
return T
37+
38+
def cholesky(A, is_upper=True):
39+
R = array()
40+
info = c_int(0)
41+
safe_call(clib.af_cholesky(pointer(R.arr), pointer(info), A.arr, is_upper))
42+
return R, info.value
43+
44+
def cholesky_inplace(A, is_upper=True):
45+
info = c_int(0)
46+
safe_call(clib.af_cholesky_inplace(pointer(info), A.arr, is_upper))
47+
return info.value
48+
49+
def solve(A, B, options=AF_MAT_NONE):
50+
X = array()
51+
safe_call(clib.af_solve(pointer(X.arr), A.arr, B.arr, options))
52+
return X
53+
54+
def solve_lu(A, P, B, options=AF_MAT_NONE):
55+
X = array()
56+
safe_call(clib.af_solve_lu(pointer(X.arr), A.arr, P.arr, B.arr, options))
57+
return X
58+
59+
def inverse(A, options=AF_MAT_NONE):
60+
I = array()
61+
safe_call(clib.af_inverse(pointer(I.arr), A.arr, options))
62+
return I
63+
64+
def rank(A, tol=1E-5):
65+
r = c_uint(0)
66+
safe_call(clib.af_rank(pointer(r), A.arr, c_double(tol)))
67+
return r.value
68+
69+
def det(A):
70+
re = c_double(0)
71+
im = c_double(0)
72+
safe_call(clib.af_det(pointer(re), pointer(im), A.arr))
73+
re = re.value
74+
im = im.value
75+
return re if (im == 0) else re + im * 1j
76+
77+
def norm(A, norm_type=AF_NORM_EUCLID, p=1.0, q=1.0):
78+
res = c_double(0)
79+
safe_call(clib.af_norm(pointer(res), A.arr, norm_type, c_double(p), c_double(q)))
80+
return res.value

tests/simple_lapack.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/python
2+
#######################################################
3+
# Copyright (c) 2014, ArrayFire
4+
# All rights reserved.
5+
#
6+
# This file is distributed under 3-clause BSD license.
7+
# The complete license agreement can be obtained at:
8+
# http://arrayfire.com/licenses/BSD-3-Clause
9+
########################################################
10+
import arrayfire as af
11+
12+
a = af.randu(5,5)
13+
14+
l,u,p = af.lu(a)
15+
16+
af.print_array(l)
17+
af.print_array(u)
18+
af.print_array(p)
19+
20+
p = af.lu_inplace(a, "full")
21+
22+
af.print_array(a)
23+
af.print_array(p)
24+
25+
a = af.randu(5,3)
26+
27+
q,r,t = af.qr(a)
28+
29+
af.print_array(q)
30+
af.print_array(r)
31+
af.print_array(t)
32+
33+
af.qr_inplace(a)
34+
35+
af.print_array(a)
36+
37+
a = af.randu(5, 5)
38+
a = af.matmulTN(a, a) + 10 * af.identity(5,5)
39+
40+
R,info = af.cholesky(a)
41+
af.print_array(R)
42+
print(info)
43+
44+
af.cholesky_inplace(a)
45+
af.print_array(a)
46+
47+
a = af.randu(5,5)
48+
ai = af.inverse(a)
49+
50+
af.print_array(a)
51+
af.print_array(ai)
52+
53+
x0 = af.randu(5, 3)
54+
b = af.matmul(a, x0)
55+
x1 = af.solve(a, b)
56+
57+
af.print_array(x0)
58+
af.print_array(x1)
59+
60+
p = af.lu_inplace(a)
61+
62+
x2 = af.solve_lu(a, p, b)
63+
64+
af.print_array(x2)
65+
66+
print(af.rank(a))
67+
print(af.det(a))
68+
print(af.norm(a, af.AF_NORM_EUCLID))
69+
print(af.norm(a, af.AF_NORM_MATRIX_1))
70+
print(af.norm(a, af.AF_NORM_MATRIX_INF))
71+
print(af.norm(a, af.AF_NORM_MATRIX_L_PQ, 1, 1))

0 commit comments

Comments
 (0)