@@ -2037,21 +2037,30 @@ def forward(
 
         if fp8_attention:
             ql_nope_shape = decode_ql_nope.shape
-            decode_ql_nope, _ = ops.scaled_fp8_quant(
-                decode_ql_nope.reshape(
-                    [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
-                ),
-                layer._q_scale,
-            )
-            decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
             q_pe_shape = decode_q_pe.shape
-            decode_q_pe, _ = ops.scaled_fp8_quant(
-                decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
-                layer._q_scale,
+            assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
+            assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
+            decode_q_shape = (
+                ql_nope_shape[0],
+                ql_nope_shape[1],
+                ql_nope_shape[2] + q_pe_shape[2],
+            )
+            # Using empty and copy since torch.cat introduces significant overhead.
+            decode_q0 = torch.empty(
+                decode_q_shape,
+                device=decode_ql_nope.device,
+                dtype=decode_ql_nope.dtype,
             )
-            decode_q_pe = decode_q_pe.reshape(q_pe_shape)
+            decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
+            decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
 
-        decode_q = (decode_ql_nope, decode_q_pe)
+            decode_q, _ = ops.scaled_fp8_quant(
+                decode_q0.view(decode_q_shape[0], -1),
+                layer._q_scale,
+            )
+            decode_q = decode_q.view(decode_q_shape)
+        else:
+            decode_q = (decode_ql_nope, decode_q_pe)
         if self.dcp_world_size > 1:
             assert not fp8_attention, "DCP not support fp8 kvcache now."
             # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
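
For reference, here is a minimal runnable sketch of the pattern this hunk adopts: preallocate the fused Q buffer and fill it with two `copy_` calls instead of `torch.cat`, then quantize once over the flattened view rather than quantizing the nope and rope parts separately. It assumes plain PyTorch >= 2.1 (for the `float8_e4m3fn` dtype); `fake_scaled_fp8_quant` and the shapes `B, N, L, P` are hypothetical stand-ins, used only to keep the sketch self-contained without vLLM's fused `ops.scaled_fp8_quant` kernel.

```python
import torch


def fake_scaled_fp8_quant(x: torch.Tensor, scale: torch.Tensor):
    # Hypothetical stand-in for vLLM's ops.scaled_fp8_quant: scale, clamp to
    # the e4m3 representable range, and cast. The real op is a fused CUDA
    # kernel with the same (tensor, scale) -> (quantized, scale) contract.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale


B, N, L, P = 4, 16, 512, 64  # illustrative batch, heads, nope dim, rope dim
ql_nope = torch.randn(B, N, L)
q_pe = torch.randn(B, N, P)
q_scale = torch.tensor(1.0)

# Preallocate the fused buffer and copy both halves in, rather than using
# torch.cat, which allocates its own output and launches a concat kernel.
decode_q0 = torch.empty(B, N, L + P, dtype=ql_nope.dtype)
decode_q0[..., :L].copy_(ql_nope)
decode_q0[..., L:].copy_(q_pe)

# One quantization call over the flattened (B, N * (L + P)) view replaces
# the two per-component calls in the old code.
decode_q, _ = fake_scaled_fp8_quant(decode_q0.view(B, -1), q_scale)
decode_q = decode_q.view(B, N, L + P)

# Sanity check: equivalent to quantizing the torch.cat result directly.
ref, _ = fake_scaled_fp8_quant(
    torch.cat([ql_nope, q_pe], dim=-1).view(B, -1), q_scale
)
assert torch.equal(decode_q.view(B, -1).to(torch.float32), ref.to(torch.float32))
```

Both paths are bitwise-identical in output; the expected win comes from skipping `torch.cat`'s extra allocation and kernel launch and halving the number of quantization calls on the decode path.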