
Commit ae339b1

Authored by Zhewen Li (zhewenli)
[Bugfix] Fix DeepGEMM after #29546 (#30267)
Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Zhewen Li <zhewenli@meta.com>
1 parent 0ee6416 · commit ae339b1

File tree: 2 files changed, +11 −6 lines

vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/utils/deep_gemm.py


vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 10 additions & 6 deletions
@@ -30,6 +30,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
     fp8_gemm_nt,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
@@ -268,12 +269,15 @@ def _run_deepgemm(
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
-        assert self.deepgemm_input_quant_op is not None
-        q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
-            input_2d,
-            group_size=self.act_quant_group_shape.col,
-            use_ue8m0=True,
-        )
+        if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
+            q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+                input_2d,
+                group_size=self.act_quant_group_shape.col,
+                use_ue8m0=True,
+            )
+        else:
+            assert self.deepgemm_input_quant_op is not None
+            q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
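
To see the fix in isolation: before this change, activations were always quantized with packed UE8M0 (power-of-two exponent) scales, which is only valid when the installed DeepGEMM actually expects that format; after #29546 the patched code asks the scale-format oracle first and otherwise falls back to the prebuilt quant op. Below is a minimal sketch of that dispatch, with the class context flattened into a plain function (quantize_for_deepgemm, group_size, and fallback_quant_op are illustrative stand-ins; DeepGemmQuantScaleFMT, from_oracle, and per_token_group_quant_fp8_packed_for_deepgemm come from the diff above).

# Sketch only: mirrors the control flow of the patched _run_deepgemm.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8_packed_for_deepgemm,
)
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT

def quantize_for_deepgemm(input_2d, group_size, fallback_quant_op):
    if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
        # DeepGEMM expects UE8M0 group scales: take the packed path.
        return per_token_group_quant_fp8_packed_for_deepgemm(
            input_2d,
            group_size=group_size,
            use_ue8m0=True,
        )
    # Otherwise quantize with the prebuilt FP8 quant op and regular scales.
    assert fallback_quant_op is not None
    return fallback_quant_op(input_2d)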

vllm/utils/deep_gemm.py

Lines changed: 1 addition & 0 deletions
@@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk(

 __all__ = [
     "calc_diff",
+    "DeepGemmQuantScaleFMT",
     "fp8_gemm_nt",
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",
