[ROCm] Aiter Quant Kernels (#25552)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
vllmellm 2025-12-09 22:27:37 +08:00 committed by GitHub
parent 1166c31cc7
commit ee14644ba9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 123 additions and 2 deletions

View File

@@ -9,6 +9,8 @@ import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer

_FP8_DTYPE = current_platform.fp8_dtype()


def is_aiter_found() -> bool:
    from importlib.util import find_spec
@@ -467,6 +469,59 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    return torch.empty_like(x), torch.empty_like(residual)


def _rocm_aiter_per_tensor_quant_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.quant import per_tensor_quant_hip

    return per_tensor_quant_hip(x, scale, quant_dtype)


def _rocm_aiter_per_tensor_quant_fake(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(x, dtype=quant_dtype), torch.empty(
        1, dtype=torch.float32, device=x.device
    )
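For orientation, a rough pure-PyTorch sketch of what per-tensor FP8 quantization computes. The helper name is hypothetical and this is not the aiter HIP kernel; it only illustrates the expected contract of returning the quantized tensor plus a single float32 scale, derived dynamically when no static scale is supplied.

import torch

# Illustrative reference only: one scale for the whole tensor, assuming the
# common scheme scale = max(|x|) / fp8_max when no static scale is provided.
def naive_per_tensor_quant(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    fp8_max = torch.finfo(quant_dtype).max
    if scale is None:
        scale = (x.abs().amax().float() / fp8_max).reshape(1)
    out = (x / scale).clamp(-fp8_max, fp8_max).to(quant_dtype)
    return out, scale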
def _rocm_aiter_per_token_quant_impl(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.quant import dynamic_per_token_scaled_quant

    assert quant_dtype in [torch.int8, _FP8_DTYPE]
    out_shape = x.shape
    out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
    if scale is None:
        scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
    dynamic_per_token_scaled_quant(
        out,
        x,
        scale,
        scale_ub=None,
        shuffle_scale=False,
        num_rows=None,
        num_rows_factor=1,
    )
    return out, scale


def _rocm_aiter_per_token_quant_fake(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    out_shape = x.shape
    return (
        torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
        torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
    )
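And the per-token counterpart, again as a hedged pure-PyTorch reference (hypothetical helper, not dynamic_per_token_scaled_quant itself): one float32 scale per row, matching the (*x.shape[:-1], 1) scale buffer allocated above.

import torch

# Illustrative reference only: dynamic per-token (per-row) quantization,
# assuming scale = max(|row|) / fp8_max and out = clamp(x / scale).
def naive_per_token_quant(
    x: torch.Tensor, quant_dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    fp8_max = torch.finfo(quant_dtype).max
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = scale.clamp(min=torch.finfo(torch.float32).tiny)  # avoid divide-by-zero
    out = (x / scale).clamp(-fp8_max, fp8_max).to(quant_dtype)
    return out, scale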
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False
@@ -665,6 +720,22 @@ class rocm_aiter_ops:
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_per_tensor_quant",
            op_func=_rocm_aiter_per_tensor_quant_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_per_tensor_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_per_token_quant",
            op_func=_rocm_aiter_per_token_quant_impl,
            mutates_args=["scale"],
            fake_impl=_rocm_aiter_per_token_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        _OPS_REGISTERED = True

    @staticmethod
@@ -859,6 +930,22 @@ class rocm_aiter_ops:
            kv_scale=kv_scale,
        )

    @staticmethod
    def per_tensor_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)

    @staticmethod
    def per_token_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)
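A usage sketch for the two new wrappers above. It is illustrative only: the input shape and dtype are arbitrary, the calls dispatch through the registered torch.ops.vllm custom ops, and they assume a ROCm build with aiter installed and the ops already registered.

import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
fp8_dtype = current_platform.fp8_dtype()

# Per-tensor: a single scale for the whole activation tensor.
x_q, x_scale = rocm_aiter_ops.per_tensor_quant(x, fp8_dtype)

# Per-token: one float32 scale per row, shape (8, 1) here.
x_q_tok, x_scale_tok = rocm_aiter_ops.per_token_quant(x, fp8_dtype)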
    @staticmethod
    def triton_fp4_gemm_dynamic_qaunt(
        x: torch.Tensor,

View File

@@ -5,6 +5,7 @@ import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
@@ -45,10 +46,13 @@ class QuantFP8(CustomOp):
        super().__init__()
        self.static = static
        self.group_shape = group_shape
        self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
        self.num_token_padding = num_token_padding
        self.column_major_scales = column_major_scales
        self.use_ue8m0 = use_ue8m0
        self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled()

        self.is_group_quant = group_shape.is_per_group()
        if self.is_group_quant:
            assert not static, "Group quantization only supports dynamic mode"
@@ -92,6 +96,33 @@
            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
        )

    def forward_hip(
        self,
        x: torch.Tensor,
        scale: torch.Tensor | None = None,
        scale_ub: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        use_aiter_quant = (
            not self.is_group_quant
            and self.use_aiter
            and scale_ub is None
            and x.is_contiguous()
        )
        use_aiter_per_tensor_quant = (
            use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
        )
        use_aiter_per_token_quant = (
            use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
        )

        if use_aiter_per_tensor_quant:
            return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)

        if use_aiter_per_token_quant:
            return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)

        # Fallback to CUDA implementation
        return self.forward_cuda(x, scale, scale_ub)
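A quick dispatch sketch for this path. The module path and the defaulted constructor arguments are assumptions; on ROCm with the aiter FP8 linear path enabled, a contiguous input with no scale_ub takes the aiter branch above, otherwise forward_cuda is used.

import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Dynamic per-token FP8 quantization of activations.
quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
# Routed to rocm_aiter_ops.per_token_quant when the aiter branch applies.
x_q, x_scale = quant_fp8(x)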
    def forward_native(
        self,
        x: torch.Tensor,

View File

@@ -381,6 +381,8 @@ class RocmPlatform(Platform):
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        is_eager_execution = compilation_config == CUDAGraphMode.NONE

        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()
        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
@@ -400,8 +402,6 @@
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16
@@ -415,6 +415,9 @@
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")
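The "+quant_fp8" entry follows the existing custom_ops convention: a "+" prefix enables the CustomOp implementation (here, QuantFP8's platform-specific forward such as forward_hip) and "-" disables it, so the platform only opts in when the user has not explicitly opted out. A user could request the same thing by hand, roughly as sketched below (field name as in CompilationConfig).

from vllm.config import CompilationConfig

# Equivalent manual opt-in to the custom FP8 quant and RMSNorm ops.
compilation_config = CompilationConfig(custom_ops=["+quant_fp8", "+rms_norm"])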
    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        if model_arch in _ROCM_UNSUPPORTED_MODELS: