@@ -245,11 +245,21 @@ void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
245245}
246246#endif
247247
248+ // vector-scalar multiplication
249+ template <auto _tag = detail::SysTag()>
250+ void VecScaMulFp32 (Context const * ctx, linalg::VectorView<float > x, float mul) {
251+ TransformKernel (ctx, x, [=] XGBOOST_DEVICE (float v) { return v * mul; });
252+ }
253+
248254// vector-scalar multiplication
249255template <auto _tag = detail::SysTag()>
250256void VecScaMul (Context const * ctx, linalg::VectorView<float > x, double mul) {
251257 CHECK_EQ (x.Device ().ordinal , ctx->Device ().ordinal );
252- TransformKernel (ctx, x, [=] XGBOOST_DEVICE (float v) { return v * mul; });
258+ if (ctx->DeviceFP64 () != ctx->Device ()) {
259+ VecScaMulFp32 (ctx, x, mul);
260+ } else {
261+ TransformKernel (ctx, x, [=] XGBOOST_DEVICE (float v) { return v * mul; });
262+ }
253263}
254264
255265// vector-scalar division
@@ -261,7 +271,11 @@ void VecScaDiv(Context const* ctx, linalg::VectorView<float> x, double div) {
261271template <auto _tag = detail::SysTag()>
262272void LogE (Context const * ctx, linalg::VectorView<float > x, float rt_eps = 0 .0f ) {
263273 CHECK_EQ (x.Device ().ordinal , ctx->Device ().ordinal );
274+ #if defined(SYCL_LANGUAGE_VERSION)
275+ TransformKernel (ctx, x, [=] XGBOOST_DEVICE (float v) { return ::sycl::log (v + rt_eps); });
276+ #else
264277 TransformKernel (ctx, x, [=] XGBOOST_DEVICE (float v) { return log (v + rt_eps); });
278+ #endif
265279}
266280
267281template <typename T, std::enable_if_t <std::is_floating_point_v<T>>* = nullptr >
0 commit comments