3232 * _int_dtypes ,
3333 torch .float32 ,
3434 torch .float64 ,
35+ torch .complex64 ,
36+ torch .complex128 ,
3537}
3638
3739_promotion_table = {
7072 (torch .float32 , torch .float64 ): torch .float64 ,
7173 (torch .float64 , torch .float32 ): torch .float64 ,
7274 (torch .float64 , torch .float64 ): torch .float64 ,
75+ # complexes
76+ (torch .complex64 , torch .complex64 ): torch .complex64 ,
77+ (torch .complex64 , torch .complex128 ): torch .complex128 ,
78+ (torch .complex128 , torch .complex64 ): torch .complex128 ,
79+ (torch .complex128 , torch .complex128 ): torch .complex128 ,
80+ # Mixed float and complex
81+ (torch .float32 , torch .complex64 ): torch .complex64 ,
82+ (torch .float32 , torch .complex128 ): torch .complex128 ,
83+ (torch .float64 , torch .complex64 ): torch .complex128 ,
84+ (torch .float64 , torch .complex128 ): torch .complex128 ,
7385}
7486
7587
@@ -652,6 +664,9 @@ def isdtype(
652664 else :
653665 return dtype == kind
654666
667+ def take (x : array , indices : array , / , * , axis : int , ** kwargs ) -> array :
668+ return torch .index_select (x , axis , indices , ** kwargs )
669+
655670__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'add' ,
656671 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
657672 'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
@@ -663,4 +678,4 @@ def isdtype(
663678 'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
664679 'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
665680 'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
666- 'vecdot' , 'tensordot' , 'isdtype' ]
681+ 'vecdot' , 'tensordot' , 'isdtype' , 'take' ]
0 commit comments