diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 8d35aa65738b..5508e59bcd2f 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -32,13 +32,13 @@ def if_aiter_supported(func: Callable) -> Callable:
 
     def wrapper(*args, **kwargs):
         # checks the platform, device arch and aiter library existance.
-        from vllm.platforms.rocm import on_gfx9
+        if current_platform.is_rocm() and IS_AITER_FOUND:
+            from vllm.platforms.rocm import on_gfx9
 
-        if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
-            return func(*args, **kwargs)
-        else:
-            # Return None or do nothing if not supported
-            return None
+            if on_gfx9():
+                return func(*args, **kwargs)
+
+        return None
 
     return wrapper
 
@@ -296,7 +296,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
     pass
 
 
-def _rocm_aiter_gemm_w8a8_impl(
+def _rocm_aiter_gemm_a8w8_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -313,7 +313,7 @@ def _rocm_aiter_gemm_w8a8_impl(
     return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
 
 
-def _rocm_aiter_gemm_w8a8_fake(
+def _rocm_aiter_gemm_a8w8_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -327,7 +327,7 @@ def _rocm_aiter_gemm_w8a8_fake(
     return Y
 
 
-def _rocm_aiter_gemm_w8a8_blockscale_impl(
+def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -339,7 +339,7 @@ def _rocm_aiter_gemm_w8a8_blockscale_impl(
     return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
 
-def _rocm_aiter_gemm_w8a8_blockscale_fake(
+def _rocm_aiter_gemm_a8w8_blockscale_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -419,6 +419,7 @@ class rocm_aiter_ops:
     _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
     _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
     _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
+    _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
 
     @classmethod
     @if_aiter_supported
@@ -494,6 +495,11 @@ class rocm_aiter_ops:
     def is_triton_rotary_embed_enabled(cls) -> bool:
         return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
 
+    @classmethod
+    @if_aiter_supported
+    def is_triton_gemm_enabled(cls) -> bool:
+        return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM
+
     @staticmethod
     @if_aiter_supported
     def register_ops_once() -> None:
@@ -555,18 +561,18 @@ class rocm_aiter_ops:
         )
 
         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8",
-            op_func=_rocm_aiter_gemm_w8a8_impl,
+            op_name="rocm_aiter_gemm_a8w8",
+            op_func=_rocm_aiter_gemm_a8w8_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_fake,
             dispatch_key=current_platform.dispatch_key,
         )
 
         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8_blockscale",
-            op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
+            op_name="rocm_aiter_gemm_a8w8_blockscale",
+            op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
             dispatch_key=current_platform.dispatch_key,
         )
 
@@ -606,7 +612,7 @@ class rocm_aiter_ops:
         return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
 
     @staticmethod
-    def gemm_w8a8(
+    def gemm_a8w8(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -614,10 +620,10 @@
         bias: torch.Tensor | None = None,
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
 
     @staticmethod
-    def gemm_w8a8_blockscale(
+    def gemm_a8w8_blockscale(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -625,7 +631,7 @@
         block_size: list[int],
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale(
             A, B, As, Bs, output_dtype
         )
@@ -938,5 +944,4 @@ class rocm_aiter_ops:
         return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
 
 
-if IS_AITER_FOUND:
-    rocm_aiter_ops.register_ops_once()
+rocm_aiter_ops.register_ops_once()
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
index f5cd91469b78..038a92c516ce 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -117,4 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
         # a to be [M, K]
         # b to be [N, K]
         # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
-        return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
+        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 03d086bda8e3..541c6c631053 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -328,7 +328,7 @@ class W8A8BlockFp8LinearOp:
         if use_triton:
             gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
         else:
-            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
 
         if input_scale is not None:
             q_input = input_2d
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index b17bdd0b7207..68262a2703f9 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -8,6 +8,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -105,8 +106,7 @@ def default_unquantized_gemm(
 
 def use_aiter_triton_gemm(n, m, k, dtype):
     if (
-        envs.VLLM_ROCM_USE_AITER == 0
-        or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
+        not rocm_aiter_ops.is_triton_gemm_enabled()
         # MI300's - fp8nuz=True
         or current_platform.is_fp8_fnuz()
         or dtype not in [torch.float16, torch.bfloat16]
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 5fa8969b860e..d977d999de67 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -325,6 +325,7 @@ class RocmPlatform(Platform):
 
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm._aiter_ops import rocm_aiter_ops
         from vllm.config.compilation import CUDAGraphMode
 
         cache_config = vllm_config.cache_config
@@ -332,9 +333,7 @@
         parallel_config = vllm_config.parallel_config
         is_eager_execution = compilation_config == CUDAGraphMode.NONE
 
-        use_aiter_rms_norm = (
-            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
-        )
+        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
 
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
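Note (not part of the diff): a minimal sketch of how the renamed gemm_a8w8 wrapper and the new is_triton_gemm_enabled() accessor are meant to be called, assuming the staticmethod signatures introduced above. The run_int8_gemm helper, tensor arguments, and dtype are illustrative only.

# Illustrative sketch only; not part of the change set above.
import torch

from vllm._aiter_ops import rocm_aiter_ops


def run_int8_gemm(x_q, w_q, x_s, w_s):
    # gemm_a8w8 keeps the old gemm_w8a8 argument order: A is the [M, K]
    # activation, B the [N, K] weight, so a weight stored as [K, N] is passed
    # transposed (as AiterScaledMMLinearKernel does in the hunk above).
    return rocm_aiter_ops.gemm_a8w8(
        x_q, w_q.t(), x_s, w_s, bias=None, output_dtype=torch.bfloat16
    )


# The new accessor mirrors the other is_*_enabled() helpers: truthy only when
# AITER is usable and VLLM_ROCM_USE_AITER_TRITON_GEMM is set; on unsupported
# platforms @if_aiter_supported makes it return None.
if rocm_aiter_ops.is_triton_gemm_enabled():
    print("AITER Triton unquantized GEMM path is enabled")

Existing call sites only need the mechanical w8a8 -> a8w8 rename, as the aiter.py and fp8_utils.py hunks show.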