11from __future__ import annotations
22
33from functools import wraps
4- from builtins import all as builtin_all
4+ from builtins import all as builtin_all , any as builtin_any
55
66from ..common ._aliases import (UniqueAllResult , UniqueCountsResult ,
77 UniqueInverseResult ,
1919
2020 array = torch .Tensor
2121
22- _array_api_dtypes = {
23- torch .bool ,
22+ _int_dtypes = {
2423 torch .uint8 ,
2524 torch .int8 ,
2625 torch .int16 ,
2726 torch .int32 ,
2827 torch .int64 ,
28+ }
29+
30+ _array_api_dtypes = {
31+ torch .bool ,
32+ * _int_dtypes ,
2933 torch .float32 ,
3034 torch .float64 ,
3135}
@@ -611,6 +615,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
611615 x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
612616 return torch .tensordot (x1 , x2 , dims = axes , ** kwargs )
613617
618+
619+ def isdtype (
620+ dtype : Dtype , kind : Union [Dtype , str , Tuple [Union [Dtype , str ], ...]],
621+ * , _tuple = True , # Disallow nested tuples
622+ ) -> bool :
623+ """
624+ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
625+
626+ Note that outside of this function, this compat library does not yet fully
627+ support complex numbers.
628+
629+ See
630+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
631+ for more details
632+ """
633+ if isinstance (kind , tuple ) and _tuple :
634+ return builtin_any (isdtype (dtype , k , _tuple = False ) for k in kind )
635+ elif isinstance (kind , str ):
636+ if kind == 'bool' :
637+ return dtype == torch .bool
638+ elif kind == 'signed integer' :
639+ return dtype in _int_dtypes and dtype .is_signed
640+ elif kind == 'unsigned integer' :
641+ return dtype in _int_dtypes and not dtype .is_signed
642+ elif kind == 'integral' :
643+ return dtype in _int_dtypes
644+ elif kind == 'real floating' :
645+ return dtype .is_floating_point
646+ elif kind == 'complex floating' :
647+ return dtype .is_complex
648+ elif kind == 'numeric' :
649+ return isdtype (dtype , ('integral' , 'real floating' , 'complex floating' ))
650+ else :
651+ raise ValueError (f"Unrecognized data type kind: { kind !r} " )
652+ else :
653+ return dtype == kind
654+
614655__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'add' ,
615656 'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
616657 'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
@@ -622,4 +663,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
622663 'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
623664 'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
624665 'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
625- 'vecdot' , 'tensordot' ]
666+ 'vecdot' , 'tensordot' , 'isdtype' ]
0 commit comments