Skip to content

Commit cafa75f

Browse files
committed
add launch bound
Signed-off-by: Xin Yao <xiny@nvidia.com>
1 parent be8df50 commit cafa75f

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

transformer_engine/common/fused_rope/fused_rope.cu

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace transformer_engine {
1616

1717
// Parameters for vectorization and thread-block sizing
1818
constexpr int desired_load_store_size = 8; // bytes
19+
// Number of warps per block: the y extent of the `threads` dim3 built by both launchers.
constexpr int n_warps_per_tile = 8;
20+
// Upper bound on threads per block, declared via __launch_bounds__ on the forward and
// backward kernels; equals the x * y extent of the launchers' `threads` dim3.
constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile;
1921

2022
template <typename scalar_t, int nvec, bool aligned>
2123
__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
@@ -227,14 +229,14 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
227229
}
228230

229231
template <typename scalar_t, int nvec, bool aligned>
230-
__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
231-
const float *freqs, const int *start_positions,
232-
scalar_t *dst, const bool interleaved, const int cp_size,
233-
const int cp_rank, const int s, const int h, const int d,
234-
const int d2, const int stride_s_or_t, const int stride_b,
235-
const int stride_h, const int stride_d,
236-
const int o_stride_s_or_t, const int o_stride_b,
237-
const int o_stride_h, const int o_stride_d) {
232+
__global__ void __launch_bounds__(threads_per_block)
233+
fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens, const float *freqs,
234+
const int *start_positions, scalar_t *dst, const bool interleaved,
235+
const int cp_size, const int cp_rank, const int s, const int h,
236+
const int d, const int d2, const int stride_s_or_t,
237+
const int stride_b, const int stride_h, const int stride_d,
238+
const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h,
239+
const int o_stride_d) {
238240
int s_id = blockIdx.x, b_id = blockIdx.y;
239241
int offset_block, offset_block_dst;
240242
int cur_seqlens;
@@ -272,12 +274,14 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
272274
}
273275

274276
template <typename scalar_t, int nvec, bool aligned>
275-
__global__ void fused_rope_backward_kernel(
276-
const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions,
277-
scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s,
278-
const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b,
279-
const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b,
280-
const int o_stride_h, const int o_stride_d) {
277+
__global__ void __launch_bounds__(threads_per_block)
278+
fused_rope_backward_kernel(const scalar_t *src, const int *cu_seqlens, const float *freqs,
279+
const int *start_positions, scalar_t *dst, const bool interleaved,
280+
const int cp_size, const int cp_rank, const int s, const int h,
281+
const int d, const int d2, const int stride_s_or_t,
282+
const int stride_b, const int stride_h, const int stride_d,
283+
const int o_stride_s_or_t, const int o_stride_b,
284+
const int o_stride_h, const int o_stride_d) {
281285
int s_id = blockIdx.x, b_id = blockIdx.y;
282286
int offset_block, offset_block_dst;
283287
int cur_seqlens;
@@ -556,9 +560,8 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
556560
const int h, const int d, const int d2, const int stride_s_or_t,
557561
const int stride_b, const int stride_h, const int stride_d,
558562
cudaStream_t stream) {
559-
int warps_per_block = h < 16 ? 4 : 8;
560563
dim3 blocks(s, b);
561-
dim3 threads(THREADS_PER_WARP, warps_per_block);
564+
dim3 threads(THREADS_PER_WARP, n_warps_per_tile);
562565
// No shared memory needed - cos/sin computed directly in registers
563566
const int shared_mem_size = 0;
564567
int o_stride_s_or_t, o_stride_b;
@@ -608,9 +611,8 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
608611
const int s, const int b, const int h, const int d, const int d2,
609612
const int stride_s_or_t, const int stride_b, const int stride_h,
610613
const int stride_d, cudaStream_t stream) {
611-
int warps_per_block = h < 16 ? 4 : 8;
612614
dim3 blocks(s, b);
613-
dim3 threads(THREADS_PER_WARP, warps_per_block);
615+
dim3 threads(THREADS_PER_WARP, n_warps_per_tile);
614616
// Shared memory for cos/sin cache: [cos(d2)] [padding(1)] [sin(d2)]
615617
const int shared_mem_size = (2 * d2 + 1) * sizeof(float);
616618
int o_stride_s_or_t, o_stride_b;

0 commit comments

Comments
 (0)