99
1010from .library import *
1111from .array import *
12+ from .broadcast import *
1213
1314def arith_binary_func (lhs , rhs , c_func ):
1415 out = array ()
@@ -20,21 +21,21 @@ def arith_binary_func(lhs, rhs, c_func):
2021 TypeError ("Atleast one input needs to be of type arrayfire.array" )
2122
2223 elif (is_left_array and is_right_array ):
23- safe_call (c_func (ct .pointer (out .arr ), lhs .arr , rhs .arr , False ))
24+ safe_call (c_func (ct .pointer (out .arr ), lhs .arr , rhs .arr , bcast . get () ))
2425
2526 elif (is_number (rhs )):
2627 ldims = dim4_tuple (lhs .dims ())
2728 lty = lhs .type ()
2829 other = array ()
2930 other .arr = constant_array (rhs , ldims [0 ], ldims [1 ], ldims [2 ], ldims [3 ], lty )
30- safe_call (c_func (ct .pointer (out .arr ), lhs .arr , other .arr , False ))
31+ safe_call (c_func (ct .pointer (out .arr ), lhs .arr , other .arr , bcast . get () ))
3132
3233 else :
3334 rdims = dim4_tuple (rhs .dims ())
3435 rty = rhs .type ()
3536 other = array ()
3637 other .arr = constant_array (lhs , rdims [0 ], rdims [1 ], rdims [2 ], rdims [3 ], rty )
37- safe_call (c_func (ct .pointer (out .arr ), lhs .arr , other .arr , False ))
38+ safe_call (c_func (ct .pointer (out .arr ), lhs .arr , other .arr , bcast . get () ))
3839
3940 return out
4041
0 commit comments