Commit cf863a7

optimize output writes

1 parent ae66314 · commit cf863a7

File tree: 1 file changed, +48 −8 lines changed


csrc/rocm/attention.cu

Lines changed: 48 additions & 8 deletions
@@ -543,7 +543,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
       for (int mask = WARP_SIZE/2; mask >= 16; mask/=2) {
         exp_sum += __shfl_xor(exp_sum,mask);
       }
-
+
       if (laneid < 16) {
         shared_qk_max[warpid][lane16id] = qk_max;
         shared_exp_sum[warpid][lane16id] = exp_sum;
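
Note on the reduction in this hunk (the diff itself only strips trailing whitespace from the blank line): the __shfl_xor loop deliberately stops at a stride of 16 instead of reducing across the whole 64-wide wavefront, so each of the 16 lane columns ends up holding the sum over its four contributing lanes before lanes 0-15 publish qk_max and exp_sum to shared memory. A minimal standalone HIP sketch of that partial butterfly, with the kernel and buffer names as illustrative stand-ins rather than anything from attention.cu:

// partial_butterfly.hip -- stand-in names, not from attention.cu
#include <hip/hip_runtime.h>
#include <cstdio>

constexpr int WARP_SIZE = 64;  // CDNA wavefront width

__global__ void partial_butterfly(float* out) {
  const int laneid = threadIdx.x % WARP_SIZE;
  float exp_sum = 1.0f;  // stand-in for this lane's partial softmax sum
  // XOR butterfly that stops at stride 16: lanes {l, l^16, l^32, l^48}
  // exchange partials, so lanes 0..15 each end up with a 4-lane sum.
  for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) {
    exp_sum += __shfl_xor(exp_sum, mask);
  }
  if (laneid < 16) out[laneid] = exp_sum;  // one result per lane column
}

int main() {
  float *d_out, h_out[16];
  hipMalloc(&d_out, sizeof(h_out));
  partial_butterfly<<<1, WARP_SIZE>>>(d_out);
  hipMemcpy(h_out, d_out, sizeof(h_out), hipMemcpyDeviceToHost);
  printf("lane 0 sum = %.1f\n", h_out[0]);  // 4.0: four lanes of 1.0 each
  hipFree(d_out);
  return 0;
}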
@@ -626,20 +626,59 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
           const int offset2 = offset / 4;

           //if output format is 16 head elems across 16 lanes, 16 qheads spread across 4 rows
-          tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
-              Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);
+          //tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(shared_logits[vtoken_depth][offset2][lane16id][offset1],
+          //    Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], tmp_out);

           //if output format is 16 qheads across 16 lanes, 16 head elems spread across 4 rows
-          //partition_out[vhe_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i],
-          //    shared_tokens[vtoken_depth][offset2][lane16id][offset1],
-          //    partition_out[vhe_depth]);
+          tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i],
+              shared_logits[vtoken_depth][offset2][lane16id][offset1],
+              tmp_out);
         }
       }
     }
     outelems[vhe_depth] = from_floatx4<scalar_t>(tmp_out);
   }

-  #if 1
+  __syncthreads();
+
+  for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
+    shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; //lane16id: head dimension; rowid: head element dimension
+  }
+
+  __syncthreads();
+
+  if (warpid == 0) {
+    _B16x8 vout[GQA_RATIO4];
+    for (int h = 0; h < GQA_RATIO4; h++) {
+      const int local_head_idx = 4 * h + rowid;
+      const int head_elem_idx = lane16id * 8;
+      const int offset1 = (head_elem_idx / 16) % 4;
+      const int offset2 = head_elem_idx / 16 / NWARPS;
+      const int offset3 = (head_elem_idx / 4) % 4;
+      for (int i = 0; i < 2; i++) {
+        vout[h].xy[i] = shared_logits[offset1][offset2][local_head_idx][offset3 + i];
+      }
+    }
+
+    const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
+    scalar_t* out_ptr = out +
+        seq_idx * total_num_heads * hsz_maxp_mult + partition_idx * HEAD_SIZE;
+    for (int h = 0; h < GQA_RATIO4; h++) {
+      const int local_head_idx = 4 * h + rowid;
+      if (local_head_idx < GQA_RATIO) {
+        const int out_head_idx = wg_start_head_idx + local_head_idx;
+        scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult;
+        const int head_elem_idx = lane16id * 8;
+        scalar_t* out_ptr3 = out_ptr2 + head_elem_idx;
+        _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3);
+        *out_ptr_B16x8 = vout[h];
+      }
+    }
+  }
+
+  #if 0
   //if output format is 16 he across 16 lanes, 16 qheads spread across 4 rows
   const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
   scalar_t* out_ptr = out +
@@ -661,7 +700,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_
       }
     }
   }
-  #else
+  #endif
+  #if 0
   //if output format is 16 qheads across 16 lanes, 16 he spread across 4 rows
   if (lane16id < GQA_RATIO) {
     const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions;
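
This hunk is the change the commit message points at. The mfma call now takes Vlocal as its first operand and the logits tile as its second, so the accumulator comes out as 16 query heads across the 16 lanes with head elements spread across the 4 rows (the layout the second in-file comment describes). Each warp then parks its outelems tile in shared_logits, and after a __syncthreads() warp 0 alone repacks eight consecutive 16-bit head elements per lane into a _B16x8 and writes the output partition with 16-byte stores, replacing the per-element write paths now fenced off behind #if 0. A minimal HIP sketch of that vectorized-store pattern, where B16x8, coalesced_head_store, and HEAD_SIZE = 128 are illustrative assumptions rather than the file's own definitions:

// vector_store.hip -- stand-in types and names, not from attention.cu
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdio>

// Stand-in for the kernel's _B16x8: eight 16-bit values, 16-byte aligned
// so that one store can write the whole struct.
struct alignas(16) B16x8 {
  __half v[8];
};

__global__ void coalesced_head_store(const __half* __restrict__ acc,
                                     __half* __restrict__ out) {
  const int lane16id = threadIdx.x % 16;
  const int head_elem_idx = lane16id * 8;  // eight fp16 elems = 16 bytes/lane
  B16x8 pkt;
  // Gather this lane's eight contiguous head elements (in attention.cu they
  // are read back from shared_logits after the MFMA results are staged).
  for (int i = 0; i < 8; i++) pkt.v[i] = acc[head_elem_idx + i];
  // Single 16-byte store per lane; 16 lanes cover a 128-element head with
  // coalesced traffic instead of eight scalar stores per lane.
  *reinterpret_cast<B16x8*>(out + head_elem_idx) = pkt;
}

int main() {
  const int HEAD_SIZE = 128;
  __half *d_acc, *d_out, h[HEAD_SIZE];
  for (int i = 0; i < HEAD_SIZE; i++) h[i] = __float2half((float)i);
  hipMalloc(&d_acc, sizeof(h));
  hipMalloc(&d_out, sizeof(h));
  hipMemcpy(d_acc, h, sizeof(h), hipMemcpyHostToDevice);
  coalesced_head_store<<<1, 16>>>(d_acc, d_out);
  hipMemcpy(h, d_out, sizeof(h), hipMemcpyDeviceToHost);
  printf("out[127] = %.0f\n", __half2float(h[127]));  // expect 127
  hipFree(d_acc);
  hipFree(d_out);
  return 0;
}

On CDNA hardware a 16-byte aligned store like this should compile to a single dwordx4 store per lane, which is the write-combining the commit title refers to.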

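A note on the gather indices in the new warp-0 block: assuming NWARPS = 4, HEAD_SIZE = 128, and four 16-bit elements per shared_logits slot (assumptions inferred from the offset arithmetic, not stated in the diff), the staging write shared_logits[warpid][vhe_depth][lane16id][rowid] places head element he at slot ((he/16)%4, he/64, head, (he/4)%4). The host-side check below inverts that mapping to confirm that offset1/offset2/offset3 make each lane's two consecutive slot reads cover exactly the eight head elements [lane16id*8, lane16id*8 + 8), i.e. one _B16x8:

// swizzle_check.cpp -- host-side check of the gather index arithmetic,
// under the layout assumptions named above.
#include <cassert>
#include <cstdio>

int main() {
  const int NWARPS = 4;
  for (int lane16id = 0; lane16id < 16; lane16id++) {
    const int head_elem_idx = lane16id * 8;           // as in the commit
    const int offset1 = (head_elem_idx / 16) % 4;     // staging warpid
    const int offset2 = head_elem_idx / 16 / NWARPS;  // staging vhe_depth
    const int offset3 = (head_elem_idx / 4) % 4;      // staging rowid
    for (int i = 0; i < 2; i++) {    // two 4-element slots -> one _B16x8
      for (int e = 0; e < 4; e++) {
        // Invert the write mapping: element e of slot (w, v, r) holds
        // head element v*64 + w*16 + r*4 + e.
        const int he = offset2 * 64 + offset1 * 16 + (offset3 + i) * 4 + e;
        assert(he == head_elem_idx + 4 * i + e);
      }
    }
  }
  printf("each lane gathers head elements [lane16id*8, lane16id*8 + 8)\n");
  return 0;
}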