From d3cf61b89bc53aa7709932ab43e7630b9a71f2b3 Mon Sep 17 00:00:00 2001
From: Qiming Zhang
Date: Tue, 29 Apr 2025 09:40:25 -0700
Subject: [PATCH] fix gemma3 results all zero (#17364)

Signed-off-by: mayuyuace
---
 vllm/model_executor/layers/layernorm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 75a5317b10bad..87d9b959e643c 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -241,7 +241,10 @@ class GemmaRMSNorm(CustomOp):
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
         if residual is not None:
-            x = x + residual
+            if orig_dtype == torch.float16:
+                x = x + residual.float()
+            else:
+                x = x + residual
             residual = x
         x = x.float()
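
Note (editor's illustration, not part of the patch): the sketch below is a hypothetical reproduction of the failure mode this change guards against, assuming only that PyTorch is installed. The tensor values are invented to exceed the float16 maximum (~65504) and do not come from the original Gemma3 report.

import torch

x = torch.full((4,), 40000.0, dtype=torch.float16)
residual = torch.full((4,), 40000.0, dtype=torch.float16)

# Before the patch: the residual add runs entirely in float16 and overflows
# to inf, which then propagates through the rest of the norm.
print(x + residual)          # tensor([inf, inf, inf, inf], dtype=torch.float16)

# With the patch, float16 inputs add against residual.float(); type promotion
# performs the sum in float32, so the result stays finite.
print(x + residual.float())  # tensor([80000., 80000., 80000., 80000.])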