1111from .library import *
1212from .util import *
1313from .broadcast import *
14+ from .base import *
15+ from .index import *
1416
1517def create_array (buf , numdims , idims , dtype ):
1618 out_arr = ct .c_longlong (0 )
@@ -92,86 +94,6 @@ def transpose(a, conj=False):
9294def transpose_inplace (a , conj = False ):
9395 safe_call (clib .af_transpose_inplace (a .arr , conj ))
9496
95- class seq (ct .Structure ):
96- _fields_ = [("begin" , ct .c_double ),
97- ("end" , ct .c_double ),
98- ("step" , ct .c_double )]
99-
100- def __init__ (self , S ):
101- num = __import__ ("numbers" )
102-
103- self .begin = ct .c_double ( 0 )
104- self .end = ct .c_double (- 1 )
105- self .step = ct .c_double ( 1 )
106-
107- if is_number (S ):
108- self .begin = ct .c_double (S )
109- self .end = ct .c_double (S )
110- elif isinstance (S , slice ):
111- if (S .start is not None ):
112- self .begin = ct .c_double (S .start )
113- if (S .stop is not None ):
114- self .end = ct .c_double (S .stop - 1 )
115- if (S .step is not None ):
116- self .step = ct .c_double (S .step )
117- else :
118- raise IndexError ("Invalid type while indexing arrayfire.array" )
119-
120- class uidx (ct .Union ):
121- _fields_ = [("arr" , ct .c_longlong ),
122- ("seq" , seq )]
123-
124- class index (ct .Structure ):
125- _fields_ = [("idx" , uidx ),
126- ("isSeq" , ct .c_bool ),
127- ("isBatch" , ct .c_bool )]
128-
129- def __init__ (self , idx ):
130-
131- self .idx = uidx ()
132- self .isBatch = False
133- self .isSeq = True
134-
135- if isinstance (idx , array ):
136- self .idx .arr = idx .arr
137- self .isSeq = False
138- else :
139- self .idx .seq = seq (idx )
140-
141- def get_indices (key , n_dims ):
142- index_vec = index * n_dims
143- inds = index_vec ()
144-
145- for n in range (n_dims ):
146- inds [n ] = index (slice (None ))
147-
148- if isinstance (key , tuple ):
149- n_idx = len (key )
150- for n in range (n_idx ):
151- inds [n ] = index (key [n ])
152- else :
153- inds [0 ] = index (key )
154-
155- return inds
156-
157- def slice_to_length (key , dim ):
158- tkey = [key .start , key .stop , key .step ]
159-
160- if tkey [0 ] is None :
161- tkey [0 ] = 0
162- elif tkey [0 ] < 0 :
163- tkey [0 ] = dim - tkey [0 ]
164-
165- if tkey [1 ] is None :
166- tkey [1 ] = dim
167- elif tkey [1 ] < 0 :
168- tkey [1 ] = dim - tkey [1 ]
169-
170- if tkey [2 ] is None :
171- tkey [2 ] = 1
172-
173- return int (((tkey [1 ] - tkey [0 ] - 1 ) / tkey [2 ]) + 1 )
174-
17597def ctype_to_lists (ctype_arr , dim , shape , offset = 0 ):
17698 if (dim == 0 ):
17799 return list (ctype_arr [offset : offset + shape [0 ]])
@@ -183,46 +105,11 @@ def ctype_to_lists(ctype_arr, dim, shape, offset=0):
183105 offset += shape [0 ]
184106 return res
185107
186- def get_assign_dims (key , idims ):
187- dims = [1 ]* 4
188-
189- for n in range (len (idims )):
190- dims [n ] = idims [n ]
191-
192- if is_number (key ):
193- dims [0 ] = 1
194- return dims
195- elif isinstance (key , slice ):
196- dims [0 ] = slice_to_length (key , idims [0 ])
197- return dims
198- elif isinstance (key , array ):
199- dims [0 ] = key .elements ()
200- return dims
201- elif isinstance (key , tuple ):
202- n_inds = len (key )
203-
204- if (n_inds > len (idims )):
205- raise IndexError ("Number of indices greater than array dimensions" )
206-
207- for n in range (n_inds ):
208- if (is_number (key [n ])):
209- dims [n ] = 1
210- elif (isinstance (key [n ], array )):
211- dims [n ] = key [n ].elements ()
212- elif (isinstance (key [n ], slice )):
213- dims [n ] = slice_to_length (key [n ], idims [n ])
214- else :
215- raise IndexError ("Invalid type while assigning to arrayfire.array" )
216-
217- return dims
218- else :
219- raise IndexError ("Invalid type while assigning to arrayfire.array" )
220-
221- class array (object ):
108+ class array (base_array ):
222109
223110 def __init__ (self , src = None , dims = (0 ,)):
224111
225- self . arr = ct . c_longlong ( 0 )
112+ super ( array , self ). __init__ ( )
226113
227114 buf = None
228115 buf_len = 0
0 commit comments