mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:05:01 +08:00
[Bugfix] Fix interns1-vit qk norm code path (#27480)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
0f67d4d962
commit
acc78aeb88
@@ -217,16 +217,15 @@ class InternSdpaAttention(nn.Module):
|
||||
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
"""x shape: (B, N, C)"""
|
||||
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(x)
|
||||
v = self.v_proj(x)
|
||||
|
||||
if self.qk_normalization:
|
||||
B_, N_, H_, D_ = q.shape
|
||||
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
x = self.attn(q, k, v)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user