mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 03:15:01 +08:00
[Kernel] Support deep_gemm for linear methods (#19085)
Signed-off-by: artetaout <lulala341@gmail.com>
This commit is contained in:
parent
5039ec2336
commit
b8e809a057
84
vllm/model_executor/layers/quantization/deepgemm.py
Normal file
84
vllm/model_executor/layers/quantization/deepgemm.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import importlib.util
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||||
|
if has_deep_gemm:
|
||||||
|
import deep_gemm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_block_fp8_matmul_inputs(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> tuple[int, int, int, torch.Tensor]:
|
||||||
|
assert len(block_size) == 2
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
|
||||||
|
assert A.shape[-1] == B.shape[-1]
|
||||||
|
assert A.shape[:-1] == As.shape[:-1]
|
||||||
|
assert A.is_contiguous()
|
||||||
|
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||||
|
|
||||||
|
M = A.numel() // A.shape[-1]
|
||||||
|
|
||||||
|
assert B.ndim == 2
|
||||||
|
assert B.is_contiguous()
|
||||||
|
assert Bs.ndim == 2
|
||||||
|
N, K = B.shape
|
||||||
|
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||||
|
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||||
|
|
||||||
|
C_shape = A.shape[:-1] + (N, )
|
||||||
|
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||||
|
|
||||||
|
return M, N, K, C
|
||||||
|
|
||||||
|
|
||||||
|
def w8a8_block_fp8_matmul_deepgemm(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
|
||||||
|
output_dtype)
|
||||||
|
# Deepgemm only supports output tensor type as bfloat16
|
||||||
|
assert C.dtype == torch.bfloat16
|
||||||
|
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
def w8a8_block_fp8_matmul_deepgemm_fake(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
output_dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
|
||||||
|
output_dtype)
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="w8a8_block_fp8_matmul_deepgemm",
|
||||||
|
op_func=w8a8_block_fp8_matmul_deepgemm,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
@ -402,6 +402,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
|
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
assert self.quant_config.weight_block_size is not None
|
assert self.quant_config.weight_block_size is not None
|
||||||
|
|
||||||
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
|
return torch.ops.vllm.apply_w8a8_block_fp8_linear(
|
||||||
input=x,
|
input=x,
|
||||||
weight=layer.weight,
|
weight=layer.weight,
|
||||||
|
|||||||
@ -3,12 +3,14 @@
|
|||||||
|
|
||||||
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
||||||
import functools
|
import functools
|
||||||
|
import importlib.util
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
@ -20,6 +22,7 @@ from vllm.triton_utils import tl, triton
|
|||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
|
||||||
@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func(
|
|||||||
return w8a8_block_fp8_matmul
|
return w8a8_block_fp8_matmul
|
||||||
|
|
||||||
|
|
||||||
|
def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Check if DeepGEMM should be used based on the output dtype and weight shape.
|
||||||
|
DeepGEMM is only supported for bfloat16 output dtype and weights with shape
|
||||||
|
divisible by 128.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return (current_platform.is_cuda()
|
||||||
|
and current_platform.is_device_capability(90) and has_deep_gemm
|
||||||
|
and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
|
||||||
|
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
|
||||||
|
|
||||||
|
|
||||||
# TODO fix ROCm->Triton custom path:
|
# TODO fix ROCm->Triton custom path:
|
||||||
# https://github.com/vllm-project/vllm/issues/14397
|
# https://github.com/vllm-project/vllm/issues/14397
|
||||||
def apply_w8a8_block_fp8_linear(
|
def apply_w8a8_block_fp8_linear(
|
||||||
@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
# View input as 2D matrix for fp8 methods
|
# View input as 2D matrix for fp8 methods
|
||||||
input_2d = input.view(-1, input.shape[-1])
|
input_2d = input.view(-1, input.shape[-1])
|
||||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||||
|
output_dtype = input.dtype
|
||||||
|
|
||||||
|
if should_use_deepgemm(output_dtype, weight):
|
||||||
|
|
||||||
|
input_2d = input.view(-1, input.shape[-1])
|
||||||
|
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||||
|
|
||||||
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
|
input_2d,
|
||||||
|
block_size[1],
|
||||||
|
column_major_scales=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
|
||||||
|
q_input,
|
||||||
|
weight,
|
||||||
|
x_scale,
|
||||||
|
weight_scale,
|
||||||
|
block_size,
|
||||||
|
output_dtype=output_dtype)
|
||||||
|
if bias is not None:
|
||||||
|
output += bias
|
||||||
|
return output.to(dtype=output_dtype).view(*output_shape)
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
if current_platform.has_device_capability(100):
|
if current_platform.has_device_capability(100):
|
||||||
@ -134,7 +173,6 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
|
|
||||||
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
|
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
|
||||||
use_cutlass, use_aiter_and_is_supported)
|
use_cutlass, use_aiter_and_is_supported)
|
||||||
|
|
||||||
if use_cutlass:
|
if use_cutlass:
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
input_2d, block_size[1], column_major_scales=use_cutlass)
|
input_2d, block_size[1], column_major_scales=use_cutlass)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user