diff --git a/vllm/model_executor/models/tpu/gemma.py b/vllm/model_executor/models/tpu/gemma.py index 56c0701b864f6..f9aba2797a8ea 100644 --- a/vllm/model_executor/models/tpu/gemma.py +++ b/vllm/model_executor/models/tpu/gemma.py @@ -54,7 +54,8 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2) - return x_out + # Reshape the output tensor to the original shape. + return x_out.reshape(x_out.shape[0], x_out.shape[1], -1) class RMSNorm(torch.nn.Module):