From 80d38b8ac850fbb19b9a76e74cd53ff04573e58b Mon Sep 17 00:00:00 2001
From: TJian
Date: Sun, 13 Jul 2025 08:19:32 -0700
Subject: [PATCH] [V1] [ROCm] [AITER] Upgrade AITER to commit `916bf3c` and bugfix APIs (#20880)

Signed-off-by: tjtanaa
---
 docker/Dockerfile.rocm_base                 |  2 +-
 .../quantization/kernels/scaled_mm/aiter.py | 49 +++++++++++++++++--
 .../layers/quantization/utils/fp8_utils.py  |  2 +-
 3 files changed, 48 insertions(+), 5 deletions(-)

diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base
index dc8ec5f1a15e5..3414c0aa845cb 100644
--- a/docker/Dockerfile.rocm_base
+++ b/docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="6487649"
+ARG AITER_BRANCH="916bf3c"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
index 165548a060128..7f808fa92a9a8 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -8,11 +8,55 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 from .cutlass import CutlassScaledMMLinearKernel
 from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
 
 
+def rocm_aiter_gemm_w8a8_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    from aiter import gemm_a8w8_CK
+
+    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
+    # a to be [M, K]
+    # b to be [N, K]
+    # CutlassScaledMMLinearKernel prepares the weight `w_q` in [K, N] format
+    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
+
+
+def rocm_aiter_gemm_w8a8_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8",
+        op_func=rocm_aiter_gemm_w8a8_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
 class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
 
     @classmethod
@@ -111,10 +155,9 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
             " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
             "does not support AITER block scaled GEMM.")
 
-        from aiter import gemm_a8w8_CK
-
         # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
         # a to be [M, K]
         # b to be [N, K]
         # CutlassScaledMMLinearKernel prepares the weight `w_q` in [K, N] format
-        return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
+                                                   bias, out_dtype)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 9c78dea17e5c4..c093a9bfc4a60 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -56,7 +56,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
 ) -> torch.Tensor:
     import aiter as rocm_aiter
 
-    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
 
 def rocm_aiter_gemm_w8a8_blockscale_fake(
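
Note on the pattern used in this patch: routing the call through a registered
custom op (torch.ops.vllm.rocm_aiter_gemm_w8a8) instead of calling
gemm_a8w8_CK directly gives torch.compile a fake implementation through which
it can trace output shapes and dtypes without launching the ROCm kernel. The
sketch below shows the same register-with-fake-impl pattern using the public
torch.library API (vllm.utils.direct_register_custom_op is vLLM's lower-level
wrapper over torch.library); the op name "mylib::scaled_mm_demo" and the plain
float matmul body are illustrative assumptions, not part of this patch. It
runs on CPU with PyTorch >= 2.4:

    import torch

    # Real implementation: executed when the op actually runs.
    @torch.library.custom_op("mylib::scaled_mm_demo", mutates_args=())
    def scaled_mm_demo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # b is [N, K], mirroring the layout gemm_a8w8_CK expects.
        return a @ b.t()

    # Fake implementation: used by torch.compile / FakeTensor tracing to
    # infer the output shape [M, N], mirroring rocm_aiter_gemm_w8a8_fake.
    @scaled_mm_demo.register_fake
    def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a.new_empty(a.shape[0], b.shape[0])

    compiled = torch.compile(
        lambda a, b: torch.ops.mylib.scaled_mm_demo(a, b))
    print(compiled(torch.randn(4, 8), torch.randn(16, 8)).shape)  # [4, 16]

Because the fake implementation never touches the AITER library, the
apply_weights path above stays graph-capturable even on machines where the
kernel itself could not run.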