# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Any, Optional, Union

import torch
from torch.distributed import ProcessGroup

from vllm.distributed.utils import pickle
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

from .base_device_communicator import DeviceCommunicatorBase


class CpuCommunicator(DeviceCommunicatorBase):
    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: Optional[torch.device] = None,
        device_group: Optional[ProcessGroup] = None,
        unique_name: str = "",
    ):
        super().__init__(cpu_group, device, device_group, unique_name)
        self.dist_module = torch.distributed

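        # Prefer the shared-memory backend when it is available: on x86 CPUs,
        # when the compiled torch.ops._C.init_shm_manager op is present, and
        # only for tensor-parallel ("tp") or pipeline-parallel ("pp") groups.
        # Otherwise keep the plain torch.distributed module assigned above.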
        if (
            (current_platform.get_cpu_architecture() == CpuArchEnum.X86)
            and hasattr(torch.ops._C, "init_shm_manager")
            and (unique_name.startswith("tp") or unique_name.startswith("pp"))
        ):
            self.dist_module = _CPUSHMDistributed(self)

    def all_reduce(self, input_):
        self.dist_module.all_reduce(input_, group=self.device_group)
        return input_

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the group-local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None

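        # `dst` here is the group-local rank; torch.distributed expects a
        # global rank, hence the translation through self.ranks[dst] below.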
        # Gather.
        self.dist_module.gather(
            input_, gather_list, dst=self.ranks[dst], group=self.device_group
        )

        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here; stack-style
        # all-gather has compatibility issues with torch.compile.
        # See https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size,) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(
            output_size, dtype=input_.dtype, device=input_.device
        )
        # All-gather.
        self.dist_module.all_gather_into_tensor(
            output_tensor, input_, group=self.device_group
        )

        # Reshape
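        # Shape walkthrough (e.g. world_size=2, input of shape (s, h), dim=1):
        # the flat (2*s, h) gather buffer is viewed as (2, s, h), the
        # world-size axis is moved next to `dim` via movedim, and the final
        # reshape folds it in, yielding (s, 2*h).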
        output_tensor = output_tensor.reshape((self.world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim]
            + (self.world_size * input_size[dim],)
            + input_size[dim + 1 :]
        )
        return output_tensor

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
        dst: int,
    ) -> None:
        return self.dist_module.send_tensor_dict(tensor_dict, dst)

    def recv_tensor_dict(
        self,
        src: int,
    ) -> dict[str, Union[torch.Tensor, Any]]:
        return self.dist_module.recv_tensor_dict(src)


class _CPUSHMDistributed:
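    """Drop-in replacement for the subset of torch.distributed collectives
    used by CpuCommunicator, backed by the custom shared-memory (shm_*) ops
    in torch.ops._C."""
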
    def __init__(self, communicator: CpuCommunicator):
        instance_identifier = os.environ["VLLM_DIST_IDENT"]
        unique_name = communicator.unique_name
        instance_identifier = f"{instance_identifier}-{unique_name}"
        self.communicator = communicator

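        # The shared-memory segment name combines the per-instance identifier
        # (VLLM_DIST_IDENT), the group's unique_name, and the member ranks,
        # presumably so that concurrent vLLM instances and different groups
        # never collide on the same segment.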
        group_ranks = [str(rank) for rank in self.communicator.ranks]
        shm_group_identifier = f"[{'-'.join(group_ranks)}]"
        self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"

        self.handle = self._init_cpu_shm()

    def _init_cpu_shm(self) -> int:
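        # Create/attach the shared-memory manager on every rank, then join it.
        # The barriers are presumably there to ensure the segment exists on
        # all ranks before anyone joins, and that everyone has joined before
        # the handle is used.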
        handle = torch.ops._C.init_shm_manager(
            self.group_name,
            self.communicator.world_size,
            self.communicator.rank,
        )
        torch.distributed.barrier(self.communicator.device_group)
        torch.ops._C.join_shm_manager(
            handle,
            self.group_name,
        )
        torch.distributed.barrier(self.communicator.device_group)

        return handle

    def all_reduce(
        self, input: torch.Tensor, group: Optional[ProcessGroup] = None
    ) -> None:
        torch.ops._C.shm_allreduce(self.handle, input)

    def gather(
        self,
        input: torch.Tensor,
        gather_list: Optional[list[torch.Tensor]],
        dst: int = -1,
        group: Optional[ProcessGroup] = None,
    ) -> None:
        # Note: different from the torch gather, here we use local dst rank.
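        # The caller (CpuCommunicator.gather) passes a global rank, so it is
        # translated back to a group-local rank with get_group_rank before
        # calling into the SHM op.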
        torch.ops._C.shm_gather(
            self.handle,
            input,
            gather_list,
            torch.distributed.get_group_rank(group, dst),
        )

    def all_gather_into_tensor(
        self,
        output: torch.Tensor,
        input: torch.Tensor,
        group: Optional[ProcessGroup] = None,
    ) -> None:
        torch.ops._C.shm_all_gather(self.handle, input, output)

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
        dst: int,
    ) -> None:
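        # Wire format: the tensors are sent in dict order, followed by one
        # extra uint8 tensor holding pickle.dumps([key_list, size_list]).
        # recv_tensor_dict pops that trailing tensor to recover the keys and
        # original shapes.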
        key_list = list(tensor_dict.keys())
        value_list = list(tensor_dict.values())
        size_list = []
        for v in value_list:
            if not isinstance(v, torch.Tensor):
                raise RuntimeError("CpuCommunicator only supports sending tensors.")
            size_list.append(v.size())
        key_size_tensor = torch.frombuffer(
            pickle.dumps([key_list, size_list]), dtype=torch.uint8
        )
        value_list.append(key_size_tensor)

        torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)

        return None

    def recv_tensor_dict(
        self,
        src: int,
    ) -> dict[str, Union[torch.Tensor, Any]]:
        tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)

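        # The last entry is the pickled [key_list, size_list] metadata appended
        # by send_tensor_dict; everything before it is the payload tensors in
        # the original dict order.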
        value_list: list[torch.Tensor] = tensor_list[:-1]
        key_size_tensor = tensor_list[-1]

        key_size = pickle.loads(key_size_tensor.numpy().tobytes())
        key_list = key_size[0]
        size_list = key_size[1]
        assert len(key_list) == len(size_list)
        assert len(key_list) == len(value_list)

        tensor_dict: dict[str, torch.Tensor] = {}
        for key, size, t in zip(key_list, size_list, value_list):
            tensor_dict[key] = t.view(size)
        return tensor_dict