@@ -93,16 +93,13 @@ def _var_impl(x, axis, correction, keepdims):
9393 )
9494 # divide in-place to get mean
9595 mean_ary_shape = mean_ary .shape
96- nelems_ary = dpt .asarray (
97- nelems , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
98- )
99- if nelems_ary .shape != mean_ary_shape :
100- nelems_ary = dpt .broadcast_to (nelems_ary , mean_ary_shape )
96+
10197 dep_evs = _manager .submitted_events
102- ht_e2 , d_e1 = tei ._divide_inplace (
103- lhs = mean_ary , rhs = nelems_ary , sycl_queue = q , depends = dep_evs
98+ ht_e2 , d_e1 = tei ._divide_by_scalar (
99+ src = mean_ary , scalar = nelems , dst = mean_ary , sycl_queue = q , depends = dep_evs
104100 )
105101 _manager .add_event_pair (ht_e2 , d_e1 )
102+
106103 # subtract mean from original array to get deviations
107104 dev_ary = dpt .empty_like (buf )
108105 if mean_ary_shape != buf .shape :
@@ -144,17 +141,18 @@ def _var_impl(x, axis, correction, keepdims):
144141 res_shape = res .shape
145142 # when nelems - correction <= 0, yield nans
146143 div = max (nelems - correction , 0 )
147- if not div :
148- div = dpt .nan
149- div_ary = dpt .asarray (
150- div , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
151- )
152- # divide in-place again
153- if div_ary .shape != res_shape :
154- div_ary = dpt .broadcast_to (div_ary , res .shape )
144+ if div :
145+ dep_evs = _manager .submitted_events
146+ ht_e7 , d_e2 = tei ._divide_by_scalar (
147+ src = res , scalar = div , dst = res , sycl_queue = q , depends = dep_evs
148+ )
149+ _manager .add_event_pair (ht_e7 , d_e2 )
150+ return res , [d_e2 ]
151+
152+ div = dpt .nan
155153 dep_evs = _manager .submitted_events
156- ht_e7 , d_e2 = tei ._divide_inplace (
157- lhs = res , rhs = div_ary , sycl_queue = q , depends = dep_evs
154+ ht_e7 , d_e2 = tei ._divide_by_scalar (
155+ src = res , scalar = div , dst = res , sycl_queue = q , depends = dep_evs
158156 )
159157 _manager .add_event_pair (ht_e7 , d_e2 )
160158 return res , [d_e2 ]
@@ -259,17 +257,9 @@ def mean(x, axis=None, keepdims=False):
259257 inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
260258 res = dpt .permute_dims (dpt .reshape (res , res_shape ), inv_perm )
261259
262- res_shape = res .shape
263- # in-place divide
264- den_dt = dpt .finfo (res_dt ).dtype if res_dt .kind == "c" else res_dt
265- nelems_arr = dpt .asarray (
266- nelems , dtype = den_dt , usm_type = res_usm_type , sycl_queue = q
267- )
268- if nelems_arr .shape != res_shape :
269- nelems_arr = dpt .broadcast_to (nelems_arr , res_shape )
270260 dep_evs = _manager .submitted_events
271- ht_e2 , div_e = tei ._divide_inplace (
272- lhs = res , rhs = nelems_arr , sycl_queue = q , depends = dep_evs
261+ ht_e2 , div_e = tei ._divide_by_scalar (
262+ src = res , scalar = nelems , dst = res , sycl_queue = q , depends = dep_evs
273263 )
274264 _manager .add_event_pair (ht_e2 , div_e )
275265 return res
0 commit comments