Skip to content

Commit 7a3bbb8

Browse files
committed
Adding blas functions and simple tests
1 parent 87ba291 commit 7a3bbb8

File tree

4 files changed

+40
-1
lines changed

4 files changed

+40
-1
lines changed

arrayfire/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .util import *
44
from .algorithm import *
55
from .device import *
6+
from .blas import *

arrayfire/algorithm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .library import *
22
from .array import *
33

4-
54
def parallel_dim(a, dim, c_func):
65
out = array()
76
safe_call(c_func(pointer(out.arr), a.arr, c_int(dim)))

arrayfire/blas.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from .library import *
2+
from .array import *
3+
4+
def matmul(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
5+
out = array()
6+
safe_call(clib.af_matmul(pointer(out.arr), lhs.arr, rhs.arr,\
7+
lhs_opts, rhs_opts))
8+
return out
9+
10+
def dot(lhs, rhs, lhs_opts=AF_MAT_NONE, rhs_opts=AF_MAT_NONE):
11+
out = array()
12+
safe_call(clib.af_dot(pointer(out.arr), lhs.arr, rhs.arr,\
13+
lhs_opts, rhs_opts))
14+
return out
15+
16+
def transpose(a, conj=False):
17+
out = array()
18+
safe_call(clib.af_transpose(pointer(out.arr), a.arr, conj))
19+
return out
20+
21+
def transpose_inplace(a, conj=False):
22+
safe_call(clib.af_transpose_inplace(a.arr, conj))

tests/simple_blas.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#!/usr/bin/python
2+
import arrayfire as af
3+
4+
a = af.randu(5,5)
5+
b = af.randu(5,5)
6+
7+
af.print_array(af.matmul(a,b))
8+
af.print_array(af.matmul(a,b,af.AF_MAT_TRANS))
9+
af.print_array(af.matmul(a,b,af.AF_MAT_NONE, af.AF_MAT_TRANS))
10+
11+
b = af.randu(5,1)
12+
af.print_array(af.dot(a,b))
13+
14+
af.print_array(af.transpose(a))
15+
16+
af.transpose_inplace(a)
17+
af.print_array(a)

0 commit comments

Comments
 (0)