[Bugfix] Fix interns1-vit qk norm code path (#27480)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
commit acc78aeb88
parent 0f67d4d962
Author: Isotr0py
Date:   2025-10-25 01:43:45 +08:00 (committed by GitHub)


@@ -217,16 +217,15 @@ class InternSdpaAttention(nn.Module):
         self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, N, C = x.shape
+        """x shape: (B, N, C)"""
         q = self.q_proj(x)
         k = self.k_proj(x)
         v = self.v_proj(x)
 
         if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
-            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
+            q = self.q_norm(q)
+            k = self.k_norm(k)
 
         # Use unified MultiHeadAttention with automatic backend selection
         x = self.attn(q, k, v)
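
Context for the fix (not part of the commit itself): in this SDPA path, q_proj and k_proj produce 3-D tensors of shape (B, N, C) with C = num_heads * head_dim, so the old qk-norm branch, which unpacked a 4-D shape and did a flatten/view round trip, could not run on that input. Applying q_norm/k_norm directly is equivalent because the normalization is taken over the full hidden size. A minimal sketch of the idea, using a plain LayerNorm as a stand-in for the model's qk normalization layer and illustrative sizes:

import torch
import torch.nn as nn

B, N, H, D = 2, 5, 4, 8      # batch, tokens, heads, head_dim (illustrative)
C = H * D                    # hidden size, as produced by q_proj / k_proj

qk_norm = nn.LayerNorm(C)    # stand-in for the real qk normalization layer
q = torch.randn(B, N, C)     # 3-D projection output, as in the fixed path

# Old path: `B_, N_, H_, D_ = q.shape` assumes a 4-D tensor and raises
# ValueError ("not enough values to unpack") on this 3-D input.

# Fixed path: normalize the last dimension directly; for a 4-D (B, N, H, D)
# tensor this matches the old flatten(-2, -1) -> norm -> view round trip.
q_4d = q.view(B, N, H, D)
direct = qk_norm(q)
roundtrip = qk_norm(q_4d.flatten(-2, -1)).view(B, N, H, D).flatten(-2, -1)
assert torch.allclose(direct, roundtrip, atol=1e-6)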