@@ -112,25 +112,32 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype]) -> Dtype:
112112 raise TypeError ("At least one array or dtype must be provided" )
113113 if len (arrays_and_dtypes ) == 1 :
114114 x = arrays_and_dtypes [0 ]
115- if isinstance (x , paddle .dtype ):
116- return x
117- return x .dtype
115+ return x if isinstance (x , paddle .dtype ) else x .dtype
118116 if len (arrays_and_dtypes ) > 2 :
119117 return result_type (arrays_and_dtypes [0 ], result_type (* arrays_and_dtypes [1 :]))
120118
121119 x , y = arrays_and_dtypes
122- xdt = x . dtype if not isinstance (x , paddle .dtype ) else x
123- ydt = y . dtype if not isinstance (y , paddle .dtype ) else y
120+ xdt = x if isinstance (x , paddle .dtype ) else x . dtype
121+ ydt = y if isinstance (y , paddle .dtype ) else y . dtype
124122
125123 if (xdt , ydt ) in _promotion_table :
126- return _promotion_table [xdt , ydt ]
127-
128- # This doesn't result_type(dtype, dtype) for non-array API dtypes
129- # because paddle.result_type only accepts tensors. This does however, allow
130- # cross-kind promotion.
131- x = paddle .to_tensor ([], dtype = x ) if isinstance (x , paddle .dtype ) else x
132- y = paddle .to_tensor ([], dtype = y ) if isinstance (y , paddle .dtype ) else y
133- return paddle .result_type (x , y )
124+ return _promotion_table [(xdt , ydt )]
125+
126+ type_order = {
127+ paddle .bool : 0 ,
128+ paddle .int8 : 1 ,
129+ paddle .uint8 : 2 ,
130+ paddle .int16 : 3 ,
131+ paddle .int32 : 4 ,
132+ paddle .int64 : 5 ,
133+ paddle .float16 : 6 ,
134+ paddle .float32 : 7 ,
135+ paddle .float64 : 8 ,
136+ paddle .complex64 : 9 ,
137+ paddle .complex128 : 10
138+ }
139+
140+ return xdt if type_order .get (xdt , 0 ) > type_order .get (ydt , 0 ) else ydt
134141
135142
136143def can_cast (from_ : Union [Dtype , array ], to : Dtype , / ) -> bool :
@@ -922,7 +929,15 @@ def astype(
922929
923930
924931def broadcast_arrays (* arrays : array ) -> List [array ]:
925- return paddle .broadcast_tensors (arrays )
932+ original_dtypes = [arr .dtype for arr in arrays ]
933+ if len (set (original_dtypes )) == 1 :
934+ return paddle .broadcast_tensors (arrays )
935+ target_dtype = result_type (* arrays )
936+ casted_arrays = [arr .astype (target_dtype ) if arr .dtype != target_dtype else arr
937+ for arr in arrays ]
938+ broadcasted = paddle .broadcast_tensors (casted_arrays )
939+ result = [arr .astype (original_dtype ) for arr , original_dtype in zip (broadcasted , original_dtypes )]
940+ return result
926941
927942
928943# Note that these named tuples aren't actually part of the standard namespace,
0 commit comments