[Bugfix] Revert Qwen2-VL part of change in #28271 (#30542)

Signed-off-by: Zifei Tong <zifeitong@gmail.com>
zifeitong 2025-12-14 05:20:08 -08:00 committed by GitHub
parent 5b64ac21f9
commit 48b8456ff9


@@ -50,7 +50,7 @@ from vllm.attention.layer import (
 )
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.distributed import parallel_state
+from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
@@ -360,10 +360,21 @@ class Qwen2VisionAttention(nn.Module):
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
         seq_len, bs, _ = qkv.shape
+        if self.tp_size > 1:
+            qkv = tensor_model_parallel_all_gather(qkv)
         # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
         q, k, v = qkv.chunk(3, dim=2)
+        # 3 * [s, b, head * head_dim]
+        if self.tp_size > 1:
+            splitter = partial(
+                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
+            )
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
         # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
         new_shape = (
             seq_len,
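
Note: the restored pattern all-gathers the fused QKV projection output across tensor-parallel ranks, splits it into full q, k, and v, and then re-slices each along the head dimension so every rank keeps only its local heads, presumably because the column-parallel partitioning of the fused [q | k | v] output dimension does not line up with the q/k/v boundaries. The following is a minimal single-process sketch of that gather-then-split shape bookkeeping; `split_qkv_sketch` and its parameters are illustrative assumptions, not vLLM APIs, and a plain `torch.chunk` stands in for both the `tensor_model_parallel_all_gather` collective and `dist_utils.split_tensor_along_last_dim`:

    import torch

    def split_qkv_sketch(
        qkv: torch.Tensor,  # [s, b, 3 * num_heads * head_dim], as if already all-gathered
        tp_size: int,
        tp_rank: int,
        num_heads: int,
        head_dim: int,
    ) -> tuple[torch.Tensor, ...]:
        seq_len, bs, _ = qkv.shape
        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # Each rank keeps only its own slice of the head dimension; this mirrors
        # dist_utils.split_tensor_along_last_dim(..., num_partitions=tp_size).
        def keep_local(t: torch.Tensor) -> torch.Tensor:
            return t.chunk(tp_size, dim=-1)[tp_rank]

        q, k, v = keep_local(q), keep_local(k), keep_local(v)

        # 3 * [s, b, local_heads * head_dim] -> 3 * [s, b, local_heads, head_dim]
        new_shape = (seq_len, bs, num_heads // tp_size, head_dim)
        return q.reshape(new_shape), k.reshape(new_shape), v.reshape(new_shape)

    # Example: 16 heads of dim 64 under tp_size=2 -> each rank keeps 8 local heads.
    qkv = torch.randn(10, 2, 3 * 16 * 64)
    q, k, v = split_qkv_sketch(qkv, tp_size=2, tp_rank=0, num_heads=16, head_dim=64)
    assert q.shape == (10, 2, 8, 64)

In the real module, `tp_size` and `tp_rank` come from vLLM's parallel state, and the gather is the `tensor_model_parallel_all_gather` collective whose import this revert restores; the sketch only mirrors the shape bookkeeping.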