[platform] add base class for communicators (#13208)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 124776ebd5
commit a0231b7c25
vllm/distributed/device_communicators/base_device_communicator.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class DeviceCommunicatorBase:
    """
    Base class for device-specific communicator.
    It can use the `cpu_group` to initialize the communicator.
    If the device has PyTorch integration (PyTorch can recognize its
    communication backend), the `device_group` will also be given.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)
        self.ranks = dist.get_process_group_ranks(cpu_group)
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
        self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                 self.global_rank)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        # Reshape
        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        pass
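Editor's note: a platform backend only has to accept the same constructor arguments and override the collectives it wants to specialize; everything else is inherited from the PyTorch-based defaults above. A minimal sketch (illustrative only, not part of this commit; "FooCommunicator" and its backend hooks are hypothetical):

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)


class FooCommunicator(DeviceCommunicatorBase):
    """Hypothetical communicator for an imaginary "foo" device."""

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        # Backend-specific handles would be created here, e.g. from
        # self.cpu_group (always available) or self.device_group
        # (available when PyTorch recognizes the backend).

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # A real backend would call its own collective library here;
        # this sketch simply reuses the torch.distributed default.
        return super().all_reduce(input_)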
vllm/distributed/device_communicators/cpu_communicator.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CpuCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        self.ipex_available = False
        self.dist_module = torch.distributed
        try:
            import intel_extension_for_pytorch as ipex
            self.ipex_available = True
            self.dist_module = ipex.distributed
        except ImportError:
            """
            Intel IPEX not found. Falling back to PyTorch native
            all_reduce for CPU (e.g. MacOS)
            """
            pass

    def all_reduce(self, input_):
        return self.dist_module.all_reduce(input_, group=self.device_group)
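Editor's note: as a construction sketch (illustrative only, not part of the commit), the communicator can be built directly from a gloo process group; with a single process the collective is trivial, but it shows where the base-class attributes and the IPEX fallback come from:

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cpu_communicator import (
    CpuCommunicator)

# single-process gloo group, purely for demonstration (port is arbitrary)
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        rank=0,
                        world_size=1)
cpu_group = dist.new_group(ranks=[0], backend="gloo")
comm = CpuCommunicator(cpu_group,
                       device=torch.device("cpu"),
                       device_group=cpu_group,
                       unique_name="demo")
print(comm.rank, comm.world_size, comm.ranks)  # 0 1 [0]
print(comm.ipex_available)  # True only if intel_extension_for_pytorch imports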
vllm/distributed/device_communicators/cuda_communicator.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CudaCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        if "pp" in unique_name:
            # pipeline parallel does not need custom allreduce
            use_custom_allreduce = False
        else:
            from vllm.distributed.parallel_state import (
                _ENABLE_CUSTOM_ALL_REDUCE)
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
        use_pynccl = True

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

    def all_reduce(self, input_):
        # always try custom allreduce first,
        # and then pynccl.
        ca_comm = self.ca_comm
        if ca_comm is not None and not ca_comm.disabled and \
                ca_comm.should_custom_ar(input_):
            out = ca_comm.custom_all_reduce(input_)
            assert out is not None
            return out
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        out = pynccl_comm.all_reduce(input_)
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
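Editor's note: the communicator is constructed the same way GroupCoordinator does it (see the parallel_state changes below): a gloo group for CPU-side coordination plus the device and device group. A hedged launch sketch, assuming two CUDA GPUs, a CUDA build of vLLM, and torchrun; the file name, port-free env:// setup and group wiring are illustrative, not part of the commit:

# run with: torchrun --nproc-per-node=2 cuda_comm_demo.py  (hypothetical file)
import os

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="gloo")        # CPU-side coordination group
device_group = dist.new_group(backend="nccl")  # device-side group
comm = CudaCommunicator(cpu_group=dist.group.WORLD,
                        device=torch.device(f"cuda:{local_rank}"),
                        device_group=device_group,
                        unique_name="tp_demo")
x = torch.ones(16, device=comm.device)
# tries custom allreduce, then pynccl, then plain torch.distributed
y = comm.all_reduce(x)
print(dist.get_rank(), y[0].item())  # should print 2.0 on each rank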
vllm/distributed/device_communicators/hpu_communicator.py
@@ -2,45 +2,40 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_hpu():
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
-class HpuCommunicator:
+class HpuCommunicator(DeviceCommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_hpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
-
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
         htorch.core.mark_step()
-        dist.all_reduce(x, group=self.group)
-        return x
+        dist.all_reduce(input_, group=self.device_group)
+        return input_
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         if dim < 0:
             # Convert negative dim to positive.
-            dim += x.dim()
-        input_size = x.size()
+            dim += input_.dim()
+        input_size = input_.size()
         # Allocate output tensor.
         output_tensor = torch.empty((world_size, ) + input_size,
-                                    dtype=x.dtype,
-                                    device=x.device)
+                                    dtype=input_.dtype,
+                                    device=input_.device)
         # All-gather.
         htorch.core.mark_step()
-        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        dist.all_gather_into_tensor(output_tensor,
+                                    input_,
+                                    group=self.device_group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,13 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from typing import Optional
 
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
@@ -16,19 +18,20 @@ if current_platform.is_tpu():
     from vllm.executor import ray_utils
 
 
-class TpuCommunicator:
+class TpuCommunicator(DeviceCommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_tpu():
-            self.disabled = True
-            return
-        self.disabled = False
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
 
         # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
         # must be used together. Therefore, the local rank and world size can
         # be simply calculated as follows.
-        global_rank = dist.get_rank(group)
-        global_world_size = dist.get_world_size(group)
+        global_rank = self.global_rank
+        global_world_size = self.global_world_size
 
         # Calculate how many TPU nodes are in the current deployment. This
         # is the Ray placement group if it is deployed with Ray. Default
@@ -55,9 +58,9 @@ class TpuCommunicator:
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, x)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, input_)
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
-        return xm.all_gather(x, dim=dim)
+        return xm.all_gather(input_, dim=dim)
vllm/distributed/device_communicators/xpu_communicator.py (deleted file, 49 lines)
@@ -1,49 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform


class XpuCommunicator:

    def __init__(self, group: ProcessGroup):
        if not current_platform.is_xpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(x, group=self.group)
        return x

    def gather(self,
               input_: torch.Tensor,
               rank_in_group: int,
               dst: int = 0,
               dim: int = -1):
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.group)
        if rank_in_group == dst:
            # Reshape
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            output_tensor = None
        return output_tensor
vllm/distributed/parallel_state.py
@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup
 import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
+from vllm.distributed.device_communicators.base_device_communicator import (
+    DeviceCommunicatorBase)
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import direct_register_custom_op, supports_custom_op
+from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
+                        supports_custom_op)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -130,9 +133,8 @@ class GroupCoordinator:
     PyTorch ProcessGroup is bound to one specific communication backend,
     e.g. NCCL, Gloo, MPI, etc.
     GroupCoordinator takes charge of all the communication operations among
-    the processes in the group. It can route the communication to
-    a specific implementation (e.g. switch allreduce implementation
-    based on the tensor size and cuda graph mode).
+    the processes in the group. It manages both CPU and device
+    communication.
     """
 
     # available attributes:
@@ -150,11 +152,8 @@ class GroupCoordinator:
     rank_in_group: int  # rank inside the group
     cpu_group: ProcessGroup  # group for CPU communication
     device_group: ProcessGroup  # group for device communication
-    use_pynccl: bool  # a hint of whether to use PyNccl
-    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
-    # communicators are only created for world size > 1
-    pynccl_comm: Optional[Any]  # PyNccl communicator
-    ca_comm: Optional[Any]  # Custom allreduce communicator
+    use_device_communicator: bool  # whether to use device communicator
+    device_communicator: DeviceCommunicatorBase  # device communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
@@ -162,11 +161,7 @@
         group_ranks: List[List[int]],
         local_rank: int,
         torch_distributed_backend: Union[str, Backend],
-        use_pynccl: bool,
-        use_custom_allreduce: bool,
-        use_tpu_communicator: bool,
-        use_hpu_communicator: bool,
-        use_xpu_communicator: bool,
+        use_device_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -196,56 +191,26 @@ class GroupCoordinator:
         assert self.device_group is not None
 
         from vllm.platforms import current_platform
 
         # TODO: fix it for other platforms
         if current_platform.is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
         else:
             self.device = torch.device("cpu")
 
-        self.use_pynccl = use_pynccl
-        self.use_custom_allreduce = use_custom_allreduce
-        self.use_tpu_communicator = use_tpu_communicator
-        self.use_hpu_communicator = use_hpu_communicator
-        self.use_xpu_communicator = use_xpu_communicator
+        self.use_device_communicator = use_device_communicator
 
-        # lazy import to avoid documentation build error
-        from vllm.distributed.device_communicators.custom_all_reduce import (
-            CustomAllreduce)
-        from vllm.distributed.device_communicators.pynccl import (
-            PyNcclCommunicator)
-
-        self.pynccl_comm: Optional[PyNcclCommunicator] = None
-        if use_pynccl and self.world_size > 1:
-            self.pynccl_comm = PyNcclCommunicator(
-                group=self.cpu_group,
+        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
+        if use_device_communicator and self.world_size > 1:
+            device_comm_cls = resolve_obj_by_qualname(
+                current_platform.get_device_communicator_cls())
+            self.device_communicator = device_comm_cls(
+                cpu_group=self.cpu_group,
                 device=self.device,
+                device_group=self.device_group,
+                unique_name=self.unique_name,
             )
 
-        self.ca_comm: Optional[CustomAllreduce] = None
-        if use_custom_allreduce and self.world_size > 1:
-            # Initialize a custom fast all-reduce implementation.
-            self.ca_comm = CustomAllreduce(
-                group=self.cpu_group,
-                device=self.device,
-            )
-
-        from vllm.distributed.device_communicators.tpu_communicator import (
-            TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator] = None
-        if use_tpu_communicator and self.world_size > 1:
-            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
-
-        from vllm.distributed.device_communicators.hpu_communicator import (
-            HpuCommunicator)
-        self.hpu_communicator: Optional[HpuCommunicator]
-        if use_hpu_communicator and self.world_size > 1:
-            self.hpu_communicator = HpuCommunicator(group=self.device_group)
-
-        from vllm.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator)
-        self.xpu_communicator: Optional[XpuCommunicator]
-        if use_xpu_communicator and self.world_size > 1:
-            self.xpu_communicator = XpuCommunicator(group=self.device_group)
-
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
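Editor's note: the dispatch added above is just a qualified-name lookup. A rough sketch of what happens at group-initialization time (illustrative, not an exact excerpt of the code):

from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname

# the platform returns a dotted path such as
# "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"
qualname = current_platform.get_device_communicator_cls()
device_comm_cls = resolve_obj_by_qualname(qualname)
# GroupCoordinator then calls device_comm_cls(cpu_group=..., device=...,
# device_group=..., unique_name=...) exactly as in the hunk above.
print(device_comm_cls.__name__)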
@@ -253,6 +218,9 @@
             self.mq_broadcaster = MessageQueue.create_from_process_group(
                 self.cpu_group, 1 << 22, 6)
 
+        from vllm.platforms import current_platform
+        self.use_custom_op_call = current_platform.is_cuda_alike()
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
@@ -296,9 +264,16 @@
         else:
             stream = graph_capture_context.stream
 
-        ca_comm = self.ca_comm
-        maybe_ca_context = nullcontext(
-        ) if ca_comm is None else ca_comm.capture()
+        # only cuda uses this function,
+        # so we don't abstract it into the base class
+        maybe_ca_context = nullcontext()
+        from vllm.distributed.device_communicators.cuda_communicator import (
+            CudaCommunicator)
+        if self.device_communicator is not None:
+            assert isinstance(self.device_communicator, CudaCommunicator)
+            ca_comm = self.device_communicator.ca_comm
+            if ca_comm is not None:
+                maybe_ca_context = ca_comm.capture()  # type: ignore
 
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
@@ -328,54 +303,14 @@
         if self.world_size == 1:
             return input_
 
-        if input_.is_cpu:
-            try:
-                import intel_extension_for_pytorch as ipex
-                ipex.distributed.all_reduce(input_, group=self.device_group)
-                return input_
-            except ImportError:
-                """
-                Intel IPEX not found. Falling back to PyTorch native
-                all_reduce for CPU
-                """
-                torch.distributed.all_reduce(input_, group=self.device_group)
-                return input_
-
-        if self.tpu_communicator is not None and \
-                not self.tpu_communicator.disabled:
-            # TPU handles Dynamo with its own logic.
-            return self.tpu_communicator.all_reduce(input_)
-
-        if self.hpu_communicator is not None and \
-                not self.hpu_communicator.disabled:
-            return self.hpu_communicator.all_reduce(input_)
-
-        if self.xpu_communicator is not None and \
-                not self.xpu_communicator.disabled:
-            return self.xpu_communicator.all_reduce(input_)
-
-        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
+        if self.use_custom_op_call:
+            return torch.ops.vllm.all_reduce(input_,
+                                             group_name=self.unique_name)
+        else:
+            return self._all_reduce_out_place(input_)
 
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
-        # always try custom allreduce first,
-        # and then pynccl.
-        ca_comm = self.ca_comm
-        if ca_comm is not None and not ca_comm.disabled and \
-                ca_comm.should_custom_ar(input_):
-            out = ca_comm.custom_all_reduce(input_)
-            assert out is not None
-            return out
-        pynccl_comm = self.pynccl_comm
-        assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
-        if out is None:
-            # fall back to the default all-reduce using PyTorch.
-            # this usually happens during testing.
-            # when we run the model, allreduce only happens for the TP
-            # group, where we always have either custom allreduce or pynccl.
-            out = input_.clone()
-            torch.distributed.all_reduce(out, group=self.device_group)
-        return out
+        return self.device_communicator.all_reduce(input_)
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
@@ -385,40 +320,7 @@
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
 
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_gather(input_, dim)
-
-        # For HPUs, use HPU communicator.
-        hpu_comm = self.hpu_communicator
-        if hpu_comm is not None and not hpu_comm.disabled:
-            return hpu_comm.all_gather(input_, dim)
-
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        torch.distributed.all_gather_into_tensor(output_tensor,
-                                                 input_,
-                                                 group=self.device_group)
-        # Reshape
-        output_tensor = output_tensor.reshape((world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+        return self.device_communicator.all_gather(input_, dim)
 
     def gather(self,
                input_: torch.Tensor,
@@ -433,30 +335,7 @@
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
-        assert -input_.dim() <= dim < input_.dim(), (
-            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        if self.xpu_communicator is not None and \
-                not self.xpu_communicator.disabled:
-            return self.xpu_communicator.gather(input_, self.rank_in_group,
-                                                dst, dim)
-        # Allocate output tensor.
-        if self.rank_in_group == dst:
-            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
-        else:
-            gather_list = None
-        # Gather.
-        torch.distributed.gather(input_,
-                                 gather_list,
-                                 dst=self.ranks[dst],
-                                 group=self.device_group)
-        if self.rank_in_group == dst:
-            output_tensor = torch.cat(gather_list, dim=dim)
-        else:
-            output_tensor = None
-        return output_tensor
+        return self.device_communicator.gather(input_, dst, dim)
 
     def broadcast(self, input_: torch.Tensor, src: int = 0):
         """Broadcast the input tensor.
@@ -798,14 +677,7 @@
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
-        if dst is None:
-            dst = (self.rank_in_group + 1) % self.world_size
-
-        pynccl_comm = self.pynccl_comm
-        if pynccl_comm is not None and not pynccl_comm.disabled:
-            pynccl_comm.send(tensor, dst)
-        else:
-            torch.distributed.send(tensor, self.ranks[dst], self.device_group)
+        self.device_communicator.send(tensor, dst)
 
     def recv(self,
              size: torch.Size,
@@ -813,16 +685,7 @@
              src: Optional[int] = None) -> torch.Tensor:
         """Receives a tensor from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
-        if src is None:
-            src = (self.rank_in_group - 1) % self.world_size
-
-        tensor = torch.empty(size, dtype=dtype, device=self.device)
-        pynccl_comm = self.pynccl_comm
-        if pynccl_comm is not None and not pynccl_comm.disabled:
-            pynccl_comm.recv(tensor, src)
-        else:
-            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
-        return tensor
+        return self.device_communicator.recv(size, dtype, src)
 
     def destroy(self):
         if self.device_group is not None:
@@ -831,10 +694,8 @@
         if self.cpu_group is not None:
             torch.distributed.destroy_process_group(self.cpu_group)
             self.cpu_group = None
-        if self.pynccl_comm is not None:
-            self.pynccl_comm = None
-        if self.ca_comm is not None:
-            self.ca_comm = None
+        if self.device_communicator is not None:
+            self.device_communicator.destroy()
         if self.mq_broadcaster is not None:
             self.mq_broadcaster = None
 
@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         group_ranks=[ranks],
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=False,
-        use_custom_allreduce=False,
-        use_tpu_communicator=False,
-        use_hpu_communicator=False,
-        use_xpu_communicator=False,
+        use_device_communicator=False,
        group_name="world",
     )
@@ -866,23 +723,15 @@ def init_model_parallel_group(
     group_ranks: List[List[int]],
     local_rank: int,
     backend: str,
-    use_custom_allreduce: Optional[bool] = None,
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
 ) -> GroupCoordinator:
-    if use_custom_allreduce is None:
-        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
-    from vllm.platforms import current_platform
-
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=current_platform.is_cuda_alike(),
-        use_custom_allreduce=current_platform.is_cuda_alike()
-        and use_custom_allreduce,
-        use_tpu_communicator=True,
-        use_hpu_communicator=True,
-        use_xpu_communicator=True,
+        use_device_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
@@ -1053,11 +902,9 @@
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
-    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_custom_allreduce=False,
                                     group_name="pp")
@@ -146,3 +146,10 @@ class CpuPlatform(Platform):
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa
@@ -233,6 +233,10 @@ class CudaPlatformBase(Platform):
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -88,3 +88,7 @@ class HpuPlatform(Platform):
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
@@ -322,6 +322,13 @@ class Platform:
         """
         raise NotImplementedError
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        return "vllm.distributed.device_communicator.base_device_communicator.DeviceCommunicatorBase"  # noqa
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
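Editor's note: this default makes get_device_communicator_cls part of the platform interface, so a platform can point vLLM at its own communicator by overriding the classmethod. A hypothetical sketch (the plugin name, module path and the OOT enum value are assumptions, not part of this diff):

from vllm.platforms.interface import Platform, PlatformEnum


class MyPlatform(Platform):
    # assumed: PlatformEnum.OOT marks an out-of-tree platform plugin
    _enum = PlatformEnum.OOT

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        # dotted path to a DeviceCommunicatorBase subclass, resolved lazily
        # by GroupCoordinator via resolve_obj_by_qualname
        return "my_plugin.communicator.MyCommunicator"  # hypothetical module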
@@ -186,3 +186,7 @@ class RocmPlatform(Platform):
         torch.cuda.reset_peak_memory_stats(device)
         return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
             device)[0]
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
@@ -115,3 +115,7 @@ class TpuPlatform(Platform):
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on TPU.")
         return False
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa