[ROCM] Fix ROCm warnings, environment flag access, and GEMM kernel naming for consistency in _aiter_ops.py (#28464)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
vllmellm 2025-11-13 05:46:57 +08:00 committed by GitHub
parent 74a9a9faad
commit d8140b9833
5 changed files with 33 additions and 29 deletions

View File

@@ -32,13 +32,13 @@ def if_aiter_supported(func: Callable) -> Callable:
     def wrapper(*args, **kwargs):
         # checks the platform, device arch and aiter library existance.
-        from vllm.platforms.rocm import on_gfx9
+        if current_platform.is_rocm() and IS_AITER_FOUND:
+            from vllm.platforms.rocm import on_gfx9
 
-        if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
-            return func(*args, **kwargs)
-        else:
-            # Return None or do nothing if not supported
-            return None
+            if on_gfx9():
+                return func(*args, **kwargs)
+        return None
 
     return wrapper
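For readers who do not build on ROCm, here is a self-contained sketch of the gating pattern the reworked decorator implements: the gfx9 check (and the ROCm-only import it needs) is only reached once the platform is known to be ROCm, which is what avoids the warnings on other backends. The boolean stand-ins and the toy_ops class below are illustrative, not vLLM code; the real names are current_platform.is_rocm(), IS_AITER_FOUND, and on_gfx9().

# A runnable sketch of the guard above, with stand-in booleans so it runs anywhere.
from functools import wraps
from typing import Callable

IS_ROCM = False          # stand-in for current_platform.is_rocm()
IS_AITER_FOUND = False   # stand-in for the aiter import probe
IS_GFX9 = False          # stand-in for on_gfx9(); only consulted on ROCm


def if_aiter_supported(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(*args, **kwargs):
        # The arch check is nested inside the platform check, mirroring the
        # new code: non-ROCm callers never touch the ROCm-only import.
        if IS_ROCM and IS_AITER_FOUND:
            if IS_GFX9:
                return func(*args, **kwargs)
        # Return None if not supported, matching the decorator's contract.
        return None

    return wrapper


class toy_ops:
    @classmethod
    @if_aiter_supported
    def is_feature_enabled(cls) -> bool:
        return True


print(toy_ops.is_feature_enabled())  # None unless ROCm + aiter + gfx9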
@@ -296,7 +296,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
     pass
 
 
-def _rocm_aiter_gemm_w8a8_impl(
+def _rocm_aiter_gemm_a8w8_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -313,7 +313,7 @@ def _rocm_aiter_gemm_w8a8_impl(
     return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
 
 
-def _rocm_aiter_gemm_w8a8_fake(
+def _rocm_aiter_gemm_a8w8_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -327,7 +327,7 @@ def _rocm_aiter_gemm_w8a8_fake(
     return Y
 
 
-def _rocm_aiter_gemm_w8a8_blockscale_impl(
+def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -339,7 +339,7 @@ def _rocm_aiter_gemm_w8a8_blockscale_impl(
     return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
 
-def _rocm_aiter_gemm_w8a8_blockscale_fake(
+def _rocm_aiter_gemm_a8w8_blockscale_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -419,6 +419,7 @@ class rocm_aiter_ops:
     _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
     _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
     _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
+    _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
 
     @classmethod
     @if_aiter_supported
@@ -494,6 +495,11 @@ class rocm_aiter_ops:
     def is_triton_rotary_embed_enabled(cls) -> bool:
         return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
 
+    @classmethod
+    @if_aiter_supported
+    def is_triton_gemm_enabled(cls) -> bool:
+        return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM
+
     @staticmethod
     @if_aiter_supported
     def register_ops_once() -> None:
@@ -555,18 +561,18 @@ class rocm_aiter_ops:
         )
 
         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8",
-            op_func=_rocm_aiter_gemm_w8a8_impl,
+            op_name="rocm_aiter_gemm_a8w8",
+            op_func=_rocm_aiter_gemm_a8w8_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_fake,
             dispatch_key=current_platform.dispatch_key,
         )
 
         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8_blockscale",
-            op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
+            op_name="rocm_aiter_gemm_a8w8_blockscale",
+            op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
             dispatch_key=current_platform.dispatch_key,
         )
@@ -606,7 +612,7 @@ class rocm_aiter_ops:
         return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
 
     @staticmethod
-    def gemm_w8a8(
+    def gemm_a8w8(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -614,10 +620,10 @@ class rocm_aiter_ops:
         bias: torch.Tensor | None = None,
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
 
     @staticmethod
-    def gemm_w8a8_blockscale(
+    def gemm_a8w8_blockscale(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -625,7 +631,7 @@ class rocm_aiter_ops:
         block_size: list[int],
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale(
             A, B, As, Bs, output_dtype
         )
@@ -938,5 +944,4 @@ class rocm_aiter_ops:
         return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
 
 
-if IS_AITER_FOUND:
-    rocm_aiter_ops.register_ops_once()
+rocm_aiter_ops.register_ops_once()

View File

@@ -117,4 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
         # a to be [M, K]
         # b to be [N, K]
         # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
-        return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
+        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
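On the a8w8 naming at this call site: A is the 8-bit activation operand in [M, K] and B the 8-bit weight operand in [N, K], which is why the [K, N] weight is passed as w_q.t(). Below is a hedged reference for what the op computes, assuming per-token activation scales As of shape [M, 1] and per-channel weight scales Bs of shape [N, 1]; that scale layout is an assumption about aiter's a8w8 convention, not something stated in this diff.

# Reference-only sketch; the real work is done by aiter's fused kernel.
import torch


def gemm_a8w8_reference(
    A: torch.Tensor,                   # int8 activations, [M, K]
    B: torch.Tensor,                   # int8 weights, [N, K]
    As: torch.Tensor,                  # per-token scales, [M, 1] (assumed)
    Bs: torch.Tensor,                  # per-channel scales, [N, 1] (assumed)
    bias: torch.Tensor | None = None,  # optional, [N]
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # The real kernel accumulates int8 x int8 in int32; float32 is exact for
    # this toy problem size and keeps the sketch runnable on any backend.
    acc = A.to(torch.float32) @ B.to(torch.float32).t()          # [M, N]
    out = acc * As.to(torch.float32) * Bs.to(torch.float32).t()  # rescale
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(output_dtype)


M, N, K = 4, 8, 16
A = torch.randint(-128, 128, (M, K), dtype=torch.int8)
B = torch.randint(-128, 128, (N, K), dtype=torch.int8)
As, Bs = torch.rand(M, 1) * 0.1, torch.rand(N, 1) * 0.1
print(gemm_a8w8_reference(A, B, As, Bs).shape)  # torch.Size([4, 8])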

View File

@@ -328,7 +328,7 @@ class W8A8BlockFp8LinearOp:
         if use_triton:
             gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
         else:
-            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
 
         if input_scale is not None:
             q_input = input_2d

View File

@@ -8,6 +8,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -105,8 +106,7 @@ def default_unquantized_gemm(
 def use_aiter_triton_gemm(n, m, k, dtype):
     if (
-        envs.VLLM_ROCM_USE_AITER == 0
-        or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
+        not rocm_aiter_ops.is_triton_gemm_enabled()
         # MI300's - fp8nuz=True
         or current_platform.is_fp8_fnuz()
        or dtype not in [torch.float16, torch.bfloat16]

View File

@@ -325,6 +325,7 @@ class RocmPlatform(Platform):
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm._aiter_ops import rocm_aiter_ops
         from vllm.config.compilation import CUDAGraphMode
 
         cache_config = vllm_config.cache_config
@@ -332,9 +333,7 @@ class RocmPlatform(Platform):
         parallel_config = vllm_config.parallel_config
         is_eager_execution = compilation_config == CUDAGraphMode.NONE
 
-        use_aiter_rms_norm = (
-            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
-        )
+        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
 
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16