fix gemma3 results all zero (#17364)

Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
This commit is contained in:
Qiming Zhang 2025-04-29 09:40:25 -07:00 committed by GitHub
parent a39203f99e
commit d3cf61b89b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -241,7 +241,10 @@ class GemmaRMSNorm(CustomOp):
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
x = x + residual
if orig_dtype == torch.float16:
x = x + residual.float()
else:
x = x + residual
residual = x
x = x.float()