From 394591e34371e4a1ad23a85401148df31d8de451 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Fri, 22 Aug 2025 00:01:08 -0400
Subject: [PATCH] [Feature] Enable DeepGEMM Linear on B200; 1.5% E2E throughput improvement (#23351)

Signed-off-by: yewentao256
---
 .../layers/quantization/utils/fp8_utils.py | 22 ++++++----------------
 vllm/utils/deep_gemm.py                    |  7 +++++++
 2 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 2fb7ef29e4684..ab1d5383f4651 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -19,8 +19,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
+from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+                                  should_use_deepgemm_for_fp8_linear)
 
 logger = init_logger(__name__)
 
@@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func(
     return w8a8_block_fp8_matmul
 
 
-def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
-    """
-    Check if DeepGEMM should be used based on the output dtype and weight shape.
-    DeepGEMM is only supported for bfloat16 output dtype and weights with shape
-    divisible by 128.
-    """
-
-    return (current_platform.is_cuda()
-            and current_platform.is_device_capability(90) and has_deep_gemm()
-            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
-            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
-
-
 # TODO fix ROCm->Triton custom path:
 #  https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear(
     output_shape = [*input.shape[:-1], weight.shape[0]]
     output_dtype = input.dtype
 
-    if should_use_deepgemm(output_dtype, weight):
+    if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
         input_2d = input.view(-1, input.shape[-1])
         output_shape = [*input.shape[:-1], weight.shape[0]]
 
@@ -150,7 +138,9 @@ def apply_w8a8_block_fp8_linear(
             column_major_scales=True,
         )
 
+        # ensure DeepGEMM-backed custom op is registered before use
         import vllm.model_executor.layers.quantization.deepgemm  # noqa: F401
+
         output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
             q_input,
             weight,
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 861d9c0c0005d..c0a4ed077e660 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -202,6 +202,12 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
     return 1 - sim
 
 
+def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
+                                       weight: torch.Tensor):
+    return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
+            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+
+
 __all__ = [
     "calc_diff",
     "fp8_gemm_nt",
@@ -210,4 +216,5 @@ __all__ = [
     "per_block_cast_to_fp8",
     "is_blackwell_deep_gemm_e8m0_used",
     "is_deep_gemm_supported",
+    "should_use_deepgemm_for_fp8_linear",
 ]
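For reference, a minimal usage sketch of the relocated helper. The weight shape, the placeholder branches, and the dispatch comments below are illustrative assumptions and not part of the patch itself:

# Illustrative-only sketch; the weight shape and dispatch comments are assumptions.
import torch

from vllm.utils.deep_gemm import should_use_deepgemm_for_fp8_linear

# Both weight dims are multiples of 128 and the output dtype is bf16, so the
# helper returns True whenever is_deep_gemm_supported() holds (deep_gemm
# installed, VLLM_USE_DEEP_GEMM enabled, supported NVIDIA GPU).
weight = torch.empty(7168, 2048, dtype=torch.float8_e4m3fn)
output_dtype = torch.bfloat16

if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
    # apply_w8a8_block_fp8_linear takes the DeepGEMM branch and invokes the
    # torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm custom op.
    pass
else:
    # Otherwise dispatch falls back to the CUTLASS / Triton block-FP8 kernels.
    pass

Moving the check into vllm/utils/deep_gemm.py lets it reuse is_deep_gemm_supported() instead of the old hard-coded compute-capability-9.0 test, which is what allows the DeepGEMM linear path to be taken on B200.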