[ROCm] Aiter Quant Kernels (#25552)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Parent: 1166c31cc7
Commit: ee14644ba9
@@ -9,6 +9,8 @@ import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer

_FP8_DTYPE = current_platform.fp8_dtype()


def is_aiter_found() -> bool:
    from importlib.util import find_spec
@@ -467,6 +469,59 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
    return torch.empty_like(x), torch.empty_like(residual)


def _rocm_aiter_per_tensor_quant_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.quant import per_tensor_quant_hip

    return per_tensor_quant_hip(x, scale, quant_dtype)


def _rocm_aiter_per_tensor_quant_fake(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(x, dtype=quant_dtype), torch.empty(
        1, dtype=torch.float32, device=x.device
    )

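per_tensor_quant_hip above is AITER's fused HIP kernel; it presumably returns the quantized tensor plus a single float32 scale, taken from the caller when given (static) or derived from the global absolute max otherwise (dynamic). As a rough point of reference only, a plain-PyTorch sketch of that conventional per-tensor scheme is below; the helper name and numerics are illustrative assumptions, not part of this change.

import torch

def _per_tensor_quant_reference(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative reference only; assumes the conventional per-tensor FP8 scheme.
    finfo = torch.finfo(quant_dtype)
    if scale is None:
        # Dynamic quantization: one scale derived from the global absolute max.
        scale = (x.abs().amax().float().clamp(min=1e-12) / finfo.max).reshape(1)
    out = (x.float() / scale).clamp(finfo.min, finfo.max).to(quant_dtype)
    return out, scale
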
def _rocm_aiter_per_token_quant_impl(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    from aiter.ops.quant import dynamic_per_token_scaled_quant

    assert quant_dtype in [torch.int8, _FP8_DTYPE]

    out_shape = x.shape
    out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
    if scale is None:
        scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
    dynamic_per_token_scaled_quant(
        out,
        x,
        scale,
        scale_ub=None,
        shuffle_scale=False,
        num_rows=None,
        num_rows_factor=1,
    )
    return out, scale


def _rocm_aiter_per_token_quant_fake(
    x: torch.Tensor, quant_dtype: torch.dtype, scale: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    out_shape = x.shape
    return (
        torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
        torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
    )

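dynamic_per_token_scaled_quant writes the quantized tensor and the per-token scales into preallocated buffers, which is why out and scale are created up front. For intuition, the per-token scheme it is assumed to implement — one float32 scale per row, taken from that row's absolute max — can be sketched in plain PyTorch as follows; again the helper is an illustration, not the kernel.

import torch

def _per_token_quant_reference(
    x: torch.Tensor, quant_dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    # Illustrative reference only; one float32 scale per token (row).
    finfo = torch.finfo(quant_dtype)
    scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / finfo.max
    out = (x.float() / scale).clamp(finfo.min, finfo.max).to(quant_dtype)
    return out, scale  # scale has shape (*x.shape[:-1], 1), matching the fake impl
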
# Global flag to ensure ops are registered only once
_OPS_REGISTERED = False

@@ -665,6 +720,22 @@ class rocm_aiter_ops:
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_per_tensor_quant",
            op_func=_rocm_aiter_per_tensor_quant_impl,
            mutates_args=[],
            fake_impl=_rocm_aiter_per_tensor_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        direct_register_custom_op(
            op_name="rocm_aiter_per_token_quant",
            op_func=_rocm_aiter_per_token_quant_impl,
            mutates_args=["scale"],
            fake_impl=_rocm_aiter_per_token_quant_fake,
            dispatch_key=current_platform.dispatch_key,
        )

        _OPS_REGISTERED = True

    @staticmethod
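The fake implementations registered here never launch the HIP kernels; they only have to reproduce output shapes and dtypes so torch.compile can trace rocm_aiter_per_tensor_quant and rocm_aiter_per_token_quant on meta tensors. mutates_args=["scale"] on the per-token op declares that the kernel may write into a caller-provided scale buffer, as the impl above does. A minimal, purely illustrative check of that shape contract:

import torch

# Illustrative check of the fake-impl contract on meta tensors (no HIP launch).
x = torch.empty(4, 8, dtype=torch.bfloat16, device="meta")

out, scale = _rocm_aiter_per_tensor_quant_fake(x, _FP8_DTYPE)
assert out.shape == (4, 8) and out.dtype == _FP8_DTYPE and scale.shape == (1,)

out, scale = _rocm_aiter_per_token_quant_fake(x, _FP8_DTYPE)
assert out.shape == (4, 8) and scale.shape == (4, 1) and scale.dtype == torch.float32
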
@@ -859,6 +930,22 @@ class rocm_aiter_ops:
            kv_scale=kv_scale,
        )

    @staticmethod
    def per_tensor_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_per_tensor_quant(x, quant_dtype, scale)

    @staticmethod
    def per_token_quant(
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ops.vllm.rocm_aiter_per_token_quant(x, quant_dtype, scale)

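These thin wrappers go through torch.ops.vllm rather than calling the impl functions directly, which keeps the AITER kernels visible to torch.compile as opaque custom ops. A hedged usage sketch (assumes a ROCm build with the aiter package installed and the ops already registered; values are made up):

import torch

x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")  # "cuda" is the HIP device on ROCm builds

# Dynamic per-token quantization: the op allocates a (16, 1) float32 scale.
x_fp8, token_scales = rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE)

# Static per-tensor quantization: reuse a scale obtained at calibration time.
static_scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_fp8, scale = rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, static_scale)
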
    @staticmethod
    def triton_fp4_gemm_dynamic_qaunt(
        x: torch.Tensor,

@@ -5,6 +5,7 @@ import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
@@ -45,10 +46,13 @@ class QuantFP8(CustomOp):
        super().__init__()
        self.static = static
        self.group_shape = group_shape
        self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
        self.num_token_padding = num_token_padding
        self.column_major_scales = column_major_scales
        self.use_ue8m0 = use_ue8m0

        self.use_aiter = rocm_aiter_ops.is_linear_fp8_enaled()

        self.is_group_quant = group_shape.is_per_group()
        if self.is_group_quant:
            assert not static, "Group quantization only supports dynamic mode"
@@ -92,6 +96,33 @@ class QuantFP8(CustomOp):
            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
        )

    def forward_hip(
        self,
        x: torch.Tensor,
        scale: torch.Tensor | None = None,
        scale_ub: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        use_aiter_quant = (
            not self.is_group_quant
            and self.use_aiter
            and scale_ub is None
            and x.is_contiguous()
        )
        use_aiter_per_tensor_quant = (
            use_aiter_quant and self.group_shape == GroupShape.PER_TENSOR
        )
        use_aiter_per_token_quant = (
            use_aiter_quant and self.group_shape == GroupShape.PER_TOKEN
        )

        if use_aiter_per_tensor_quant:
            return rocm_aiter_ops.per_tensor_quant(x, _FP8_DTYPE, scale)
        if use_aiter_per_token_quant:
            return rocm_aiter_ops.per_token_quant(x, _FP8_DTYPE, scale)

        # Fallback to CUDA implementation
        return self.forward_cuda(x, scale, scale_ub)

    def forward_native(
        self,
        x: torch.Tensor,

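forward_hip only takes the AITER path for non-group (per-tensor or per-token) quantization with no scale_ub and a contiguous input; anything else falls back to forward_cuda. A rough sketch of exercising that dispatch, with constructor arguments assumed from the surrounding class rather than taken from this diff:

import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

quant = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)  # assumed signature
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")

x_fp8, scales = quant.forward_hip(x)       # contiguous, dynamic per-token -> AITER kernel
x_fp8, scales = quant.forward_hip(x.t())   # non-contiguous input -> forward_cuda fallback
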
@@ -381,6 +381,8 @@ class RocmPlatform(Platform):
        compilation_config = vllm_config.compilation_config
        parallel_config = vllm_config.parallel_config
        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
        use_aiter_fp8_linear = rocm_aiter_ops.is_linear_fp8_enaled()

        if compilation_config.cudagraph_mode.has_full_cudagraphs():
            # decode context parallel does not support full cudagraphs
@@ -400,8 +402,6 @@ class RocmPlatform(Platform):
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()

        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

@@ -415,6 +415,9 @@ class RocmPlatform(Platform):
        ):
            compilation_config.custom_ops.append("+rms_norm")

        if use_aiter_fp8_linear and "-quant_fp8" not in compilation_config.custom_ops:
            compilation_config.custom_ops.append("+quant_fp8")

    @classmethod
    def verify_model_arch(cls, model_arch: str) -> None:
        if model_arch in _ROCM_UNSUPPORTED_MODELS:

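The net effect of the platform hook: when AITER FP8 linear is enabled, the quant_fp8 custom op is switched on by default unless the user explicitly disabled it. Entries of the form "+name" and "-name" in compilation_config.custom_ops enable or disable a CustomOp implementation. A self-contained illustration of the toggle logic added above (values are made up):

use_aiter_fp8_linear = True                      # stands in for rocm_aiter_ops.is_linear_fp8_enaled()
custom_ops = ["+rms_norm", "-silu_and_mul"]      # example user-supplied list
if use_aiter_fp8_linear and "-quant_fp8" not in custom_ops:
    custom_ops.append("+quant_fp8")
assert custom_ops == ["+rms_norm", "-silu_and_mul", "+quant_fp8"]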