From 46b31ed98d8879261cdf9f81115519acec028342 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 1 Apr 2024 08:22:47 +0000 Subject: [PATCH] Fix RoPE output shape --- vllm/model_executor/models/tpu/gemma.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/tpu/gemma.py b/vllm/model_executor/models/tpu/gemma.py index 56c0701b864f6..f9aba2797a8ea 100644 --- a/vllm/model_executor/models/tpu/gemma.py +++ b/vllm/model_executor/models/tpu/gemma.py @@ -54,7 +54,8 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2) - return x_out + # Reshape the output tensor to the original shape. + return x_out.reshape(x_out.shape[0], x_out.shape[1], -1) class RMSNorm(torch.nn.Module):