From ec870fba9a59b8287fa205e4c35def4d3d153080 Mon Sep 17 00:00:00 2001 From: TJian Date: Sat, 22 Mar 2025 13:36:14 +0800 Subject: [PATCH] [FEAT] [ROCm]: Add AITER RMS Norm (Layer Norm) Feature (#14959) Signed-off-by: tjtanaa --- Dockerfile.rocm_base | 16 +++- .../model_executor/test_enabled_custom_ops.py | 29 +++++- .../decoder_only/language/test_models.py | 50 ++++++++-- vllm/envs.py | 13 +++ vllm/model_executor/layers/layernorm.py | 94 +++++++++++++++---- 5 files changed, 173 insertions(+), 29 deletions(-) diff --git a/Dockerfile.rocm_base b/Dockerfile.rocm_base index e33e73b303098..38d6a33636eba 100644 --- a/Dockerfile.rocm_base +++ b/Dockerfile.rocm_base @@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="b7d29fb" ARG FA_REPO="https://github.com/ROCm/flash-attention.git" +ARG AITER_BRANCH="21d47a9" +ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base @@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ pip install /install/*.whl +ARG AITER_REPO +ARG AITER_BRANCH +RUN git clone --recursive ${AITER_REPO} +RUN cd aiter \ + && git checkout ${AITER_BRANCH} \ + && git submodule update --init --recursive \ + && pip install -r requirements.txt \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter + ARG BASE_IMAGE ARG HIPBLASLT_BRANCH +ARG HIPBLAS_COMMON_BRANCH ARG LEGACY_HIPBLASLT_OPTION ARG RCCL_BRANCH ARG RCCL_REPO @@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ - && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt + && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \ + && echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \ + && echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 4a6a766b8ca0b..24147b741278b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -7,7 +7,10 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import ( + RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, + rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) +from vllm.platforms import current_platform # Registered subclass for test @@ -87,3 +90,27 @@ def test_enabled_ops_invalid(env: str): custom_ops=env.split(","))) with set_current_vllm_config(vllm_config): RMSNorm(1024).enabled() + + +@pytest.mark.parametrize("add_residual", [True, False]) +@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) +@pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"]) +@pytest.mark.skipif(not current_platform.is_rocm(), + reason="AITER is a feature exclusive for ROCm") +def test_rms_norm_dispatch(add_residual: bool, use_rocm_aiter: str, + use_rocm_aiter_norm: str, monkeypatch): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + monkeypatch.setenv("VLLM_ROCM_USE_AITER_RMSNORM", use_rocm_aiter_norm) + rms_norm_func = dispatch_cuda_rmsnorm_func(add_residual) + + if not add_residual: + if current_platform.is_rocm() and int(use_rocm_aiter) and int( + use_rocm_aiter_norm): + assert rms_norm_func == rocm_aiter_rms_norm + else: + assert rms_norm_func == rms_norm + elif current_platform.is_rocm() and int(use_rocm_aiter) and int( + use_rocm_aiter_norm): + assert rms_norm_func == rocm_aiter_fused_add_rms_norm + else: + assert rms_norm_func == fused_add_rms_norm diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/decoder_only/language/test_models.py index a49926ea220e8..79fa3fa997738 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/decoder_only/language/test_models.py @@ -3,7 +3,11 @@ Run `pytest tests/models/test_models.py`. """ + import pytest +import torch + +from vllm.platforms import current_platform from ...utils import check_logprobs_close @@ -13,7 +17,21 @@ from ...utils import check_logprobs_close # https://github.com/vllm-project/vllm/issues/14524 REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] +# This list contains the model that are using AITER kernel. +# Skip model that are not using AITER tests. +# When more AITER kernels are added, this list will not be +# needed as all the models will be calling AITER kernels +# in parts of the operators +AITER_MODEL_LIST = [ + "meta-llama/Llama-3.2-1B-Instruct", + "openbmb/MiniCPM3-4B", + "Qwen/Qwen-7B", + "Qwen/Qwen2.5-0.5B-Instruct", + "ehristoforu/Falcon3-MoE-2x7B-Insruct", +] + +# @maybe_test_rocm_aiter @pytest.mark.parametrize( "model", [ @@ -69,19 +87,24 @@ REQUIRES_V0 = ["microsoft/phi-2", "stabilityai/stablelm-3b-4e1t"] @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, - monkeypatch, -) -> None: +@pytest.mark.parametrize( + "use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False]) +def test_models(hf_runner, vllm_runner, example_prompts, model: str, + dtype: str, max_tokens: int, num_logprobs: int, + use_rocm_aiter: bool, monkeypatch) -> None: + if model in REQUIRES_V0: monkeypatch.setenv("VLLM_USE_V1", "0") + if use_rocm_aiter and (model in AITER_MODEL_LIST): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + elif use_rocm_aiter and model not in AITER_MODEL_LIST: + # Skip model that are not using AITER tests. + # When more AITER kernels are added, this list will not be + # needed as all the models will be calling AITER kernels + # in parts of the operators + pytest.skip(f"Skipping '{model}' model test with AITER kernel.") + with hf_runner(model, dtype=dtype) as hf_model: if model.startswith("THUDM/chatglm3"): hf_model.model.get_output_embeddings = lambda: \ @@ -100,3 +123,10 @@ def test_models( name_0="hf", name_1="vllm", ) + if use_rocm_aiter: + # this is to ensure that vllm engine + # has deallocated the memory before running the next + # unit tests. On ROCm, when using AITER + # the memory might not be deallocated completely + # before running the next test case + torch.cuda.synchronize() diff --git a/vllm/envs.py b/vllm/envs.py index d54de9da25315..7c07940c26c26 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -75,6 +75,8 @@ if TYPE_CHECKING: VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] VLLM_USE_V1: bool = True + VLLM_ROCM_USE_AITER: bool = False + VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 @@ -528,6 +530,17 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), + # Disable aiter ops unless specifically enabled. + # Acts as a parent switch to enable the rest of the other operations. + "VLLM_ROCM_USE_AITER": + lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in + ("true", "1")), + + # use aiter rms norm op if aiter ops are enabled. + "VLLM_ROCM_USE_AITER_RMSNORM": + lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in + ("true", "1")), + # Pad the fp8 weights to 256 bytes for ROCm "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index b476fb0dbc7eb..76d3acb92fb81 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -5,7 +5,77 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn +import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform + + +def is_rocm_aiter_rmsnorm_enabled() -> bool: + return current_platform.is_rocm() \ + and envs.VLLM_ROCM_USE_AITER_RMSNORM \ + and envs.VLLM_ROCM_USE_AITER + + +def rms_norm(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + from vllm import _custom_ops as ops + out = torch.empty_like(x) + ops.rms_norm( + out, + x, + weight, + variance_epsilon, + ) + return out + + +def fused_add_rms_norm( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + ops.fused_add_rms_norm( + x, + residual, + weight, + variance_epsilon, + ) + return x, residual + + +def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> torch.Tensor: + + import aiter as rocm_aiter + return rocm_aiter.rms_norm(x, weight, variance_epsilon) + + +def rocm_aiter_fused_add_rms_norm( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]: + + import aiter as rocm_aiter + + # Assuming the correct signature for rmsnorm2d_fwd_with_add + rocm_aiter.rmsnorm2d_fwd_with_add( + x, # output + x, # input + residual, # residual input + residual, # residual output + weight, + variance_epsilon, + ) + return x, residual + + +def dispatch_cuda_rmsnorm_func(add_residual: bool): + if add_residual: + if is_rocm_aiter_rmsnorm_enabled(): + return rocm_aiter_fused_add_rms_norm + return fused_add_rms_norm + + if is_rocm_aiter_rmsnorm_enabled(): + return rocm_aiter_rms_norm + return rms_norm @CustomOp.register("rms_norm") @@ -81,24 +151,14 @@ class RMSNorm(CustomOp): if self.variance_size_override is not None: return self.forward_native(x, residual) - from vllm import _custom_ops as ops + add_residual = residual is not None + norm_func = dispatch_cuda_rmsnorm_func(add_residual) - if residual is not None: - ops.fused_add_rms_norm( - x, - residual, - self.weight.data, - self.variance_epsilon, - ) - return x, residual - out = torch.empty_like(x) - ops.rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - return out + if add_residual: + return norm_func(x, residual, self.weight.data, + self.variance_epsilon) + else: + return norm_func(x, self.weight.data, self.variance_epsilon) def forward_hpu( self,