Skip to content

Commit 402fdde

Browse files
committed
fetch q in shared mem for better address patterns
1 parent cf863a7 commit 402fdde

File tree

1 file changed

+96
-26
lines changed

1 file changed

+96
-26
lines changed

csrc/rocm/attention.cu

Lines changed: 96 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -364,33 +364,33 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
364364

365365
__shared__ float shared_qk_max[NWARPS][16 + 1];
366366
__shared__ float shared_exp_sum[NWARPS][16 + 1];
367+
//shared_logits is used for multiple purposes
368+
__shared__ _B16x4 shared_logits[NWARPS][4][16][4 + 1];
367369

368370
//for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes HeadElements in each lane, 4x16B HeadElements across 4 rows of warp
369371
constexpr int ROWS_PER_WARP = WARP_SIZE / 16; //rows refers to 16 lanes; refer dpp terminology
370-
constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t);
371-
constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; //TODO 8B form?
372+
constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); //8 for 16 bit cache type, 16 for 8 bit types
373+
constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; //each fetch across a warp fetches these many elements
374+
constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); //1 for 16bit types, 2 for 8bit types
372375
constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; //4xQKHE_16B across warp
373376

374-
_B16x8 Qlocal[QKHELOOP]; //this could be B8x16 too
377+
_B16x8 Qlocal[QKHELOOP][QK_SIZE_RATIO]; //note that 16 contiguous elements of Q should be fetched per lane for 8 bit cache types : QK_SIZE_RATIO changes for this
375378

376379
constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t);
377-
constexpr int x = CONTIGUOUS_SCALAR_ELEMS_16B; //x is defined by vLLM as 16Bytes
380+
//constexpr int x = CONTIGUOUS_SCALAR_ELEMS_16B; //x is defined by vLLM as 16Bytes
378381

379-
constexpr int TLOOP1 = CONTIGUOUS_KV_ELEMS_16B_LOAD / 4; //mfma16x16x16 outputs 4 elements per lane: will be moved to match layout for V dwordx4 loads
380-
constexpr int TOKENS_PER_WARP1 = 16 * TLOOP1; //16 tokens across lanes * TLOOP factor
381-
constexpr int T_PAR_SIZE = 256;
382-
constexpr int T_PAR_LOOP = T_PAR_SIZE / TOKENS_PER_WARP1 / NWARPS;
383-
constexpr int TLOOP = TLOOP1 * T_PAR_LOOP;
384-
constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; //TOKENS_PER_WARP1 * T_PAR_LOOP;
382+
constexpr int T_PAR_SIZE = 256; //partition size set to 256 TODO move to template param
383+
//constexpr int TLOOP1 = CONTIGUOUS_KV_ELEMS_16B_LOAD / 4; //mfma16x16x16 outputs 4 elements per lane: will be moved to match layout for V dwordx4 loads
384+
//constexpr int TOKENS_PER_WARP1 = 16 * TLOOP1; //16 tokens across lanes * TLOOP factor
385+
//constexpr int T_PAR_LOOP = T_PAR_SIZE / TOKENS_PER_WARP1 / NWARPS;
386+
constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; //sub partition of tokens per warp for qk calculation
387+
constexpr int TLOOP = TOKENS_PER_WARP / 16; //each mfma16x16x16 instruction processes 16 tokens
385388

386389
_B16x8 Klocal[TLOOP][QKHELOOP]; //this could be B8x16 too
387390

388391
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
389392
const int wg_start_kv_head_idx = blockIdx.z;
390393
const int total_num_heads = gridDim.z * GQA_RATIO;
391-
const bool warp_in_context = (partition_start_token_idx + warpid * TOKENS_PER_WARP) < context_len;
392-
393-
//TODO implement warp out of context logic
394394

395395
//for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps
396396
//each mfma takes QH16xT16x16HE across warp
@@ -414,23 +414,63 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
414414
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
415415
}
416416

417+
#if 0 //fetch Q into registers
418+
417419
const int local_qhead_idx = lane16id % GQA_RATIO;
418420
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
419421
const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
420422
const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
421423

422424
if (lane16id < GQA_RATIO) {
423425
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
424-
const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH;
425-
const _B16x8* q_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(q_fetch_ptr);
426-
Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
426+
const scalar_t* q_ptr2 = q_ptr + qkhe_depth * QKHE_PER_FETCH;
427+
for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
428+
const scalar_t* q_fetch_ptr = q_ptr2 + qkratio * CONTIGUOUS_SCALAR_ELEMS_16B;
429+
const _B16x8* q_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(q_fetch_ptr);
430+
Qlocal[qkhe_depth][qkratio] = *q_fetch_ptr_16B;
431+
}
427432
}
428433
} else {
429434
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
430-
Qlocal[qkhe_depth].xy[0] = {0};
431-
Qlocal[qkhe_depth].xy[1] = {0};
435+
for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
436+
Qlocal[qkhe_depth][qkratio].xy[0] = {0};
437+
Qlocal[qkhe_depth][qkratio].xy[1] = {0};
438+
}
432439
}
433440
}
441+
#else //fetch Q in shared
442+
const int local_qhead_idx = 4 * warpid + rowid;
443+
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
444+
const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
445+
const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; //+ rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
446+
447+
if (local_qhead_idx < GQA_RATIO) {
448+
const scalar_t* q_fetch_ptr = q_ptr + lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; //this works for head size 128 : 16 lanes x 8 elems = 128 elems
449+
const _B16x8* q_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(q_fetch_ptr);
450+
_B16x8 tmp = *q_fetch_ptr_16B;
451+
const int offset1 = lane16id/4; //16 contiguous chunks of head elems are spread across 4x4lanes
452+
shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0];
453+
shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1];
454+
}
455+
//else { //TODO: is this part needed?
456+
// const int offset1 = lane16id/4; //16 contiguous chunks of head elems are spread across 4x4lanes
457+
// shared_logits[offset1][lane4id][local_qhead_idx][0] = {0};
458+
// shared_logits[offset1][lane4id][local_qhead_idx][1] = {0};
459+
//}
460+
__syncthreads();
461+
//if (lane16id < GQA_RATIO) {
462+
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
463+
Qlocal[qkhe_depth][0].xy[0] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0];
464+
Qlocal[qkhe_depth][0].xy[1] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][1];
465+
}
466+
//}
467+
//else {
468+
// for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
469+
// Qlocal[qkhe_depth][0].xy[0] = {0};
470+
// Qlocal[qkhe_depth][0].xy[1] = {0};
471+
// }
472+
//}
473+
#endif
434474

435475
constexpr int KX = 16 / sizeof(cache_t);
436476
const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride;
@@ -493,17 +533,44 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
493533
}
494534
}
495535

536+
//__syncthreads(); //if using shared Q
537+
496538
floatx4 dout[TLOOP];
497-
__shared__ _B16x4 shared_logits[NWARPS][TLOOP][16][VTOKENS_PER_LANE/4 + 1];
539+
#if 1 //Q stored in registers
498540
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
499541
dout[token_depth] = {0};
500542
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
543+
for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
501544
for (int i=0; i<2; i++) {
502-
dout[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Klocal[token_depth][qkhe_depth].xy[i], Qlocal[qkhe_depth].xy[i], dout[token_depth]);
545+
dout[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Klocal[token_depth][qkhe_depth].xy[i], Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]);
503546
}
547+
}
504548
}
505549
dout[token_depth] *= scale;
506550
}
551+
552+
#else //Q in shared
553+
_B16x4 tmpQ[QKHELOOP][2];
554+
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
555+
tmpQ[qkhe_depth][0] = shared_logits[qkhe_depth][rowid][lane16id][0];
556+
tmpQ[qkhe_depth][1] = shared_logits[qkhe_depth][rowid][lane16id][1];
557+
}
558+
559+
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
560+
dout[token_depth] = {0};
561+
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
562+
//for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
563+
for (int i=0; i<2; i++) {
564+
dout[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Klocal[token_depth][qkhe_depth].xy[i],
565+
tmpQ[qkhe_depth][i], //shared_logits[qkhe_depth][rowid][lane16id][i],
566+
dout[token_depth]);
567+
}
568+
//}
569+
}
570+
dout[token_depth] *= scale;
571+
}
572+
#endif
573+
507574
#if 0 //DEBUG ONLY qk * scale
508575
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
509576
auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4;
@@ -514,6 +581,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
514581
#endif
515582

516583
float qk_max = -FLT_MAX;
584+
float exp_sum = 0.0f;
517585

518586
const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4;
519587

@@ -529,7 +597,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
529597
qk_max = fmaxf(qk_max, __shfl_xor(qk_max,mask));
530598
}
531599

532-
float exp_sum = 0.0f;
533600

534601
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
535602
const int local_token_idx = qkout_token_idx + token_depth * 16;
@@ -578,6 +645,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
578645

579646
const float inv_sum_scale = __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid];
580647

648+
//__shared__ _B16x4 shared_logits[NWARPS][TLOOP][16][VTOKENS_PER_LANE/4 + 1];
581649
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
582650
dout[token_depth] *= inv_sum_scale;
583651
shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4<scalar_t>(dout[token_depth]);
@@ -624,21 +692,23 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
624692
const int offset = 4*rowid + 2*vfetch_depth + i;
625693
const int offset1 = offset % 4;
626694
const int offset2 = offset / 4;
627-
695+
#if 0
628696
//if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows
629-
//tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
630-
// Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);
631-
697+
tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
698+
Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);
699+
#else
632700
//if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows
633701
tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i],
634702
shared_logits[vtoken_depth][offset2][lane16id][offset1],
635703
tmp_out);
704+
#endif
636705
}
637706
}
638707
}
639708
outelems[vhe_depth] = from_floatx4<scalar_t>(tmp_out);
640709
}
641710

711+
#if 1
642712
__syncthreads();
643713

644714
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
@@ -676,7 +746,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
676746
}
677747

678748
}
679-
749+
#endif
680750

681751
#if 0
682752
//if output format is 16 he across 16 lanes, 16 qheads spread across 4 rows

0 commit comments

Comments (0)