From ae339b1a67ac6651dae926aae2afa0c168c1ddb5 Mon Sep 17 00:00:00 2001 From: Zhewen Li Date: Mon, 8 Dec 2025 17:05:27 -0800 Subject: [PATCH] [Bugfix] Fix DeepGEMM after #29546 (#30267) Signed-off-by: zhewenli Signed-off-by: Zhewen Li --- .../layers/quantization/utils/fp8_utils.py | 16 ++++++++++------ vllm/utils/deep_gemm.py | 1 + 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ad92f4ec63c34..366c5778fa5d9 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -30,6 +30,7 @@ from vllm.model_executor.parameter import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import ( + DeepGemmQuantScaleFMT, fp8_gemm_nt, is_deep_gemm_e8m0_used, is_deep_gemm_supported, @@ -268,12 +269,15 @@ class W8A8BlockFp8LinearOp: weight: torch.Tensor, weight_scale: torch.Tensor, ) -> torch.Tensor: - assert self.deepgemm_input_quant_op is not None - q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm( - input_2d, - group_size=self.act_quant_group_shape.col, - use_ue8m0=True, - ) + if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0: + q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm( + input_2d, + group_size=self.act_quant_group_shape.col, + use_ue8m0=True, + ) + else: + assert self.deepgemm_input_quant_op is not None + q_input, input_scale = self.deepgemm_input_quant_op(input_2d) output = torch.empty( (q_input.shape[0], weight.shape[0]), dtype=torch.bfloat16, diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 8545108a02666..a099fde1bdc45 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk( __all__ = [ "calc_diff", + "DeepGemmQuantScaleFMT", "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked",