[Bugfix] Fix interns1-vit qk norm code path (#27480)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
commit: acc78aeb88
parent: 0f67d4d962
@@ -217,16 +217,15 @@ class InternSdpaAttention(nn.Module):
         self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
+        """x shape: (B, N, C)"""
         q = self.q_proj(x)
         k = self.k_proj(x)
         v = self.v_proj(x)

         if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
-            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
+            q = self.q_norm(q)
+            k = self.k_norm(k)

         # Use unified MultiHeadAttention with automatic backend selection
         x = self.attn(q, k, v)
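For context on the fix, here is a minimal sketch (not the vLLM implementation; the tensor sizes and the use of nn.LayerNorm as a stand-in for the model's qk norm are illustrative assumptions) of why the old branch breaks in this forward pass: q and k stay in (B, N, C) layout, so unpacking four dimensions from q.shape raises an error, while applying the norm over the hidden dimension directly works.

import torch
import torch.nn as nn

# Illustrative sizes only (hypothetical, not taken from the model config).
B, N, H, D = 2, 8, 4, 16
C = H * D  # hidden size = num_heads * head_dim

q = torch.randn(B, N, C)   # q_proj output keeps the (B, N, C) layout in this path
q_norm = nn.LayerNorm(C)   # stand-in for the model's qk normalization layer

# Old code path: assumed q was (B, N, H, D) and flattened the head dims first.
try:
    B_, N_, H_, D_ = q.shape  # only 3 dims available -> ValueError
    q_old = q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
except ValueError as err:
    print(f"old path fails on a (B, N, C) tensor: {err}")

# Fixed code path: normalize the (B, N, C) tensor directly.
q = q_norm(q)
print(q.shape)  # torch.Size([2, 8, 64])

This is consistent with the unchanged lines in the hunk: the unified MultiHeadAttention call consumes q, k, v in the flat (B, N, C) layout, so the per-head unpacking and the flatten/view round trip in the qk-normalization branch are no longer needed.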