[platform] add base class for communicators (#13208)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Parent: 124776ebd5
Commit: a0231b7c25
vllm/distributed/device_communicators/base_device_communicator.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class DeviceCommunicatorBase:
    """
    Base class for device-specific communicator.
    It can use the `cpu_group` to initialize the communicator.
    If the device has PyTorch integration (PyTorch can recognize its
    communication backend), the `device_group` will also be given.
    """

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)
        self.ranks = dist.get_process_group_ranks(cpu_group)
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
        self.rank_in_group = dist.get_group_rank(self.cpu_group,
                                                 self.global_rank)

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(output_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        dist.all_gather_into_tensor(output_tensor,
                                    input_,
                                    group=self.device_group)
        # Reshape
        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (self.world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        pass
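Platform back ends plug in by subclassing this base class and overriding only the collectives they can accelerate; everything else falls back to the torch.distributed implementations above. A minimal hypothetical subclass, reusing the imports from the file above (the name and body are illustrative, not part of this commit), might look like:

# Illustrative sketch only: `MyDeviceCommunicator` is a hypothetical backend.
class MyDeviceCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        # Backend-specific handles (e.g. a vendor collective library)
        # would be created here.

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        # A real backend would call its own collective; this sketch simply
        # reuses the torch.distributed fallback from the base class.
        return super().all_reduce(input_)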
vllm/distributed/device_communicators/cpu_communicator.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CpuCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        self.ipex_available = False
        self.dist_module = torch.distributed
        try:
            import intel_extension_for_pytorch as ipex
            self.ipex_available = True
            self.dist_module = ipex.distributed
        except ImportError:
            """
            Intel IPEX not found. Falling back to PyTorch native
            all_reduce for CPU (e.g. MacOS)
            """
            pass

    def all_reduce(self, input_):
        return self.dist_module.all_reduce(input_, group=self.device_group)
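For context, a hedged standalone usage sketch (not part of the commit): in vLLM this class is constructed by GroupCoordinator (see parallel_state.py below), but it can be driven directly, assuming torch.distributed has already been initialized with the gloo backend.

# Hypothetical usage; assumes dist.init_process_group(backend="gloo", ...)
# has already run on every rank.
import torch
import torch.distributed as dist

cpu_group = dist.new_group(backend="gloo")
comm = CpuCommunicator(cpu_group=cpu_group,
                       device=torch.device("cpu"),
                       device_group=cpu_group)
t = torch.ones(4)
# Reduces in place across ranks, via IPEX if installed, otherwise via
# torch.distributed.
comm.all_reduce(t)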
vllm/distributed/device_communicators/cuda_communicator.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
from torch.distributed import ProcessGroup

from .base_device_communicator import DeviceCommunicatorBase


class CudaCommunicator(DeviceCommunicatorBase):

    def __init__(self,
                 cpu_group: ProcessGroup,
                 device: Optional[torch.device] = None,
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
        if "pp" in unique_name:
            # pipeline parallel does not need custom allreduce
            use_custom_allreduce = False
        else:
            from vllm.distributed.parallel_state import (
                _ENABLE_CUSTOM_ALL_REDUCE)
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
        use_pynccl = True

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

    def all_reduce(self, input_):
        # always try custom allreduce first,
        # and then pynccl.
        ca_comm = self.ca_comm
        if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
            out = ca_comm.custom_all_reduce(input_)
            assert out is not None
            return out
        pynccl_comm = self.pynccl_comm
        assert pynccl_comm is not None
        out = pynccl_comm.all_reduce(input_)
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        if self.pynccl_comm is not None:
            self.pynccl_comm = None
        if self.ca_comm is not None:
            self.ca_comm = None
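One design note worth calling out: when dst/src are omitted, send/recv default to the next and previous local rank, which is the neighbour pattern pipeline-parallel stages use. A hedged sketch of that usage (illustrative only, not vLLM's actual pipeline code):

# Illustrative only: pass activations one stage forward along the group.
def pass_forward(comm: CudaCommunicator, hidden: torch.Tensor) -> torch.Tensor:
    if comm.rank_in_group > 0:
        # Receive from the previous stage (src defaults to rank - 1).
        hidden = comm.recv(hidden.size(), hidden.dtype)
    # ... run this stage's layers on `hidden` ...
    if comm.rank_in_group < comm.world_size - 1:
        # Send to the next stage (dst defaults to rank + 1).
        comm.send(hidden)
    return hidden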
vllm/distributed/device_communicators/hpu_communicator.py
@@ -2,45 +2,40 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_hpu():
     import habana_frameworks.torch as htorch  # noqa: F401
 
 
-class HpuCommunicator:
-
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_hpu():
-            self.disabled = True
-            return
-        self.disabled = False
-        self.group = group
-        self.world_size = dist.get_world_size(self.group)
+class HpuCommunicator(DeviceCommunicatorBase):
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
         # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
         # (which is required for tensor parallel HPUGraph inference)
         htorch.core.mark_step()
-        dist.all_reduce(x, group=self.group)
-        return x
+        dist.all_reduce(input_, group=self.device_group)
+        return input_
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         if dim < 0:
             # Convert negative dim to positive.
-            dim += x.dim()
-        input_size = x.size()
+            dim += input_.dim()
+        input_size = input_.size()
         # Allocate output tensor.
         output_tensor = torch.empty((world_size, ) + input_size,
-                                    dtype=x.dtype,
-                                    device=x.device)
+                                    dtype=input_.dtype,
+                                    device=input_.device)
         # All-gather.
         htorch.core.mark_step()
-        dist.all_gather_into_tensor(output_tensor, x, group=self.group)
+        dist.all_gather_into_tensor(output_tensor,
+                                    input_,
+                                    group=self.device_group)
         # Reshape
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
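As a concrete shape check for the movedim/reshape step used here (and in the base class's all_gather): with world_size = 2 and an input of shape (3, 4), gathering along dim = 1 first produces a (2, 3, 4) stack, which becomes (3, 8) after moving the world dimension next to dim and flattening. A small self-contained sketch, with torch.stack standing in for the collective's output (no distributed runtime involved):

import torch

world_size, dim = 2, 1
input_ = torch.arange(12, dtype=torch.float32).reshape(3, 4)
input_size = input_.size()
# Stand-in for the (world_size, *input_size) buffer that all_gather fills.
output_tensor = torch.stack([input_] * world_size)          # (2, 3, 4)
output_tensor = output_tensor.movedim(0, dim)                # (3, 2, 4)
output_tensor = output_tensor.reshape(input_size[:dim] +
                                      (world_size * input_size[dim], ) +
                                      input_size[dim + 1:])  # (3, 8)
assert output_tensor.shape == (3, 8)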
vllm/distributed/device_communicators/tpu_communicator.py
@@ -1,13 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from typing import Optional
 
 import torch
-import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.platforms import current_platform
 
+from .base_device_communicator import DeviceCommunicatorBase
+
 if current_platform.is_tpu():
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
@@ -16,19 +18,20 @@ if current_platform.is_tpu():
     from vllm.executor import ray_utils
 
 
-class TpuCommunicator:
+class TpuCommunicator(DeviceCommunicatorBase):
 
-    def __init__(self, group: ProcessGroup):
-        if not current_platform.is_tpu():
-            self.disabled = True
-            return
-        self.disabled = False
+    def __init__(self,
+                 cpu_group: ProcessGroup,
+                 device: Optional[torch.device] = None,
+                 device_group: Optional[ProcessGroup] = None,
+                 unique_name: str = ""):
+        super().__init__(cpu_group, device, device_group, unique_name)
 
         # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
         # must be used together. Therefore, the local rank and world size can
         # be simply calculated as follows.
-        global_rank = dist.get_rank(group)
-        global_world_size = dist.get_world_size(group)
+        global_rank = self.global_rank
+        global_world_size = self.global_world_size
 
         # Calculate how many TPU nodes are in the current deployment. This
         # is the Ray placement group if it is deployed with Ray. Default
@@ -55,9 +58,9 @@ class TpuCommunicator:
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
-    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
-        return xm.all_reduce(xm.REDUCE_SUM, x)
+    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, input_)
 
-    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
-        return xm.all_gather(x, dim=dim)
+        return xm.all_gather(input_, dim=dim)
vllm/distributed/device_communicators/xpu_communicator.py (deleted, 49 lines)
@@ -1,49 +0,0 @@
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform


class XpuCommunicator:

    def __init__(self, group: ProcessGroup):
        if not current_platform.is_xpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(x, group=self.group)
        return x

    def gather(self,
               input_: torch.Tensor,
               rank_in_group: int,
               dst: int = 0,
               dim: int = -1):
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.group)
        if rank_in_group == dst:
            # Reshape
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            output_tensor = None
        return output_tensor
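The deleted XPU class worked around a gather issue under Ray by emulating gather with an all-gather and discarding the result on non-destination ranks. That trick can still be expressed against the new base class; a condensed, illustrative sketch (not part of the commit, referring to the DeviceCommunicatorBase introduced above):

from typing import Optional

import torch


def gather_via_all_gather(comm: DeviceCommunicatorBase,
                          input_: torch.Tensor,
                          dst: int = 0,
                          dim: int = -1) -> Optional[torch.Tensor]:
    # Every rank pays for the full all-gather; only `dst` keeps the result.
    gathered = comm.all_gather(input_, dim)
    return gathered if comm.rank_in_group == dst else None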
vllm/distributed/parallel_state.py
@@ -39,9 +39,12 @@ from torch.distributed import Backend, ProcessGroup
 
 import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
 import vllm.envs as envs
+from vllm.distributed.device_communicators.base_device_communicator import (
+    DeviceCommunicatorBase)
 from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
-from vllm.utils import direct_register_custom_op, supports_custom_op
+from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
+                        supports_custom_op)
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -130,9 +133,8 @@ class GroupCoordinator:
     PyTorch ProcessGroup is bound to one specific communication backend,
     e.g. NCCL, Gloo, MPI, etc.
     GroupCoordinator takes charge of all the communication operations among
-    the processes in the group. It can route the communication to
-    a specific implementation (e.g. switch allreduce implementation
-    based on the tensor size and cuda graph mode).
+    the processes in the group. It manages both CPU and device
+    communication.
     """
 
     # available attributes:
@@ -150,11 +152,8 @@ class GroupCoordinator:
     rank_in_group: int  # rank inside the group
     cpu_group: ProcessGroup  # group for CPU communication
     device_group: ProcessGroup  # group for device communication
-    use_pynccl: bool  # a hint of whether to use PyNccl
-    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
-    # communicators are only created for world size > 1
-    pynccl_comm: Optional[Any]  # PyNccl communicator
-    ca_comm: Optional[Any]  # Custom allreduce communicator
+    use_device_communicator: bool  # whether to use device communicator
+    device_communicator: DeviceCommunicatorBase  # device communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
@@ -162,11 +161,7 @@ class GroupCoordinator:
         group_ranks: List[List[int]],
        local_rank: int,
         torch_distributed_backend: Union[str, Backend],
-        use_pynccl: bool,
-        use_custom_allreduce: bool,
-        use_tpu_communicator: bool,
-        use_hpu_communicator: bool,
-        use_xpu_communicator: bool,
+        use_device_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -196,56 +191,26 @@ class GroupCoordinator:
         assert self.device_group is not None
 
         from vllm.platforms import current_platform
 
+        # TODO: fix it for other platforms
         if current_platform.is_cuda_alike():
             self.device = torch.device(f"cuda:{local_rank}")
         else:
             self.device = torch.device("cpu")
 
-        self.use_pynccl = use_pynccl
-        self.use_custom_allreduce = use_custom_allreduce
-        self.use_tpu_communicator = use_tpu_communicator
-        self.use_hpu_communicator = use_hpu_communicator
-        self.use_xpu_communicator = use_xpu_communicator
-
-        # lazy import to avoid documentation build error
-        from vllm.distributed.device_communicators.custom_all_reduce import (
-            CustomAllreduce)
-        from vllm.distributed.device_communicators.pynccl import (
-            PyNcclCommunicator)
-
-        self.pynccl_comm: Optional[PyNcclCommunicator] = None
-        if use_pynccl and self.world_size > 1:
-            self.pynccl_comm = PyNcclCommunicator(
-                group=self.cpu_group,
-                device=self.device,
-            )
-
-        self.ca_comm: Optional[CustomAllreduce] = None
-        if use_custom_allreduce and self.world_size > 1:
-            # Initialize a custom fast all-reduce implementation.
-            self.ca_comm = CustomAllreduce(
-                group=self.cpu_group,
-                device=self.device,
-            )
-
-        from vllm.distributed.device_communicators.tpu_communicator import (
-            TpuCommunicator)
-        self.tpu_communicator: Optional[TpuCommunicator] = None
-        if use_tpu_communicator and self.world_size > 1:
-            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
-
-        from vllm.distributed.device_communicators.hpu_communicator import (
-            HpuCommunicator)
-        self.hpu_communicator: Optional[HpuCommunicator]
-        if use_hpu_communicator and self.world_size > 1:
-            self.hpu_communicator = HpuCommunicator(group=self.device_group)
-
-        from vllm.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator)
-        self.xpu_communicator: Optional[XpuCommunicator]
-        if use_xpu_communicator and self.world_size > 1:
-            self.xpu_communicator = XpuCommunicator(group=self.device_group)
+        self.use_device_communicator = use_device_communicator
+
+        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
+        if use_device_communicator and self.world_size > 1:
+            device_comm_cls = resolve_obj_by_qualname(
+                current_platform.get_device_communicator_cls())
+            self.device_communicator = device_comm_cls(
+                cpu_group=self.cpu_group,
+                device=self.device,
+                device_group=self.device_group,
+                unique_name=self.unique_name,
+            )
 
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
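The key mechanism in the hunk above is that the platform now hands back a fully-qualified class name and resolve_obj_by_qualname turns it into the class to instantiate. In essence it boils down to the following (an illustrative re-implementation; the real helper is vllm.utils.resolve_obj_by_qualname and may differ in details such as error handling):

import importlib


def resolve_obj_by_qualname_sketch(qualname: str):
    # "pkg.module.ClassName" -> import pkg.module, then fetch ClassName.
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# For example, on CUDA the platform returns
# "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator",
# which resolves to the CudaCommunicator class added in this commit.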
@@ -253,6 +218,9 @@
             self.mq_broadcaster = MessageQueue.create_from_process_group(
                 self.cpu_group, 1 << 22, 6)
 
+        from vllm.platforms import current_platform
+        self.use_custom_op_call = current_platform.is_cuda_alike()
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
@@ -296,9 +264,16 @@
         else:
             stream = graph_capture_context.stream
 
-        ca_comm = self.ca_comm
-        maybe_ca_context = nullcontext(
-        ) if ca_comm is None else ca_comm.capture()
+        # only cuda uses this function,
+        # so we don't abstract it into the base class
+        maybe_ca_context = nullcontext()
+        from vllm.distributed.device_communicators.cuda_communicator import (
+            CudaCommunicator)
+        if self.device_communicator is not None:
+            assert isinstance(self.device_communicator, CudaCommunicator)
+            ca_comm = self.device_communicator.ca_comm
+            if ca_comm is not None:
+                maybe_ca_context = ca_comm.capture()  # type: ignore
 
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
@@ -328,54 +303,14 @@
         if self.world_size == 1:
             return input_
 
-        if input_.is_cpu:
-            try:
-                import intel_extension_for_pytorch as ipex
-                ipex.distributed.all_reduce(input_, group=self.device_group)
-                return input_
-            except ImportError:
-                """
-                Intel IPEX not found. Falling back to PyTorch native
-                all_reduce for CPU
-                """
-                torch.distributed.all_reduce(input_, group=self.device_group)
-                return input_
-
-        if self.tpu_communicator is not None and \
-            not self.tpu_communicator.disabled:
-            # TPU handles Dynamo with its own logic.
-            return self.tpu_communicator.all_reduce(input_)
-
-        if self.hpu_communicator is not None and \
-            not self.hpu_communicator.disabled:
-            return self.hpu_communicator.all_reduce(input_)
-
-        if self.xpu_communicator is not None and \
-            not self.xpu_communicator.disabled:
-            return self.xpu_communicator.all_reduce(input_)
-
-        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
+        if self.use_custom_op_call:
+            return torch.ops.vllm.all_reduce(input_,
+                                             group_name=self.unique_name)
+        else:
+            return self._all_reduce_out_place(input_)
 
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
-        # always try custom allreduce first,
-        # and then pynccl.
-        ca_comm = self.ca_comm
-        if ca_comm is not None and not ca_comm.disabled and \
-            ca_comm.should_custom_ar(input_):
-            out = ca_comm.custom_all_reduce(input_)
-            assert out is not None
-            return out
-        pynccl_comm = self.pynccl_comm
-        assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
-        if out is None:
-            # fall back to the default all-reduce using PyTorch.
-            # this usually happens during testing.
-            # when we run the model, allreduce only happens for the TP
-            # group, where we always have either custom allreduce or pynccl.
-            out = input_.clone()
-            torch.distributed.all_reduce(out, group=self.device_group)
-        return out
+        return self.device_communicator.all_reduce(input_)
 
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
@@ -385,40 +320,7 @@
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
-
-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_gather(input_, dim)
-
-        # For HPUs, use HPU communicator.
-        hpu_comm = self.hpu_communicator
-        if hpu_comm is not None and not hpu_comm.disabled:
-            return hpu_comm.all_gather(input_, dim)
-
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # NOTE: we have to use concat-style all-gather here,
-        # stack-style all-gather has compatibility issues with
-        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
-        output_size = (input_size[0] * world_size, ) + input_size[1:]
-        # Allocate output tensor.
-        output_tensor = torch.empty(output_size,
-                                    dtype=input_.dtype,
-                                    device=input_.device)
-        # All-gather.
-        torch.distributed.all_gather_into_tensor(output_tensor,
-                                                 input_,
-                                                 group=self.device_group)
-        # Reshape
-        output_tensor = output_tensor.reshape((world_size, ) + input_size)
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(input_size[:dim] +
-                                              (world_size *
-                                               input_size[dim], ) +
-                                              input_size[dim + 1:])
-        return output_tensor
+        return self.device_communicator.all_gather(input_, dim)
 
     def gather(self,
                input_: torch.Tensor,
@@ -433,30 +335,7 @@
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
-        assert -input_.dim() <= dim < input_.dim(), (
-            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        if self.xpu_communicator is not None and \
-                not self.xpu_communicator.disabled:
-            return self.xpu_communicator.gather(input_, self.rank_in_group,
-                                                dst, dim)
-        # Allocate output tensor.
-        if self.rank_in_group == dst:
-            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
-        else:
-            gather_list = None
-        # Gather.
-        torch.distributed.gather(input_,
-                                 gather_list,
-                                 dst=self.ranks[dst],
-                                 group=self.device_group)
-        if self.rank_in_group == dst:
-            output_tensor = torch.cat(gather_list, dim=dim)
-        else:
-            output_tensor = None
-        return output_tensor
+        return self.device_communicator.gather(input_, dst, dim)
 
     def broadcast(self, input_: torch.Tensor, src: int = 0):
         """Broadcast the input tensor.
@@ -798,14 +677,7 @@
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
-        if dst is None:
-            dst = (self.rank_in_group + 1) % self.world_size
-
-        pynccl_comm = self.pynccl_comm
-        if pynccl_comm is not None and not pynccl_comm.disabled:
-            pynccl_comm.send(tensor, dst)
-        else:
-            torch.distributed.send(tensor, self.ranks[dst], self.device_group)
+        self.device_communicator.send(tensor, dst)
 
     def recv(self,
              size: torch.Size,
@@ -813,16 +685,7 @@
              src: Optional[int] = None) -> torch.Tensor:
         """Receives a tensor from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
-        if src is None:
-            src = (self.rank_in_group - 1) % self.world_size
-
-        tensor = torch.empty(size, dtype=dtype, device=self.device)
-        pynccl_comm = self.pynccl_comm
-        if pynccl_comm is not None and not pynccl_comm.disabled:
-            pynccl_comm.recv(tensor, src)
-        else:
-            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
-        return tensor
+        return self.device_communicator.recv(size, dtype, src)
 
     def destroy(self):
         if self.device_group is not None:
@@ -831,10 +694,8 @@
         if self.cpu_group is not None:
             torch.distributed.destroy_process_group(self.cpu_group)
             self.cpu_group = None
-        if self.pynccl_comm is not None:
-            self.pynccl_comm = None
-        if self.ca_comm is not None:
-            self.ca_comm = None
+        if self.device_communicator is not None:
+            self.device_communicator.destroy()
         if self.mq_broadcaster is not None:
             self.mq_broadcaster = None
 
@@ -853,11 +714,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         group_ranks=[ranks],
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=False,
-        use_custom_allreduce=False,
-        use_tpu_communicator=False,
-        use_hpu_communicator=False,
-        use_xpu_communicator=False,
+        use_device_communicator=False,
         group_name="world",
     )
 
@@ -866,23 +723,15 @@ def init_model_parallel_group(
     group_ranks: List[List[int]],
     local_rank: int,
     backend: str,
-    use_custom_allreduce: Optional[bool] = None,
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
 ) -> GroupCoordinator:
-    if use_custom_allreduce is None:
-        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
-    from vllm.platforms import current_platform
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=current_platform.is_cuda_alike(),
-        use_custom_allreduce=current_platform.is_cuda_alike()
-        and use_custom_allreduce,
-        use_tpu_communicator=True,
-        use_hpu_communicator=True,
-        use_xpu_communicator=True,
+        use_device_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
@@ -1053,11 +902,9 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
-    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
-                                    use_custom_allreduce=False,
                                     group_name="pp")
 
 
vllm/platforms/cpu.py
@@ -146,3 +146,10 @@ class CpuPlatform(Platform):
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        return "vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"  # noqa

vllm/platforms/cuda.py
@@ -233,6 +233,10 @@ class CudaPlatformBase(Platform):
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/hpu.py
@@ -88,3 +88,7 @@ class HpuPlatform(Platform):
     @classmethod
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"  # noqa
vllm/platforms/interface.py
@@ -322,6 +322,13 @@ class Platform:
         """
         raise NotImplementedError
 
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        """
+        Get device specific communicator class for distributed communication.
+        """
+        return "vllm.distributed.device_communicator.base_device_communicator.DeviceCommunicatorBase"  # noqa
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
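Because this hook returns a string rather than a class, a platform (including an out-of-tree plugin) can point it at any DeviceCommunicatorBase subclass without importing device code at module load time. A hypothetical example, with the module path and class names purely illustrative:

# Hypothetical out-of-tree platform; `my_plugin` and MyCommunicator are
# illustrative names, not real modules.
class MyAcceleratorPlatform(Platform):

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "my_plugin.distributed.my_communicator.MyCommunicator"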
vllm/platforms/rocm.py
@@ -186,3 +186,7 @@ class RocmPlatform(Platform):
         torch.cuda.reset_peak_memory_stats(device)
         return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
             device)[0]
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

vllm/platforms/tpu.py
@@ -115,3 +115,7 @@ class TpuPlatform(Platform):
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on TPU.")
         return False
+
+    @classmethod
+    def get_device_communicator_cls(cls) -> str:
+        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa