[Bugfix] Fix distributed bug again in Qwen2.5-VL & Qwen2.5-Omni (#16974)

Signed-off-by: fyabc <suyang.fy@alibaba-inc.com>
Yang Fan 2025-04-22 20:23:17 +08:00 committed by GitHub
parent 4b91c927f6
commit 571e8dd65e


@@ -198,8 +198,11 @@ class Qwen2_5_VisionMLP(nn.Module):
 def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
     """All-gather the input tensor interleavely across model parallel group."""
+    import torch.distributed as dist
     gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
-    parallel_state.get_tp_group().all_gather(gathered_tensors, local_tensor)
+    dist.all_gather(gathered_tensors,
+                    local_tensor,
+                    group=parallel_state.get_tp_group().device_group)
     gathered_tensors_split = [
         torch.split(tensor, hidden_size // tp_size, -1)
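
The fix routes the gather through torch.distributed.all_gather on the tensor-parallel group's underlying device_group instead of the GroupCoordinator.all_gather wrapper, whose interface apparently does not match the (output list, input tensor) calling convention used here. Below is a minimal, single-process sketch (not vLLM code) of that calling convention: it pre-allocates one output tensor per rank and passes the process group explicitly. The gloo backend, world_size=1, and the loopback init address are assumptions chosen only so the snippet runs standalone.

# Sketch of the dist.all_gather call pattern used by the fix.
# Assumptions (not from the patch): gloo backend, world_size=1, local TCP init.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    # Stand-in for parallel_state.get_tp_group().device_group in vLLM.
    tp_group = dist.new_group(ranks=[0])
    tp_size = dist.get_world_size(tp_group)

    # Each rank holds a local shard of the activations.
    local_tensor = torch.arange(8, dtype=torch.float32).reshape(2, 4)

    # Pre-allocate one output slot per rank, then fill them in rank order.
    gathered = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(gathered, local_tensor, group=tp_group)

    # With tp_size == 1 the gather is just a copy of the local shard.
    print(torch.equal(gathered[0], local_tensor))
    dist.destroy_process_group()

if __name__ == "__main__":
    main()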