mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 04:34:54 +08:00
[Feature] Enable DeepGEMM Linear on B200; 1.5% E2E throughput improvement (#23351)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
3ac849665d
commit
394591e343
@ -19,8 +19,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
|
||||
should_use_deepgemm_for_fp8_linear)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func(
|
||||
return w8a8_block_fp8_matmul
|
||||
|
||||
|
||||
def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
|
||||
"""
|
||||
Check if DeepGEMM should be used based on the output dtype and weight shape.
|
||||
DeepGEMM is only supported for bfloat16 output dtype and weights with shape
|
||||
divisible by 128.
|
||||
"""
|
||||
|
||||
return (current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(90) and has_deep_gemm()
|
||||
and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
|
||||
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
|
||||
|
||||
|
||||
# TODO fix ROCm->Triton custom path:
|
||||
# https://github.com/vllm-project/vllm/issues/14397
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear(
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
output_dtype = input.dtype
|
||||
|
||||
if should_use_deepgemm(output_dtype, weight):
|
||||
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
|
||||
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
@ -150,7 +138,9 @@ def apply_w8a8_block_fp8_linear(
|
||||
column_major_scales=True,
|
||||
)
|
||||
|
||||
# ensure DeepGEMM-backed custom op is registered before use
|
||||
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
|
||||
|
||||
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
|
||||
q_input,
|
||||
weight,
|
||||
|
||||
@ -202,6 +202,12 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
return 1 - sim
|
||||
|
||||
|
||||
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
|
||||
weight: torch.Tensor):
|
||||
return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
|
||||
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"calc_diff",
|
||||
"fp8_gemm_nt",
|
||||
@ -210,4 +216,5 @@ __all__ = [
|
||||
"per_block_cast_to_fp8",
|
||||
"is_blackwell_deep_gemm_e8m0_used",
|
||||
"is_deep_gemm_supported",
|
||||
"should_use_deepgemm_for_fp8_linear",
|
||||
]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user