2 files changed: +11 −6

model_executor/layers/quantization/utils
@@ -30,6 +30,7 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (
+    DeepGemmQuantScaleFMT,
     fp8_gemm_nt,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
@@ -268,12 +269,15 @@ def _run_deepgemm(
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
-        assert self.deepgemm_input_quant_op is not None
-        q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
-            input_2d,
-            group_size=self.act_quant_group_shape.col,
-            use_ue8m0=True,
-        )
+        if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
+            q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
+                input_2d,
+                group_size=self.act_quant_group_shape.col,
+                use_ue8m0=True,
+            )
+        else:
+            assert self.deepgemm_input_quant_op is not None
+            q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
         output = torch.empty(
             (q_input.shape[0], weight.shape[0]),
             dtype=torch.bfloat16,
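The new branch replaces an unconditional UE8M0 packed quantization with a dispatch on the scale format the oracle reports: UE8M0 (power-of-two, exponent-only scales) takes the packed per-token-group path with `use_ue8m0=True`, and anything else falls back to the layer's preconfigured `deepgemm_input_quant_op`. Below is a minimal, self-contained sketch of that dispatch; the enum members other than `UE8M0`, the `from_oracle()` body, and both quant helpers are simplified stand-ins I am assuming for illustration, not vLLM's actual implementations.

```python
# Sketch of the scale-format dispatch this diff introduces. All names here
# mirror the diff but the bodies are mocked stand-ins, not vLLM internals.
from enum import Enum

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


class DeepGemmQuantScaleFMT(Enum):
    UE8M0 = "ue8m0"  # exponent-only (power-of-two) scales, from the diff
    FP32 = "fp32"    # assumed fallback member; the real enum may differ

    @classmethod
    def from_oracle(cls):
        # The real oracle consults the platform/DeepGEMM build; hardcoded here.
        return cls.UE8M0


def _group_amax(x: torch.Tensor, group_size: int) -> torch.Tensor:
    # Per-token, per-group absolute maximum: shape (tokens, groups).
    return x.view(x.shape[0], -1, group_size).abs().amax(dim=-1).clamp(min=1e-10)


def quant_fp8(x: torch.Tensor, group_size: int, use_ue8m0: bool):
    scale = _group_amax(x, group_size) / FP8_MAX
    if use_ue8m0:
        # UE8M0: round each scale up to the next power of two, so only the
        # exponent needs to be stored/packed for DeepGEMM.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    q = (x.view(x.shape[0], -1, group_size) / scale.unsqueeze(-1)).to(
        torch.float8_e4m3fn
    )
    return q.view_as(x), scale


def run_deepgemm_quant(input_2d: torch.Tensor, group_size: int = 128):
    """Mirror of the dispatch in _run_deepgemm: pick the quant path by format."""
    use_ue8m0 = DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0
    return quant_fp8(input_2d, group_size, use_ue8m0=use_ue8m0)


q, s = run_deepgemm_quant(torch.randn(4, 256))
print(q.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([4, 2])
```

One consequence of the design: the assert on `deepgemm_input_quant_op` now only guards the fallback path, since the UE8M0 path builds its scales without the preconfigured op.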
vllm/utils/deep_gemm.py

@@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk(
 
 __all__ = [
     "calc_diff",
+    "DeepGemmQuantScaleFMT",
     "fp8_gemm_nt",
     "m_grouped_fp8_gemm_nt_contiguous",
     "fp8_m_grouped_gemm_nt_masked",