Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-01-19 11:34:29 +08:00
[ROCm] triton fp8 kernel (#27058)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>

parent d4aa65c998
commit 449de9001a
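This change reworks the W8A8 block-scaled FP8 linear path on ROCm: the AITER custom op now receives the unquantized 2-D activation together with an optional precomputed activation scale, and it decides internally between the AITER Triton kernel (non-FNUZ FP8, e.g. MI350-class GPUs, for a fixed list of tuned (n, k) weight shapes) and the tuned AITER ASM/C++ kernel (FNUZ FP8, e.g. MI300-class GPUs). The sketch below only illustrates that dispatch order in plain Python; the helper name and arguments are not part of the diff.

def pick_w8a8_blockscale_path(
    input_scale, n: int, k: int, is_fp8_fnuz: bool, tuned_shapes: set[tuple[int, int]]
) -> str:
    # Illustrative only: mirrors the branch order of the new op implementation.
    if input_scale is not None:
        # Caller already quantized the activation, so its scale is used as-is.
        return "pre-quantized input"
    if not is_fp8_fnuz and (n, k) in tuned_shapes:
        # Non-FNUZ FP8 with a tuned weight shape: AITER Triton kernel.
        return "aiter triton kernel"
    # Otherwise (e.g. FNUZ FP8 on MI300-class GPUs): tuned AITER ASM/C++ kernel with HIP quant.
    return "aiter asm/c++ kernel"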
@@ -69,30 +69,67 @@ def cutlass_scaled_mm(
 
 
 def rocm_aiter_gemm_w8a8_blockscale_impl(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
+    input_2d: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    import aiter as rocm_aiter
+    def is_aiter_triton_kernel_tuned(n, k):
+        return (n, k) in [
+            (1024, 8192),
+            (2112, 7168),
+            (3072, 1536),
+            (32768, 8192),
+            (4096, 7168),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2048),
+            (7168, 256),
+            (8192, 1024),
+            (8192, 32768),
+        ]
 
-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    n, k = weight.shape
+    if input_scale is not None:
+        q_input = input_2d
+    elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+        # MI350 case uses triton kernel
+        q_input, input_scale = per_token_group_quant_fp8(
+            input_2d,
+            group_size,
+            column_major_scales=False,
+            use_ue8m0=False,
+        )
+    else:
+        # MI300 uses tuned AITER ASM/C++ kernel
+        import aiter as rocm_aiter
+        from aiter import gemm_a8w8_blockscale, get_hip_quant
+
+        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+        q_input, input_scale = aiter_per1x128_quant(
+            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
+        )
+
+    return gemm_a8w8_blockscale(
+        q_input, weight, input_scale, weight_scale, dtype=output_dtype
+    )
 
 
 def rocm_aiter_gemm_w8a8_blockscale_fake(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
+    input_2d: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    m = A.shape[0]
-    n = B.shape[0]
-    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
-    return Y
+    m = input_2d.shape[0]
+    n = weight.shape[0]
+    return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)
 
 
 if current_platform.is_rocm():
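When no input_scale is supplied, the Triton branch quantizes the activation with per_token_group_quant_fp8 using group_size columns per scale (1x128 groups in practice). As a rough reference for what that per-group scaling computes, here is a minimal toy version in plain PyTorch; it is not the vLLM or AITER kernel, and it uses standard float8_e4m3fn rather than the FNUZ variant.

import torch

def toy_per_token_group_quant_fp8(x: torch.Tensor, group_size: int = 128):
    # One scale per (row, group of 128 columns); x is (M, K) with K divisible by group_size.
    m, k = x.shape
    assert k % group_size == 0
    groups = x.float().reshape(m, k // group_size, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Returns the FP8 activation and an (M, K // group_size) scale tensor.
    return q.reshape(m, k), scales.squeeze(-1)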
@@ -101,15 +138,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
         fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
     )
-    if (
-        envs.VLLM_ROCM_USE_AITER
-        and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
-    ):
-        import aiter as rocm_aiter
-        from aiter import get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
 
 
 # TODO we should be able to change the type of block_size to GroupShape
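For context, the impl/fake pair above is registered as a torch custom op so the call can be traced and compiled; the fake only reports output shape and dtype for meta tensors. A hedged sketch of what that registration typically looks like follows; the mutates_args value is an assumption, and the exact direct_register_custom_op signature may differ between vLLM versions.

if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_aiter_gemm_w8a8_blockscale",  # exposed as torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
    )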
@@ -293,7 +321,9 @@ class W8A8BlockFp8LinearOp:
         ):
             output = self._run_deepgemm(input_2d, weight, weight_scale)
         else:
-            output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
+            output = self.w8a8_blockscale_op(
+                input_2d, weight, weight_scale, input_scale
+            )
 
         if bias is not None:
             output = output + bias
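The extra argument lets apply() forward an activation scale that was computed upstream, for example by a fused quantization step, so the selected blockscale op only quantizes when no scale is given. A hedged illustration of the two call patterns; linear_op, q_input_2d and act_scale are hypothetical names.

# No precomputed scale: the op quantizes input_2d itself (Triton or AITER path, chosen by platform).
out = linear_op.w8a8_blockscale_op(input_2d, weight, weight_scale, None)

# Activation already quantized to FP8 upstream: pass its per-group scale straight through.
out = linear_op.w8a8_blockscale_op(q_input_2d, weight, weight_scale, act_scale)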
@@ -322,7 +352,9 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert input_scale is None
         assert self.input_quant_op is not None
         q_input, input_scale = self.input_quant_op(input_2d)
         if self.is_hopper:
@@ -350,17 +382,15 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
         assert self.act_quant_group_shape == GroupShape(1, 128)
-        q_input, input_scale = aiter_per1x128_quant(
-            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
-        )
         return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
-            q_input,
+            input_2d,
             weight,
             input_scale,
             weight_scale,
-            list(self.weight_group_shape),
+            self.act_quant_group_shape.col,
             input_2d.dtype,
         )
 
@@ -369,7 +399,9 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert input_scale is None
         assert self.input_quant_op is not None
         q_input, input_scale = self.input_quant_op(input_2d)
         return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
@@ -391,6 +423,7 @@ class W8A8BlockFp8LinearOp:
             torch.Tensor,
             torch.Tensor,
             torch.Tensor,
+            torch.Tensor | None,
         ],
         torch.Tensor,
     ],
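With the widened signature, every selectable blockscale op (cutlass, aiter, triton and the deepgemm wrapper) shares one callable shape, roughly the alias below; the alias name itself is illustrative and not defined in the diff.

from collections.abc import Callable
import torch

# (input_2d, weight, weight_scale, optional activation scale) -> output
W8A8BlockscaleOp = Callable[
    [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None],
    torch.Tensor,
]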
@@ -939,13 +972,11 @@ def requant_weight_ue8m0_inplace(
 
 
 def check_aiter_fp8_linear_support() -> bool:
-    """AITER is only supported on ROCm and only for FP8_FNUZ
-    and at the moment are MI300 series"""
+    """AITER is only supported on ROCm for MI3XX"""
     return (
         current_platform.is_rocm()
         and envs.VLLM_ROCM_USE_AITER
         and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
     )
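With the FNUZ requirement dropped, the gate is just ROCm plus the two AITER environment flags; whether the Triton or the ASM/C++ path runs is decided later inside the op. A rough standalone approximation of the check is below; the defaults shown are illustrative, since vLLM actually reads these flags through its envs module.

import os
import torch

def aiter_fp8_linear_enabled_sketch() -> bool:
    # Stand-in for check_aiter_fp8_linear_support(); not the vLLM implementation.
    on_rocm = torch.version.hip is not None
    use_aiter = os.getenv("VLLM_ROCM_USE_AITER", "0") == "1"
    use_aiter_linear = os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "1") == "1"
    return on_rocm and use_aiter and use_aiter_linear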