mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-12 11:44:38 +08:00
[Refactor] Remove DeepGEMM OP Register (#25710)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
1d21080118
commit
3a32aa8a6b
@ -1,78 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.triton_utils import triton
|
|
||||||
from vllm.utils import direct_register_custom_op
|
|
||||||
from vllm.utils.deep_gemm import fp8_gemm_nt
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_block_fp8_matmul_inputs(
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
As: torch.Tensor,
|
|
||||||
Bs: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype = torch.float16,
|
|
||||||
) -> tuple[int, int, int, torch.Tensor]:
|
|
||||||
assert len(block_size) == 2
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
|
||||||
|
|
||||||
assert A.shape[-1] == B.shape[-1]
|
|
||||||
assert A.shape[:-1] == As.shape[:-1]
|
|
||||||
assert A.is_contiguous()
|
|
||||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
|
||||||
|
|
||||||
M = A.numel() // A.shape[-1]
|
|
||||||
|
|
||||||
assert B.ndim == 2
|
|
||||||
assert B.is_contiguous()
|
|
||||||
assert Bs.ndim == 2
|
|
||||||
N, K = B.shape
|
|
||||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
|
||||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
|
||||||
|
|
||||||
C_shape = A.shape[:-1] + (N, )
|
|
||||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
|
||||||
|
|
||||||
return M, N, K, C
|
|
||||||
|
|
||||||
|
|
||||||
def w8a8_block_fp8_matmul_deepgemm(
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
As: torch.Tensor,
|
|
||||||
Bs: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
|
|
||||||
output_dtype)
|
|
||||||
# Deepgemm only supports output tensor type as bfloat16
|
|
||||||
assert C.dtype == torch.bfloat16
|
|
||||||
fp8_gemm_nt((A, As), (B, Bs), C)
|
|
||||||
return C
|
|
||||||
|
|
||||||
|
|
||||||
def w8a8_block_fp8_matmul_deepgemm_fake(
|
|
||||||
A: torch.Tensor,
|
|
||||||
B: torch.Tensor,
|
|
||||||
As: torch.Tensor,
|
|
||||||
Bs: torch.Tensor,
|
|
||||||
block_size: list[int],
|
|
||||||
output_dtype: torch.dtype,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size,
|
|
||||||
output_dtype)
|
|
||||||
return C
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="w8a8_block_fp8_matmul_deepgemm",
|
|
||||||
op_func=w8a8_block_fp8_matmul_deepgemm,
|
|
||||||
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
|
|
||||||
)
|
|
||||||
@ -23,7 +23,7 @@ from vllm.model_executor.parameter import (BlockQuantScaleParameter,
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
|
from vllm.utils.deep_gemm import (fp8_gemm_nt, is_deep_gemm_e8m0_used,
|
||||||
should_use_deepgemm_for_fp8_linear)
|
should_use_deepgemm_for_fp8_linear)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -141,17 +141,10 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
block_size[1],
|
block_size[1],
|
||||||
column_major_scales=True,
|
column_major_scales=True,
|
||||||
)
|
)
|
||||||
|
output = torch.empty((q_input.shape[0], weight.shape[0]),
|
||||||
# ensure DeepGEMM-backed custom op is registered before use
|
dtype=torch.bfloat16,
|
||||||
import vllm.model_executor.layers.quantization.deepgemm # noqa: F401
|
device=q_input.device)
|
||||||
|
fp8_gemm_nt((q_input, x_scale), (weight, weight_scale), output)
|
||||||
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
|
|
||||||
q_input,
|
|
||||||
weight,
|
|
||||||
x_scale,
|
|
||||||
weight_scale,
|
|
||||||
block_size,
|
|
||||||
output_dtype=output_dtype)
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
output += bias
|
output += bias
|
||||||
return output.to(dtype=output_dtype).view(*output_shape)
|
return output.to(dtype=output_dtype).view(*output_shape)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user