[Feature] Enable DeepGEMM Linear on B200; 1.5% E2E throughput improvement (#23351)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-08-22 00:01:08 -04:00 committed by GitHub
parent 3ac849665d
commit 394591e343
2 changed files with 13 additions and 16 deletions
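
Note: the DeepGEMM path remains opt-in via the VLLM_USE_DEEP_GEMM environment variable (visible in the gate removed below). As a minimal, illustrative way to exercise the block-FP8 linear path end to end, something like the following sketch could be used; the specific checkpoint and sampling settings are assumptions, not part of this commit.

# Illustrative only (not part of this commit): opt in to the DeepGEMM-backed
# kernels and run a block-FP8-quantized checkpoint through vLLM's offline API.
import os

os.environ["VLLM_USE_DEEP_GEMM"] = "1"  # enable DeepGEMM kernels where supported

from vllm import LLM, SamplingParams

# Any 128x128 block-FP8 quantized checkpoint; this model choice is an assumption.
llm = LLM(model="deepseek-ai/DeepSeek-V3", trust_remote_code=True)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)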

View File

@@ -19,8 +19,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+                                  should_use_deepgemm_for_fp8_linear)
 logger = init_logger(__name__)
@@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func(
     return w8a8_block_fp8_matmul
-def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
-    """
-    Check if DeepGEMM should be used based on the output dtype and weight shape.
-    DeepGEMM is only supported for bfloat16 output dtype and weights with shape
-    divisible by 128.
-    """
-    return (current_platform.is_cuda()
-            and current_platform.is_device_capability(90) and has_deep_gemm()
-            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
-            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
     output_dtype = input.dtype
-    if should_use_deepgemm(output_dtype, weight):
+    if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
         input_2d = input.view(-1, input.shape[-1])
         output_shape = [*input.shape[:-1], weight.shape[0]]
@@ -150,7 +138,9 @@ def apply_w8a8_block_fp8_linear(
             column_major_scales=True,
         )
+        # ensure DeepGEMM-backed custom op is registered before use
+        import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
         output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
             q_input,
             weight,
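
For context, the lazy import added above is what registers the DeepGEMM-backed custom op with torch before its first use. A small, illustrative check (assuming a CUDA build of vLLM with DeepGEMM installed; on other setups the import or op lookup may fail):

# Illustrative only: verify the custom op is registered by the lazy import
# referenced in the diff above (assumes DeepGEMM is installed and usable).
import torch
import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401

# Resolves to an op overload packet once the import above has run.
print(torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm)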

View File

@@ -202,6 +202,12 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     return 1 - sim
+def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
+                                       weight: torch.Tensor):
+    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
+            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
 __all__ = [
     "calc_diff",
     "fp8_gemm_nt",
@@ -210,4 +216,5 @@ __all__ = [
     "per_block_cast_to_fp8",
     "is_blackwell_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
+    "should_use_deepgemm_for_fp8_linear",
 ]
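
The new helper can be sanity-checked in isolation; the sketch below builds dummy weights to show which dtype/shape combinations take the DeepGEMM path (illustrative; on setups where is_deep_gemm_supported() is False, every call returns False).

# Illustrative check of the new gate: bfloat16 output plus weight dims
# divisible by 128 are required, in addition to platform support.
import torch

from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear

weight_ok = torch.empty(4096, 512)   # both dims divisible by 128
weight_bad = torch.empty(4096, 500)  # 500 % 128 != 0

print(should_use_deepgemm_for_fp8_linear(torch.bfloat16, weight_ok))   # True if supported
print(should_use_deepgemm_for_fp8_linear(torch.float16, weight_ok))    # False: bf16 only
print(should_use_deepgemm_for_fp8_linear(torch.bfloat16, weight_bad))  # False: shape gate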