Add group as an argument in broadcast ops (#2522)

Junda Chen 2024-01-20 16:00:26 -08:00 committed by GitHub
parent 00efdc84ba
commit 5b23c3f26f

@@ -1,6 +1,8 @@
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union
 
+from torch.distributed import ProcessGroup
+
 import torch
 
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -86,47 +88,59 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return output_tensor
 
 
-def broadcast(input_: torch.Tensor, src: int = 0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return input_
     # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
     return input_
 
 
-def broadcast_object_list(obj_list: List[Any], src: int = 0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return obj_list
     # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
     return obj_list
 
 
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
 
 
-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
-                                                                 Any]]] = None,
-                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
     """Broadcast the input tensor dictionary."""
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict
 
+    rank = torch.distributed.get_rank()
     if rank == src:
         assert isinstance(
             tensor_dict,
@@ -141,14 +155,18 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                     (key, TensorMetadata(value.dtype, value.size())))
             else:
                 metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                 src=src,
+                                                 group=group)
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
                 torch.distributed.broadcast(tensor, src=src)
     else:
         recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                 src=src,
+                                                 group=group)
         metadata_list = recv_metadata_list[0]
         tensor_dict = {}
         async_handles = []
@@ -159,7 +177,8 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                      device="cuda")
                 async_handle = torch.distributed.broadcast(tensor,
                                                            src=src,
-                                                           async_op=True)
+                                                           async_op=True,
+                                                           group=group)
                 async_handles.append(async_handle)
                 tensor_dict[key] = tensor
             else:
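
For readers following the API change, here is a minimal usage sketch of the new group argument. It is not part of the commit: it assumes the patched helpers live in vllm.model_executor.parallel_utils.communication_op, that torch.distributed has already been initialized with a CUDA-capable backend (e.g. NCCL launched via torchrun), and that the subgroup ranks below are purely illustrative.

# Usage sketch only -- not part of this commit.
# Assumptions: the helpers above live in
# vllm.model_executor.parallel_utils.communication_op, torch.distributed is
# already initialized (e.g. NCCL via torchrun), and ranks 0 and 1 exist.
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast, broadcast_object_list)


def demo(rank: int) -> None:
    # Caller-defined subgroup; every rank must call new_group, even ranks
    # that are not members of it. The chosen ranks are illustrative.
    subgroup = dist.new_group(ranks=[0, 1])

    if rank in (0, 1):
        # In-place tensor broadcast inside the subgroup. `src` is a global
        # rank and must belong to the group (checked by the new assert).
        x = torch.full((4,), float(rank), device="cuda")
        broadcast(x, src=0, group=subgroup)  # x is now all zeros on rank 1

        # Python-object broadcast inside the same subgroup.
        objs = [{"step": 1}] if rank == 0 else [None]
        broadcast_object_list(objs, src=0, group=subgroup)

    # Omitting `group` keeps the old behaviour and uses the default
    # WORLD group.
    y = torch.ones(2, device="cuda")
    broadcast(y, src=0)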