@@ -543,6 +543,14 @@ def empty(shape: Union[int, Tuple[int, ...]],
543543 ** kwargs ) -> array :
544544 return torch .empty (shape , dtype = dtype , device = device , ** kwargs )
545545
546+ # tril and triu do not call the keyword argument k
547+
548+ def tril (x : array , / , * , k : int = 0 ) -> array :
549+ return torch .tril (x , k )
550+
551+ def triu (x : array , / , * , k : int = 0 ) -> array :
552+ return torch .triu (x , k )
553+
546554# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
547555def expand_dims (x : array , / , * , axis : int = 0 ) -> array :
548556 return torch .unsqueeze (x , axis )
@@ -610,6 +618,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
610618 'subtract' , 'max' , 'min' , 'sort' , 'prod' , 'sum' , 'any' , 'all' ,
611619 'mean' , 'std' , 'var' , 'concat' , 'squeeze' , 'flip' , 'roll' ,
612620 'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones' ,
613- 'zeros' , 'empty' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
614- 'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
615- 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' ]
621+ 'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
622+ 'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
623+ 'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
624+ 'vecdot' , 'tensordot' ]
0 commit comments