From 7c195d43da241d1ae07e73062c6fe593be3e4aac Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Wed, 10 Sep 2025 21:08:03 +0800
Subject: [PATCH] [ROCm][Bugfix] Fix Aiter RMSNorm (#23412)

Signed-off-by: vllmellm
---
 .../model_executor/test_enabled_custom_ops.py | 37 ++++----
 vllm/model_executor/layers/layernorm.py       | 89 +++++++++++++++----
 vllm/platforms/rocm.py                        | 18 +++-
 3 files changed, 108 insertions(+), 36 deletions(-)

diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py
index 140f00294765d..86139d598582d 100644
--- a/tests/model_executor/test_enabled_custom_ops.py
+++ b/tests/model_executor/test_enabled_custom_ops.py
@@ -13,13 +13,15 @@
 from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
                                                             vllm_topk_softmax)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
-from vllm.model_executor.layers.layernorm import (
-    RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
-    rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
+from vllm.model_executor.layers.layernorm import (RMSNorm,
+                                                  dispatch_rocm_rmsnorm_func,
+                                                  fused_add_rms_norm, rms_norm)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 
+RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
+
 # Registered subclass for test
 @CustomOp.register("relu3")
@@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
 
 
 @pytest.mark.parametrize("add_residual", [True, False])
+@pytest.mark.parametrize("dtype",
+                         [torch.float32, torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
 @pytest.mark.skipif(not current_platform.is_rocm(),
                     reason="AITER is a feature exclusive for ROCm")
-def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
-                           use_rocm_aiter_norm: str, monkeypatch):
+def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
+                           use_rocm_aiter: str, use_rocm_aiter_norm: str,
+                           monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
     monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
-    rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
+    rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
 
-    if not add_residual:
-        if current_platform.is_rocm() and int(use_rocm_aiter) and int(
-                use_rocm_aiter_norm):
-            assert rms_norm_func == rocm_aiter_rms_norm
-        else:
-            assert rms_norm_func == rms_norm
-    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
-            use_rocm_aiter_norm):
-        assert rms_norm_func == rocm_aiter_fused_add_rms_norm
-    else:
+    should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
+        and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
+
+    if add_residual and should_use_rocm_aiter:
+        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+    elif should_use_rocm_aiter:
+        assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
+    elif add_residual:
         assert rms_norm_func == fused_add_rms_norm
+    else:
+        assert rms_norm_func == rms_norm
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index a5fc1db2dc10f..0488eab1e03fb 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -9,11 +9,11 @@
 import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 
 def is_rocm_aiter_rmsnorm_enabled() -> bool:
-    return current_platform.is_rocm() \
-        and envs.VLLM_ROCM_USE_AITER_RMSNORM \
+    return envs.VLLM_ROCM_USE_AITER_RMSNORM \
         and envs.VLLM_ROCM_USE_AITER
@@ -43,8 +43,8 @@ def fused_add_rms_norm(
     return x, residual
 
 
-def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
-                        variance_epsilon: float) -> torch.Tensor:
+def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
+                             variance_epsilon: float) -> torch.Tensor:
     import aiter as rocm_aiter
     if x.dim() > 2:
         x_original_shape = x.shape
@@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
     return rocm_aiter.rms_norm(x, weight, variance_epsilon)
 
 
-def rocm_aiter_fused_add_rms_norm(
+def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
         x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
         variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
@@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
     return output, residual_out
 
 
-def dispatch_cuda_rmsnorm_func(add_residual: bool):
-    if add_residual:
-        if is_rocm_aiter_rmsnorm_enabled():
-            return rocm_aiter_fused_add_rms_norm
-        return fused_add_rms_norm
+def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
+                             variance_epsilon: float) -> torch.Tensor:
+    return torch.empty_like(x)
 
-    if is_rocm_aiter_rmsnorm_enabled():
-        return rocm_aiter_rms_norm
+
+def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
+        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
+        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(x), torch.empty_like(residual)
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_rms_norm",
+        op_func=rocm_aiter_rms_norm_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rms_norm_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
+        op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
+    use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
+        torch.float16, torch.bfloat16
+    ]
+
+    if use_aiter and with_fused_add:
+        return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
+    if use_aiter:
+        return torch.ops.vllm.rocm_aiter_rms_norm
+
+    # fall back to CUDA implementation
+    if with_fused_add:
+        return fused_add_rms_norm
     return rms_norm
 
 
@@ -114,6 +148,13 @@ class RMSNorm(CustomOp):
         self.weight = torch.ones(hidden_size)
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
+        weight_dtype = self.weight.data.dtype
+
+        if current_platform.is_rocm():
+            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
+                with_fused_add=False, dtype=weight_dtype)
+            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
+                with_fused_add=True, dtype=weight_dtype)
 
     def forward_native(
         self,
@@ -162,13 +203,27 @@ class RMSNorm(CustomOp):
             return self.forward_native(x, residual)
 
         add_residual = residual is not None
-        norm_func = dispatch_cuda_rmsnorm_func(add_residual)
-
         if add_residual:
-            return norm_func(x, residual, self.weight.data,
-                             self.variance_epsilon)
+            return fused_add_rms_norm(x, residual, self.weight.data,
+                                      self.variance_epsilon)
         else:
-            return norm_func(x, self.weight.data, self.variance_epsilon)
+            return rms_norm(x, self.weight.data, self.variance_epsilon)
+
+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+
+        add_residual = residual is not None
+        if add_residual:
+            return self.rocm_norm_func_with_add(x, residual, self.weight.data,
+                                                self.variance_epsilon)
+        else:
+            return self.rocm_norm_func(x, self.weight.data,
+                                       self.variance_epsilon)
 
     def forward_xpu(
         self,
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index c6d14aa87c7f2..cf7b87cf030a9 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -322,23 +322,35 @@ class RocmPlatform(Platform):
 
     @classmethod
    def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm.config.compilation import CUDAGraphMode
+
         cache_config = vllm_config.cache_config
+        compilation_config = vllm_config.compilation_config
+        parallel_config = vllm_config.parallel_config
+        is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+
+        use_v1 = envs.VLLM_USE_V1
+        use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
+            envs.VLLM_ROCM_USE_AITER_RMSNORM
+
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
-        parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             if vllm_config.speculative_config:
-                if not envs.VLLM_USE_V1:
+                if not use_v1:
                     raise NotImplementedError(
                         "Speculative decoding is not supported on vLLM V0.")
                 parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
             else:
-                if envs.VLLM_USE_V1:
+                if use_v1:
                     parallel_config.worker_cls = \
                         "vllm.v1.worker.gpu_worker.Worker"
                 else:
                     parallel_config.worker_cls = "vllm.worker.worker.Worker"
+        # AITER RMSNorm performs best when CUDA graph capture is enabled.
+        if use_v1 and use_aiter_rms_norm and not is_eager_execution:
+            compilation_config.custom_ops.append("+rms_norm")
 
     @classmethod
     def verify_model_arch(cls, model_arch: str) -> None:
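
For reference, the two operations being dispatched above implement standard RMSNorm and its residual-fused variant. The sketch below is a minimal eager-mode restatement of that math, accumulating in float32 and casting back; the ref_* names and the example shapes are illustrative assumptions, not part of the patch.

import torch


def ref_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                 eps: float) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    return (x32 * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)


def ref_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        eps: float) -> tuple[torch.Tensor, torch.Tensor]:
    # Add the residual first, normalize the sum, and hand the sum back as the
    # new residual for the next layer.
    summed = x + residual
    return ref_rms_norm(summed, weight, eps), summed


if __name__ == "__main__":
    x = torch.randn(4, 8, dtype=torch.bfloat16)
    residual = torch.randn_like(x)
    weight = torch.ones(8, dtype=torch.bfloat16)
    out, new_residual = ref_fused_add_rms_norm(x, residual, weight, 1e-6)
    print(out.shape, new_residual.dtype)

The fused variant returns the post-add tensor as the new residual, which is why both the fused CUDA op and the AITER rmsnorm2d_fwd_with_add op return a pair.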
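The substantive change in layernorm.py is that the AITER kernels are no longer called as plain Python functions but as registered custom ops paired with *_fake implementations, so torch.compile and CUDA-graph capture can trace their output shapes without executing the HIP kernel. The sketch below reproduces that pattern with the public torch.library API (assuming PyTorch 2.4+); the "demo::rms_norm" op name and the eager body are illustrative stand-ins, not vLLM's direct_register_custom_op helper.

import torch
from torch.library import custom_op


@custom_op("demo::rms_norm", mutates_args=())
def demo_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                  eps: float) -> torch.Tensor:
    # Real implementation; in the patch this is where the AITER kernel runs.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = x.float() * torch.rsqrt(variance + eps)
    return (out * weight.float()).to(x.dtype)


@demo_rms_norm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Shape-and-dtype-only stand-in used while tracing or capturing graphs,
    # mirroring rocm_aiter_rms_norm_fake above.
    return torch.empty_like(x)


x = torch.randn(2, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
out = torch.ops.demo.rms_norm(x, weight, 1e-6)
print(out.shape, out.dtype)

This is also why rocm.py appends "+rms_norm" to compilation_config.custom_ops only when V1, the AITER flags, and CUDA graph capture are all in play: the custom-op path is the one that routes through these registered ops.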
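The dispatch rule introduced by dispatch_rocm_rmsnorm_func can be summarized as: use the AITER ops only on ROCm, only when both VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_RMSNORM are enabled, and only for float16/bfloat16 weights; float32 always takes the default path. A standalone sketch of just that selection logic follows, with plain stand-in callables and the flags read directly from os.environ instead of vllm.envs (the default values used here are assumptions).

import os

import torch

AITER_RMSNORM_DTYPES = (torch.float16, torch.bfloat16)


# Stand-ins for the real kernels; only the selection logic matters here.
def cuda_rms_norm(*args): ...
def cuda_fused_add_rms_norm(*args): ...
def aiter_rms_norm(*args): ...
def aiter_fused_add_rms_norm(*args): ...


def pick_rmsnorm(with_fused_add: bool, dtype: torch.dtype, is_rocm: bool):
    use_aiter = (is_rocm
                 and os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
                 and os.environ.get("VLLM_ROCM_USE_AITER_RMSNORM", "1") == "1"
                 and dtype in AITER_RMSNORM_DTYPES)
    if use_aiter:
        return aiter_fused_add_rms_norm if with_fused_add else aiter_rms_norm
    return cuda_fused_add_rms_norm if with_fused_add else cuda_rms_norm


# float32 never routes to AITER, even with both flags set.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
print(pick_rmsnorm(True, torch.float32, is_rocm=True).__name__)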