mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-11 01:17:04 +08:00
Signed-off-by: Zifei Tong <zifeitong@gmail.com>
This commit is contained in:
parent
5b64ac21f9
commit
48b8456ff9
@@ -50,7 +50,7 @@ from vllm.attention.layer import (
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
@@ -360,10 +360,21 @@ class Qwen2VisionAttention(nn.Module):
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
if self.tp_size > 1:
|
||||
qkv = tensor_model_parallel_all_gather(qkv)
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
|
||||
# 3 * [s, b, head * head_dim]
|
||||
if self.tp_size > 1:
|
||||
splitter = partial(
|
||||
dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
|
||||
)
|
||||
q = splitter(q)[self.tp_rank]
|
||||
k = splitter(k)[self.tp_rank]
|
||||
v = splitter(v)[self.tp_rank]
|
||||
|
||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
new_shape = (
|
||||
seq_len,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user