Skip to content

Commit f4a1738

Browse files
committed
FEAT/TEST: Added indexing and assignment support
- Added simple tests in simple_array for verification
1 parent 27010f4 commit f4a1738

File tree

6 files changed

+216
-12
lines changed

6 files changed

+216
-12
lines changed

arrayfire/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,27 @@
2727
del inspect
2828
del numbers
2929
del os
30+
31+
#do not export internal classes
32+
del uidx
33+
del seq
34+
del index
35+
36+
#do not export internal functions
37+
del binary_func
38+
del binary_funcr
39+
del create_array
40+
del constant_array
41+
del parallel_dim
42+
del reduce_all
43+
del arith_unary_func
44+
del arith_binary_func
45+
del brange
46+
del load_backend
47+
del dim4_tuple
48+
del is_number
49+
del to_str
50+
del safe_call
51+
del get_indices
52+
del get_assign_dims
53+
del slice_to_length

arrayfire/arith.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def arith_binary_func(lhs, rhs, c_func):
2222
elif (is_left_array and is_right_array):
2323
safe_call(c_func(ct.pointer(out.arr), lhs.arr, rhs.arr, False))
2424

25-
elif (is_valid_scalar(rhs)):
25+
elif (is_number(rhs)):
2626
ldims = dim4_tuple(lhs.dims())
2727
lty = lhs.type()
2828
other = array()

arrayfire/array.py

Lines changed: 156 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ def binary_func(lhs, rhs, c_func):
5555
out = array()
5656
other = rhs
5757

58-
if (is_valid_scalar(rhs)):
58+
if (is_number(rhs)):
5959
ldims = dim4_tuple(lhs.dims())
6060
lty = lhs.type()
6161
other = array()
6262
other.arr = constant_array(rhs, ldims[0], ldims[1], ldims[2], ldims[3], lty)
6363
elif not isinstance(rhs, array):
64-
TypeError("Invalid parameter to binary function")
64+
raise TypeError("Invalid parameter to binary function")
6565

6666
safe_call(c_func(ct.pointer(out.arr), lhs.arr, other.arr, False))
6767

@@ -71,18 +71,133 @@ def binary_funcr(lhs, rhs, c_func):
7171
out = array()
7272
other = lhs
7373

74-
if (is_valid_scalar(lhs)):
74+
if (is_number(lhs)):
7575
rdims = dim4_tuple(rhs.dims())
7676
rty = rhs.type()
7777
other = array()
7878
other.arr = constant_array(lhs, rdims[0], rdims[1], rdims[2], rdims[3], rty)
7979
elif not isinstance(lhs, array):
80-
TypeError("Invalid parameter to binary function")
80+
raise TypeError("Invalid parameter to binary function")
8181

8282
c_func(ct.pointer(out.arr), other.arr, rhs.arr, False)
8383

8484
return out
8585

86+
class seq(ct.Structure):
87+
_fields_ = [("begin", ct.c_double),
88+
("end" , ct.c_double),
89+
("step" , ct.c_double)]
90+
91+
def __init__ (self, S):
92+
num = __import__("numbers")
93+
94+
self.begin = ct.c_double( 0)
95+
self.end = ct.c_double(-1)
96+
self.step = ct.c_double( 1)
97+
98+
if is_number(S):
99+
self.begin = ct.c_double(S)
100+
self.end = ct.c_double(S)
101+
elif isinstance(S, slice):
102+
if (S.start is not None):
103+
self.begin = ct.c_double(S.start)
104+
if (S.stop is not None):
105+
self.end = ct.c_double(S.stop - 1) if S.stop >= 0 else ct.c_double(S.stop)
106+
if (S.step is not None):
107+
self.step = ct.c_double(S.step)
108+
else:
109+
raise IndexError("Invalid type while indexing arrayfire.array")
110+
111+
class uidx(ct.Union):
112+
_fields_ = [("arr", ct.c_longlong),
113+
("seq", seq)]
114+
115+
class index(ct.Structure):
116+
_fields_ = [("idx", uidx),
117+
("isSeq", ct.c_bool),
118+
("isBatch", ct.c_bool)]
119+
120+
def __init__ (self, idx):
121+
122+
self.idx = uidx()
123+
self.isBatch = False
124+
self.isSeq = True
125+
126+
if isinstance(idx, array):
127+
self.idx.arr = idx.arr
128+
self.isSeq = False
129+
else:
130+
self.idx.seq = seq(idx)
131+
132+
def get_indices(key, n_dims):
133+
index_vec = index * n_dims
134+
inds = index_vec()
135+
136+
for n in range(n_dims):
137+
inds[n] = index(slice(0, -1))
138+
139+
if isinstance(key, tuple):
140+
num_idx = len(key)
141+
for n in range(n_dims):
142+
inds[n] = index(key[n]) if (n < num_idx) else index(slice(0, -1))
143+
else:
144+
inds[0] = index(key)
145+
146+
return inds
147+
148+
def slice_to_length(key, dim):
149+
tkey = [key.start, key.stop, key.step]
150+
151+
if tkey[0] is None:
152+
tkey[0] = 0
153+
elif tkey[0] < 0:
154+
tkey[0] = dim - tkey[0]
155+
156+
if tkey[1] is None:
157+
tkey[1] = dim
158+
elif tkey[1] < 0:
159+
tkey[1] = dim - tkey[1]
160+
161+
if tkey[2] is None:
162+
tkey[2] = 1
163+
164+
return int(((tkey[1] - tkey[0] - 1) / tkey[2]) + 1)
165+
166+
def get_assign_dims(key, idims):
167+
dims = [1]*4
168+
169+
for n in range(len(idims)):
170+
dims[n] = idims[n]
171+
172+
if is_number(key):
173+
dims[0] = 1
174+
return dims
175+
elif isinstance(key, slice):
176+
dims[0] = slice_to_length(key, idims[0])
177+
return dims
178+
elif isinstance(key, array):
179+
dims[0] = key.elements()
180+
return dims
181+
elif isinstance(key, tuple):
182+
n_inds = len(key)
183+
184+
if (n_inds > len(idims)):
185+
raise IndexError("Number of indices greater than array dimensions")
186+
187+
for n in range(n_inds):
188+
if (is_number(key[n])):
189+
dims[n] = 1
190+
elif (isinstance(key[n], array)):
191+
dims[n] = key[n].elements()
192+
elif (isinstance(key[n], slice)):
193+
dims[n] = slice_to_length(key[n], idims[n])
194+
else:
195+
raise IndexError("Invalid type while assigning to arrayfire.array")
196+
197+
return dims
198+
else:
199+
raise IndexError("Invalid type while assigning to arrayfire.array")
200+
86201
class array(object):
87202

88203
def __init__(self, src=None, dims=(0,)):
@@ -152,7 +267,8 @@ def dims(self):
152267
d1 = ct.c_longlong(0)
153268
d2 = ct.c_longlong(0)
154269
d3 = ct.c_longlong(0)
155-
safe_call(clib.af_get_dims(ct.pointer(d0), ct.pointer(d1), ct.pointer(d2), ct.pointer(d3), self.arr))
270+
safe_call(clib.af_get_dims(ct.pointer(d0), ct.pointer(d1),\
271+
ct.pointer(d2), ct.pointer(d3), self.arr))
156272
dims = (d0.value,d1.value,d2.value,d3.value)
157273
return dims[:self.numdims()]
158274

@@ -367,6 +483,41 @@ def __nonzero__(self):
367483
# def __abs__(self):
368484
# return self
369485

486+
def __getitem__(self, key):
487+
try:
488+
out = array()
489+
n_dims = self.numdims()
490+
inds = get_indices(key, n_dims)
491+
492+
safe_call(clib.af_index_gen(ct.pointer(out.arr),\
493+
self.arr, ct.c_longlong(n_dims), ct.pointer(inds)))
494+
return out
495+
except RuntimeError as e:
496+
raise IndexError(str(e))
497+
498+
499+
def __setitem__(self, key, val):
500+
try:
501+
n_dims = self.numdims()
502+
503+
if (is_number(val)):
504+
tdims = get_assign_dims(key, self.dims())
505+
other_arr = constant_array(val, tdims[0], tdims[1], tdims[2], tdims[3])
506+
else:
507+
other_arr = val.arr
508+
509+
out_arr = ct.c_longlong(0)
510+
inds = get_indices(key, n_dims)
511+
512+
safe_call(clib.af_assign_gen(ct.pointer(out_arr),\
513+
self.arr, ct.c_longlong(n_dims), ct.pointer(inds),\
514+
other_arr))
515+
safe_call(clib.af_release_array(self.arr))
516+
self.arr = out_arr
517+
518+
except RuntimeError as e:
519+
raise IndexError(str(e))
520+
370521
def print_array(a):
371522
expr = inspect.stack()[1][-2]
372523
if (expr is not None):

arrayfire/signal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212

1313
def approx1(signal, pos0, method=AF_INTERP_LINEAR, off_grid=0.0):
1414
output = array()
15-
safe_call(clib.af_approx1(ct.pointer(output.arr), signal.arr, pos0.arr, method, ct.c_double(off_grid)))
15+
safe_call(clib.af_approx1(ct.pointer(output.arr), signal.arr, pos0.arr,\
16+
method, ct.c_double(off_grid)))
1617
return output
1718

1819
def approx2(signal, pos0, pos1, method=AF_INTERP_LINEAR, off_grid=0.0):

arrayfire/util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@ def dim4(d0=1, d1=1, d2=1, d3=1):
1919

2020
return out
2121

22+
def is_number(a):
23+
return isinstance(a, numbers.Number)
24+
2225
def dim4_tuple(dims, default=1):
2326
assert(isinstance(dims, tuple))
2427

2528
if (default is not None):
26-
assert(isinstance(default, numbers.Number))
29+
assert(is_number(default))
2730

2831
out = [default]*4
2932

@@ -32,9 +35,6 @@ def dim4_tuple(dims, default=1):
3235

3336
return tuple(out)
3437

35-
def is_valid_scalar(a):
36-
return isinstance(a, float) or isinstance(a, int) or isinstance(a, complex)
37-
3838
def to_str(c_str):
3939
return str(c_str.value.decode('utf-8'))
4040

tests/simple_array.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
print(a.is_real_floating(), a.is_floating(), a.is_integer(), a.is_bool())
2020

2121

22-
a = af.array(host.array('d', [4, 5, 6]))
22+
a = af.array(host.array('i', [4, 5, 6]))
2323
af.print_array(a)
2424
print(a.elements(), a.type(), a.dims(), a.numdims())
2525
print(a.is_empty(), a.is_scalar(), a.is_column(), a.is_row())
@@ -33,8 +33,36 @@
3333
print(a.is_complex(), a.is_real(), a.is_double(), a.is_single())
3434
print(a.is_real_floating(), a.is_floating(), a.is_integer(), a.is_bool())
3535

36+
a = af.randu(5, 5)
37+
af.print_array(a)
3638
b = af.array(a)
3739
af.print_array(b)
3840

3941
c = a.copy()
4042
af.print_array(c)
43+
af.print_array(a[0,0])
44+
af.print_array(a[0])
45+
af.print_array(a[:])
46+
af.print_array(a[:,:])
47+
af.print_array(a[0:3,])
48+
af.print_array(a[-2:-1,-1])
49+
af.print_array(a[0:5])
50+
af.print_array(a[0:5:2])
51+
52+
idx = af.array(host.array('i', [0, 3, 2]))
53+
af.print_array(idx)
54+
aa = a[idx]
55+
af.print_array(aa)
56+
57+
a[0] = 1
58+
af.print_array(a)
59+
a[0] = af.randu(1, 5)
60+
af.print_array(a)
61+
a[:] = af.randu(5,5)
62+
af.print_array(a)
63+
a[:,-1] = af.randu(5,1)
64+
af.print_array(a)
65+
a[0:5:2] = af.randu(3, 5)
66+
af.print_array(a)
67+
a[idx, idx] = af.randu(3,3)
68+
af.print_array(a)

0 commit comments

Comments
 (0)