[Refactor] Remove redundant TP gather/split in split_qkv in QwenVL (#28271)
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
parent f76e85c299
commit bc5bd45c7d
vllm/model_executor/models/qwen2_5_vl.py

@@ -291,25 +291,6 @@ class Qwen2_5_VisionMLP(nn.Module):
         return x_down
 
 
-def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
-    """All-gather the input tensor interleavely across model parallel group."""
-    import torch.distributed as dist
-
-    gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
-    dist.all_gather(
-        gathered_tensors, local_tensor, group=parallel_state.get_tp_group().device_group
-    )
-
-    gathered_tensors_split = [
-        torch.split(tensor, hidden_size // tp_size, -1) for tensor in gathered_tensors
-    ]
-    ordered_tensors = [
-        tensor for pair in zip(*gathered_tensors_split) for tensor in pair
-    ]
-    result_tensor = torch.cat(ordered_tensors, dim=-1)
-    return result_tensor
-
-
 class Qwen2_5_VisionAttention(nn.Module):
     def __init__(
         self,
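Why the removed gather/split pair was a no-op: assuming the fused qkv projection is a head-sharded QKVParallelLinear (which is what makes this removal safe), each rank's output is already laid out as [q_r | k_r | v_r] along the last dimension. The deleted helper gathered every rank's shard, interleaved them into a global [q | k | v], chunked that into thirds, and re-split by rank, landing back on exactly the local chunks. A minimal single-process sketch of that round trip (plain tensors stand in for the TP group; the helper name is illustrative):

import torch

def interleave_gather(shards: list[torch.Tensor], hidden: int, tp: int) -> torch.Tensor:
    # Mimics all_gather_interleave without torch.distributed: reorder the
    # per-rank [q_r | k_r | v_r] shards into a global [q_0..q_{tp-1} | k | v].
    splits = [torch.split(t, hidden // tp, dim=-1) for t in shards]
    return torch.cat([t for group in zip(*splits) for t in group], dim=-1)

tp, hidden = 4, 8
# Each rank's QKVParallelLinear output: [q_r | k_r | v_r] on the last dim.
shards = [torch.randn(2, 1, 3 * hidden // tp) for _ in range(tp)]

q, k, v = interleave_gather(shards, hidden, tp).chunk(3, dim=-1)
for rank in range(tp):
    # Old path: split the gathered q/k/v along the last dim, take own slice.
    q_r = q.split(hidden // tp, dim=-1)[rank]
    k_r = k.split(hidden // tp, dim=-1)[rank]
    v_r = v.split(hidden // tp, dim=-1)[rank]
    # New path: chunk the local shard directly -- identical tensors.
    q_l, k_l, v_l = shards[rank].chunk(3, dim=-1)
    assert torch.equal(q_r, q_l) and torch.equal(k_r, k_l) and torch.equal(v_r, v_l)

The equality holds on every rank, so dropping the collective changes no results; it only removes communication.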
@@ -383,21 +364,10 @@ class Qwen2_5_VisionAttention(nn.Module):
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
         seq_len, bs, _ = qkv.shape
-        if self.tp_size > 1:
-            qkv = all_gather_interleave(qkv, self.qkv.hidden_size, self.tp_size)
 
         # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
         q, k, v = qkv.chunk(3, dim=2)
 
-        # 3 * [s, b, head * head_dim]
-        if self.tp_size > 1:
-            splitter = partial(
-                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
-            )
-            q = splitter(q)[self.tp_rank]
-            k = splitter(k)[self.tp_rank]
-            v = splitter(v)[self.tp_rank]
-
         # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
         new_shape = (
             seq_len,
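With the gather gone, split_qkv in both vision attention classes reduces to a local chunk-and-reshape. The hunk is truncated before the end of the method, so the tail below is a sketch reconstructed from the shape comments rather than the verbatim new code (the function name is illustrative):

import torch

def split_qkv_local(qkv: torch.Tensor, head_dim: int) -> tuple[torch.Tensor, ...]:
    # qkv: [s, b, 3 * head * head_dim], straight from QKVParallelLinear; under
    # TP each rank already holds only its own heads, so no gather is needed.
    seq_len, bs, _ = qkv.shape
    # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
    q, k, v = qkv.chunk(3, dim=2)
    # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
    new_shape = (seq_len, bs, -1, head_dim)
    return tuple(t.reshape(new_shape) for t in (q, k, v))

# Two local heads of dim 4 over a 5-token, batch-1 sequence:
q, k, v = split_qkv_local(torch.randn(5, 1, 3 * 2 * 4), head_dim=4)
assert q.shape == k.shape == v.shape == (5, 1, 2, 4)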
vllm/model_executor/models/qwen2_vl.py

@@ -50,7 +50,7 @@ from vllm.attention.layer import (
 )
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
+from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
@@ -396,21 +396,10 @@ class Qwen2VisionAttention(nn.Module):
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         # [s, b, 3 * head * head_dim]
         seq_len, bs, _ = qkv.shape
-        if self.tp_size > 1:
-            qkv = tensor_model_parallel_all_gather(qkv)
 
         # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
         q, k, v = qkv.chunk(3, dim=2)
 
-        # 3 * [s, b, head * head_dim]
-        if self.tp_size > 1:
-            splitter = partial(
-                dist_utils.split_tensor_along_last_dim, num_partitions=self.tp_size
-            )
-            q = splitter(q)[self.tp_rank]
-            k = splitter(k)[self.tp_rank]
-            v = splitter(v)[self.tp_rank]
-
         # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
         new_shape = (
             seq_len,
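This last hunk's Qwen2 variant gathered with plain tensor_model_parallel_all_gather rather than the interleaved helper, but the replacement is the same: q, k, v now come straight from chunk(3, dim=2) on the local shard. Besides deleting the re-split, this removes a per-layer collective over the full fused qkv activation. A back-of-the-envelope estimate of the traffic that disappears (shapes and dtype are illustrative, not taken from the commit):

# The all-gather handed each rank (tp - 1) remote shards of
# [s, b, 3 * hidden / tp] per vision layer; now that traffic is zero.
seq_len, bs, hidden, tp = 1024, 1, 1280, 4
shard_elems = seq_len * bs * 3 * hidden // tp
recv_bytes = (tp - 1) * shard_elems * 2  # assuming bf16 activations
print(f"~{recv_bytes / 2**20:.1f} MiB received per rank per layer")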