@@ -543,7 +543,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
   for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) {
     exp_sum += __shfl_xor(exp_sum, mask);
   }
-
+
   if (laneid < 16) {
     shared_qk_max[warpid][lane16id] = qk_max;
     shared_exp_sum[warpid][lane16id] = exp_sum;
@@ -626,20 +626,59 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
           const int offset2 = offset / 4;

           // if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows
-          tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
-                                                              Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);
+          // tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
+          //                                                     Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);

           // if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows
-          // partition_out[vhe_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i],
-          //                                                                      shared_tokens[vtoken_depth][offset2][lane16id][offset1],
-          //                                                                      partition_out[vhe_depth]);
+          tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i],
+                                                              shared_logits[vtoken_depth][offset2][lane16id][offset1],
+                                                              tmp_out);
         }
       }
     }
     outelems[vhe_depth] = from_floatx4<scalar_t>(tmp_out);
   }

-  #if 1
+  __syncthreads();
+
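+  // stage each warp's converted output accumulators in shared memory so a
+  // single warp can gather the full result for this partition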
+  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
+    shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth];  // lane16id: head dimension; rowid: head element dimension
+  }
+
+  __syncthreads();
+
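+  // warp 0 reads the staged values back from shared memory and performs the
+  // global-memory writes for this partition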
+  if (warpid == 0) {
+    _B16x8 vout[GQA_RATIO4];
+    for (int h = 0; h < GQA_RATIO4; h++) {
+      const int local_head_idx = 4 * h + rowid;
+      const int head_elem_idx = lane16id * 8;
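+      // decode this lane's head-element offset into shared_logits coordinates
+      // (offset1: warp slot, offset2: VHELOOP slot, offset3: row slot)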
+      const int offset1 = (head_elem_idx / 16) % 4;
+      const int offset2 = head_elem_idx / 16 / NWARPS;
+      const int offset3 = (head_elem_idx / 4) % 4;
+      for (int i = 0; i < 2; i++) {
+        vout[h].xy[i] = shared_logits[offset1][offset2][local_head_idx][offset3 + i];
+      }
+    }
+
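+    // out is indexed as [seq][head][partition][head_elem]; this block writes
+    // the HEAD_SIZE chunk belonging to partition_idx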
+    const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
+    scalar_t* out_ptr = out +
+        seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE;
+    for (int h = 0; h < GQA_RATIO4; h++) {
+      const int local_head_idx = 4 * h + rowid;
+      if (local_head_idx < GQA_RATIO) {
+        const int out_head_idx = wg_start_head_idx + local_head_idx;
+        scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
+        const int head_elem_idx = lane16id * 8;
+        scalar_t* out_ptr3 = out_ptr2 + head_elem_idx;
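+        // each lane issues one vectorized 16-byte store (8 x 16-bit elements)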
+        _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3);
+        *out_ptr_B16x8 = vout[h];
+      }
+    }
+
+  }
+
+
+  #if 0
   //if output format is 16 he across 16 lanes, 16 qheads spread across 4 rows
   const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
   scalar_t* out_ptr = out +
@@ -661,7 +700,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
       }
     }
   }
-  #else
+  #endif
+  #if 0
   //if output format is 16 qheads across 16 lanes, 16 he spread across 4 rows
   if (lane16id < GQA_RATIO) {
     const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;