mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-26 01:03:30 +08:00
Fix RoPE output shape
This commit is contained in:
parent
31d05f7edb
commit
46b31ed98d
@ -54,7 +54,8 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
|
||||
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
|
||||
-1).transpose(1, 2)
|
||||
return x_out
|
||||
# Reshape the output tensor to the original shape.
|
||||
return x_out.reshape(x_out.shape[0], x_out.shape[1], -1)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user