[Bugfix] Fix DeepGEMM after #29546 (#30267)

Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Zhewen Li <zhewenli@meta.com>
This commit is contained in:
Zhewen Li 2025-12-08 17:05:27 -08:00 committed by GitHub
parent 0ee6416f67
commit ae339b1a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 6 deletions

View File

@ -30,6 +30,7 @@ from vllm.model_executor.parameter import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt, fp8_gemm_nt,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported, is_deep_gemm_supported,
@ -268,12 +269,15 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm( q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d, input_2d,
group_size=self.act_quant_group_shape.col, group_size=self.act_quant_group_shape.col,
use_ue8m0=True, use_ue8m0=True,
) )
else:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty( output = torch.empty(
(q_input.shape[0], weight.shape[0]), (q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16, dtype=torch.bfloat16,

View File

@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk(
__all__ = [ __all__ = [
"calc_diff", "calc_diff",
"DeepGemmQuantScaleFMT",
"fp8_gemm_nt", "fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous", "m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked", "fp8_m_grouped_gemm_nt_masked",