@@ -364,33 +364,33 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
364364
365365  __shared__ float shared_qk_max[NWARPS][16 + 1];
366366  __shared__ float shared_exp_sum[NWARPS][16 + 1];
367+  // shared_logits is multi-purpose: it stages Q below and later holds the softmax logits
368+  __shared__ _B16x4 shared_logits[NWARPS][4][16][4 + 1];
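+  // LDS footprint (illustrative, assuming NWARPS = 4): 4 * 4 * 16 * 5 _B16x4 entries
+  // = 1280 * 8 B = 10 KiB; the "+ 1" is padding, a common bank-conflict-avoidance trick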
367369
368370  // for the QK mfma16x16, the layout is QHead/Token x16 across every 16 lanes, 16 B of HeadElements in each lane, 4x16B HeadElements across the 4 rows of the warp
369371  constexpr int ROWS_PER_WARP = WARP_SIZE / 16; // a "row" is a group of 16 lanes; see DPP terminology
370-  constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t);
371-  constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; // TODO 8B form?
372+  constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / sizeof(cache_t); // 8 for 16-bit cache types, 16 for 8-bit types
373+  constexpr int QKHE_PER_FETCH = CONTIGUOUS_KV_ELEMS_16B_LOAD * ROWS_PER_WARP; // number of head elements fetched across a warp per fetch
374+  constexpr int QK_SIZE_RATIO = sizeof(scalar_t) / sizeof(cache_t); // 1 for 16-bit types, 2 for 8-bit types
372375  constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 4x QKHE_16B across the warp
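+  // worked example (illustrative): for a 16-bit cache type, HEAD_SIZE = 128, wave64:
+  //   CONTIGUOUS_KV_ELEMS_16B_LOAD = 16 / 2 = 8, QKHE_PER_FETCH = 8 * 4 = 32,
+  //   QK_SIZE_RATIO = 2 / 2 = 1, QKHELOOP = 128 / 32 = 4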
373376
374- _B16x8 Qlocal[QKHELOOP]; // this could be B8x16 too
377+  _B16x8 Qlocal[QKHELOOP][QK_SIZE_RATIO]; // for 8-bit cache types, 16 contiguous Q elements must be fetched per lane; QK_SIZE_RATIO accounts for this
375378
376379 constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof (scalar_t );
377- constexpr int x = CONTIGUOUS_SCALAR_ELEMS_16B; // x is defined by vLLM as 16Bytes
380+  // constexpr int x = CONTIGUOUS_SCALAR_ELEMS_16B; // x is defined by vLLM as 16 bytes
378381
379-  constexpr int TLOOP1 = CONTIGUOUS_KV_ELEMS_16B_LOAD / 4; // mfma16x16x16 outputs 4 elements per lane: will be moved to match layout for V dwordx4 loads
380-  constexpr int TOKENS_PER_WARP1 = 16 * TLOOP1; // 16 tokens across lanes * TLOOP factor
381-  constexpr int T_PAR_SIZE = 256;
382-  constexpr int T_PAR_LOOP = T_PAR_SIZE / TOKENS_PER_WARP1 / NWARPS;
383-  constexpr int TLOOP = TLOOP1 * T_PAR_LOOP;
384-  constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; // TOKENS_PER_WARP1 * T_PAR_LOOP;
382+  constexpr int T_PAR_SIZE = 256; // token partition size, set to 256 (TODO: move to a template parameter)
383+  // constexpr int TLOOP1 = CONTIGUOUS_KV_ELEMS_16B_LOAD / 4; // mfma16x16x16 outputs 4 elements per lane: will be moved to match layout for V dwordx4 loads
384+  // constexpr int TOKENS_PER_WARP1 = 16 * TLOOP1; // 16 tokens across lanes * TLOOP factor
385+  // constexpr int T_PAR_LOOP = T_PAR_SIZE / TOKENS_PER_WARP1 / NWARPS;
386+  constexpr int TOKENS_PER_WARP = T_PAR_SIZE / NWARPS; // sub-partition of tokens handled per warp in the QK calculation
387+  constexpr int TLOOP = TOKENS_PER_WARP / 16; // each mfma16x16x16 instruction processes 16 tokens
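+  // e.g. with NWARPS = 4 (illustrative): TOKENS_PER_WARP = 256 / 4 = 64, so
+  // TLOOP = 64 / 16 = 4 mfma iterations over the token dimension per warp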
385388
386389 _B16x8 Klocal[TLOOP][QKHELOOP]; // this could be B8x16 too
387390
388391  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
389392  const int wg_start_kv_head_idx = blockIdx.z;
390393  const int total_num_heads = gridDim.z * GQA_RATIO;
391- const bool warp_in_context = (partition_start_token_idx + warpid * TOKENS_PER_WARP) < context_len;
392-
393- // TODO implement warp out of context logic
394394
395395 // for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps
396396 // each mfma takes QH16xT16x16HE across warp
@@ -414,23 +414,63 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
414414 kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
415415 }
416416
417+  #if 0 // fetch Q into registers
418+
417419 const int local_qhead_idx = lane16id % GQA_RATIO;
418420 const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
419421 const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
420422 const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
421423
422424 if (lane16id < GQA_RATIO) {
423425 for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
424-      const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH;
425-      const _B16x8* q_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(q_fetch_ptr);
426-      Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
426+ const scalar_t* q_ptr2 = q_ptr + qkhe_depth * QKHE_PER_FETCH;
427+ for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
428+ const scalar_t* q_fetch_ptr = q_ptr2 + qkratio * CONTIGUOUS_SCALAR_ELEMS_16B;
429+ const _B16x8* q_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(q_fetch_ptr);
430+ Qlocal[qkhe_depth][qkratio] = *q_fetch_ptr_16B;
431+ }
427432 }
428433 } else {
429434 for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
430-      Qlocal[qkhe_depth].xy[0] = {0};
431-      Qlocal[qkhe_depth].xy[1] = {0};
435+ for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
436+ Qlocal[qkhe_depth][qkratio].xy[0] = {0};
437+ Qlocal[qkhe_depth][qkratio].xy[1] = {0};
438+ }
432439 }
433440 }
441+ #else // fetch Q in shared
442+ const int local_qhead_idx = 4 * warpid + rowid;
443+ const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
444+  const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
445+  const scalar_t* q_ptr = q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; // + rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
446+
447+  if (local_qhead_idx < GQA_RATIO) {
448+    const scalar_t* q_fetch_ptr = q_ptr + lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; // this works for head size 128: 16 lanes x 8 elems = 128 elems
449+    const _B16x8* q_fetch_ptr_16B = reinterpret_cast<const _B16x8*>(q_fetch_ptr);
450+    _B16x8 tmp = *q_fetch_ptr_16B;
451+    const int offset1 = lane16id / 4; // 16 contiguous chunks of head elems are spread across 4x4 lanes
452+    shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0];
453+    shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1];
454+ }
455+ // else { //TODO: is this part needed?
456+ // const int offset1 = lane16id/4; //16 contiguous chunks of head elems are spread across 4x4lanes
457+ // shared_logits[offset1][lane4id][local_qhead_idx][0] = {0};
458+ // shared_logits[offset1][lane4id][local_qhead_idx][1] = {0};
459+ // }
460+  __syncthreads();
461+  // if (lane16id < GQA_RATIO) {
462+  for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
463+    Qlocal[qkhe_depth][0].xy[0] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0];
464+    Qlocal[qkhe_depth][0].xy[1] = shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][1];
465+  }
466+ // }
467+ // else {
468+ // for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
469+ // Qlocal[qkhe_depth][0].xy[0] = {0};
470+ // Qlocal[qkhe_depth][0].xy[1] = {0};
471+ // }
472+ // }
473+ #endif
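+  // net effect of the staging above (assuming HEAD_SIZE = 128): each query head is written
+  // once to shared memory as 16 lanes x 8 contiguous scalars by its (warpid, rowid) owner,
+  // and after the barrier each lane of a 16-lane row reads back head (lane16id % GQA_RATIO)
+  // in the [qkhe_depth][rowid] order the QK mfma expects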
434474
435475  constexpr int KX = 16 / sizeof(cache_t);
436476  const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride;
@@ -493,17 +533,44 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
493533 }
494534 }
495535
536+  // __syncthreads(); // if using shared Q
537+
496538 floatx4 dout[TLOOP];
497-  __shared__ _B16x4 shared_logits[NWARPS][TLOOP][16][VTOKENS_PER_LANE / 4 + 1];
539+  #if 1 // Q stored in registers
498540  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
499541    dout[token_depth] = {0};
500542    for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
543+      for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
501544        for (int i = 0; i < 2; i++) {
502-          dout[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Klocal[token_depth][qkhe_depth].xy[i], Qlocal[qkhe_depth].xy[i], dout[token_depth]);
545+          dout[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Klocal[token_depth][qkhe_depth].xy[i], Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]);
503546        }
547+      }
504548    }
505549    dout[token_depth] *= scale;
506550  }
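+  // each gcn_mfma16x16x16 contracts 16 head elements for a 16 token x 16 qhead tile, so the
+  // nest above issues TLOOP * QKHELOOP * QK_SIZE_RATIO * 2 mfmas per warp
+  // (4 * 4 * 1 * 2 = 32 in the illustrative fp16 / HEAD_SIZE = 128 case)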
551+
552+  #else // Q in shared
553+ _B16x4 tmpQ[QKHELOOP][2];
554+ for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
555+ tmpQ[qkhe_depth][0] = shared_logits[qkhe_depth][rowid][lane16id][0];
556+ tmpQ[qkhe_depth][1] = shared_logits[qkhe_depth][rowid][lane16id][1];
557+ }
558+
559+ for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
560+ dout[token_depth] = {0};
561+ for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
562+ //for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
563+ for (int i=0; i<2; i++) {
564+ dout[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Klocal[token_depth][qkhe_depth].xy[i],
565+ tmpQ[qkhe_depth][i], //shared_logits[qkhe_depth][rowid][lane16id][i],
566+ dout[token_depth]);
567+ }
568+ //}
569+ }
570+ dout[token_depth] *= scale;
571+ }
572+ #endif
573+
507574#if 0 // DEBUG ONLY: qk * scale
508575 for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
509576 auto qkout_ptr2 = qkout_ptr + warpid * TLOOP * 16 + token_depth * 16 + rowid * 4;
@@ -514,6 +581,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
514581#endif
515582
516583 float qk_max = -FLT_MAX;
584+  float exp_sum = 0.0f;
517585
518586  const int qkout_token_idx = partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4;
519587
@@ -529,7 +597,6 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
529597    qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask));
530598 }
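+  // xor-shuffle butterfly: each round folds a partner lane's partial max in, so all lanes
+  // that share a query head converge on the same qk_max before the cross-warp combine
+  // through shared_qk_max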
531599
532-  float exp_sum = 0.0f;
533600
534601  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
535602    const int local_token_idx = qkout_token_idx + token_depth * 16;
@@ -578,6 +645,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
578645
579646  const float inv_sum_scale = __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid];
580647
648+ // __shared__ _B16x4 shared_logits[NWARPS][TLOOP][16][VTOKENS_PER_LANE/4 + 1];
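+  // shared_logits (declared once at the top of the kernel) is reused here to hold the scaled
+  // softmax logits in [warpid][token_depth][qhead = lane16id][rowid] order, the operand
+  // layout the V mfma below consumes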
581649  for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
582650    dout[token_depth] *= inv_sum_scale;
583651    shared_logits[warpid][token_depth][lane16id][rowid] = from_floatx4<scalar_t>(dout[token_depth]);
@@ -624,21 +692,23 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
624692          const int offset = 4 * rowid + 2 * vfetch_depth + i;
625693          const int offset1 = offset % 4;
626694          const int offset2 = offset / 4;
627-
695+  #if 0
628696          // if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows
629-          // tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
630-          //                                                    Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);
631-
697+          tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
698+                                                              Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);
699+  #else
632700          // if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows
633701          tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i],
634702                                                              shared_logits[vtoken_depth][offset2][lane16id][offset1],
635703                                                              tmp_out);
704+  #endif
636705 }
637706 }
638707 }
639708    outelems[vhe_depth] = from_floatx4<scalar_t>(tmp_out);
640709 }
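+  // outelems now holds the output in the layout selected above: 16 qheads across the 16
+  // lanes, with 4 head elements per lane per vhe_depth iteration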
641710
711+ #if 1
642712  __syncthreads();
643713
644714  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
@@ -676,7 +746,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
676746 }
677747
678748 }
679-
749+  #endif
680750
681751#if 0
682752  // if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows