Skip to content

Commit be8df50

Browse files
committed
no fast math for sincos (replace the `__sincosf` fast-math intrinsic with `sincosf`)
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent ebddbcb commit be8df50

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

transformer_engine/common/fused_rope/fused_rope.cu

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
42 42
float v_cos[nvec], v_sin[nvec];
43 43
#pragma unroll
44 44
for (int i = 0; i < nvec; i++) {
45-
__sincosf(freqs_row[d_id + i], &v_sin[i], &v_cos[i]);
45+
sincosf(freqs_row[d_id + i], &v_sin[i], &v_cos[i]);
46 46
}
47 47

48 48
// Process all heads with the same cos/sin values
@@ -68,7 +68,7 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
68 68
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
69 69
// Compute cos/sin once per d_id, store in registers
70 70
float v_sin, v_cos;
71-
__sincosf(freqs_row[d_id], &v_sin, &v_cos);
71+
sincosf(freqs_row[d_id], &v_sin, &v_cos);
72 72

73 73
// Precompute rotation parameters once per d_id
74 74
int rot_offset_d;
@@ -138,7 +138,7 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
138 138
const float *freqs_row = freqs + s_id * d2;
139 139
int tid = threadIdx.x * blockDim.y + threadIdx.y;
140 140
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
141-
__sincosf(freqs_row[i], &shared_sin[i], &shared_cos[i]);
141+
sincosf(freqs_row[i], &shared_sin[i], &shared_cos[i]);
142 142
}
143 143
__syncthreads();
144 144

0 commit comments

Comments
 (0)