Skip to content

Commit f728d7a

Browse files
committed
Adding error checking to all clib function calls
1 parent 63f08cd commit f728d7a

File tree

4 files changed

+54
-46
lines changed

4 files changed

+54
-46
lines changed

arrayfire/algorithm.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
def parallel_dim(a, dim, c_func):
66
out = array()
7-
c_func(pointer(out.arr), a.arr, c_int(dim))
7+
safe_call(c_func(pointer(out.arr), a.arr, c_int(dim)))
88
return out
99

1010
def reduce_all(a, c_func):
1111
real = c_double(0)
1212
imag = c_double(0)
13-
c_func(pointer(real), pointer(imag), a.arr)
13+
safe_call(c_func(pointer(real), pointer(imag), a.arr))
1414
real = real.value
1515
imag = imag.value
1616
return real if imag == 0 else real + imag * 1j
@@ -61,13 +61,13 @@ def imin(a, dim=None):
6161
if dim is not None:
6262
out = array()
6363
idx = array()
64-
clib.af_imin(pointer(out.arr), pointer(idx.arr), a.arr, c_int(dim))
64+
safe_call(clib.af_imin(pointer(out.arr), pointer(idx.arr), a.arr, c_int(dim)))
6565
return out,idx
6666
else:
6767
real = c_double(0)
6868
imag = c_double(0)
6969
idx = c_uint(0)
70-
clib.af_imin_all(pointer(real), pointer(imag), pointer(idx), a.arr)
70+
safe_call(clib.af_imin_all(pointer(real), pointer(imag), pointer(idx), a.arr))
7171
real = real.value
7272
imag = imag.value
7373
val = real if imag == 0 else real + imag * 1j
@@ -77,13 +77,13 @@ def imax(a, dim=None):
7777
if dim is not None:
7878
out = array()
7979
idx = array()
80-
clib.af_imax(pointer(out.arr), pointer(idx.arr), a.arr, c_int(dim))
80+
safe_call(clib.af_imax(pointer(out.arr), pointer(idx.arr), a.arr, c_int(dim)))
8181
return out,idx
8282
else:
8383
real = c_double(0)
8484
imag = c_double(0)
8585
idx = c_uint(0)
86-
clib.af_imax_all(pointer(real), pointer(imag), pointer(idx), a.arr)
86+
safe_call(clib.af_imax_all(pointer(real), pointer(imag), pointer(idx), a.arr))
8787
real = real.value
8888
imag = imag.value
8989
val = real if imag == 0 else real + imag * 1j
@@ -95,7 +95,7 @@ def accum(a, dim=0):
9595

9696
def where(a):
9797
out = array()
98-
clib.af_where(pointer(out.arr), a.arr)
98+
safe_call(clib.af_where(pointer(out.arr), a.arr))
9999
return out
100100

101101
def diff1(a, dim=0):
@@ -106,33 +106,34 @@ def diff2(a, dim=0):
106106

107107
def sort(a, dim=0, is_ascending=True):
108108
out = array()
109-
clib.af_sort(pointer(out.arr), a.arr, c_uint(dim), c_bool(is_ascending))
109+
safe_call(clib.af_sort(pointer(out.arr), a.arr, c_uint(dim), c_bool(is_ascending)))
110110
return out
111111

112112
def sort_index(a, dim=0, is_ascending=True):
113113
out = array()
114114
idx = array()
115-
clib.af_sort_index(pointer(out.arr), pointer(idx.arr), a.arr, c_uint(dim), c_bool(is_ascending))
115+
safe_call(clib.af_sort_index(pointer(out.arr), pointer(idx.arr), a.arr, \
116+
c_uint(dim), c_bool(is_ascending)))
116117
return out,idx
117118

118119
def sort_by_key(iv, ik, dim=0, is_ascending=True):
119120
ov = array()
120121
ok = array()
121-
clib.af_sort_by_key(pointer(ov.arr), pointer(ok.arr), \
122-
iv.arr, ik.arr, c_uint(dim), c_bool(is_ascending))
122+
safe_call(clib.af_sort_by_key(pointer(ov.arr), pointer(ok.arr), \
123+
iv.arr, ik.arr, c_uint(dim), c_bool(is_ascending)))
123124
return ov,ok
124125

125126
def set_unique(a, is_sorted=False):
126127
out = array()
127-
clib.af_set_unique(pointer(out.arr), a.arr, c_bool(is_sorted))
128+
safe_call(clib.af_set_unique(pointer(out.arr), a.arr, c_bool(is_sorted)))
128129
return out
129130

130131
def set_union(a, b, is_unique=False):
131132
out = array()
132-
clib.af_set_union(pointer(out.arr), a.arr, b.arr, c_bool(is_unique))
133+
safe_call(clib.af_set_union(pointer(out.arr), a.arr, b.arr, c_bool(is_unique)))
133134
return out
134135

135136
def set_intersect(a, b, is_unique=False):
136137
out = array()
137-
clib.af_set_intersect(pointer(out.arr), a.arr, b.arr, c_bool(is_unique))
138+
safe_call(clib.af_set_intersect(pointer(out.arr), a.arr, b.arr, c_bool(is_unique)))
138139
return out

arrayfire/array.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def create_array(buf, numdims, idims, dtype):
77
out_arr = c_longlong(0)
88
c_dims = dim4(idims[0], idims[1], idims[2], idims[3])
9-
clib.af_create_array(pointer(out_arr), c_longlong(buf), numdims, pointer(c_dims), dtype)
9+
safe_call(clib.af_create_array(pointer(out_arr), c_longlong(buf), numdims, pointer(c_dims), dtype))
1010
return out_arr
1111

1212
def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
@@ -23,16 +23,16 @@ def constant_array(val, d0, d1=None, d2=None, d3=None, dtype=f32):
2323
if (dtype != c32 and dtype != c64):
2424
dtype = c32
2525

26-
clib.af_constant_complex(pointer(out), c_real, c_imag, 4, pointer(dims), dtype)
26+
safe_call(clib.af_constant_complex(pointer(out), c_real, c_imag, 4, pointer(dims), dtype))
2727
elif dtype == s64:
2828
c_val = c_longlong(val.real)
29-
clib.af_constant_long(pointer(out), c_val, 4, pointer(dims))
29+
safe_call(clib.af_constant_long(pointer(out), c_val, 4, pointer(dims)))
3030
elif dtype == u64:
3131
c_val = c_ulonglong(val.real)
32-
clib.af_constant_ulong(pointer(out), c_val, 4, pointer(dims))
32+
safe_call(clib.af_constant_ulong(pointer(out), c_val, 4, pointer(dims)))
3333
else:
3434
c_val = c_double(val)
35-
clib.af_constant(pointer(out), c_val, 4, pointer(dims), dtype)
35+
safe_call(clib.af_constant(pointer(out), c_val, 4, pointer(dims), dtype))
3636

3737
return out
3838

@@ -49,7 +49,7 @@ def binary_func(lhs, rhs, c_func):
4949
elif not isinstance(rhs, array):
5050
TypeError("Invalid parameter to binary function")
5151

52-
c_func(pointer(out.arr), lhs.arr, other.arr, False)
52+
safe_call(c_func(pointer(out.arr), lhs.arr, other.arr, False))
5353

5454
return out
5555

@@ -129,21 +129,21 @@ def __del__(self):
129129

130130
def numdims(self):
131131
nd = c_uint(0)
132-
clib.af_get_numdims(pointer(nd), self.arr)
132+
safe_call(clib.af_get_numdims(pointer(nd), self.arr))
133133
return nd.value
134134

135135
def dims(self):
136136
d0 = c_longlong(0)
137137
d1 = c_longlong(0)
138138
d2 = c_longlong(0)
139139
d3 = c_longlong(0)
140-
clib.af_get_dims(pointer(d0), pointer(d1), pointer(d2), pointer(d3), self.arr)
140+
safe_call(clib.af_get_dims(pointer(d0), pointer(d1), pointer(d2), pointer(d3), self.arr))
141141
dims = (d0.value,d1.value,d2.value,d3.value)
142142
return dims[:self.numdims()]
143143

144144
def type(self):
145145
dty = f32
146-
clib.af_get_type(pointer(dty), self.arr)
146+
safe_call(clib.af_get_type(pointer(dty), self.arr))
147147
return dty
148148

149149
def __add__(self, other):
@@ -286,3 +286,6 @@ def __nonzero__(self):
286286
# TODO:
287287
# def __abs__(self):
288288
# return self
289+
290+
def print_array(a):
291+
safe_call(clib.af_print_array(a.arr))

arrayfire/data.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def range(d0, d1=None, d2=None, d3=None, dim=-1, dtype=f32):
1818
out = array()
1919
dims = dim4(d0, d1, d2, d3)
2020

21-
clib.af_range(pointer(out.arr), 4, pointer(dims), dim, dtype)
21+
safe_call(clib.af_range(pointer(out.arr), 4, pointer(dims), dim, dtype))
2222
return out
2323

2424

@@ -36,7 +36,7 @@ def iota(d0, d1=None, d2=None, d3=None, dim=-1, tile_dims=None, dtype=f32):
3636

3737
tdims = dim4(td[0], td[1], td[2], td[3])
3838

39-
clib.af_iota(pointer(out.arr), 4, pointer(dims), 4, pointer(tdims), dtype)
39+
safe_call(clib.af_iota(pointer(out.arr), 4, pointer(dims), 4, pointer(tdims), dtype))
4040
return out
4141

4242
def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
@@ -47,7 +47,7 @@ def randu(d0, d1=None, d2=None, d3=None, dtype=f32):
4747
out = array()
4848
dims = dim4(d0, d1, d2, d3)
4949

50-
clib.af_randu(pointer(out.arr), 4, pointer(dims), dtype)
50+
safe_call(clib.af_randu(pointer(out.arr), 4, pointer(dims), dtype))
5151
return out
5252

5353
def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
@@ -58,15 +58,15 @@ def randn(d0, d1=None, d2=None, d3=None, dtype=f32):
5858
out = array()
5959
dims = dim4(d0, d1, d2, d3)
6060

61-
clib.af_randn(pointer(out.arr), 4, pointer(dims), dtype)
61+
safe_call(clib.af_randn(pointer(out.arr), 4, pointer(dims), dtype))
6262
return out
6363

6464
def set_seed(seed=0):
65-
clib.af_set_seed(c_ulonglong(seed))
65+
safe_call(clib.af_set_seed(c_ulonglong(seed)))
6666

6767
def get_seed():
6868
seed = c_ulonglong(0)
69-
clib.af_get_seed(pointer(seed))
69+
safe_call(clib.af_get_seed(pointer(seed)))
7070
return seed.value
7171

7272
def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
@@ -77,21 +77,21 @@ def identity(d0, d1=None, d2=None, d3=None, dtype=f32):
7777
out = array()
7878
dims = dim4(d0, d1, d2, d3)
7979

80-
clib.af_identity(pointer(out.arr), 4, pointer(dims), dtype)
80+
safe_call(clib.af_identity(pointer(out.arr), 4, pointer(dims), dtype))
8181
return out
8282

8383
def diag(a, num=0, extract=True):
8484
out = array()
8585
if extract:
86-
clib.af_diag_extract(pointer(out.arr), a.arr, c_int(num))
86+
safe_call(clib.af_diag_extract(pointer(out.arr), a.arr, c_int(num)))
8787
else:
88-
clib.af_diag_create(pointer(out.arr), a.arr, c_int(num))
88+
safe_call(clib.af_diag_create(pointer(out.arr), a.arr, c_int(num)))
8989
return out
9090

9191
def join(dim, first, second, third=None, fourth=None):
9292
out = array()
9393
if (third is None and fourth is None):
94-
clib.af_join(pointer(out.arr), dim, first.arr, second.arr)
94+
safe_call(clib.af_join(pointer(out.arr), dim, first.arr, second.arr))
9595
else:
9696
c_array_vec = dim4(first, second, 0, 0)
9797
num = 2
@@ -102,47 +102,47 @@ def join(dim, first, second, third=None, fourth=None):
102102
c_array_vec[num] = fourth.arr
103103
num+=1
104104

105-
clib.af_join_many(pointer(out.arr), dim, num, pointer(c_array_vec))
105+
safe_call(clib.af_join_many(pointer(out.arr), dim, num, pointer(c_array_vec)))
106106

107107

108108
def tile(a, d0, d1=1, d2=1, d3=1):
109109
out = array()
110-
clib.af_tile(pointer(out.arr), a.arr, d0, d1, d2, d3)
110+
safe_call(clib.af_tile(pointer(out.arr), a.arr, d0, d1, d2, d3))
111111
return out
112112

113113

114114
def reorder(a, d0=1, d1=0, d2=2, d3=3):
115115
out = array()
116-
clib.af_reorder(pointer(out.arr), a.arr, d0, d1, d2, d3)
116+
safe_call(clib.af_reorder(pointer(out.arr), a.arr, d0, d1, d2, d3))
117117
return out
118118

119119
def shift(a, d0, d1=0, d2=0, d3=0):
120120
out = array()
121-
clib.af_shift(pointer(out.arr), a.arr, d0, d1, d2, d3)
121+
safe_call(clib.af_shift(pointer(out.arr), a.arr, d0, d1, d2, d3))
122122
return out
123123

124124
def moddims(a, d0, d1=1, d2=1, d3=1):
125125
out = array()
126126
dims = dim4(d0, d1, d2, d3)
127-
clib.af_moddims(pointer(out.arr), a.arr, 4, pointer(dims))
127+
safe_call(clib.af_moddims(pointer(out.arr), a.arr, 4, pointer(dims)))
128128
return out
129129

130130
def flat(a):
131131
out = array()
132-
clib.af_flat(pointer(out.arr), a.arr)
132+
safe_call(clib.af_flat(pointer(out.arr), a.arr))
133133
return out
134134

135135
def flip(a, dim=0):
136136
out = array()
137-
clib.af_flip(pointer(out.arr), a.arr, c_int(dim))
137+
safe_call(clib.af_flip(pointer(out.arr), a.arr, c_int(dim)))
138138
return out
139139

140140
def lower(a, is_unit_diag=False):
141141
out = array()
142-
clib.af_lower(pointer(out.arr), a.arr, is_unit_diag)
142+
safe_call(clib.af_lower(pointer(out.arr), a.arr, is_unit_diag))
143143
return out
144144

145145
def upper(a, is_unit_diag=False):
146146
out = array()
147-
clib.af_upper(pointer(out.arr), a.arr, is_unit_diag)
147+
safe_call(clib.af_upper(pointer(out.arr), a.arr, is_unit_diag))
148148
return out

arrayfire/util.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import inspect
12
from .library import *
2-
from .array import *
33

44
def dim4(d0=1, d1=1, d2=1, d3=1):
55
c_dim4 = c_longlong * 4
@@ -19,8 +19,12 @@ def dim4_tuple(dims):
1919

2020
return tuple(out)
2121

22-
def print_array(a):
23-
clib.af_print_array(a.arr)
24-
2522
def is_valid_scalar(a):
2623
return isinstance(a, float) or isinstance(a, int) or isinstance(a, complex)
24+
25+
def safe_call(af_error):
26+
if (af_error != AF_SUCCESS.value):
27+
c_err_str = c_char_p(0)
28+
c_err_len = c_longlong(0)
29+
clib.af_get_last_error(pointer(c_err_str), pointer(c_err_len))
30+
raise RuntimeError(c_err_str.value, af_error)

0 commit comments

Comments
 (0)