Skip to content

Commit 035ea49

Browse files
authored
fix for fp32 (#84)
* fix for fp32 * precommit --------- Co-authored-by: Dmitry Razdoburdin <>
1 parent a5710df commit 035ea49

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

src/common/linalg_op.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,21 @@ void TransformKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
245245
}
246246
#endif
247247

248+
// Vector-scalar multiplication, single-precision variant.
// The callback handed to TransformKernel computes `v * mul` with both operands
// of type float, so no double-precision arithmetic appears in the callback.
// NOTE(review): presumably TransformKernel stores each result back into `x` in
// place -- its body is not visible here; confirm.
// NOTE(review): unlike VecScaMul, this helper performs no check that `x` and
// `ctx` refer to the same device; callers are expected to have verified that.
249+
template <auto _tag = detail::SysTag()>
250+
void VecScaMulFp32(Context const* ctx, linalg::VectorView<float> x, float mul) {
251+
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; });
252+
}
253+
248254
// Vector-scalar multiplication with a double-precision multiplier.
// First CHECKs that `x` and `ctx` carry the same device ordinal. Then, when
// ctx->DeviceFP64() differs from ctx->Device(), the work is forwarded to
// VecScaMulFp32 (the double `mul` is implicitly narrowed to float at that
// call); otherwise the callback multiplies each float element by the double
// `mul` directly.
// NOTE(review): presumably DeviceFP64() != Device() signals that the current
// device lacks fp64 support -- that accessor is not defined in this view;
// confirm against Context.
249255
template <auto _tag = detail::SysTag()>
250256
void VecScaMul(Context const* ctx, linalg::VectorView<float> x, double mul) {
251257
CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal);
252-
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; });
258+
if (ctx->DeviceFP64() != ctx->Device()) {
259+
VecScaMulFp32(ctx, x, mul);
260+
} else {
261+
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return v * mul; });
262+
}
253263
}
254264

255265
// vector-scalar division
@@ -261,7 +271,11 @@ void VecScaDiv(Context const* ctx, linalg::VectorView<float> x, double div) {
261271
// Element-wise logarithm: the callback handed to TransformKernel evaluates
// `log(v + rt_eps)` for each float element `v` of `x`.
// `rt_eps` is an offset added to the element before the logarithm is taken; it
// defaults to 0.0f.
// CHECKs up front that `x` and `ctx` carry the same device ordinal.
// When SYCL_LANGUAGE_VERSION is defined the call is spelled `::sycl::log`
// explicitly; otherwise the unqualified `log` is used.
// NOTE(review): presumably TransformKernel stores each result back into `x` in
// place -- its body is not visible here; confirm.
template <auto _tag = detail::SysTag()>
262272
void LogE(Context const* ctx, linalg::VectorView<float> x, float rt_eps = 0.0f) {
263273
CHECK_EQ(x.Device().ordinal, ctx->Device().ordinal);
274+
#if defined(SYCL_LANGUAGE_VERSION)
275+
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return ::sycl::log(v + rt_eps); });
276+
#else
264277
TransformKernel(ctx, x, [=] XGBOOST_DEVICE(float v) { return log(v + rt_eps); });
278+
#endif
265279
}
266280

267281
template <typename T, std::enable_if_t<std::is_floating_point_v<T>>* = nullptr>

0 commit comments

Comments
 (0)