Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:04:58 +08:00)
Add group as an argument in broadcast ops (#2522)
parent 00efdc84ba
commit 5b23c3f26f
@@ -1,6 +1,8 @@
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union
 
+from torch.distributed import ProcessGroup
+
 import torch
 
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -86,47 +88,59 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return output_tensor
 
 
-def broadcast(input_: torch.Tensor, src: int = 0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return input_
     # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
     return input_
 
 
-def broadcast_object_list(obj_list: List[Any], src: int = 0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return obj_list
     # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
     return obj_list
 
 
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
 
 
-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
-                                                                 Any]]] = None,
-                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
     """Broadcast the input tensor dictionary."""
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict
 
+    rank = torch.distributed.get_rank()
     if rank == src:
         assert isinstance(
             tensor_dict,
@@ -141,14 +155,18 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                     (key, TensorMetadata(value.dtype, value.size())))
             else:
                 metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                src=src,
+                                                group=group)
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
                 torch.distributed.broadcast(tensor, src=src)
     else:
         recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                src=src,
+                                                group=group)
         metadata_list = recv_metadata_list[0]
         tensor_dict = {}
         async_handles = []
@@ -159,7 +177,8 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                      device="cuda")
                 async_handle = torch.distributed.broadcast(tensor,
                                                            src=src,
-                                                           async_op=True)
+                                                           async_op=True,
+                                                           group=group)
                 async_handles.append(async_handle)
                 tensor_dict[key] = tensor
             else:
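Below is a minimal usage sketch of the new argument; it is not part of the commit. It assumes torch.distributed has already been initialized with a CUDA-capable backend across at least two ranks, and that these helpers live in vllm.model_executor.parallel_utils.communication_op (the changed file's path is not shown in this excerpt, so the import path is an assumption). The two-rank subgroup is purely illustrative.

import torch
import torch.distributed as dist

# Assumed module path for the helpers changed above.
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast, broadcast_object_list, broadcast_tensor_dict)

# Illustrative subgroup over global ranks 0 and 1; any ProcessGroup works.
# new_group() is collective, so every rank in the default group calls it.
subgroup = dist.new_group(ranks=[0, 1])

if dist.get_rank() in (0, 1):
    # src is a global rank and must be a member of the group.
    tensor = torch.ones(4, device="cuda")
    broadcast(tensor, src=0, group=subgroup)

    # Non-source ranks supply a placeholder list of the same length.
    objs = ["step", 7] if dist.get_rank() == 0 else [None, None]
    broadcast_object_list(objs, src=0, group=subgroup)

    # Only the source rank needs to provide the dictionary.
    payload = {"x": tensor, "tag": "demo"} if dist.get_rank() == 0 else None
    payload = broadcast_tensor_dict(payload, src=0, group=subgroup)

Passing group=None (or omitting it) falls back to torch.distributed.group.WORLD, so existing call sites keep their previous behavior; the src check now validates membership in the supplied group rather than comparing against the global world size.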