Skip to content

Commit be27746

Browse files
committed
Adding operator overloading to the array class
1 parent 2be3cbf commit be27746

File tree

3 files changed

+223
-25
lines changed

3 files changed

+223
-25
lines changed

arrayfire/array.py

Lines changed: 210 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,74 @@
11
import array as host
22
from .library import *
3-
from .util import dim4
3+
from .util import *
4+
from .data import *
45

56
def create_array(buf, numdims, idims, dtype):
67
out_arr = c_longlong(0)
78
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
89
clib.af_create_array(pointer(out_arr), c_longlong(buf), numdims, pointer(c_dims), dtype)
910
return out_arr
1011

12+
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
13+
if not isinstance(dtype, c_int):
14+
raise TypeError("Invalid dtype")
15+
16+
out = c_longlong(0)
17+
dims = dim4(d0, d1, d2, d3)
18+
19+
if isinstance(val, complex):
20+
c_real = c_double(val.real)
21+
c_imag = c_double(val.imag)
22+
23+
if (dtype != c32 and dtype != c64):
24+
dtype = c32
25+
26+
clib.af_constant_complex(pointer(out), c_real, c_imag, 4, pointer(dims), dtype)
27+
elif dtype == s64:
28+
c_val = c_longlong(val.real)
29+
clib.af_constant_long(pointer(out), c_val, 4, pointer(dims))
30+
elif dtype == u64:
31+
c_val = c_ulonglong(val.real)
32+
clib.af_constant_ulong(pointer(out), c_val, 4, pointer(dims))
33+
else:
34+
c_val = c_double(val)
35+
clib.af_constant(pointer(out), c_val, 4, pointer(dims), dtype)
36+
37+
return out
38+
39+
40+
def binary_func(lhs, rhs, c_func):
41+
out = array()
42+
other = rhs
43+
44+
if (is_valid_scalar(rhs)):
45+
ldims = dim4_tuple(lhs.dims())
46+
lty = lhs.type()
47+
other = array()
48+
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
49+
elif not isinstance(rhs, array):
50+
TypeError("Invalid parameter to binary function")
51+
52+
c_func(pointer(out.arr), lhs.arr, other.arr, False)
53+
54+
return out
55+
56+
def binary_funcr(lhs, rhs, c_func):
57+
out = array()
58+
other = lhs
59+
60+
if (is_valid_scalar(lhs)):
61+
rdims = dim4_tuple(rhs.dims())
62+
rty = rhs.type()
63+
other = array()
64+
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
65+
elif not isinstance(lhs, array):
66+
TypeError("Invalid parameter to binary function")
67+
68+
c_func(pointer(out.arr), other.arr, rhs.arr, False)
69+
70+
return out
71+
1172
class array(object):
1273

1374
def __init__(self, src=None, dims=(0,)):
@@ -65,3 +126,151 @@ def __init__(self, src=None, dims=(0,)):
65126
def __del__(self):
66127
if (self.arr.value != 0):
67128
clib.af_release_array(self.arr)
129+
130+
def numdims(self):
131+
nd = c_uint(0)
132+
clib.af_get_numdims(pointer(nd), self.arr)
133+
return nd.value
134+
135+
def dims(self):
136+
d0 = c_longlong(0)
137+
d1 = c_longlong(0)
138+
d2 = c_longlong(0)
139+
d3 = c_longlong(0)
140+
clib.af_get_dims(pointer(d0), pointer(d1), pointer(d2), pointer(d3), self.arr)
141+
dims = (d0.value,d1.value,d2.value,d3.value)
142+
return dims[:self.numdims()]
143+
144+
def type(self):
145+
dty = f32
146+
clib.af_get_type(pointer(dty), self.arr)
147+
return dty
148+
149+
def __add__(self, other):
150+
return binary_func(self, other, clib.af_add)
151+
152+
def __iadd__(self, other):
153+
self = binary_func(self, other, clib.af_add)
154+
return self
155+
156+
def __radd__(self, other):
157+
return binary_funcr(other, self, clib.af_add)
158+
159+
def __sub__(self, other):
160+
return binary_func(self, other, clib.af_sub)
161+
162+
def __isub__(self, other):
163+
self = binary_func(self, other, clib.af_sub)
164+
return self
165+
166+
def __rsub__(self, other):
167+
return binary_funcr(other, self, clib.af_sub)
168+
169+
def __mul__(self, other):
170+
return binary_func(self, other, clib.af_mul)
171+
172+
def __imul__(self, other):
173+
self = binary_func(self, other, clib.af_mul)
174+
return self
175+
176+
def __rmul__(self, other):
177+
return binary_funcr(other, self, clib.af_mul)
178+
179+
def __truediv__(self, other):
180+
return binary_func(self, other, clib.af_div)
181+
182+
def __itruediv__(self, other):
183+
self = binary_func(self, other, clib.af_div)
184+
return self
185+
186+
def __rtruediv__(self, other):
187+
return binary_funcr(other, self, clib.af_div)
188+
189+
def __mod__(self, other):
190+
return binary_func(self, other, clib.af_mod)
191+
192+
def __imod__(self, other):
193+
self = binary_func(self, other, clib.af_mod)
194+
return self
195+
196+
def __rmod__(self, other):
197+
return binary_funcr(other, self, clib.af_mod)
198+
199+
def __pow__(self, other):
200+
return binary_func(self, other, clib.af_pow)
201+
202+
def __ipow__(self, other):
203+
self = binary_func(self, other, clib.af_pow)
204+
return self
205+
206+
def __rpow__(self, other):
207+
return binary_funcr(other, self, clib.af_pow)
208+
209+
def __lt__(self, other):
210+
return binary_func(self, other, clib.af_lt)
211+
212+
def __gt__(self, other):
213+
return binary_func(self, other, clib.af_gt)
214+
215+
def __le__(self, other):
216+
return binary_func(self, other, clib.af_le)
217+
218+
def __ge__(self, other):
219+
return binary_func(self, other, clib.af_ge)
220+
221+
def __eq__(self, other):
222+
return binary_func(self, other, clib.af_eq)
223+
224+
def __ne__(self, other):
225+
return binary_func(self, other, clib.af_neq)
226+
227+
def __and__(self, other):
228+
return binary_func(self, other, clib.af_bitand)
229+
230+
def __iand__(self, other):
231+
self = binary_func(self, other, clib.af_bitand)
232+
return self
233+
234+
def __or__(self, other):
235+
return binary_func(self, other, clib.af_bitor)
236+
237+
def __ior__(self, other):
238+
self = binary_func(self, other, clib.af_bitor)
239+
return self
240+
241+
def __xor__(self, other):
242+
return binary_func(self, other, clib.af_bitxor)
243+
244+
def __ixor__(self, other):
245+
self = binary_func(self, other, clib.af_bitxor)
246+
return self
247+
248+
def __lshift__(self, other):
249+
return binary_func(self, other, clib.af_bitshiftl)
250+
251+
def __ilshift__(self, other):
252+
self = binary_func(self, other, clib.af_bitshiftl)
253+
return self
254+
255+
def __rshift__(self, other):
256+
return binary_func(self, other, clib.af_bitshiftr)
257+
258+
def __irshift__(self, other):
259+
self = binary_func(self, other, clib.af_bitshiftr)
260+
return self
261+
262+
def __neg__(self):
263+
return 0 - self
264+
265+
def __pos__(self):
266+
return self
267+
268+
def __invert__(self):
269+
return self == 0
270+
271+
def __nonzero__(self):
272+
return self != 0
273+
274+
# TODO:
275+
# def __abs__(self):
276+
# return self

arrayfire/data.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,9 @@
33
from .array import *
44
from .util import *
55

6-
76
def constant(val, d0, d1=None, d2=None, d3=None, dtype=f32):
8-
9-
if not isinstance(dtype, c_int):
10-
raise TypeError("Invalid dtype")
11-
127
out = array()
13-
dims = dim4(d0, d1, d2, d3)
14-
15-
if isinstance(val, complex):
16-
c_real = c_double(val.real)
17-
c_imag = c_double(val.imag)
18-
19-
if (dtype != c32 and dtype != c64):
20-
dtype = c32
21-
22-
clib.af_constant_complex(pointer(out.arr), c_real, c_imag, 4, pointer(dims), dtype)
23-
elif dtype == s64:
24-
c_val = c_longlong(val.real)
25-
clib.af_constant_long(pointer(out.arr), c_val, 4, pointer(dims))
26-
elif dtype == u64:
27-
c_val = c_ulonglong(val.real)
28-
clib.af_constant_ulong(pointer(out.arr), c_val, 4, pointer(dims))
29-
else:
30-
c_val = c_double(val)
31-
clib.af_constant(pointer(out.arr), c_val, 4, pointer(dims), dtype)
8+
out.arr = constant_array(val, d0, d1, d2, d3, dtype)
329
return out
3310

3411
# Store builtin range function to be used later

arrayfire/util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,17 @@ def dim4(d0=1, d1=1, d2=1, d3=1):
1010

1111
return out
1212

13+
def dim4_tuple(dims):
14+
assert(isinstance(dims, tuple))
15+
out = [1]*4
16+
17+
for i, dim in enumerate(dims):
18+
out[i] = dim
19+
20+
return tuple(out)
21+
1322
def print_array(a):
1423
clib.af_print_array(a.arr)
24+
25+
def is_valid_scalar(a):
26+
return isinstance(a, float) or isinstance(a, int) or isinstance(a, complex)

0 commit comments

Comments
 (0)