mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 06:24:30 +08:00
[ROCm][Bugfix] Fix Aiter RMSNorm (#23412)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
0ae43dbf8c
commit
7c195d43da
@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
|
||||
vllm_topk_softmax)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
from vllm.model_executor.layers.layernorm import (
|
||||
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
||||
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
||||
from vllm.model_executor.layers.layernorm import (RMSNorm,
|
||||
dispatch_rocm_rmsnorm_func,
|
||||
fused_add_rms_norm, rms_norm)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
# Registered subclass for test
|
||||
@CustomOp.register("relu3")
|
||||
@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("add_residual", [True, False])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="AITER is a feature exclusive for ROCm")
|
||||
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
|
||||
use_rocm_aiter_norm: str, monkeypatch):
|
||||
def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
|
||||
use_rocm_aiter: str, use_rocm_aiter_norm: str,
|
||||
monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
|
||||
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
|
||||
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
|
||||
|
||||
if not add_residual:
|
||||
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
||||
use_rocm_aiter_norm):
|
||||
assert rms_norm_func == rocm_aiter_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == rms_norm
|
||||
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
||||
use_rocm_aiter_norm):
|
||||
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
|
||||
else:
|
||||
should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
|
||||
and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||
|
||||
if add_residual and should_use_rocm_aiter:
|
||||
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||
elif should_use_rocm_aiter:
|
||||
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
|
||||
elif add_residual:
|
||||
assert rms_norm_func == fused_add_rms_norm
|
||||
else:
|
||||
assert rms_norm_func == rms_norm
|
||||
|
||||
@ -9,11 +9,11 @@ import torch.nn as nn
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
def is_rocm_aiter_rmsnorm_enabled() -> bool:
|
||||
return current_platform.is_rocm() \
|
||||
and envs.VLLM_ROCM_USE_AITER_RMSNORM \
|
||||
return envs.VLLM_ROCM_USE_AITER_RMSNORM \
|
||||
and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
@ -43,8 +43,8 @@ def fused_add_rms_norm(
|
||||
return x, residual
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
||||
variance_epsilon: float) -> torch.Tensor:
|
||||
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
|
||||
variance_epsilon: float) -> torch.Tensor:
|
||||
import aiter as rocm_aiter
|
||||
if x.dim() > 2:
|
||||
x_original_shape = x.shape
|
||||
@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
||||
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
|
||||
def rocm_aiter_fused_add_rms_norm(
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
|
||||
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
|
||||
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
|
||||
return output, residual_out
|
||||
|
||||
|
||||
def dispatch_cuda_rmsnorm_func(add_residual: bool):
|
||||
if add_residual:
|
||||
if is_rocm_aiter_rmsnorm_enabled():
|
||||
return rocm_aiter_fused_add_rms_norm
|
||||
return fused_add_rms_norm
|
||||
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
|
||||
variance_epsilon: float) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
if is_rocm_aiter_rmsnorm_enabled():
|
||||
return rocm_aiter_rms_norm
|
||||
|
||||
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
|
||||
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.empty_like(x), torch.empty_like(residual)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rms_norm",
|
||||
op_func=rocm_aiter_rms_norm_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_rms_norm_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
|
||||
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
|
||||
torch.float16, torch.bfloat16
|
||||
]
|
||||
|
||||
if use_aiter and with_fused_add:
|
||||
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||
if use_aiter:
|
||||
return torch.ops.vllm.rocm_aiter_rms_norm
|
||||
|
||||
# fall back to CUDA implementation
|
||||
if with_fused_add:
|
||||
return fused_add_rms_norm
|
||||
return rms_norm
|
||||
|
||||
|
||||
@ -114,6 +148,13 @@ class RMSNorm(CustomOp):
|
||||
self.weight = torch.ones(hidden_size)
|
||||
if self.has_weight:
|
||||
self.weight = nn.Parameter(self.weight)
|
||||
weight_dtype = self.weight.data.dtype
|
||||
|
||||
if current_platform.is_rocm():
|
||||
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=False, dtype=weight_dtype)
|
||||
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
|
||||
with_fused_add=True, dtype=weight_dtype)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
@ -162,13 +203,27 @@ class RMSNorm(CustomOp):
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
add_residual = residual is not None
|
||||
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
|
||||
|
||||
if add_residual:
|
||||
return norm_func(x, residual, self.weight.data,
|
||||
self.variance_epsilon)
|
||||
return fused_add_rms_norm(x, residual, self.weight.data,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
return norm_func(x, self.weight.data, self.variance_epsilon)
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
if self.variance_size_override is not None:
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
add_residual = residual is not None
|
||||
if add_residual:
|
||||
return self.rocm_norm_func_with_add(x, residual, self.weight.data,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
return self.rocm_norm_func(x, self.weight.data,
|
||||
self.variance_epsilon)
|
||||
|
||||
def forward_xpu(
|
||||
self,
|
||||
|
||||
@ -322,23 +322,35 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
compilation_config = vllm_config.compilation_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
is_eager_execution = compilation_config == CUDAGraphMode.NONE
|
||||
|
||||
use_v1 = envs.VLLM_USE_V1
|
||||
use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
|
||||
envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 16
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if vllm_config.speculative_config:
|
||||
if not envs.VLLM_USE_V1:
|
||||
if not use_v1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not supported on vLLM V0.")
|
||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
if use_v1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
# Aiter rms norm perform best when CUDA Graph capture is enabled.
|
||||
if use_v1 and use_aiter_rms_norm and not is_eager_execution:
|
||||
compilation_config.custom_ops.append("+rms_norm")
|
||||
|
||||
@classmethod
|
||||
def verify_model_arch(cls, model_arch: str) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user