From ee14644ba9a3696c83ede2c948b73ebc3e1ffb33 Mon Sep 17 00:00:00 2001 From: vllmellm Date: Tue, 9 Dec 2025 22:27:37 +0800 Subject: [PATCH] [ROCm] Aiter Quant Kernels (#25552) Signed-off-by: vllmellm --- vllm/_aiter_ops.py | 87 +++++++++++++++++++ .../layers/quantization/input_quant_fp8.py | 31 +++++++ vllm/platforms/rocm.py | 7 +- 3 files changed, 123 insertions(+), 2 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 35920d826578e..94bbc9b00225e 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -9,6 +9,8 @@ import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +_FP8_DTYPE = current_platform.fp8_dtype() + def is_aiter_found() -> bool: from importlib.util import find_spec @@ -467,6 +469,59 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( return torch.empty_like(x), torch.empty_like(residual) +def _rocm_aiter_per_tensor_quant_impl( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.quant import per_tensor_quant_hip + + return per_tensor_quant_hip(x, scale, quant_dtype) + + +def _rocm_aiter_per_tensor_quant_fake( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(x, dtype=quant_dtype), torch.empty( + 1, dtype=torch.float32, device=x.device + ) + + +def _rocm_aiter_per_token_quant_impl( + x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.ops.quant import dynamic_per_token_scaled_quant + + assert quant_dtype in [torch.int8, _FP8_DTYPE] + + out_shape = x.shape + out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device) + if scale is None: + scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device) + dynamic_per_token_scaled_quant( + out, + x, + scale, + scale_ub=None, + shuffle_scale=False, + num_rows=None, + num_rows_factor=1, + ) + return out, scale + + +def _rocm_aiter_per_token_quant_fake( + x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None +) -> tuple[torch.Tensor, torch.Tensor]: + out_shape = x.shape + return ( + torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device), + torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device), + ) + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -665,6 +720,22 @@ class rocm_aiter_ops: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="rocm_aiter_per_tensor_quant", + op_func=_rocm_aiter_per_tensor_quant_impl, + mutates_args=[], + fake_impl=_rocm_aiter_per_tensor_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + + direct_register_custom_op( + op_name="rocm_aiter_per_token_quant", + op_func=_rocm_aiter_per_token_quant_impl, + mutates_args=["scale"], + fake_impl=_rocm_aiter_per_token_quant_fake, + dispatch_key=current_platform.dispatch_key, + ) + _OPS_REGISTERED = True @staticmethod @@ -859,6 +930,22 @@ class rocm_aiter_ops: kv_scale=kv_scale, ) + @staticmethod + def per_tensor_quant( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale) + + @staticmethod + def per_token_quant( + x: torch.Tensor, + quant_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale) + @staticmethod def triton_fp4_gemm_dynamic_qaunt( x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 7ded8eea79060..a5db086fb4729 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -5,6 +5,7 @@ import torch import torch.nn.functional as F from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.platforms import current_platform @@ -45,10 +46,13 @@ class QuantFP8(CustomOp): super().__init__() self.static = static self.group_shape = group_shape + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN self.num_token_padding = num_token_padding self.column_major_scales = column_major_scales self.use_ue8m0 = use_ue8m0 + self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled() + self.is_group_quant = group_shape.is_per_group() if self.is_group_quant: assert not static, "Group quantization only supports dynamic mode" @@ -92,6 +96,33 @@ class QuantFP8(CustomOp): use_per_token_if_dynamic=self.use_per_token_if_dynamic, ) + def forward_hip( + self, + x: torch.Tensor, + scale: torch.Tensor | None = None, + scale_ub: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + use_aiter_quant = ( + not self.is_group_quant + and self.use_aiter + and scale_ub is None + and x.is_contiguous() + ) + use_aiter_per_tensor_quant = ( + use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR + ) + use_aiter_per_token_quant = ( + use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN + ) + + if use_aiter_per_tensor_quant: + return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale) + if use_aiter_per_token_quant: + return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale) + + # Fallback to CUDA implementation + return self.forward_cuda(x, scale, scale_ub) + def forward_native( self, x: torch.Tensor, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ff0fc78517876..f7adecbd88746 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -381,6 +381,8 @@ class RocmPlatform(Platform): compilation_config = vllm_config.compilation_config parallel_config = vllm_config.parallel_config is_eager_execution = compilation_config == CUDAGraphMode.NONE + use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() + use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled() if compilation_config.cudagraph_mode.has_full_cudagraphs(): # decode context parallel does not support full cudagraphs @@ -400,8 +402,6 @@ class RocmPlatform(Platform): ) compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE - use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled() - if cache_config and cache_config.block_size is None: cache_config.block_size = 16 @@ -415,6 +415,9 @@ class RocmPlatform(Platform): ): compilation_config.custom_ops.append("+rms_norm") + if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops: + compilation_config.custom_ops.append("+quant_fp8") + @classmethod def verify_model_arch(cls, model_arch: str) -> None: if model_arch in _ROCM_UNSUPPORTED_MODELS: