diff --git a/vllm/model_executor/models/interns1_vit.py b/vllm/model_executor/models/interns1_vit.py index cfc8b7e6084e..507503d75046 100644 --- a/vllm/model_executor/models/interns1_vit.py +++ b/vllm/model_executor/models/interns1_vit.py @@ -217,16 +217,15 @@ class InternSdpaAttention(nn.Module): self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale) def forward(self, x: torch.Tensor) -> torch.Tensor: - B, N, C = x.shape + """x shape: (B, N, C)""" q = self.q_proj(x) k = self.k_proj(x) v = self.v_proj(x) if self.qk_normalization: - B_, N_, H_, D_ = q.shape - q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_) - k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_) + q = self.q_norm(q) + k = self.k_norm(k) # Use unified MultiHeadAttention with automatic backend selection x = self.attn(q, k, v)