[BugFix] [ROCm]: Bugfix and handle addition case of input for rocm_aiter_rms_norm (#17857)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian 2025-05-11 19:17:11 +08:00 committed by GitHub
parent 009b3d5382
commit a810b5b088
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 4 deletions

View File

@ -28,6 +28,7 @@ AITER_MODEL_LIST = [
"Qwen/Qwen-7B-Chat",
"Qwen/Qwen2.5-0.5B-Instruct",
"TitanML/tiny-mixtral",
"Qwen/Qwen3-8B",
]
@ -78,6 +79,9 @@ AITER_MODEL_LIST = [
"Qwen/Qwen2.5-0.5B-Instruct", # qwen2
marks=[pytest.mark.core_model],
),
pytest.param(
"Qwen/Qwen3-8B", # qwen (text-only)
),
pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm
pytest.param("bigcode/starcoder2-3b"), # starcoder2
pytest.param(

View File

@ -46,6 +46,12 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
return rocm_aiter.rms_norm(x, weight, variance_epsilon)
@ -55,16 +61,17 @@ def rocm_aiter_fused_add_rms_norm(
import aiter as rocm_aiter
# Assuming the correct signature for rmsnorm2d_fwd_with_add
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
x, # output
output, # output
x, # input
residual, # residual input
residual, # residual output
residual_out, # residual output
weight,
variance_epsilon,
)
return x, residual
return output, residual_out
def dispatch_cuda_rmsnorm_func(add_residual: bool):