mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:45:22 +08:00
[Feature] Enable DeepGEMM Linear on B200; 1.5% E2E throughput improvement (#23351)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
3ac849665d
commit
394591e343
@ -19,8 +19,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
|
from vllm.utils import cdiv, direct_register_custom_op
|
||||||
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
|
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
|
||||||
|
should_use_deepgemm_for_fp8_linear)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -108,19 +109,6 @@ def dispatch_w8a8_blockscale_func(
|
|||||||
return w8a8_block_fp8_matmul
|
return w8a8_block_fp8_matmul
|
||||||
|
|
||||||
|
|
||||||
def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Check if DeepGEMM should be used based on the output dtype and weight shape.
|
|
||||||
DeepGEMM is only supported for bfloat16 output dtype and weights with shape
|
|
||||||
divisible by 128.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return (current_platform.is_cuda()
|
|
||||||
and current_platform.is_device_capability(90) and has_deep_gemm()
|
|
||||||
and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
|
|
||||||
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO fix ROCm->Triton custom path:
|
# TODO fix ROCm->Triton custom path:
|
||||||
# https://github.com/vllm-project/vllm/issues/14397
|
# https://github.com/vllm-project/vllm/issues/14397
|
||||||
def apply_w8a8_block_fp8_linear(
|
def apply_w8a8_block_fp8_linear(
|
||||||
@ -139,7 +127,7 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||||
output_dtype = input.dtype
|
output_dtype = input.dtype
|
||||||
|
|
||||||
if should_use_deepgemm(output_dtype, weight):
|
if should_use_deepgemm_for_fp8_linear(output_dtype, weight):
|
||||||
|
|
||||||
input_2d = input.view(-1, input.shape[-1])
|
input_2d = input.view(-1, input.shape[-1])
|
||||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||||
@ -150,7 +138,9 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
column_major_scales=True,
|
column_major_scales=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ensure DeepGEMM-backed custom op is registered before use
|
||||||
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
|
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
|
||||||
|
|
||||||
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
|
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
|
||||||
q_input,
|
q_input,
|
||||||
weight,
|
weight,
|
||||||
|
|||||||
@ -202,6 +202,12 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
|||||||
return 1 - sim
|
return 1 - sim
|
||||||
|
|
||||||
|
|
||||||
|
def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype,
|
||||||
|
weight: torch.Tensor):
|
||||||
|
return (is_deep_gemm_supported() and output_dtype == torch.bfloat16
|
||||||
|
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"calc_diff",
|
"calc_diff",
|
||||||
"fp8_gemm_nt",
|
"fp8_gemm_nt",
|
||||||
@ -210,4 +216,5 @@ __all__ = [
|
|||||||
"per_block_cast_to_fp8",
|
"per_block_cast_to_fp8",
|
||||||
"is_blackwell_deep_gemm_e8m0_used",
|
"is_blackwell_deep_gemm_e8m0_used",
|
||||||
"is_deep_gemm_supported",
|
"is_deep_gemm_supported",
|
||||||
|
"should_use_deepgemm_for_fp8_linear",
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user