2929from ._copy_utils import _empty_like_orderK , _empty_like_pair_orderK
3030from ._type_utils import (
3131 _acceptance_fn_default ,
32+ _all_data_types ,
3233 _find_buf_dtype ,
3334 _find_buf_dtype2 ,
3435 _to_device_supported_dtype ,
@@ -44,6 +45,7 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4445 self .__name__ = "UnaryElementwiseFunc"
4546 self .name_ = name
4647 self .result_type_resolver_fn_ = result_type_resolver_fn
48+ self .types_ = None
4749 self .unary_fn_ = unary_dp_impl_fn
4850 self .__doc__ = docs
4951
@@ -53,6 +55,18 @@ def __str__(self):
5355 def __repr__ (self ):
5456 return f"<{ self .__name__ } '{ self .name_ } '>"
5557
58+ @property
59+ def types (self ):
60+ types = self .types_
61+ if not types :
62+ types = []
63+ for dt1 in _all_data_types (True , True ):
64+ dt2 = self .result_type_resolver_fn_ (dt1 )
65+ if dt2 :
66+ types .append (f"{ dt1 .char } ->{ dt2 .char } " )
67+ self .types_ = types
68+ return types
69+
5670 def __call__ (self , x , out = None , order = "K" ):
5771 if not isinstance (x , dpt .usm_ndarray ):
5872 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
@@ -363,6 +377,7 @@ def __init__(
363377 self .__name__ = "BinaryElementwiseFunc"
364378 self .name_ = name
365379 self .result_type_resolver_fn_ = result_type_resolver_fn
380+ self .types_ = None
366381 self .binary_fn_ = binary_dp_impl_fn
367382 self .binary_inplace_fn_ = binary_inplace_fn
368383 self .__doc__ = docs
@@ -377,6 +392,20 @@ def __str__(self):
377392 def __repr__ (self ):
378393 return f"<{ self .__name__ } '{ self .name_ } '>"
379394
395+ @property
396+ def types (self ):
397+ types = self .types_
398+ if not types :
399+ types = []
400+ _all_dtypes = _all_data_types (True , True )
401+ for dt1 in _all_dtypes :
402+ for dt2 in _all_dtypes :
403+ dt3 = self .result_type_resolver_fn_ (dt1 , dt2 )
404+ if dt3 :
405+ types .append (f"{ dt1 .char } { dt2 .char } ->{ dt3 .char } " )
406+ self .types_ = types
407+ return types
408+
380409 def __call__ (self , o1 , o2 , out = None , order = "K" ):
381410 if order not in ["K" , "C" , "F" , "A" ]:
382411 order = "K"
0 commit comments