3232
3333#include " kernels/reductions.hpp"
3434#include " reduction_over_axis.hpp"
35+ #include " utils/sycl_utils.hpp"
3536#include " utils/type_dispatch_building.hpp"
3637
3738namespace py = pybind11;
@@ -44,6 +45,7 @@ namespace py_internal
4445{
4546
4647namespace td_ns = dpctl::tensor::type_dispatch;
48+ namespace su_ns = dpctl::tensor::sycl_utils;
4749
4850namespace impl
4951{
@@ -68,6 +70,7 @@ struct TypePairSupportDataForLogSumExpReductionTemps
6870 static constexpr bool is_defined = std::disjunction< // disjunction is C++17
6971 // feature, supported
7072 // by DPC++ input bool
73+ #if 1
7174 td_ns::TypePairDefinedEntry<argTy, bool , outTy, sycl::half>,
7275 td_ns::TypePairDefinedEntry<argTy, bool , outTy, float >,
7376 td_ns::TypePairDefinedEntry<argTy, bool , outTy, double >,
@@ -105,7 +108,6 @@ struct TypePairSupportDataForLogSumExpReductionTemps
105108 // input uint64_t
106109 td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, float >,
107110 td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, double >,
108-
109111 // input half
110112 td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
111113 td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, float >,
@@ -117,6 +119,7 @@ struct TypePairSupportDataForLogSumExpReductionTemps
117119
118120 // input double
119121 td_ns::TypePairDefinedEntry<argTy, double , outTy, double >,
122+ #endif
120123
121124 // fall-through
122125 td_ns::NotDefinedEntry>::is_defined;
0 commit comments