@@ -938,6 +938,8 @@ def triu(x: array, /, *, k: int = 0) -> array:
938938
939939
940940def expand_dims (x : array , / , * , axis : int = 0 ) -> array :
941+ if axis < - x .ndim - 1 or axis > x .ndim :
942+ raise IndexError (f"Axis { axis } is out of bounds for array of dimension { x .ndim } " )
941943 return paddle .unsqueeze (x , axis )
942944
943945
@@ -1087,15 +1089,31 @@ def is_complex(dtype):
10871089
10881090
10891091def take (x : array , indices : array , / , * , axis : Optional [int ] = None , ** kwargs ) -> array :
1090- if axis is None :
1092+ _axis = axis
1093+ if _axis is None :
10911094 if x .ndim != 1 :
1092- raise ValueError ("axis must be specified when ndim > 1" )
1093- axis = 0
1094- if not paddle .is_tensor (indices ):
1095- indices = paddle .to_tensor (indices )
1096- if not paddle .is_tensor (axis ):
1097- axis = paddle .to_tensor (axis )
1098- return paddle .index_select (x , axis , indices , ** kwargs )
1095+ raise ValueError ("axis must be specified when x.ndim > 1" )
1096+ _axis = 0
1097+ elif not isinstance (_axis , int ):
1098+ raise TypeError (f"axis must be an integer, but received { type (_axis )} " )
1099+
1100+ if not (- x .ndim <= _axis < x .ndim ):
1101+ raise IndexError (f"axis { _axis } is out of bounds for tensor of dimension { x .ndim } " )
1102+
1103+ if isinstance (indices , paddle .Tensor ):
1104+ indices_tensor = indices
1105+ elif isinstance (indices , int ):
1106+ indices_tensor = paddle .to_tensor ([indices ], dtype = 'int64' )
1107+ else :
1108+ # Otherwise (e.g., list, tuple), convert directly
1109+ indices_tensor = paddle .to_tensor (indices , dtype = 'int64' )
1110+ # Ensure indices is a 1D tensor
1111+ if indices_tensor .ndim == 0 :
1112+ indices_tensor = indices_tensor .reshape ([1 ])
1113+ elif indices_tensor .ndim > 1 :
1114+ raise ValueError (f"indices must be a 1D tensor, but received a { indices_tensor .ndim } D tensor" )
1115+
1116+ return paddle .index_select (x , index = indices_tensor , axis = _axis )
10991117
11001118
11011119def sign (x : array , / ) -> array :
@@ -1261,7 +1279,6 @@ def cumulative_sum(
12611279 "axis must be specified in cumulative_sum for more than one dimension"
12621280 )
12631281 axis = 0
1264-
12651282 res = paddle .cumsum (x , axis = axis , dtype = dtype )
12661283
12671284 # np.cumsum does not support include_initial
0 commit comments