mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 17:56:08 +08:00
[bugfix] fix MHA for models like OpenGVLab/InternVL3_5-38B (#25146)
Signed-off-by: Yan Ma <yan.ma@intel.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
f2718d2948
commit
a684c0124c
@ -430,9 +430,11 @@ class MultiHeadAttention(nn.Module):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Input shape: batch_size x seq_len x hidden_size"""
|
"""Input shape:
|
||||||
# TODO(Isotr0py): Use existing backend implementations and support FA3
|
(batch_size x seq_len x hidden_size) or
|
||||||
bsz, q_len, _ = query.size()
|
(batch_size x seq_len x num_heads x head_size)
|
||||||
|
"""
|
||||||
|
bsz, q_len = query.size()[:2]
|
||||||
kv_len = key.size(1)
|
kv_len = key.size(1)
|
||||||
|
|
||||||
query = query.view(bsz, q_len, self.num_heads, self.head_size)
|
query = query.view(bsz, q_len, self.num_heads, self.head_size)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user