[Bugfix] Fix DeepGEMM after #29546 (#30267)

Signed-off-by: zhewenli <zhewenli@meta.com>
Signed-off-by: Zhewen Li <zhewenli@meta.com>
This commit is contained in:
Zhewen Li 2025-12-08 17:05:27 -08:00 committed by GitHub
parent 0ee6416f67
commit ae339b1a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 6 deletions

View File

@ -30,6 +30,7 @@ from vllm.model_executor.parameter import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
@ -268,12 +269,15 @@ class W8A8BlockFp8LinearOp:
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
use_ue8m0=True,
)
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
input_2d,
group_size=self.act_quant_group_shape.col,
use_ue8m0=True,
)
else:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty(
(q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,

View File

@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk(
__all__ = [
"calc_diff",
"DeepGemmQuantScaleFMT",
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",