diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 75a5317b10bad..87d9b959e643c 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -241,7 +241,10 @@ class GemmaRMSNorm(CustomOp): """PyTorch-native implementation equivalent to forward().""" orig_dtype = x.dtype if residual is not None: - x = x + residual + if orig_dtype == torch.float16: + x = x + residual.float() + else: + x = x + residual residual = x x = x.float()