
Commit 1fb632f

[Perf] Improve fp8 quant in mla; replace ReduceSum with ReduceScatterSum (#29795)
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
1 parent 6af70e1

2 files changed: 22 additions & 13 deletions

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def reduce_scatterv(
             output_shape, dtype=input_tensor.dtype, device=input_tensor.device
         )
 
-        if sizes is not None:
+        if sizes is not None and sizes.count(sizes[0]) != len(sizes):
             pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
         else:
             pynccl_comm.reduce_scatter(output, input_tensor)
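
The one-line change above is the whole behavioral fix in this file: the variable-chunk reduce_scatterv path is now taken only when the per-rank sizes actually differ, while uniform (or unspecified) sizes fall through to the plain reduce_scatter collective. A minimal standalone sketch of that predicate, with a hypothetical helper name that is not part of vLLM:

def pick_reduce_scatter_path(sizes: list[int] | None) -> str:
    # Ragged per-rank chunk sizes require the variable-size collective.
    if sizes is not None and sizes.count(sizes[0]) != len(sizes):
        return "reduce_scatterv"
    # None, or all chunks equal: the fixed-size collective is enough.
    return "reduce_scatter"

assert pick_reduce_scatter_path([4, 4, 4, 4]) == "reduce_scatter"
assert pick_reduce_scatter_path([4, 3, 4, 4]) == "reduce_scatterv"
assert pick_reduce_scatter_path(None) == "reduce_scatter"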

vllm/v1/attention/backends/mla/common.py

Lines changed: 21 additions & 12 deletions
@@ -2037,21 +2037,30 @@ def forward(
 
         if fp8_attention:
             ql_nope_shape = decode_ql_nope.shape
-            decode_ql_nope, _ = ops.scaled_fp8_quant(
-                decode_ql_nope.reshape(
-                    [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
-                ),
-                layer._q_scale,
-            )
-            decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
             q_pe_shape = decode_q_pe.shape
-            decode_q_pe, _ = ops.scaled_fp8_quant(
-                decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
-                layer._q_scale,
+            assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
+            assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
+            decode_q_shape = (
+                ql_nope_shape[0],
+                ql_nope_shape[1],
+                ql_nope_shape[2] + q_pe_shape[2],
+            )
+            # Using empty and copy since torch.cat introduces significant overhead.
+            decode_q0 = torch.empty(
+                decode_q_shape,
+                device=decode_ql_nope.device,
+                dtype=decode_ql_nope.dtype,
             )
-            decode_q_pe = decode_q_pe.reshape(q_pe_shape)
+            decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
+            decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
 
-        decode_q = (decode_ql_nope, decode_q_pe)
+            decode_q, _ = ops.scaled_fp8_quant(
+                decode_q0.view(decode_q_shape[0], -1),
+                layer._q_scale,
+            )
+            decode_q = decode_q.view(decode_q_shape)
+        else:
+            decode_q = (decode_ql_nope, decode_q_pe)
         if self.dcp_world_size > 1:
             assert not fp8_attention, "DCP not support fp8 kvcache now."
             # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
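
This hunk replaces two separate ops.scaled_fp8_quant calls (one for decode_ql_nope, one for decode_q_pe) with a single quantization of a fused buffer: the two tensors are copied into a pre-allocated (B, N, L + P) tensor and quantized once with layer._q_scale. The in-diff comment explains the buffer choice: torch.empty plus two slice copies avoids torch.cat overhead. An illustrative standalone PyTorch sketch of that concatenation pattern (shapes are made-up example sizes, not taken from the model):

import torch

B, N, L, P = 4, 16, 512, 64  # batch, heads, nope width, rope width (illustrative only)
ql_nope = torch.randn(B, N, L)
q_pe = torch.randn(B, N, P)

# Pre-allocate the fused (B, N, L + P) buffer and copy each part into its slice.
q = torch.empty((B, N, L + P), device=ql_nope.device, dtype=ql_nope.dtype)
q[..., :L].copy_(ql_nope)
q[..., L:].copy_(q_pe)

# Same result as torch.cat along the last dim, without cat's extra allocation/dispatch.
assert torch.equal(q, torch.cat([ql_nope, q_pe], dim=-1))

# The commit then quantizes q.view(B, -1) once instead of quantizing each part separately.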
