@@ -48,18 +48,14 @@ def _broadcast_strides(X_shape, X_strides, res_ndim):
4848 return tuple (out_strides )
4949
5050
51- def _broadcast_shapes (* args ):
52- """
53- Broadcast the input shapes into a single shape;
54- returns tuple broadcasted shape.
55- """
56- shapes = [array .shape for array in args ]
51+ def _broadcast_shape_impl (shapes ):
5752 if len (set (shapes )) == 1 :
5853 return shapes [0 ]
5954 mutable_shapes = False
6055 nds = [len (s ) for s in shapes ]
6156 biggest = max (nds )
62- for i in range (len (args )):
57+ sh_len = len (shapes )
58+ for i in range (sh_len ):
6359 diff = biggest - nds [i ]
6460 if diff > 0 :
6561 ty = type (shapes [i ])
@@ -77,7 +73,7 @@ def _broadcast_shapes(*args):
7773 unique .remove (1 )
7874 new_length = unique .pop ()
7975 common_shape .append (new_length )
80- for i in range (len ( args ) ):
76+ for i in range (sh_len ):
8177 if shapes [i ][axis ] == 1 :
8278 if not mutable_shapes :
8379 shapes = [list (s ) for s in shapes ]
@@ -89,6 +85,15 @@ def _broadcast_shapes(*args):
8985 return tuple (common_shape )
9086
9187
88+ def _broadcast_shapes (* args ):
89+ """
90+ Broadcast the input shapes into a single shape;
91+ returns tuple broadcasted shape.
92+ """
93+ array_shapes = [array .shape for array in args ]
94+ return _broadcast_shape_impl (array_shapes )
95+
96+
9297def permute_dims (X , axes ):
9398 """
9499 permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray
0 commit comments