[bugfix] fix MHA for models like OpenGVLab/InternVL3_5-38B (#25146)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Yan Ma 2025-09-19 16:45:06 +08:00 committed by GitHub
parent f2718d2948
commit a684c0124c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -430,9 +430,11 @@ class MultiHeadAttention(nn.Module):
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
-"""Input shape: batch_size x seq_len x hidden_size"""
-# TODO(Isotr0py): Use existing backend implementations and support FA3
-bsz, q_len, _ = query.size()
+"""Input shape:
+(batch_size x seq_len x hidden_size) or
+(batch_size x seq_len x num_heads x head_size)
+"""
+bsz, q_len = query.size()[:2]
kv_len = key.size(1)
query = query.view(bsz, q_len, self.num_heads, self.head_size)