[platform] add base class for communicators (#13208)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-02-16 22:14:22 +08:00 (committed by GitHub)
parent 124776ebd5
commit a0231b7c25
13 changed files with 364 additions and 282 deletions
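
At a high level, this change collapses the per-backend flags and communicator fields that GroupCoordinator used to carry (pynccl, custom allreduce, TPU/HPU/XPU communicators) into a single pluggable object: each platform names a subclass of DeviceCommunicatorBase, and GroupCoordinator instantiates it by qualified name. A minimal sketch of the new lookup path, using only helpers that appear in the diff below and assuming a vLLM checkout that contains this commit:

from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname

# The platform returns a dotted path such as
# "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator".
qualname = current_platform.get_device_communicator_cls()
communicator_cls = resolve_obj_by_qualname(qualname)

# GroupCoordinator then constructs it with the handles it already owns:
#   communicator_cls(cpu_group=..., device=..., device_group=..., unique_name=...)
print(communicator_cls.__name__)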

vllm/distributed/device_communicators/base_device_communicator.py

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
It can use the `cpu_group` to initialize the communicator.
If the device has PyTorch integration (PyTorch can recognize its
communication backend), the `device_group` will also be given.
"""
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
self.global_rank = dist.get_rank()
self.global_world_size = dist.get_world_size()
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * self.world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def gather(self,
input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> Optional[torch.Tensor]:
"""
NOTE: We assume that the input tensor is on the same device across
all the ranks.
NOTE: `dst` is the local rank of the destination rank.
"""
world_size = self.world_size
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
pass
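
The reshape/movedim sequence in all_gather above is easiest to verify on concrete shapes. The following standalone sketch fakes the gathered buffer locally (no process group involved), assuming a hypothetical world size of 2 and dim=1:

import torch

world_size = 2
dim = 1
input_ = torch.arange(12.0).reshape(3, 4)               # per-rank tensor, shape (3, 4)
input_size = input_.size()

# all_gather_into_tensor concatenates along dim 0; fake it with two local copies.
gathered = torch.cat([input_, input_ + 100.0], dim=0)   # shape (6, 4)

# Same post-processing as DeviceCommunicatorBase.all_gather:
out = gathered.reshape((world_size, ) + input_size)     # (2, 3, 4)
out = out.movedim(0, dim)                               # (3, 2, 4)
out = out.reshape(input_size[:dim] +
                  (world_size * input_size[dim], ) +
                  input_size[dim + 1:])                  # (3, 8)
assert out.shape == (3, 8)
# Row i now holds rank 0's row i followed by rank 1's row i.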

vllm/distributed/device_communicators/cpu_communicator.py

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
class CpuCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
self.ipex_available = False
self.dist_module = torch.distributed
try:
import intel_extension_for_pytorch as ipex
self.ipex_available = True
self.dist_module = ipex.distributed
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU (e.g. MacOS)
"""
pass
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_

vllm/distributed/device_communicators/cuda_communicator.py

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
class CudaCommunicator(DeviceCommunicatorBase):
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "pp" in unique_name:
# pipeline parallel does not need custom allreduce
use_custom_allreduce = False
else:
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_pynccl = True
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
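
For reference, a hypothetical two-process script exercising the communicator's send/recv API above. The file name, group setup, and 2-GPU assumption are illustrative only; it presumes vLLM at or after this commit, two CUDA devices, and a launch with torchrun --nproc-per-node=2:

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# The communicator takes a CPU (gloo) group plus the device (nccl) group.
cpu_group = dist.new_group(backend="gloo")
device_group = dist.new_group(backend="nccl")
comm = CudaCommunicator(cpu_group=cpu_group,
                        device=torch.device(f"cuda:{rank}"),
                        device_group=device_group,
                        unique_name="pp_demo")  # "pp" disables custom allreduce

if rank == 0:
    # dst defaults to the next local rank in the group, i.e. rank 1 here.
    comm.send(torch.ones(4, device=f"cuda:{rank}"))
else:
    # src defaults to the previous local rank, i.e. rank 0 here.
    received = comm.recv(torch.Size([4]), torch.float32)
    print(received)

torch.cuda.synchronize()
dist.destroy_process_group()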

vllm/distributed/device_communicators/hpu_communicator.py

@@ -2,45 +2,40 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
if current_platform.is_hpu():
import habana_frameworks.torch as htorch # noqa: F401
class HpuCommunicator:
class HpuCommunicator(DeviceCommunicatorBase):
def __init__(self, group: ProcessGroup):
if not current_platform.is_hpu():
self.disabled = True
return
self.disabled = False
self.group = group
self.world_size = dist.get_world_size(self.group)
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
# (which is required for tensor parallel HPUGraph inference)
htorch.core.mark_step()
dist.all_reduce(x, group=self.group)
return x
dist.all_reduce(input_, group=self.device_group)
return input_
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
if dim < 0:
# Convert negative dim to positive.
dim += x.dim()
input_size = x.size()
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=x.dtype,
device=x.device)
dtype=input_.dtype,
device=input_.device)
# All-gather.
htorch.core.mark_step()
dist.all_gather_into_tensor(output_tensor, x, group=self.group)
dist.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +

vllm/distributed/device_communicators/tpu_communicator.py

@@ -1,13 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase
if current_platform.is_tpu():
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
@@ -16,19 +18,20 @@ if current_platform.is_tpu():
from vllm.executor import ray_utils
class TpuCommunicator:
class TpuCommunicator(DeviceCommunicatorBase):
def __init__(self, group: ProcessGroup):
if not current_platform.is_tpu():
self.disabled = True
return
self.disabled = False
def __init__(self,
cpu_group: ProcessGroup,
device: Optional[torch.device] = None,
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# must be used together. Therefore, the local rank and world size can
# be simply calculated as follows.
global_rank = dist.get_rank(group)
global_world_size = dist.get_world_size(group)
global_rank = self.global_rank
global_world_size = self.global_world_size
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
@@ -55,9 +58,9 @@ class TpuCommunicator:
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, x)
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, input_)
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(x, dim=dim)
return xm.all_gather(input_, dim=dim)

vllm/distributed/device_communicators/xpu_communicator.py

@@ -1,49 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.platforms import current_platform
class XpuCommunicator:
def __init__(self, group: ProcessGroup):
if not current_platform.is_xpu():
self.disabled = True
return
self.disabled = False
self.group = group
self.world_size = dist.get_world_size(self.group)
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
dist.all_reduce(x, group=self.group)
return x
def gather(self,
input_: torch.Tensor,
rank_in_group: int,
dst: int = 0,
dim: int = -1):
# For xpu path, gather doesn't work properly together with ray
# cluster so we use all_gather instead for now.
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((self.world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.group)
if rank_in_group == dst:
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(self.world_size *
input_size[dim], ) +
input_size[dim + 1:])
else:
output_tensor = None
return output_tensor

vllm/distributed/parallel_state.py

@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import direct_register_custom_op, supports_custom_op
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op)
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -130,9 +133,8 @@ class GroupCoordinator:
PyTorch ProcessGroup is bound to one specific communication backend,
e.g. NCCL, Gloo, MPI, etc.
GroupCoordinator takes charge of all the communication operations among
the processes in the group. It can route the communication to
a specific implementation (e.g. switch allreduce implementation
based on the tensor size and cuda graph mode).
the processes in the group. It manages both CPU and device
communication.
"""
# available attributes:
@@ -150,11 +152,8 @@
rank_in_group: int # rank inside the group
cpu_group: ProcessGroup # group for CPU communication
device_group: ProcessGroup # group for device communication
use_pynccl: bool # a hint of whether to use PyNccl
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
use_device_communicator: bool # whether to use device communicator
device_communicator: DeviceCommunicatorBase # device communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
@@ -162,11 +161,7 @@
group_ranks: List[List[int]],
local_rank: int,
torch_distributed_backend: Union[str, Backend],
use_pynccl: bool,
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_hpu_communicator: bool,
use_xpu_communicator: bool,
use_device_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
@@ -196,56 +191,26 @@
assert self.device_group is not None
from vllm.platforms import current_platform
# TODO: fix it for other platforms
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
else:
self.device = torch.device("cpu")
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
self.use_device_communicator = use_device_communicator
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
self.device_communicator: DeviceCommunicatorBase = None # type: ignore
if use_device_communicator and self.world_size > 1:
device_comm_cls = resolve_obj_by_qualname(
current_platform.get_device_communicator_cls())
self.device_communicator = device_comm_cls(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
)
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
from vllm.distributed.device_communicators.tpu_communicator import (
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator] = None
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
from vllm.distributed.device_communicators.hpu_communicator import (
HpuCommunicator)
self.hpu_communicator: Optional[HpuCommunicator]
if use_hpu_communicator and self.world_size > 1:
self.hpu_communicator = HpuCommunicator(group=self.device_group)
from vllm.distributed.device_communicators.xpu_communicator import (
XpuCommunicator)
self.xpu_communicator: Optional[XpuCommunicator]
if use_xpu_communicator and self.world_size > 1:
self.xpu_communicator = XpuCommunicator(group=self.device_group)
from vllm.distributed.device_communicators.shm_broadcast import (
MessageQueue)
self.mq_broadcaster: Optional[MessageQueue] = None
@@ -253,6 +218,9 @@ class GroupCoordinator:
self.mq_broadcaster = MessageQueue.create_from_process_group(
self.cpu_group, 1 << 22, 6)
from vllm.platforms import current_platform
self.use_custom_op_call = current_platform.is_cuda_alike()
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@@ -296,9 +264,16 @@
else:
stream = graph_capture_context.stream
ca_comm = self.ca_comm
maybe_ca_context = nullcontext(
) if ca_comm is None else ca_comm.capture()
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
if self.device_communicator is not None:
assert isinstance(self.device_communicator, CudaCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
@@ -328,54 +303,14 @@
if self.world_size == 1:
return input_
if input_.is_cpu:
try:
import intel_extension_for_pytorch as ipex
ipex.distributed.all_reduce(input_, group=self.device_group)
return input_
except ImportError:
"""
Intel IPEX not found. Falling back to PyTorch native
all_reduce for CPU
"""
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self.tpu_communicator.all_reduce(input_)
if self.hpu_communicator is not None and \
not self.hpu_communicator.disabled:
return self.hpu_communicator.all_reduce(input_)
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.all_reduce(input_)
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
if self.use_custom_op_call:
return torch.ops.vllm.all_reduce(input_,
group_name=self.unique_name)
else:
return self._all_reduce_out_place(input_)
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
# always try custom allreduce first,
# and then pynccl.
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
return self.device_communicator.all_reduce(input_)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
world_size = self.world_size
@@ -385,40 +320,7 @@
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
return tpu_comm.all_gather(input_, dim)
# For HPUs, use HPU communicator.
hpu_comm = self.hpu_communicator
if hpu_comm is not None and not hpu_comm.disabled:
return hpu_comm.all_gather(input_, dim)
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# NOTE: we have to use concat-style all-gather here,
# stack-style all-gather has compatibility issues with
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
return self.device_communicator.all_gather(input_, dim)
def gather(self,
input_: torch.Tensor,
@@ -433,30 +335,7 @@
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
if self.xpu_communicator is not None and \
not self.xpu_communicator.disabled:
return self.xpu_communicator.gather(input_, self.rank_in_group,
dst, dim)
# Allocate output tensor.
if self.rank_in_group == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
return self.device_communicator.gather(input_, dst, dim)
def broadcast(self, input_: torch.Tensor, src: int = 0):
"""Broadcast the input tensor.
@@ -798,14 +677,7 @@ class GroupCoordinator:
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
self.device_communicator.send(tensor, dst)
def recv(self,
size: torch.Size,
@@ -813,16 +685,7 @@
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""
if src is None:
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
return self.device_communicator.recv(size, dtype, src)
def destroy(self):
if self.device_group is not None:
@@ -831,10 +694,8 @@ class GroupCoordinator:
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
self.cpu_group = None
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.device_communicator is not None:
self.device_communicator.destroy()
if self.mq_broadcaster is not None:
self.mq_broadcaster = None
@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int,
group_ranks=[ranks],
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
use_hpu_communicator=False,
use_xpu_communicator=False,
use_device_communicator=False,
group_name="world",
)
@@ -866,23 +723,15 @@ def init_model_parallel_group(
group_ranks: List[List[int]],
local_rank: int,
backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
from vllm.platforms import current_platform
return GroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_pynccl=current_platform.is_cuda_alike(),
use_custom_allreduce=current_platform.is_cuda_alike()
and use_custom_allreduce,
use_tpu_communicator=True,
use_hpu_communicator=True,
use_xpu_communicator=True,
use_device_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)
@@ -1053,11 +902,9 @@ def initialize_model_parallel(
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
# pipeline parallel does not need custom allreduce
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False,
group_name="pp")
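
The GroupCoordinator constructor above relies on resolve_obj_by_qualname to turn the platform-provided dotted string into a class. Its implementation is not part of this diff; the following is a minimal importlib-based stand-in, inferred from how it is called here (the real helper may differ, for example in error handling):

import importlib
from typing import Any


def resolve_by_qualname(qualname: str) -> Any:
    # "pkg.module.ClassName" -> the ClassName attribute of pkg.module
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Self-contained check against a stdlib class, so the sketch runs without vLLM:
from collections import OrderedDict
assert resolve_by_qualname("collections.OrderedDict") is OrderedDict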

vllm/platforms/cpu.py

@@ -146,3 +146,10 @@ class CpuPlatform(Platform):
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
"""
Get device specific communicator class for distributed communication.
"""
return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator" # noqa

vllm/platforms/cuda.py

@@ -233,6 +233,10 @@ class CudaPlatformBase(Platform):
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/hpu.py

@@ -88,3 +88,7 @@ class HpuPlatform(Platform):
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa

vllm/platforms/interface.py

@@ -322,6 +322,13 @@
"""
raise NotImplementedError
@classmethod
def get_device_communicator_cls(cls) -> str:
"""
Get device specific communicator class for distributed communication.
"""
return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
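
Taken together with the base class, this hook is the whole extension surface for a new backend: a platform names its communicator class, and that class subclasses DeviceCommunicatorBase. A hypothetical sketch of wiring an out-of-tree device through this interface (FooCommunicator, FooPlatform, and the module path my_plugin.foo_communicator are invented names):

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.platforms.interface import Platform


class FooCommunicator(DeviceCommunicatorBase):
    """Override only what the device needs; everything else uses the base."""

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # A real backend would call its native collective here; the base
        # class's torch.distributed path is kept as a placeholder.
        dist.all_reduce(input_, group=self.device_group)
        return input_


class FooPlatform(Platform):
    # Only the communicator hook is shown; a real platform also fills in the
    # other Platform attributes and methods.

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        # Dotted path that GroupCoordinator resolves via resolve_obj_by_qualname.
        return "my_plugin.foo_communicator.FooCommunicator"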

vllm/platforms/rocm.py

@@ -186,3 +186,7 @@ class RocmPlatform(Platform):
torch.cuda.reset_peak_memory_stats(device)
return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
device)[0]
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa

vllm/platforms/tpu.py

@@ -115,3 +115,7 @@ class TpuPlatform(Platform):
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")
return False
@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa