mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 14:16:37 +08:00
[ROCm][Bugfix] Fix Aiter RMSNorm (#23412)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
0ae43dbf8c
commit
7c195d43da
@ -13,13 +13,15 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
|
|||||||
vllm_topk_softmax)
|
vllm_topk_softmax)
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
is_rocm_aiter_moe_enabled)
|
is_rocm_aiter_moe_enabled)
|
||||||
from vllm.model_executor.layers.layernorm import (
|
from vllm.model_executor.layers.layernorm import (RMSNorm,
|
||||||
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
dispatch_rocm_rmsnorm_func,
|
||||||
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
fused_add_rms_norm, rms_norm)
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
|
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
|
||||||
|
|
||||||
|
|
||||||
# Registered subclass for test
|
# Registered subclass for test
|
||||||
@CustomOp.register("relu3")
|
@CustomOp.register("relu3")
|
||||||
@ -149,24 +151,27 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("add_residual", [True, False])
|
@pytest.mark.parametrize("add_residual", [True, False])
|
||||||
|
@pytest.mark.parametrize("dtype",
|
||||||
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
|
||||||
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
|
||||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||||
reason="AITER is a feature exclusive for ROCm")
|
reason="AITER is a feature exclusive for ROCm")
|
||||||
def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str,
|
def test_rms_norm_dispatch(add_residual: bool, dtype: torch.dtype,
|
||||||
use_rocm_aiter_norm: str, monkeypatch):
|
use_rocm_aiter: str, use_rocm_aiter_norm: str,
|
||||||
|
monkeypatch):
|
||||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
|
monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm)
|
||||||
rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual)
|
rms_norm_func = dispatch_rocm_rmsnorm_func(add_residual, dtype)
|
||||||
|
|
||||||
if not add_residual:
|
should_use_rocm_aiter = current_platform.is_rocm() and int(use_rocm_aiter) \
|
||||||
if current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
and int(use_rocm_aiter_norm) and dtype in RMS_NORM_SUPPORTED_DTYPES
|
||||||
use_rocm_aiter_norm):
|
|
||||||
assert rms_norm_func == rocm_aiter_rms_norm
|
if add_residual and should_use_rocm_aiter:
|
||||||
else:
|
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||||
assert rms_norm_func == rms_norm
|
elif should_use_rocm_aiter:
|
||||||
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
|
assert rms_norm_func == torch.ops.vllm.rocm_aiter_rms_norm
|
||||||
use_rocm_aiter_norm):
|
elif add_residual:
|
||||||
assert rms_norm_func == rocm_aiter_fused_add_rms_norm
|
|
||||||
else:
|
|
||||||
assert rms_norm_func == fused_add_rms_norm
|
assert rms_norm_func == fused_add_rms_norm
|
||||||
|
else:
|
||||||
|
assert rms_norm_func == rms_norm
|
||||||
|
|||||||
@ -9,11 +9,11 @@ import torch.nn as nn
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
def is_rocm_aiter_rmsnorm_enabled() -> bool:
|
def is_rocm_aiter_rmsnorm_enabled() -> bool:
|
||||||
return current_platform.is_rocm() \
|
return envs.VLLM_ROCM_USE_AITER_RMSNORM \
|
||||||
and envs.VLLM_ROCM_USE_AITER_RMSNORM \
|
|
||||||
and envs.VLLM_ROCM_USE_AITER
|
and envs.VLLM_ROCM_USE_AITER
|
||||||
|
|
||||||
|
|
||||||
@ -43,8 +43,8 @@ def fused_add_rms_norm(
|
|||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
|
||||||
variance_epsilon: float) -> torch.Tensor:
|
variance_epsilon: float) -> torch.Tensor:
|
||||||
import aiter as rocm_aiter
|
import aiter as rocm_aiter
|
||||||
if x.dim() > 2:
|
if x.dim() > 2:
|
||||||
x_original_shape = x.shape
|
x_original_shape = x.shape
|
||||||
@ -55,7 +55,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
|||||||
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_fused_add_rms_norm(
|
def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
|
||||||
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
|
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
|
||||||
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
|
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
@ -74,14 +74,48 @@ def rocm_aiter_fused_add_rms_norm(
|
|||||||
return output, residual_out
|
return output, residual_out
|
||||||
|
|
||||||
|
|
||||||
def dispatch_cuda_rmsnorm_func(add_residual: bool):
|
def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
|
||||||
if add_residual:
|
variance_epsilon: float) -> torch.Tensor:
|
||||||
if is_rocm_aiter_rmsnorm_enabled():
|
return torch.empty_like(x)
|
||||||
return rocm_aiter_fused_add_rms_norm
|
|
||||||
return fused_add_rms_norm
|
|
||||||
|
|
||||||
if is_rocm_aiter_rmsnorm_enabled():
|
|
||||||
return rocm_aiter_rms_norm
|
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
|
||||||
|
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
|
||||||
|
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return torch.empty_like(x), torch.empty_like(residual)
|
||||||
|
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rms_norm",
|
||||||
|
op_func=rocm_aiter_rms_norm_impl,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=rocm_aiter_rms_norm_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||||
|
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
|
||||||
|
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
|
||||||
|
torch.float16, torch.bfloat16
|
||||||
|
]
|
||||||
|
|
||||||
|
if use_aiter and with_fused_add:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
|
||||||
|
if use_aiter:
|
||||||
|
return torch.ops.vllm.rocm_aiter_rms_norm
|
||||||
|
|
||||||
|
# fall back to CUDA implementation
|
||||||
|
if with_fused_add:
|
||||||
|
return fused_add_rms_norm
|
||||||
return rms_norm
|
return rms_norm
|
||||||
|
|
||||||
|
|
||||||
@ -114,6 +148,13 @@ class RMSNorm(CustomOp):
|
|||||||
self.weight = torch.ones(hidden_size)
|
self.weight = torch.ones(hidden_size)
|
||||||
if self.has_weight:
|
if self.has_weight:
|
||||||
self.weight = nn.Parameter(self.weight)
|
self.weight = nn.Parameter(self.weight)
|
||||||
|
weight_dtype = self.weight.data.dtype
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
|
||||||
|
with_fused_add=False, dtype=weight_dtype)
|
||||||
|
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
|
||||||
|
with_fused_add=True, dtype=weight_dtype)
|
||||||
|
|
||||||
def forward_native(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
@ -162,13 +203,27 @@ class RMSNorm(CustomOp):
|
|||||||
return self.forward_native(x, residual)
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
add_residual = residual is not None
|
add_residual = residual is not None
|
||||||
norm_func = dispatch_cuda_rmsnorm_func(add_residual)
|
|
||||||
|
|
||||||
if add_residual:
|
if add_residual:
|
||||||
return norm_func(x, residual, self.weight.data,
|
return fused_add_rms_norm(x, residual, self.weight.data,
|
||||||
self.variance_epsilon)
|
self.variance_epsilon)
|
||||||
else:
|
else:
|
||||||
return norm_func(x, self.weight.data, self.variance_epsilon)
|
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||||
|
|
||||||
|
def forward_hip(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
if self.variance_size_override is not None:
|
||||||
|
return self.forward_native(x, residual)
|
||||||
|
|
||||||
|
add_residual = residual is not None
|
||||||
|
if add_residual:
|
||||||
|
return self.rocm_norm_func_with_add(x, residual, self.weight.data,
|
||||||
|
self.variance_epsilon)
|
||||||
|
else:
|
||||||
|
return self.rocm_norm_func(x, self.weight.data,
|
||||||
|
self.variance_epsilon)
|
||||||
|
|
||||||
def forward_xpu(
|
def forward_xpu(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -322,23 +322,35 @@ class RocmPlatform(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||||
|
from vllm.config.compilation import CUDAGraphMode
|
||||||
|
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
|
compilation_config = vllm_config.compilation_config
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
is_eager_execution = compilation_config == CUDAGraphMode.NONE
|
||||||
|
|
||||||
|
use_v1 = envs.VLLM_USE_V1
|
||||||
|
use_aiter_rms_norm = envs.VLLM_ROCM_USE_AITER and \
|
||||||
|
envs.VLLM_ROCM_USE_AITER_RMSNORM
|
||||||
|
|
||||||
if cache_config and cache_config.block_size is None:
|
if cache_config and cache_config.block_size is None:
|
||||||
cache_config.block_size = 16
|
cache_config.block_size = 16
|
||||||
|
|
||||||
parallel_config = vllm_config.parallel_config
|
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
if vllm_config.speculative_config:
|
if vllm_config.speculative_config:
|
||||||
if not envs.VLLM_USE_V1:
|
if not use_v1:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Speculative decoding is not supported on vLLM V0.")
|
"Speculative decoding is not supported on vLLM V0.")
|
||||||
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
|
||||||
else:
|
else:
|
||||||
if envs.VLLM_USE_V1:
|
if use_v1:
|
||||||
parallel_config.worker_cls = \
|
parallel_config.worker_cls = \
|
||||||
"vllm.v1.worker.gpu_worker.Worker"
|
"vllm.v1.worker.gpu_worker.Worker"
|
||||||
else:
|
else:
|
||||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||||
|
# Aiter rms norm perform best when CUDA Graph capture is enabled.
|
||||||
|
if use_v1 and use_aiter_rms_norm and not is_eager_execution:
|
||||||
|
compilation_config.custom_ops.append("+rms_norm")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def verify_model_arch(cls, model_arch: str) -> None:
|
def verify_model_arch(cls, model_arch: str) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user