From a810b5b088b898bdfe25606589af69da30e85ae6 Mon Sep 17 00:00:00 2001 From: TJian Date: Sun, 11 May 2025 19:17:11 +0800 Subject: [PATCH] [BugFix] [ROCm]: Bugfix and handle addition case of input for `rocm_aiter_rms_norm` (#17857) Signed-off-by: tjtanaa --- tests/models/language/generation/test_common.py | 4 ++++ vllm/model_executor/layers/layernorm.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index c755593c9acb..05dd18fbdf8b 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -28,6 +28,7 @@ AITER_MODEL_LIST = [ "Qwen/Qwen-7B-Chat", "Qwen/Qwen2.5-0.5B-Instruct", "TitanML/tiny-mixtral", + "Qwen/Qwen3-8B", ] @@ -78,6 +79,9 @@ AITER_MODEL_LIST = [ "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 marks=[pytest.mark.core_model], ), + pytest.param( + "Qwen/Qwen3-8B", # qwen (text-only) + ), pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 87d9b959e643..cdf9ecc25107 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -46,6 +46,12 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float) -> torch.Tensor: import aiter as rocm_aiter + if x.dim() > 2: + x_original_shape = x.shape + x = x.reshape(-1, x_original_shape[-1]) + x = rocm_aiter.rms_norm(x, weight, variance_epsilon) + return x.reshape(x_original_shape) + return rocm_aiter.rms_norm(x, weight, variance_epsilon) @@ -55,16 +61,17 @@ def rocm_aiter_fused_add_rms_norm( import aiter as rocm_aiter - # Assuming the correct signature for rmsnorm2d_fwd_with_add + residual_out = torch.empty_like(residual) + output = torch.empty_like(x) rocm_aiter.rmsnorm2d_fwd_with_add( - x, # output + output, # output x, # input residual, # residual input - residual, # residual output + residual_out, # residual output weight, variance_epsilon, ) - return x, residual + return output, residual_out def dispatch_cuda_rmsnorm_func(add_residual: bool):