
Commit 6f32d63

Author: hongyuHe
Commit message: update
1 parent fd6eea0 commit 6f32d63

File tree: 3 files changed (+54, -6 lines)

array_api_compat/paddle/_aliases.py

Lines changed: 31 additions & 5 deletions
@@ -78,7 +78,7 @@
 def _two_arg(f):
     @_wraps(f)
     def _f(x1, x2, /, **kwargs):
-        x1, x2 = _fix_promotion(x1, x2)
+        # x1, x2 = _fix_promotion(x1, x2)
         return f(x1, x2, **kwargs)

     if _f.__doc__ is None:
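For context, a standalone sketch of the decorator pattern in this hunk (not part of the commit, no paddle required; functools.wraps stands in for the module's _wraps): with the _fix_promotion call commented out, both operands now reach the wrapped function exactly as given.

from functools import wraps   # stand-in for the module's _wraps

def _two_arg(f):
    @wraps(f)
    def _f(x1, x2, /, **kwargs):
        # x1, x2 = _fix_promotion(x1, x2)  # promotion step disabled by this commit
        return f(x1, x2, **kwargs)
    return _f

add = _two_arg(lambda a, b: a + b)
print(add(2, 3))  # 5 -- arguments are forwarded unchanged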
@@ -312,6 +312,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
     }
     return can_cast_dict[from_][to]

+def test_bitwise_or(x: array, y: array):
+    if not paddle.is_tensor(x):
+        x = paddle.to_tensor(x)
+    if not paddle.is_tensor(y):
+        y = paddle.to_tensor(y)
+    return paddle.bitwise_or(x, y)

 # Basic renames
 bitwise_invert = paddle.bitwise_not
@@ -326,7 +332,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
 atan2 = _two_arg(paddle.atan2)
 bitwise_and = _two_arg(paddle.bitwise_and)
 bitwise_left_shift = _two_arg(paddle.bitwise_left_shift)
-bitwise_or = _two_arg(paddle.bitwise_or)
+bitwise_or = _two_arg(test_bitwise_or)
 bitwise_right_shift = _two_arg(paddle.bitwise_right_shift)
 bitwise_xor = _two_arg(paddle.bitwise_xor)
 copysign = _two_arg(paddle.copysign)
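A short usage sketch of what the new wrapper changes (not part of the commit; assumes paddle is installed and that the wrapped namespace imports as array_api_compat.paddle): paddle.bitwise_or itself expects tensors, so test_bitwise_or coerces array-like operands first.

import paddle
import array_api_compat.paddle as xp

a = paddle.to_tensor([0b0101, 0b0011])   # int64 tensor
b = [0b1000, 0b0100]                     # plain Python list, not a tensor

# test_bitwise_or converts b with paddle.to_tensor before calling
# paddle.bitwise_or, so mixed tensor/list inputs go through.
print(xp.bitwise_or(a, b))               # expected values: [13, 7]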
@@ -455,6 +461,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
             out = paddle.unsqueeze(out, a)
     return out

+_NP_2_PADDLE_DTYPE = {
+    "BOOL": 'bool',
+    "UINT8": 'uint8',
+    "INT8": 'int8',
+    "INT16": 'int16',
+    "INT32": 'int32',
+    "INT64": 'int64',
+    "FLOAT16": 'float16',
+    "BFLOAT16": 'bfloat16',
+    "FLOAT32": 'float32',
+    "FLOAT64": 'float64',
+    "COMPLEX128": 'complex128',
+    "COMPLEX64": 'complex64',
+}

 def prod(
     x: array,
@@ -469,6 +489,10 @@ def prod(
         x = paddle.to_tensor(x)
     ndim = x.ndim

+    if dtype is not None:
+        # import pdb
+        # pdb.set_trace()
+        dtype = _NP_2_PADDLE_DTYPE[dtype.name]
     # below because it still needs to upcast.
     if axis == ():
         if dtype is None:
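_NP_2_PADDLE_DTYPE is keyed by what looks like the uppercase .name of a dtype enum member and maps it to the lowercase string form that paddle APIs accept. A paddle-free sketch of the lookup in prod (the _MyDtype enum is a hypothetical stand-in for whatever dtype object gets passed in):

from enum import Enum

class _MyDtype(Enum):        # hypothetical stand-in for a dtype object with .name
    FLOAT32 = 0

_NP_2_PADDLE_DTYPE = {"FLOAT32": 'float32', "INT64": 'int64'}   # excerpt of the dict above

dtype = _MyDtype.FLOAT32
if dtype is not None:
    dtype = _NP_2_PADDLE_DTYPE[dtype.name]
print(dtype)                 # 'float32'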
@@ -825,7 +849,7 @@ def eye(
     if n_cols is None:
         n_cols = n_rows
     z = paddle.zeros([n_rows, n_cols], dtype=dtype, **kwargs).to(device)
-    if abs(k) <= n_rows + n_cols:
+    if n_rows > 0 and n_cols > 0 and abs(k) <= n_rows + n_cols:
         z.diagonal(k).fill_(1)
     return z

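The added n_rows > 0 and n_cols > 0 guard skips the diagonal fill when either dimension is zero, since an empty matrix has no diagonal to fill. A small usage sketch (not part of the commit; assumes paddle is installed and the namespace imports as array_api_compat.paddle):

import array_api_compat.paddle as xp

print(xp.eye(3, k=1))   # 3x3 matrix with ones on the first superdiagonal
print(xp.eye(0))        # (0, 0) matrix; with the guard, no fill is attempted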
@@ -1052,6 +1076,10 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
         if x.ndim != 1:
             raise ValueError("axis must be specified when ndim > 1")
         axis = 0
+    if not paddle.is_tensor(indices):
+        indices = paddle.to_tensor(indices)
+    if not paddle.is_tensor(axis):
+        axis = paddle.to_tensor(axis)
     return paddle.index_select(x, axis, indices, **kwargs)

@@ -1144,7 +1172,6 @@ def floor(x: array, /) -> array:
 def ceil(x: array, /) -> array:
     return paddle.ceil(x).to(x.dtype)

-
 def clip(
     x: array,
     /,
@@ -1250,7 +1277,6 @@ def searchsorted(
         right=(side == "right"),
     )

-
 __all__ = [
     "__array_namespace_info__",
     "result_type",

array_api_compat/paddle/linalg.py

Lines changed: 22 additions & 0 deletions
@@ -84,6 +84,8 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
     # Use our wrapped sum to make sure it does upcasting correctly
     return sum(paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1), axis=-1, dtype=dtype)

+def diagonal(x: ndarray, /, *, offset: int = 0, **kwargs) -> ndarray:
+    return paddle.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)

 def vector_norm(
     x: array,
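The new linalg.diagonal wrapper pins axis1/axis2 to the last two axes, matching the array API's batched-matrix convention rather than paddle.diagonal's leading-axes default. A brief sketch (not part of the commit; assumes paddle is installed):

import paddle
from array_api_compat.paddle import linalg

batch = paddle.arange(18, dtype='float32').reshape([2, 3, 3])
print(linalg.diagonal(batch).shape)       # [2, 3]: one diagonal per 3x3 matrix
print(linalg.diagonal(batch, offset=1))   # first superdiagonal of each matrix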
@@ -141,6 +143,24 @@ def slogdet(x: array):
     slotdet = namedtuple("slotdet", ["sign", "logabsdet"])
     return slotdet(sign, log_det)

+def tuple_to_namedtuple(data, fields):
+    nt_class = namedtuple('DynamicNameTuple', fields)
+    return nt_class(*data)
+
+def eigh(x: array):
+    return tuple_to_namedtuple(paddle.linalg.eigh(x), ['eigenvalues', 'eigenvectors'])
+
+def qr(x: array, mode: Optional[str] = None) -> array:
+    if mode is None:
+        return tuple_to_namedtuple(paddle.linalg.qr(x), ['Q', 'R'])
+
+    return tuple_to_namedtuple(paddle.linalg.qr(x, mode), ['Q', 'R'])
+
+
+def svd(x: array, full_matrices: Optional[bool] = None) -> array:
+    if full_matrices is None:
+        return tuple_to_namedtuple(paddle.linalg.svd(x), ['U', 'S', 'Vh'])
+    return tuple_to_namedtuple(paddle.linalg.svd(x, full_matrices), ['U', 'S', 'Vh'])

 __all__ = linalg_all + [
     "outer",
@@ -154,6 +174,8 @@ def slogdet(x: array):
     "trace",
     "vector_norm",
     "slogdet",
+    "eigh",
+    "diagonal",
 ]

 _all_ignore = ["paddle_linalg", "sum"]
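The tuple_to_namedtuple helper lets eigh, qr, and svd return results with named fields, as array API consumers expect, while still unpacking like plain tuples. A usage sketch (not part of the commit; assumes paddle is installed):

import paddle
from array_api_compat.paddle import linalg

x = paddle.to_tensor([[2.0, 0.0], [0.0, 3.0]])

res = linalg.qr(x)
print(res.Q.shape, res.R.shape)     # named fields on the returned namedtuple

U, S, Vh = linalg.svd(x)            # also unpacks like an ordinary tuple
print(linalg.eigh(x).eigenvalues)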

vendor_test/vendored/_compat

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-../../array_api_compat/
+../../array_api_compat
