[ROCm] triton fp8 kernel (#27058)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Aleksandr Malyshev 2025-11-06 11:46:44 -08:00 committed by GitHub
parent d4aa65c998
commit 449de9001a


@@ -69,30 +69,67 @@ def cutlass_scaled_mm(
 def rocm_aiter_gemm_w8a8_blockscale_impl(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
+    input_2d: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    import aiter as rocm_aiter
-
-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    def is_aiter_triton_kernel_tuned(n, k):
+        return (n, k) in [
+            (1024, 8192),
+            (2112, 7168),
+            (3072, 1536),
+            (32768, 8192),
+            (4096, 7168),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2048),
+            (7168, 256),
+            (8192, 1024),
+            (8192, 32768),
+        ]
+
+    n, k = weight.shape
+    if input_scale is not None:
+        q_input = input_2d
+    elif not current_platform.is_fp8_fnuz() and is_aiter_triton_kernel_tuned(n, k):
+        from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale
+
+        # MI350 case uses triton kernel
+        q_input, input_scale = per_token_group_quant_fp8(
+            input_2d,
+            group_size,
+            column_major_scales=False,
+            use_ue8m0=False,
+        )
+    else:
+        # MI300 uses tuned AITER ASM/C++ kernel
+        import aiter as rocm_aiter
+        from aiter import gemm_a8w8_blockscale, get_hip_quant
+
+        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
+        q_input, input_scale = aiter_per1x128_quant(
+            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
+        )
+
+    return gemm_a8w8_blockscale(
+        q_input, weight, input_scale, weight_scale, dtype=output_dtype
+    )


 def rocm_aiter_gemm_w8a8_blockscale_fake(
-    A: torch.Tensor,
-    B: torch.Tensor,
-    As: torch.Tensor,
-    Bs: torch.Tensor,
-    block_size: list[int],
+    input_2d: torch.Tensor,
+    weight: torch.Tensor,
+    input_scale: torch.Tensor,
+    weight_scale: torch.Tensor,
+    group_size: int,
     output_dtype: torch.dtype = torch.float16,
 ) -> torch.Tensor:
-    m = A.shape[0]
-    n = B.shape[0]
-    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
-    return Y
+    m = input_2d.shape[0]
+    n = weight.shape[0]
+    return torch.empty(m, n, dtype=output_dtype, device=input_2d.device)


 if current_platform.is_rocm():
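For orientation, here is a rough reference for what the block-scaled FP8 GEMM above computes, assuming the usual layout in this file (per-1x128 activation scales, per-128x128 weight scales, and M/N/K divisible by 128). The function name and layout are illustrative, not the AITER API; the real ASM/Triton kernels fuse the dequantization and never materialize the float operands.

import torch


def gemm_w8a8_blockscale_ref(
    q_input: torch.Tensor,       # (M, K), fp8
    weight: torch.Tensor,        # (N, K), fp8
    input_scale: torch.Tensor,   # (M, K // 128), per-1x128-group scales
    weight_scale: torch.Tensor,  # (N // 128, K // 128), per-128x128-block scales
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    # Dequantize both operands by broadcasting their block scales, then matmul.
    a = q_input.float() * input_scale.repeat_interleave(128, dim=1)
    b = weight.float() * weight_scale.repeat_interleave(128, dim=0).repeat_interleave(
        128, dim=1
    )
    return (a @ b.t()).to(out_dtype)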
@@ -101,15 +138,6 @@ if current_platform.is_rocm():
         op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
         fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
     )

-    if (
-        envs.VLLM_ROCM_USE_AITER
-        and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
-    ):
-        import aiter as rocm_aiter
-        from aiter import get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)

 # TODO we should be able to change the type of block_size to GroupShape
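The op_func/fake_impl pair above is registered through vLLM's own custom-op helper; as a standalone sketch of the same idea with plain torch.library (assuming PyTorch 2.4+, names hypothetical), the fake implementation only reports output shape, dtype, and device so torch.compile can trace the op without dispatching to a ROCm kernel.

import torch


@torch.library.custom_op("demo::gemm_w8a8_blockscale", mutates_args=())
def demo_gemm(q_input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Placeholder body; the real op calls the AITER ASM or Triton kernel.
    return q_input.float() @ weight.float().t()


@demo_gemm.register_fake
def _(q_input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only "meta" implementation used during tracing.
    m, n = q_input.shape[0], weight.shape[0]
    return torch.empty(m, n, dtype=torch.float32, device=q_input.device)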
@@ -293,7 +321,9 @@ class W8A8BlockFp8LinearOp:
         ):
             output = self._run_deepgemm(input_2d, weight, weight_scale)
         else:
-            output = self.w8a8_blockscale_op(input_2d, weight, weight_scale)
+            output = self.w8a8_blockscale_op(
+                input_2d, weight, weight_scale, input_scale
+            )

         if bias is not None:
             output = output + bias
@@ -322,7 +352,9 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert input_scale is None
         assert self.input_quant_op is not None
         q_input, input_scale = self.input_quant_op(input_2d)
         if self.is_hopper:
@@ -350,17 +382,15 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert self.act_quant_group_shape == GroupShape(1, 128)
-        q_input, input_scale = aiter_per1x128_quant(
-            input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8
-        )
         return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
-            q_input,
+            input_2d,
             weight,
             input_scale,
             weight_scale,
-            list(self.weight_group_shape),
+            self.act_quant_group_shape.col,
            input_2d.dtype,
         )
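With this change the activation quantization happens inside the custom op (per_token_group_quant_fp8 on the Triton path, AITER's per-1x128 HIP quant otherwise), so the Python-side helper only forwards input_2d plus the group size (act_quant_group_shape.col, i.e. 128). Conceptually that quantization step amounts to the following sketch, assuming OCP float8_e4m3fn and a row length divisible by the group size; the production kernels are fused Triton/HIP implementations and this reference only shows the math.

import torch


def per_group_quant_fp8_ref(
    x: torch.Tensor, group_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    m, k = x.shape
    groups = x.reshape(m, k // group_size, group_size).float()
    # One scale per (token, 1 x group_size) block: amax / fp8_max, zero-guarded.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max
    q = (groups / scale).clamp(-fp8_max, fp8_max).to(fp8).reshape(m, k)
    return q, scale.squeeze(-1)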
@@ -369,7 +399,9 @@ class W8A8BlockFp8LinearOp:
         input_2d: torch.Tensor,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
+        input_scale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert input_scale is None
         assert self.input_quant_op is not None
         q_input, input_scale = self.input_quant_op(input_2d)
         return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
@@ -391,6 +423,7 @@ class W8A8BlockFp8LinearOp:
                 torch.Tensor,
                 torch.Tensor,
                 torch.Tensor,
+                torch.Tensor | None,
             ],
             torch.Tensor,
         ],
@@ -939,13 +972,11 @@ def requant_weight_ue8m0_inplace(
 def check_aiter_fp8_linear_support() -> bool:
-    """AITER is only supported on ROCm and only for FP8_FNUZ
-    and at the moment are MI300 series"""
+    """AITER is only supported on ROCm for MI3XX"""
     return (
         current_platform.is_rocm()
         and envs.VLLM_ROCM_USE_AITER
        and envs.VLLM_ROCM_USE_AITER_LINEAR
-        and current_platform.is_fp8_fnuz()
     )
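For context on the dropped is_fp8_fnuz() condition: MI300-class GPUs use the FNUZ FP8 variant, while the non-FNUZ (MI350-class) devices targeted by the new Triton path use the OCP e4m3fn format, so the relaxed check admits both. The two formats differ in range, which is why the quantization dtype is chosen per platform; a quick PyTorch check of the two dtypes:

import torch

# FNUZ variant used on MI300 (no inf, single NaN encoding).
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
# OCP e4m3fn used on the non-FNUZ Triton path.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0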