Fix RoPE output shape

This commit is contained in:
Woosuk Kwon 2024-04-01 08:22:47 +00:00
parent 31d05f7edb
commit 46b31ed98d

View File

@ -54,7 +54,8 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
# Reshape the output tensor to the original shape.
return x_out.reshape(x_out.shape[0], x_out.shape[1], -1)
class RMSNorm(torch.nn.Module):