@@ -55,13 +55,13 @@ def binary_func(lhs, rhs, c_func):
5555 out = array ()
5656 other = rhs
5757
58- if (is_valid_scalar (rhs )):
58+ if (is_number (rhs )):
5959 ldims = dim4_tuple (lhs .dims ())
6060 lty = lhs .type ()
6161 other = array ()
6262 other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], lty )
6363 elif not isinstance (rhs , array ):
64- TypeError ("Invalid parameter to binary function" )
64+ raise TypeError ("Invalid parameter to binary function" )
6565
6666 safe_call (c_func (ct .pointer (out .arr ), lhs .arr , other .arr , False ))
6767
@@ -71,18 +71,133 @@ def binary_funcr(lhs, rhs, c_func):
7171 out = array ()
7272 other = lhs
7373
74- if (is_valid_scalar (lhs )):
74+ if (is_number (lhs )):
7575 rdims = dim4_tuple (rhs .dims ())
7676 rty = rhs .type ()
7777 other = array ()
7878 other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], rty )
7979 elif not isinstance (lhs , array ):
80- TypeError ("Invalid parameter to binary function" )
80+ raise TypeError ("Invalid parameter to binary function" )
8181
8282 c_func (ct .pointer (out .arr ), other .arr , rhs .arr , False )
8383
8484 return out
8585
86+ class seq (ct .Structure ):
87+ _fields_ = [("begin" , ct .c_double ),
88+ ("end" , ct .c_double ),
89+ ("step" , ct .c_double )]
90+
91+ def __init__ (self , S ):
92+ num = __import__ ("numbers" )
93+
94+ self .begin = ct .c_double ( 0 )
95+ self .end = ct .c_double (- 1 )
96+ self .step = ct .c_double ( 1 )
97+
98+ if is_number (S ):
99+ self .begin = ct .c_double (S )
100+ self .end = ct .c_double (S )
101+ elif isinstance (S , slice ):
102+ if (S .start is not None ):
103+ self .begin = ct .c_double (S .start )
104+ if (S .stop is not None ):
105+ self .end = ct .c_double (S .stop - 1 ) if S .stop >= 0 else ct .c_double (S .stop )
106+ if (S .step is not None ):
107+ self .step = ct .c_double (S .step )
108+ else :
109+ raise IndexError ("Invalid type while indexing arrayfire.array" )
110+
111+ class uidx (ct .Union ):
112+ _fields_ = [("arr" , ct .c_longlong ),
113+ ("seq" , seq )]
114+
115+ class index (ct .Structure ):
116+ _fields_ = [("idx" , uidx ),
117+ ("isSeq" , ct .c_bool ),
118+ ("isBatch" , ct .c_bool )]
119+
120+ def __init__ (self , idx ):
121+
122+ self .idx = uidx ()
123+ self .isBatch = False
124+ self .isSeq = True
125+
126+ if isinstance (idx , array ):
127+ self .idx .arr = idx .arr
128+ self .isSeq = False
129+ else :
130+ self .idx .seq = seq (idx )
131+
132+ def get_indices (key , n_dims ):
133+ index_vec = index * n_dims
134+ inds = index_vec ()
135+
136+ for n in range (n_dims ):
137+ inds [n ] = index (slice (0 , - 1 ))
138+
139+ if isinstance (key , tuple ):
140+ num_idx = len (key )
141+ for n in range (n_dims ):
142+ inds [n ] = index (key [n ]) if (n < num_idx ) else index (slice (0 , - 1 ))
143+ else :
144+ inds [0 ] = index (key )
145+
146+ return inds
147+
148+ def slice_to_length (key , dim ):
149+ tkey = [key .start , key .stop , key .step ]
150+
151+ if tkey [0 ] is None :
152+ tkey [0 ] = 0
153+ elif tkey [0 ] < 0 :
154+ tkey [0 ] = dim - tkey [0 ]
155+
156+ if tkey [1 ] is None :
157+ tkey [1 ] = dim
158+ elif tkey [1 ] < 0 :
159+ tkey [1 ] = dim - tkey [1 ]
160+
161+ if tkey [2 ] is None :
162+ tkey [2 ] = 1
163+
164+ return int (((tkey [1 ] - tkey [0 ] - 1 ) / tkey [2 ]) + 1 )
165+
166+ def get_assign_dims (key , idims ):
167+ dims = [1 ]* 4
168+
169+ for n in range (len (idims )):
170+ dims [n ] = idims [n ]
171+
172+ if is_number (key ):
173+ dims [0 ] = 1
174+ return dims
175+ elif isinstance (key , slice ):
176+ dims [0 ] = slice_to_length (key , idims [0 ])
177+ return dims
178+ elif isinstance (key , array ):
179+ dims [0 ] = key .elements ()
180+ return dims
181+ elif isinstance (key , tuple ):
182+ n_inds = len (key )
183+
184+ if (n_inds > len (idims )):
185+ raise IndexError ("Number of indices greater than array dimensions" )
186+
187+ for n in range (n_inds ):
188+ if (is_number (key [n ])):
189+ dims [n ] = 1
190+ elif (isinstance (key [n ], array )):
191+ dims [n ] = key [n ].elements ()
192+ elif (isinstance (key [n ], slice )):
193+ dims [n ] = slice_to_length (key [n ], idims [n ])
194+ else :
195+ raise IndexError ("Invalid type while assigning to arrayfire.array" )
196+
197+ return dims
198+ else :
199+ raise IndexError ("Invalid type while assigning to arrayfire.array" )
200+
86201class array (object ):
87202
88203 def __init__ (self , src = None , dims = (0 ,)):
@@ -152,7 +267,8 @@ def dims(self):
152267 d1 = ct .c_longlong (0 )
153268 d2 = ct .c_longlong (0 )
154269 d3 = ct .c_longlong (0 )
155- safe_call (clib .af_get_dims (ct .pointer (d0 ), ct .pointer (d1 ), ct .pointer (d2 ), ct .pointer (d3 ), self .arr ))
270+ safe_call (clib .af_get_dims (ct .pointer (d0 ), ct .pointer (d1 ),\
271+ ct .pointer (d2 ), ct .pointer (d3 ), self .arr ))
156272 dims = (d0 .value ,d1 .value ,d2 .value ,d3 .value )
157273 return dims [:self .numdims ()]
158274
@@ -367,6 +483,41 @@ def __nonzero__(self):
367483 # def __abs__(self):
368484 # return self
369485
486+ def __getitem__ (self , key ):
487+ try :
488+ out = array ()
489+ n_dims = self .numdims ()
490+ inds = get_indices (key , n_dims )
491+
492+ safe_call (clib .af_index_gen (ct .pointer (out .arr ),\
493+ self .arr , ct .c_longlong (n_dims ), ct .pointer (inds )))
494+ return out
495+ except RuntimeError as e :
496+ raise IndexError (str (e ))
497+
498+
499+ def __setitem__ (self , key , val ):
500+ try :
501+ n_dims = self .numdims ()
502+
503+ if (is_number (val )):
504+ tdims = get_assign_dims (key , self .dims ())
505+ other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ])
506+ else :
507+ other_arr = val .arr
508+
509+ out_arr = ct .c_longlong (0 )
510+ inds = get_indices (key , n_dims )
511+
512+ safe_call (clib .af_assign_gen (ct .pointer (out_arr ),\
513+ self .arr , ct .c_longlong (n_dims ), ct .pointer (inds ),\
514+ other_arr ))
515+ safe_call (clib .af_release_array (self .arr ))
516+ self .arr = out_arr
517+
518+ except RuntimeError as e :
519+ raise IndexError (str (e ))
520+
370521def print_array (a ):
371522 expr = inspect .stack ()[1 ][- 2 ]
372523 if (expr is not None ):
0 commit comments