Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Add group as an argument in broadcast ops (#2522)
commit 5b23c3f26f (parent 00efdc84ba)
@@ -1,6 +1,8 @@
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union
 
+from torch.distributed import ProcessGroup
+
 import torch
 
 from vllm.model_executor.parallel_utils.parallel_state import (
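For context, a ProcessGroup is the handle torch.distributed uses to scope a collective to a subset of ranks; passing group=None to the helpers below falls back to the default (WORLD) group. A minimal sketch of how such a group is typically created (illustrative, not part of this commit):

    import torch.distributed as dist

    # Illustrative only: assumes the process was launched with torchrun so
    # init_process_group() can read rank and world size from the environment.
    dist.init_process_group(backend="gloo")

    # Every rank must call new_group() with the same argument list, even the
    # ranks that are not members of the resulting group.
    subgroup = dist.new_group(ranks=[0, 1])  # hypothetical two-rank sub-group

    # group=None falls back to the default group, exposed as dist.group.WORLD.
    assert dist.group.WORLD is not None

    dist.destroy_process_group()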
@@ -86,47 +88,59 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
     return output_tensor
 
 
-def broadcast(input_: torch.Tensor, src: int = 0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return input_
     # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
     return input_
 
 
-def broadcast_object_list(obj_list: List[Any], src: int = 0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return obj_list
     # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
     return obj_list
 
 
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
 
 
-def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
-                                                                 Any]]] = None,
-                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
     """Broadcast the input tensor dictionary."""
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
     if world_size == 1:
         return tensor_dict
 
+    rank = torch.distributed.get_rank()
     if rank == src:
         assert isinstance(
             tensor_dict,
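Note that src stays a global rank: torch.distributed.get_process_group_ranks(group) returns the global ranks that form the group, so the assertion now checks membership instead of 0 <= src < world_size, and the world size is taken per group. A hedged sketch of the resulting semantics (the four-rank layout is illustrative, not part of this commit; requires a PyTorch version that provides get_process_group_ranks):

    import torch
    import torch.distributed as dist

    # Illustrative only: assumes 4 processes launched with torchrun.
    dist.init_process_group(backend="gloo")
    group = dist.new_group(ranks=[2, 3])  # hypothetical sub-group

    # get_process_group_ranks() returns *global* ranks, so src=2 is valid for
    # this group while src=0 is not, even though the group's world size is 2.
    ranks = dist.get_process_group_ranks(group)    # [2, 3]
    group_size = dist.get_world_size(group=group)  # 2

    if dist.get_rank() in ranks:
        t = torch.arange(4.0) if dist.get_rank() == 2 else torch.empty(4)
        dist.broadcast(t, src=2, group=group)

    dist.destroy_process_group()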
@@ -141,14 +155,18 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                     (key, TensorMetadata(value.dtype, value.size())))
             else:
                 metadata_list.append((key, value))
-        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                src=src,
+                                                group=group)
         for key, value in metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = tensor_dict[key]
                 torch.distributed.broadcast(tensor, src=src)
     else:
         recv_metadata_list = [None]
-        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                src=src,
+                                                group=group)
         metadata_list = recv_metadata_list[0]
         tensor_dict = {}
         async_handles = []
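The two broadcast_object_list calls above form the metadata phase of broadcast_tensor_dict: the source rank first broadcasts a pickled list of (key, TensorMetadata(dtype, size)) entries so that receivers can allocate buffers before any tensor payload moves. A standalone sketch of that two-phase pattern, using CPU tensors and the default group for brevity (the function above uses CUDA tensors and the caller-supplied group):

    from collections import namedtuple

    import torch
    import torch.distributed as dist

    TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])

    dist.init_process_group(backend="gloo")  # illustrative; launched via torchrun
    src = 0

    if dist.get_rank() == src:
        tensors = {"x": torch.randn(2, 3), "y": torch.arange(4)}
        # Phase 1: broadcast cheap, pickled dtype/shape metadata only.
        meta = [(k, TensorMetadata(v.dtype, v.size())) for k, v in tensors.items()]
        dist.broadcast_object_list([meta], src=src)
        # Phase 2: broadcast the tensor payloads themselves.
        for k, _ in meta:
            dist.broadcast(tensors[k], src=src)
    else:
        recv = [None]
        dist.broadcast_object_list(recv, src=src)
        tensors = {}
        for k, m in recv[0]:
            buf = torch.empty(m.size, dtype=m.dtype)  # allocate from metadata
            dist.broadcast(buf, src=src)
            tensors[k] = buf

    dist.destroy_process_group()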
@@ -159,7 +177,8 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
                                      device="cuda")
                 async_handle = torch.distributed.broadcast(tensor,
                                                            src=src,
-                                                           async_op=True)
+                                                           async_op=True,
+                                                           group=group)
                 async_handles.append(async_handle)
                 tensor_dict[key] = tensor
             else:
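With async_op=True, torch.distributed.broadcast returns a work handle immediately instead of blocking, so every receive in the loop can be in flight at once; the surrounding function collects the handles and wait()s on them afterwards. A minimal sketch of that pattern, again with CPU tensors and the default group for brevity:

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # illustrative; launched via torchrun
    src = 0

    # Source fills the buffers; receivers start with uninitialized ones.
    buffers = [torch.empty(8) for _ in range(3)]
    if dist.get_rank() == src:
        for i, b in enumerate(buffers):
            b.fill_(float(i))

    # Issue all broadcasts without blocking, then wait on the handles.
    handles = [dist.broadcast(b, src=src, async_op=True) for b in buffers]
    for h in handles:
        h.wait()  # after this, every rank's buffers match the source's

    dist.destroy_process_group()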