[ROCm][Bugfix] Fix Aiter RMSNorm (#23412)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
vllmellm 2025-09-10 21:08:03 +08:00 committed by GitHub
parent 0ae43dbf8c
commit 7c195d43da
3 changed files with 108 additions and 36 deletions


@@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
# Registered subclass for test
@CustomOp.register("relu3")
@@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
@pytest.mark.parametrize("add_residual", [True, False])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="AITER is a feature exclusive for ROCm")
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
use_rocm_aiter_norm: str, monkeypatch):
def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
use_rocm_aiter: str, use_rocm_aiter_norm: str,
monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
if not add_residual:
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_rms_norm
else:
assert rms_norm_func == rms_norm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_norm):
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
else:
should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
if add_residual and should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
elif should_use_rocm_aiter:
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
elif add_residual:
assert rms_norm_func == fused_add_rms_norm
else:
assert rms_norm_func == rms_norm
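
The updated test pins down the new dispatch rule: selection now depends on both the fused-add flag and the tensor dtype, because the AITER kernels only cover fp16/bf16. A minimal standalone sketch of that rule, assuming nothing beyond plain PyTorch (the helper pick_rmsnorm and the returned strings are illustrative stand-ins, not vLLM APIs):

import torch

RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]

def pick_rmsnorm(add_residual: bool, dtype: torch.dtype,
                 aiter_enabled: bool) -> str:
    # Mirrors the expectations asserted in test_rms_norm_dispatch:
    # AITER ops are only eligible for fp16/bf16, otherwise fall back.
    use_aiter = aiter_enabled and dtype in RMS_NORM_SUPPORTED_DTYPES
    if add_residual and use_aiter:
        return "rocm_aiter_rmsnorm2d_fwd_with_add"
    if use_aiter:
        return "rocm_aiter_rms_norm"
    return "fused_add_rms_norm" if add_residual else "rms_norm"

# fp32 always falls back, even with AITER enabled.
assert pick_rmsnorm(False, torch.float32, True) == "rms_norm"
assert pick_rmsnorm(True, torch.bfloat16, True) == "rocm_aiter_rmsnorm2d_fwd_with_add"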


@@ -9,11 +9,11 @@ import torch.nn as nn
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_RMSNORM \
return envs.VLLM_ROCM_USE_AITER_RMSNORM \
and envs.VLLM_ROCM_USE_AITER
@@ -43,8 +43,8 @@ def fused_add_rms_norm(
return x, residual
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
@@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_fused_add_rms_norm(
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
@@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
return output, residual_out
def dispatch_cuda_rmsnorm_func(add_residual: bool):
if add_residual:
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_fused_add_rms_norm
return fused_add_rms_norm
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
return torch.empty_like(x)
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_rms_norm
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
torch.float16, torch.bfloat16
]
if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm
# fall back to CUDA implementation
if with_fused_add:
return fused_add_rms_norm
return rms_norm
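
For context on the registration above: wrapping the AITER kernels as registered custom ops gives the compiler a fake (meta) implementation to trace against, so the dispatcher can hand out torch.ops.vllm.* callables instead of raw Python functions. A rough standalone analogue using plain torch.library (PyTorch >= 2.4); the demo:: namespace and the naive norm body are placeholders, not vLLM's direct_register_custom_op helper:

import torch

# Eager implementation, registered as a custom op.
@torch.library.custom_op("demo::rms_norm", mutates_args=())
def demo_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

# Fake implementation: only describes output shape/dtype for tracing,
# analogous to rocm_aiter_rms_norm_fake above.
@demo_rms_norm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    return torch.empty_like(x)

out = torch.ops.demo.rms_norm(torch.randn(4, 8), torch.ones(8), 1e-6)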
@@ -114,6 +148,13 @@ class RMSNorm(CustomOp):
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype
if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype)
def forward_native(
self,
@@ -162,13 +203,27 @@
return self.forward_native(x, residual)
add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
if add_residual:
return norm_func(x, residual, self.weight.data,
self.variance_epsilon)
return fused_add_rms_norm(x, residual, self.weight.data,
self.variance_epsilon)
else:
return norm_func(x, self.weight.data, self.variance_epsilon)
return rms_norm(x, self.weight.data, self.variance_epsilon)
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
add_residual = residual is not None
if add_residual:
return self.rocm_norm_func_with_add(x, residual, self.weight.data,
self.variance_epsilon)
else:
return self.rocm_norm_func(x, self.weight.data,
self.variance_epsilon)
def forward_xpu(
self,

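One design point worth noting in the RMSNorm change: the ROCm norm callables are resolved once in __init__, keyed on the weight dtype, and forward_hip just invokes the cached function rather than re-dispatching on every call. A small sketch of that resolve-once pattern, assuming a recent PyTorch with F.rms_norm; TinyRMSNorm and fallback_norm are illustrative names, not vLLM code:

from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

def fallback_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    return F.rms_norm(x, (x.shape[-1],), weight, eps)

class TinyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6,
                 fast_kernel: Optional[Callable] = None) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        # Pick the kernel once at construction, mirroring the cached
        # rocm_norm_func / rocm_norm_func_with_add attributes above.
        self._norm = fast_kernel if fast_kernel is not None else fallback_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._norm(x, self.weight, self.eps)

y = TinyRMSNorm(8)(torch.randn(2, 8))
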

@@ -322,23 +322,35 @@ class RocmPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.config.compilation import CUDAGraphMode
cache_config = vllm_config.cache_config
compilation_config = vllm_config.compilation_config
parallel_config = vllm_config.parallel_config
is_eager_execution = compilation_config.cudagraph_mode == CUDAGraphMode.NONE
use_v1 = envs.VLLM_USE_V1
use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
envs.VLLM_ROCM_USE_AITER_RMSNORM
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if vllm_config.speculative_config:
if not envs.VLLM_USE_V1:
if not use_v1:
raise NotImplementedError(
"Speculative decoding is not supported on vLLM V0.")
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
else:
if envs.VLLM_USE_V1:
if use_v1:
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
# AITER RMSNorm performs best when CUDA graph capture is enabled.
if use_v1 and use_aiter_rms_norm and not is_eager_execution:
compilation_config.custom_ops.append("+rms_norm")
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
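
From a user's perspective, the end result of this change: on a ROCm build running V1 with CUDA graph capture enabled, turning on the AITER flags routes fp16/bf16 RMSNorm through the registered custom ops and appends "+rms_norm" to compilation_config.custom_ops. A hedged usage sketch; the environment variables are the ones read in this diff, while the model name is only a placeholder:

import os

# Set the flags before the engine is constructed so the ROCm platform
# hooks in check_and_update_config see them.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)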