[Bugfix] Fix fp8 DeepGemm compilation issues (#30336)

ElizaWszola 2025-12-10 02:17:25 +01:00 committed by GitHub
parent 4c2e10ea19
commit 2e7035dd8c


@@ -31,7 +31,6 @@ from vllm.model_executor.utils import replace_parameter
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.deep_gemm import (
-    DeepGemmQuantScaleFMT,
     fp8_gemm_nt,
     is_deep_gemm_e8m0_used,
     is_deep_gemm_supported,
@@ -248,6 +247,7 @@ class W8A8BlockFp8LinearOp:
         self.act_quant_group_shape = act_quant_group_shape
         self.is_deep_gemm_supported = is_deep_gemm_supported()
         self.is_hopper = current_platform.is_device_capability(90)
+        self.is_blackwell = current_platform.is_device_capability(100)
         self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()

         # Get the correct blockscale mul and input quant operations.
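The constructor change adds a Blackwell capability flag next to the existing Hopper one. As a rough sketch of what such a capability check boils down to on a CUDA platform (the helper below is hypothetical; vLLM's current_platform.is_device_capability encapsulates the real logic):

import torch

def capability_int() -> int:
    # (9, 0) -> 90 on Hopper, (10, 0) -> 100 on Blackwell
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

is_hopper = capability_int() == 90      # SM90
is_blackwell = capability_int() == 100  # SM100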
@@ -303,7 +303,7 @@ class W8A8BlockFp8LinearOp:
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
     ) -> torch.Tensor:
-        if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
+        if self.use_deep_gemm_e8m0 and self.is_blackwell:
            q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
                input_2d,
                group_size=self.act_quant_group_shape.col,
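With this change, the packed-scale activation quantization path is taken only when DeepGEMM's UE8M0 scale format is in use and the device is Blackwell. Below is a simplified, standalone sketch of per-token-group FP8 quantization with optional power-of-two (E8M0-style) scales, assuming the usual float8_e4m3fn range; it is not vLLM's implementation, and per_token_group_quant_fp8_packed_for_deepgemm additionally packs the byte-sized scale exponents, which this sketch omits:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_token_group_quant(x: torch.Tensor, group_size: int, ue8m0_scales: bool):
    # x: (num_tokens, hidden_size); hidden_size must be divisible by group_size
    t, h = x.shape
    groups = x.view(t, h // group_size, group_size).float()
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = amax / FP8_MAX
    if ue8m0_scales:
        # Round each scale up to the next power of two so it can be stored
        # as an exponent-only (UE8M0) byte, as the Blackwell DeepGEMM path expects.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    q = (groups / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.view(t, h), scale.squeeze(-1)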