Skip to content

Commit b946e82

Browse files
author
Hongyuhe
committed
update
1 parent 46f81c7 commit b946e82

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

array_api_compat/paddle/_aliases.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -938,6 +938,8 @@ def triu(x: array, /, *, k: int = 0) -> array:
938938

939939

940940
def expand_dims(x: array, /, *, axis: int = 0) -> array:
941+
if axis < -x.ndim - 1 or axis > x.ndim:
942+
raise IndexError(f"Axis {axis} is out of bounds for array of dimension { x.ndim}")
941943
return paddle.unsqueeze(x, axis)
942944

943945

@@ -1087,15 +1089,31 @@ def is_complex(dtype):
10871089

10881090

10891091
def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
1090-
if axis is None:
1092+
_axis = axis
1093+
if _axis is None:
10911094
if x.ndim != 1:
1092-
raise ValueError("axis must be specified when ndim > 1")
1093-
axis = 0
1094-
if not paddle.is_tensor(indices):
1095-
indices = paddle.to_tensor(indices)
1096-
if not paddle.is_tensor(axis):
1097-
axis = paddle.to_tensor(axis)
1098-
return paddle.index_select(x, axis, indices, **kwargs)
1095+
raise ValueError("axis must be specified when x.ndim > 1")
1096+
_axis = 0
1097+
elif not isinstance(_axis, int):
1098+
raise TypeError(f"axis must be an integer, but received {type(_axis)}")
1099+
1100+
if not (-x.ndim <= _axis < x.ndim):
1101+
raise IndexError(f"axis {_axis} is out of bounds for tensor of dimension {x.ndim}")
1102+
1103+
if isinstance(indices, paddle.Tensor):
1104+
indices_tensor = indices
1105+
elif isinstance(indices, int):
1106+
indices_tensor = paddle.to_tensor([indices], dtype='int64')
1107+
else:
1108+
# Otherwise (e.g., list, tuple), convert directly
1109+
indices_tensor = paddle.to_tensor(indices, dtype='int64')
1110+
# Ensure indices is a 1D tensor
1111+
if indices_tensor.ndim == 0:
1112+
indices_tensor = indices_tensor.reshape([1])
1113+
elif indices_tensor.ndim > 1:
1114+
raise ValueError(f"indices must be a 1D tensor, but received a {indices_tensor.ndim}D tensor")
1115+
1116+
return paddle.index_select(x, index=indices_tensor, axis=_axis)
10991117

11001118

11011119
def sign(x: array, /) -> array:
@@ -1261,7 +1279,6 @@ def cumulative_sum(
12611279
"axis must be specified in cumulative_sum for more than one dimension"
12621280
)
12631281
axis = 0
1264-
12651282
res = paddle.cumsum(x, axis=axis, dtype=dtype)
12661283

12671284
# np.cumsum does not support include_initial

array_api_compat/paddle/linalg.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,7 @@ def pinv(x: array, /, *, rtol: Optional[Union[float, array]] = None) -> array:
136136

137137

138138
def slogdet(x: array):
139-
det = paddle.linalg.det(x)
140-
sign = paddle.sign(det)
141-
log_det = paddle.log(det)
142-
143-
slotdet = namedtuple("slotdet", ["sign", "logabsdet"])
144-
return slotdet(sign, log_det)
139+
return tuple_to_namedtuple(paddle.linalg.slogdet(x), ["sign", "logabsdet"])
145140

146141
def tuple_to_namedtuple(data, fields):
147142
nt_class = namedtuple('DynamicNameTuple', fields)

array_api_compat/torch/_aliases.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,9 @@ def triu(x: array, /, *, k: int = 0) -> array:
611611

612612
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
613613
def expand_dims(x: array, /, *, axis: int = 0) -> array:
614+
if axis == 2:
615+
import pdb
616+
pdb.set_trace()
614617
return torch.unsqueeze(x, axis)
615618

616619
def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array:

0 commit comments

Comments
 (0)