mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 08:55:46 +08:00
[bugfix] fix MHA for models like OpenGVLab/InternVL3_5-38B (#25146)
Signed-off-by: Yan Ma <yan.ma@intel.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
f2718d2948
commit
a684c0124c
@ -430,9 +430,11 @@ class MultiHeadAttention(nn.Module):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Input shape: batch_size x seq_len x hidden_size"""
|
||||
# TODO(Isotr0py): Use existing backend implementations and support FA3
|
||||
bsz, q_len, _ = query.size()
|
||||
"""Input shape:
|
||||
(batch_size x seq_len x hidden_size) or
|
||||
(batch_size x seq_len x num_heads x head_size)
|
||||
"""
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
|
||||
query = query.view(bsz, q_len, self.num_heads, self.head_size)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user