Skip to content

Commit 99c6cee

Browse files
committed
Adding all functions from arith.h
1 parent 7a3bbb8 commit 99c6cee

File tree

5 files changed

+253
-1
lines changed

5 files changed

+253
-1
lines changed

arrayfire/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .algorithm import *
55
from .device import *
66
from .blas import *
7+
from .arith import *

arrayfire/arith.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from .library import *
2+
from .array import *
3+
4+
def arith_binary_func(lhs, rhs, c_func):
5+
out = array()
6+
7+
is_left_array = isinstance(lhs, array)
8+
is_right_array = isinstance(rhs, array)
9+
10+
if not (is_left_array or is_right_array):
11+
TypeError("Atleast one input needs to be of type arrayfire.array")
12+
13+
elif (is_left_array and is_right_array):
14+
safe_call(c_func(pointer(out.arr), lhs.arr, rhs.arr, False))
15+
16+
elif (is_valid_scalar(rhs)):
17+
ldims = dim4_tuple(lhs.dims())
18+
lty = lhs.type()
19+
other = array()
20+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
21+
safe_call(c_func(pointer(out.arr), lhs.arr, other.arr, False))
22+
23+
else:
24+
rdims = dim4_tuple(rhs.dims())
25+
rty = rhs.type()
26+
other = array()
27+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
28+
safe_call(c_func(pointer(out.arr), lhs.arr, other.arr, False))
29+
30+
return out
31+
32+
def arith_unary_func(a, c_func):
33+
out = array()
34+
safe_call(c_func(pointer(out.arr), a.arr))
35+
return out
36+
37+
def cast(a, dtype=f32):
38+
out=array()
39+
safe_call(clib.af_cast(pointer(out.arr), a.arr, dtype))
40+
return out
41+
42+
def minof(lhs, rhs):
43+
return arith_binary_func(lhs, rhs, clib.af_minof)
44+
45+
def maxof(lhs, rhs):
46+
return arith_binary_func(lhs, rhs, clib.af_maxof)
47+
48+
def rem(lhs, rhs):
49+
return arith_binary_func(lhs, rhs, clib.af_rem)
50+
51+
def abs(a):
52+
return arith_unary_func(a, clib.af_abs)
53+
54+
def arg(a):
55+
return arith_unary_func(a, clib.af_arg)
56+
57+
def sign(a):
58+
return arith_unary_func(a, clib.af_sign)
59+
60+
def round(a):
61+
return arith_unary_func(a, clib.af_round)
62+
63+
def trunc(a):
64+
return arith_unary_func(a, clib.af_trunc)
65+
66+
def floor(a):
67+
return arith_unary_func(a, clib.af_floor)
68+
69+
def ceil(a):
70+
return arith_unary_func(a, clib.af_ceil)
71+
72+
def hypot(lhs, rhs):
73+
return arith_binary_func(lhs, rhs, clib.af_hypot)
74+
75+
def sin(a):
76+
return arith_unary_func(a, clib.af_sin)
77+
78+
def cos(a):
79+
return arith_unary_func(a, clib.af_cos)
80+
81+
def tan(a):
82+
return arith_unary_func(a, clib.af_tan)
83+
84+
def asin(a):
85+
return arith_unary_func(a, clib.af_asin)
86+
87+
def acos(a):
88+
return arith_unary_func(a, clib.af_acos)
89+
90+
def atan(a):
91+
return arith_unary_func(a, clib.af_atan)
92+
93+
def atan2(lhs, rhs):
94+
return arith_binary_func(lhs, rhs, clib.af_atan2)
95+
96+
def cplx(lhs, rhs=None):
97+
if rhs is None:
98+
return arith_unary_func(lhs, clib.af_cplx)
99+
else:
100+
return arith_binary_func(lhs, rhs, clib.af_cplx2)
101+
102+
def real(lhs):
103+
return arith_unary_func(lhs, clib.af_real)
104+
105+
def imag(lhs):
106+
return arith_unary_func(lhs, clib.af_imag)
107+
108+
def conjg(lhs):
109+
return arith_unary_func(lhs, clib.af_conjg)
110+
111+
def sinh(a):
112+
return arith_unary_func(a, clib.af_sinh)
113+
114+
def cosh(a):
115+
return arith_unary_func(a, clib.af_cosh)
116+
117+
def tanh(a):
118+
return arith_unary_func(a, clib.af_tanh)
119+
120+
def asinh(a):
121+
return arith_unary_func(a, clib.af_asinh)
122+
123+
def acosh(a):
124+
return arith_unary_func(a, clib.af_acosh)
125+
126+
def atanh(a):
127+
return arith_unary_func(a, clib.af_atanh)
128+
129+
def root(lhs, rhs):
130+
return arith_binary_func(lhs, rhs, clib.af_root)
131+
132+
def pow(lhs, rhs):
133+
return arith_binary_func(lhs, rhs, clib.af_pow)
134+
135+
def pow2(a):
136+
return arith_unary_func(a, clib.af_pow2)
137+
138+
def exp(a):
139+
return arith_unary_func(a, clib.af_exp)
140+
141+
def expm1(a):
142+
return arith_unary_func(a, clib.af_expm1)
143+
144+
def erf(a):
145+
return arith_unary_func(a, clib.af_erf)
146+
147+
def erfc(a):
148+
return arith_unary_func(a, clib.af_erfc)
149+
150+
def log(a):
151+
return arith_unary_func(a, clib.af_log)
152+
153+
def log1p(a):
154+
return arith_unary_func(a, clib.af_log1p)
155+
156+
def log10(a):
157+
return arith_unary_func(a, clib.af_log10)
158+
159+
def log2(a):
160+
return arith_unary_func(a, clib.af_log2)
161+
162+
def sqrt(a):
163+
return arith_unary_func(a, clib.af_sqrt)
164+
165+
def cbrt(a):
166+
return arith_unary_func(a, clib.af_cbrt)
167+
168+
def factorial(a):
169+
return arith_unary_func(a, clib.af_factorial)
170+
171+
def tgamma(a):
172+
return arith_unary_func(a, clib.af_tgamma)
173+
174+
def lgamma(a):
175+
return arith_unary_func(a, clib.af_lgamma)
176+
177+
def iszero(a):
178+
return arith_unary_func(a, clib.af_iszero)
179+
180+
def isinf(a):
181+
return arith_unary_func(a, clib.af_isinf)
182+
183+
def isnan(a):
184+
return arith_unary_func(a, clib.af_isnan)

arrayfire/array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import array as host
2+
import inspect
23
from .library import *
34
from .util import *
45
from .data import *

arrayfire/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ def safe_call(af_error):
2929
c_err_str = c_char_p(0)
3030
c_err_len = c_longlong(0)
3131
clib.af_get_last_error(pointer(c_err_str), pointer(c_err_len))
32-
raise RuntimeError('test', to_str(c_err_str), af_error)
32+
raise RuntimeError(to_str(c_err_str), af_error)

tests/simple_arith.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,69 @@
113113
af.print_array(+a)
114114
af.print_array(~a)
115115
af.print_array(a)
116+
117+
af.print_array(af.cast(a, af.c32))
118+
af.print_array(af.maxof(a,b))
119+
af.print_array(af.minof(a,b))
120+
af.print_array(af.rem(a,b))
121+
122+
a = af.randu(3,3) - 0.5
123+
b = af.randu(3,3) - 0.5
124+
125+
af.print_array(af.abs(a))
126+
af.print_array(af.arg(a))
127+
af.print_array(af.sign(a))
128+
af.print_array(af.round(a))
129+
af.print_array(af.trunc(a))
130+
af.print_array(af.floor(a))
131+
af.print_array(af.ceil(a))
132+
af.print_array(af.hypot(a, b))
133+
af.print_array(af.sin(a))
134+
af.print_array(af.cos(a))
135+
af.print_array(af.tan(a))
136+
af.print_array(af.asin(a))
137+
af.print_array(af.acos(a))
138+
af.print_array(af.atan(a))
139+
af.print_array(af.atan2(a, b))
140+
141+
c = af.cplx(a)
142+
d = af.cplx(a,b)
143+
af.print_array(c)
144+
af.print_array(d)
145+
af.print_array(af.real(d))
146+
af.print_array(af.imag(d))
147+
af.print_array(af.conjg(d))
148+
149+
af.print_array(af.sinh(a))
150+
af.print_array(af.cosh(a))
151+
af.print_array(af.tanh(a))
152+
af.print_array(af.asinh(a))
153+
af.print_array(af.acosh(a))
154+
af.print_array(af.atanh(a))
155+
156+
a = af.abs(a)
157+
b = af.abs(b)
158+
159+
af.print_array(af.root(a, b))
160+
af.print_array(af.pow(a, b))
161+
af.print_array(af.pow2(a))
162+
af.print_array(af.exp(a))
163+
af.print_array(af.expm1(a))
164+
af.print_array(af.erf(a))
165+
af.print_array(af.erfc(a))
166+
af.print_array(af.log(a))
167+
af.print_array(af.log1p(a))
168+
af.print_array(af.log10(a))
169+
af.print_array(af.log2(a))
170+
af.print_array(af.sqrt(a))
171+
af.print_array(af.cbrt(a))
172+
173+
a = af.round(5 * af.randu(3,3) - 1)
174+
b = af.round(5 * af.randu(3,3) - 1)
175+
176+
af.print_array(af.factorial(a))
177+
af.print_array(af.tgamma(a))
178+
af.print_array(af.lgamma(a))
179+
af.print_array(af.iszero(a))
180+
af.print_array(af.isinf(a/b))
181+
af.print_array(af.isnan(a/a))

0 commit comments

Comments
 (0)