@@ -16,6 +16,8 @@ namespace transformer_engine {
1616
1717// Parameters for vectorization
1818constexpr int desired_load_store_size = 8 ; // bytes
19+ constexpr int n_warps_per_tile = 8 ;
20+ constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
1921
2022template <typename scalar_t , int nvec, bool aligned>
2123__device__ void fused_rope_block_forward (const scalar_t *src, const float *freqs, scalar_t *dst,
@@ -227,14 +229,14 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
227229}
228230
229231template <typename scalar_t , int nvec, bool aligned>
230- __global__ void fused_rope_forward_kernel ( const scalar_t *src, const int *cu_seqlens,
231- const float *freqs , const int *start_positions ,
232- scalar_t *dst, const bool interleaved, const int cp_size ,
233- const int cp_rank , const int s , const int h , const int d ,
234- const int d2 , const int stride_s_or_t , const int stride_b ,
235- const int stride_h, const int stride_d,
236- const int o_stride_s_or_t , const int o_stride_b ,
237- const int o_stride_h, const int o_stride_d) {
232+ __global__ void __launch_bounds__ (threads_per_block)
233+ fused_rope_forward_kernel( const scalar_t *src, const int *cu_seqlens , const float *freqs ,
234+ const int *start_positions, scalar_t *dst, const bool interleaved,
235+ const int cp_size , const int cp_rank , const int s , const int h ,
236+ const int d , const int d2 , const int stride_s_or_t ,
237+ const int stride_b, const int stride_h, const int stride_d,
238+ const int o_stride_s_or_t , const int o_stride_b , const int o_stride_h ,
239+ const int o_stride_d) {
238240 int s_id = blockIdx .x , b_id = blockIdx .y ;
239241 int offset_block, offset_block_dst;
240242 int cur_seqlens;
@@ -272,12 +274,14 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
272274}
273275
274276template <typename scalar_t , int nvec, bool aligned>
275- __global__ void fused_rope_backward_kernel (
276- const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions,
277- scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s,
278- const int h, const int d, const int d2, const int stride_s_or_t , const int stride_b,
279- const int stride_h, const int stride_d, const int o_stride_s_or_t , const int o_stride_b,
280- const int o_stride_h, const int o_stride_d) {
277+ __global__ void __launch_bounds__ (threads_per_block)
278+ fused_rope_backward_kernel(const scalar_t *src, const int *cu_seqlens, const float *freqs,
279+ const int *start_positions, scalar_t *dst, const bool interleaved,
280+ const int cp_size, const int cp_rank, const int s, const int h,
281+ const int d, const int d2, const int stride_s_or_t ,
282+ const int stride_b, const int stride_h, const int stride_d,
283+ const int o_stride_s_or_t , const int o_stride_b,
284+ const int o_stride_h, const int o_stride_d) {
281285 int s_id = blockIdx .x , b_id = blockIdx .y ;
282286 int offset_block, offset_block_dst;
283287 int cur_seqlens;
@@ -556,9 +560,8 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
556560 const int h, const int d, const int d2, const int stride_s_or_t ,
557561 const int stride_b, const int stride_h, const int stride_d,
558562 cudaStream_t stream) {
559- int warps_per_block = h < 16 ? 4 : 8 ;
560563 dim3 blocks (s, b);
561- dim3 threads (THREADS_PER_WARP, warps_per_block );
564+ dim3 threads (THREADS_PER_WARP, n_warps_per_tile );
562565 // No shared memory needed - cos/sin computed directly in registers
563566 const int shared_mem_size = 0 ;
564567 int o_stride_s_or_t , o_stride_b;
@@ -608,9 +611,8 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
608611 const int s, const int b, const int h, const int d, const int d2,
609612 const int stride_s_or_t , const int stride_b, const int stride_h,
610613 const int stride_d, cudaStream_t stream) {
611- int warps_per_block = h < 16 ? 4 : 8 ;
612614 dim3 blocks (s, b);
613- dim3 threads (THREADS_PER_WARP, warps_per_block );
615+ dim3 threads (THREADS_PER_WARP, n_warps_per_tile );
614616 // Shared memory for cos/sin cache: [cos(d2)] [padding(1)] [sin(d2)]
615617 const int shared_mem_size = (2 * d2 + 1 ) * sizeof (float );
616618 int o_stride_s_or_t , o_stride_b;
0 commit comments