mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:01:40 +08:00
[BugFix] [ROCm]: Bugfix and handle addition case of input for rocm_aiter_rms_norm (#17857)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
009b3d5382
commit
a810b5b088
@@ -28,6 +28,7 @@ AITER_MODEL_LIST = [
|
||||
"Qwen/Qwen-7B-Chat",
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"TitanML/tiny-mixtral",
|
||||
"Qwen/Qwen3-8B",
|
||||
]
|
||||
|
||||
|
||||
@@ -78,6 +79,9 @@ AITER_MODEL_LIST = [
|
||||
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2
|
||||
marks=[pytest.mark.core_model],
|
||||
),
|
||||
pytest.param(
|
||||
"Qwen/Qwen3-8B", # qwen (text-only)
|
||||
),
|
||||
pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
|
||||
pytest.param("bigcode/starcoder2-3b"), # starcoder2
|
||||
pytest.param(
|
||||
|
||||
@@ -46,6 +46,12 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
||||
variance_epsilon: float) -> torch.Tensor:
|
||||
|
||||
import aiter as rocm_aiter
|
||||
if x.dim() > 2:
|
||||
x_original_shape = x.shape
|
||||
x = x.reshape(-1, x_original_shape[-1])
|
||||
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
return x.reshape(x_original_shape)
|
||||
|
||||
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
|
||||
|
||||
|
||||
@@ -55,16 +61,17 @@ def rocm_aiter_fused_add_rms_norm(
|
||||
|
||||
import aiter as rocm_aiter
|
||||
|
||||
# Assuming the correct signature for rmsnorm2d_fwd_with_add
|
||||
residual_out = torch.empty_like(residual)
|
||||
output = torch.empty_like(x)
|
||||
rocm_aiter.rmsnorm2d_fwd_with_add(
|
||||
x, # output
|
||||
output, # output
|
||||
x, # input
|
||||
residual, # residual input
|
||||
residual, # residual output
|
||||
residual_out, # residual output
|
||||
weight,
|
||||
variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
return output, residual_out
|
||||
|
||||
|
||||
def dispatch_cuda_rmsnorm_func(add_residual: bool):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user