Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 23:45:39 +08:00)
[Bugfix] Fix distributed bug again in Qwen2.5-VL & Qwen2.5-Omni (#16974)
Signed-off-by: fyabc <suyang.fy@alibaba-inc.com>
parent 4b91c927f6
commit 571e8dd65e
@@ -198,8 +198,11 @@ class Qwen2_5_VisionMLP(nn.Module):
 
 def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
     """All-gather the input tensor interleavely across model parallel group."""
+    import torch.distributed as dist
     gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
-    parallel_state.get_tp_group().all_gather(gathered_tensors, local_tensor)
+    dist.all_gather(gathered_tensors,
+                    local_tensor,
+                    group=parallel_state.get_tp_group().device_group)
 
     gathered_tensors_split = [
         torch.split(tensor, hidden_size // tp_size, -1)
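For context, the change swaps the tensor-parallel group's all_gather helper for a direct torch.distributed.all_gather call on the TP device group, which fills a pre-allocated list with one tensor per rank. Below is a minimal single-process sketch of the interleaving that the surrounding helper then performs on those gathered tensors; the shapes, the fake per-rank tensors, and the reassembly step after torch.split are illustrative assumptions, since the diff above only shows the gather and the first split line.

# Minimal single-process sketch (no distributed init needed).
# `gathered_tensors` stands in for the per-rank outputs that
# dist.all_gather would write; the re-interleave step below is an
# assumption based on the helper's name, not part of the diff.
import torch

tp_size = 2
hidden_size = 8

# Pretend each TP rank produced a slice of width hidden_size.
gathered_tensors = [
    torch.arange(hidden_size).unsqueeze(0),        # "rank 0" output
    torch.arange(hidden_size).unsqueeze(0) + 100,  # "rank 1" output
]

# Split each rank's tensor into tp_size chunks along the last dim
# (this matches the torch.split line shown in the diff) ...
gathered_tensors_split = [
    torch.split(tensor, hidden_size // tp_size, -1)
    for tensor in gathered_tensors
]

# ... then interleave: chunk i from every rank comes before chunk i+1 from any rank.
interleaved = torch.cat(
    [chunk for group in zip(*gathered_tensors_split) for chunk in group],
    dim=-1,
)
print(interleaved.shape)  # torch.Size([1, 16])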