7878def _two_arg (f ):
7979 @_wraps (f )
8080 def _f (x1 , x2 , / , ** kwargs ):
81- x1 , x2 = _fix_promotion (x1 , x2 )
81+ # x1, x2 = _fix_promotion(x1, x2)
8282 return f (x1 , x2 , ** kwargs )
8383
8484 if _f .__doc__ is None :
@@ -312,6 +312,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
312312 }
313313 return can_cast_dict [from_ ][to ]
314314
315+ def test_bitwise_or (x : array , y : array ):
316+ if not paddle .is_tensor (x ):
317+ x = paddle .to_tensor (x )
318+ if not paddle .is_tensor (y ):
319+ y = paddle .to_tensor (y )
320+ return paddle .bitwise_or (x , y )
315321
316322# Basic renames
317323bitwise_invert = paddle .bitwise_not
@@ -326,7 +332,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
326332atan2 = _two_arg (paddle .atan2 )
327333bitwise_and = _two_arg (paddle .bitwise_and )
328334bitwise_left_shift = _two_arg (paddle .bitwise_left_shift )
329- bitwise_or = _two_arg (paddle . bitwise_or )
335+ bitwise_or = _two_arg (test_bitwise_or )
330336bitwise_right_shift = _two_arg (paddle .bitwise_right_shift )
331337bitwise_xor = _two_arg (paddle .bitwise_xor )
332338copysign = _two_arg (paddle .copysign )
@@ -455,6 +461,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
455461 out = paddle .unsqueeze (out , a )
456462 return out
457463
464+ _NP_2_PADDLE_DTYPE = {
465+ "BOOL" : 'bool' ,
466+ "UINT8" : 'uint8' ,
467+ "INT8" : 'int8' ,
468+ "INT16" : 'int16' ,
469+ "INT32" : 'int32' ,
470+ "INT64" : 'int64' ,
471+ "FLOAT16" : 'float16' ,
472+ "BFLOAT16" : 'bfloat16' ,
473+ "FLOAT32" : 'float32' ,
474+ "FLOAT64" : 'float64' ,
475+ "COMPLEX128" : 'complex128' ,
476+ "COMPLEX64" : 'complex64' ,
477+ }
458478
459479def prod (
460480 x : array ,
@@ -469,6 +489,10 @@ def prod(
469489 x = paddle .to_tensor (x )
470490 ndim = x .ndim
471491
492+ if dtype is not None :
493+ # import pdb
494+ # pdb.set_trace()
495+ dtype = _NP_2_PADDLE_DTYPE [dtype .name ]
472496 # below because it still needs to upcast.
473497 if axis == ():
474498 if dtype is None :
@@ -825,7 +849,7 @@ def eye(
825849 if n_cols is None :
826850 n_cols = n_rows
827851 z = paddle .zeros ([n_rows , n_cols ], dtype = dtype , ** kwargs ).to (device )
828- if abs (k ) <= n_rows + n_cols :
852+ if n_rows > 0 and n_cols > 0 and abs (k ) <= n_rows + n_cols :
829853 z .diagonal (k ).fill_ (1 )
830854 return z
831855
@@ -1052,6 +1076,10 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
10521076 if x .ndim != 1 :
10531077 raise ValueError ("axis must be specified when ndim > 1" )
10541078 axis = 0
1079+ if not paddle .is_tensor (indices ):
1080+ indices = paddle .to_tensor (indices )
1081+ if not paddle .is_tensor (axis ):
1082+ axis = paddle .to_tensor (axis )
10551083 return paddle .index_select (x , axis , indices , ** kwargs )
10561084
10571085
@@ -1144,7 +1172,6 @@ def floor(x: array, /) -> array:
11441172def ceil (x : array , / ) -> array :
11451173 return paddle .ceil (x ).to (x .dtype )
11461174
1147-
11481175def clip (
11491176 x : array ,
11501177 / ,
@@ -1250,7 +1277,6 @@ def searchsorted(
12501277 right = (side == "right" ),
12511278 )
12521279
1253-
12541280__all__ = [
12551281 "__array_namespace_info__" ,
12561282 "result_type" ,
0 commit comments