@@ -42,14 +42,20 @@ def get_group_id(self, dim):
4242
4343 def get_group_linear_id (self ):
4444 """Returns a linearized version of the work-group index."""
45- if len (self ._index ) == 1 :
46- return self ._index [0 ]
47- if len (self ._index ) == 2 :
48- return self ._index [0 ] * self ._group_range [1 ] + self ._index [1 ]
45+ if self .dimensions == 1 :
46+ return self .get_group_id (0 )
47+ if self .dimensions == 2 :
48+ return self .get_group_id (0 ) * self .get_group_range (
49+ 1
50+ ) + self .get_group_id (1 )
4951 return (
50- (self ._index [0 ] * self ._group_range [1 ] * self ._group_range [2 ])
51- + (self ._index [1 ] * self ._group_range [2 ])
52- + (self ._index [2 ])
52+ (
53+ self .get_group_id (0 )
54+ * self .get_group_range (1 )
55+ * self .get_group_range (2 )
56+ )
57+ + (self .get_group_id (1 ) * self .get_group_range (2 ))
58+ + (self .get_group_id (2 ))
5359 )
5460
5561 def get_group_range (self , dim ):
@@ -61,8 +67,8 @@ def get_group_range(self, dim):
6167 def get_group_linear_range (self ):
6268 """Return the total number of work-groups in the nd_range."""
6369 num_wg = 1
64- for ext in self ._group_range :
65- num_wg *= ext
70+ for i in range ( self .dimensions ) :
71+ num_wg *= self . get_group_range ( i )
6672
6773 return num_wg
6874
@@ -76,8 +82,8 @@ def get_local_range(self, dim):
7682 def get_local_linear_range (self ):
7783 """Return the total number of work-items in the work-group."""
7884 num_wi = 1
79- for ext in self ._local_range :
80- num_wi *= ext
85+ for i in range ( self .dimensions ) :
86+ num_wi *= self . get_local_range ( i )
8187
8288 return num_wi
8389
@@ -128,14 +134,14 @@ def get_linear_id(self):
128134 Returns:
129135 int: The linear id.
130136 """
131- if len ( self ._extent ) == 1 :
132- return self ._index [ 0 ]
133- if len ( self ._extent ) == 2 :
134- return self ._index [ 0 ] * self ._extent [ 1 ] + self ._index [ 1 ]
137+ if self .dimensions == 1 :
138+ return self .get_id ( 0 )
139+ if self .dimensions == 2 :
140+ return self .get_id ( 0 ) * self .get_range ( 1 ) + self .get_id ( 1 )
135141 return (
136- (self ._index [ 0 ] * self ._extent [ 1 ] * self ._extent [ 2 ] )
137- + (self ._index [ 1 ] * self ._extent [ 2 ] )
138- + (self ._index [ 2 ] )
142+ (self .get_id ( 0 ) * self .get_range ( 1 ) * self .get_range ( 2 ) )
143+ + (self .get_id ( 1 ) * self .get_range ( 2 ) )
144+ + (self .get_id ( 2 ) )
139145 )
140146
141147 def get_id (self , idx ):
@@ -146,6 +152,14 @@ def get_id(self, idx):
146152 """
147153 return self ._index [idx ]
148154
155+ def get_linear_range (self ):
156+ """Return the total number of work-items in the work-group."""
157+ num_wi = 1
158+ for i in range (self .dimensions ):
159+ num_wi *= self .get_range (i )
160+
161+ return num_wi
162+
149163 def get_range (self , idx ):
150164 """Get the range size for a specific dimension.
151165
@@ -193,7 +207,24 @@ def get_global_linear_id(self):
193207 Returns:
194208 int: The global linear id.
195209 """
196- return self ._global_item .get_linear_id ()
210+ # Instead of calling self._global_item.get_linear_id(), the linearization
211+ # logic is duplicated here so that the method can be JIT compiled by
212+ # numba-dpex and works in both Python and Numba nopython modes.
213+ if self .dimensions == 1 :
214+ return self .get_global_id (0 )
215+ if self .dimensions == 2 :
216+ return self .get_global_id (0 ) * self .get_global_range (
217+ 1
218+ ) + self .get_global_id (1 )
219+ return (
220+ (
221+ self .get_global_id (0 )
222+ * self .get_global_range (1 )
223+ * self .get_global_range (2 )
224+ )
225+ + (self .get_global_id (1 ) * self .get_global_range (2 ))
226+ + (self .get_global_id (2 ))
227+ )
197228
198229 def get_local_id (self , idx ):
199230 """Get the local id for a specific dimension.
@@ -210,7 +241,24 @@ def get_local_linear_id(self):
210241 Returns:
211242 int: The local linear id.
212243 """
213- return self ._local_item .get_linear_id ()
244+ # Instead of calling self._local_item.get_linear_id(), the linearization
245+ # logic is duplicated here so that the method can be JIT compiled by
246+ # numba-dpex and works in both Python and Numba nopython modes.
247+ if self .dimensions == 1 :
248+ return self .get_local_id (0 )
249+ if self .dimensions == 2 :
250+ return self .get_local_id (0 ) * self .get_local_range (
251+ 1
252+ ) + self .get_local_id (1 )
253+ return (
254+ (
255+ self .get_local_id (0 )
256+ * self .get_local_range (1 )
257+ * self .get_local_range (2 )
258+ )
259+ + (self .get_local_id (1 ) * self .get_local_range (2 ))
260+ + (self .get_local_id (2 ))
261+ )
214262
215263 def get_global_range (self , idx ):
216264 """Get the global range size for a specific dimension.
@@ -228,6 +276,22 @@ def get_local_range(self, idx):
228276 """
229277 return self ._local_item .get_range (idx )
230278
279+ def get_local_linear_range (self ):
280+ """Return the total number of work-items in the work-group."""
281+ num_wi = 1
282+ for i in range (self .dimensions ):
283+ num_wi *= self .get_local_range (i )
284+
285+ return num_wi
286+
287+ def get_global_linear_range (self ):
288+ """Return the total number of work-items in the work-group."""
289+ num_wi = 1
290+ for i in range (self .dimensions ):
291+ num_wi *= self .get_global_range (i )
292+
293+ return num_wi
294+
231295 def get_group (self ):
232296 """Returns the group.
233297
0 commit comments