99from .library import *
1010from .util import *
1111from .base import *
12+ from .broadcast import *
1213
1314class seq (ct .Structure ):
1415 _fields_ = [("begin" , ct .c_double ),
@@ -35,6 +36,31 @@ def __init__ (self, S):
3536 else :
3637 raise IndexError ("Invalid type while indexing arrayfire.array" )
3738
39+ class parallel_range (seq ):
40+
41+ def __init__ (self , start , stop = None , step = None ):
42+
43+ if (stop is None ):
44+ stop = start
45+ start = 0
46+
47+ self .S = slice (start , stop , step )
48+ super (parallel_range , self ).__init__ (self .S )
49+
50+ def __iter__ (self ):
51+ return self
52+
53+ def next (self ):
54+ if bcast .get () is True :
55+ bcast .toggle ()
56+ raise StopIteration
57+ else :
58+ bcast .toggle ()
59+ return self
60+
61+ def __next__ (self ):
62+ return self .next ()
63+
3864def slice_to_length (key , dim ):
3965 tkey = [key .start , key .stop , key .step ]
4066
@@ -71,6 +97,9 @@ def __init__ (self, idx):
7197 if isinstance (idx , base_array ):
7298 self .idx .arr = idx .arr
7399 self .isSeq = False
100+ elif isinstance (idx , parallel_range ):
101+ self .idx .seq = idx
102+ self .isBatch = True
74103 else :
75104 self .idx .seq = seq (idx )
76105
@@ -104,6 +133,9 @@ def get_assign_dims(key, idims):
104133 elif isinstance (key , slice ):
105134 dims [0 ] = slice_to_length (key , idims [0 ])
106135 return dims
136+ elif isinstance (key , parallel_range ):
137+ dims [0 ] = slice_to_length (key .S , idims [0 ])
138+ return dims
107139 elif isinstance (key , base_array ):
108140 dims [0 ] = key .elements ()
109141 return dims
@@ -120,6 +152,8 @@ def get_assign_dims(key, idims):
120152 dims [n ] = key [n ].elements ()
121153 elif (isinstance (key [n ], slice )):
122154 dims [n ] = slice_to_length (key [n ], idims [n ])
155+ elif (isinstance (key [n ], parallel_range )):
156+ dims [n ] = slice_to_length (key [n ].S , idims [n ])
123157 else :
124158 raise IndexError ("Invalid type while assigning to arrayfire.array" )
125159
0 commit comments