# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import datetime
import pickle
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple

import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
                                                _unregister_process_group,
                                                is_nccl_available)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """ Split a tensor along its last dimension.

        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor into.
            contiguous_split_chunks: If True, make each chunk contiguous
                                     in memory.

        Returns:
            A list of Tensors
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


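# Illustrative sketch (kept as a comment so it is not executed at import
# time): splitting the last dimension of a [4, 8] tensor into 4 partitions
# yields four [4, 2] chunks. The shape and partition count are arbitrary
# example values, not anything required by vLLM.
#
#     x = torch.arange(32, dtype=torch.float32).reshape(4, 8)
#     chunks = split_tensor_along_last_dim(x, num_partitions=4,
#                                          contiguous_split_chunks=True)
#     assert len(chunks) == 4 and chunks[0].shape == (4, 2)

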
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Try to evenly distribute layers across partitions.

    If the number of layers is not divisible by the number of partitions,
    the remaining layers are evenly distributed across all but the last
    partition. The last partition is excluded because it often contains an
    additional norm layer and we are attempting to balance compute.

    If `pp_size > 2` and the number of remaining layers is
    `0 < x <= pp_size - 2` then the remaining layers are evenly distributed
    across the middle partitions. The first and last partitions are excluded
    because they contain the input and output embeddings respectively and we
    are attempting to reduce maximum memory consumption across partitions.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION
    if partition_list_str is not None:
        try:
            partitions = [
                int(layer) for layer in partition_list_str.split(",")
            ]
        except ValueError as err:
            raise ValueError("Invalid partition string: {}".format(
                partition_list_str)) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
    else:
        layers_per_partition = num_hidden_layers // pp_size
        partitions = [layers_per_partition for _ in range(pp_size)]

        if remaining_layers := num_hidden_layers % pp_size:
            for i in range(2, remaining_layers + 2):
                partitions[-i] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION environment variable",
                ",".join(str(p) for p in partitions))

    start_layer = sum(partitions[:pp_rank])
    end_layer = start_layer + partitions[pp_rank]

    return (start_layer, end_layer)


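# Illustrative sketch (comment only), assuming VLLM_PP_LAYER_PARTITION is
# unset: with 13 hidden layers and pp_size=4, the default path assigns
# [3, 3, 4, 3] layers per partition (the single leftover layer goes to the
# second-to-last partition), so pp_rank=2 owns layers [6, 10). The layer and
# partition counts are arbitrary example values.
#
#     assert get_pp_indices(13, pp_rank=2, pp_size=4) == (6, 10)

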
@dataclasses.dataclass
class StatelessProcessGroup:
    """A dataclass that holds a metadata store and the rank and world_size of
    the group. Only use it to communicate metadata between processes.
    For data-plane communication, create NCCL-related objects.
    """
    rank: int
    world_size: int
    store: torch._C._distributed_c10d.Store
    data_expiration_seconds: int = 3600  # 1 hour

    # dst rank -> counter
    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
    # src rank -> counter
    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
    broadcast_send_counter: int = 0
    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
        default_factory=dict)

    # A deque to store the data entries, with key and timestamp.
    entries: Deque[Tuple[str,
                         float]] = dataclasses.field(default_factory=deque)

    def __post_init__(self):
        assert self.rank < self.world_size
        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
        self.broadcast_recv_src_counter = {
            i: 0
            for i in range(self.world_size)
        }

    def send_obj(self, obj: Any, dst: int):
        """Send an object to a destination rank."""
        self.expire_data()
        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
        self.store.set(key, pickle.dumps(obj))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def expire_data(self):
        """Expire data that is older than `data_expiration_seconds` seconds."""
        while self.entries:
            # check the oldest entry
            key, timestamp = self.entries[0]
            if time.time() - timestamp > self.data_expiration_seconds:
                self.store.delete_key(key)
                self.entries.popleft()
            else:
                break

    def recv_obj(self, src: int) -> Any:
        """Receive an object from a source rank."""
        obj = pickle.loads(
            self.store.get(
                f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
        self.recv_src_counter[src] += 1
        return obj

    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
        """Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object.
        Use it sparingly, e.g., for initialization.
        """
        if self.rank == src:
            self.expire_data()
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_send_counter}")
            self.store.set(key, pickle.dumps(obj))
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return obj
        else:
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_recv_src_counter[src]}")
            recv_obj = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return recv_obj

    def all_gather_obj(self, obj: Any) -> list[Any]:
        """All-gather an object from all ranks."""
        gathered_objs = []
        for i in range(self.world_size):
            if i == self.rank:
                gathered_objs.append(obj)
                self.broadcast_obj(obj, src=self.rank)
            else:
                recv_obj = self.broadcast_obj(None, src=i)
                gathered_objs.append(recv_obj)
        return gathered_objs

    def barrier(self):
        """A barrier to synchronize all ranks."""
        for i in range(self.world_size):
            self.broadcast_obj(None, src=i)

    @staticmethod
    def create(
        host: str,
        port: int,
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
        store_timeout: int = 300,
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does not
        pollute the global state.

        If process A and process B call `torch.distributed.init_process_group`
        to form a group, and we then want to form another group with processes
        A, B, C, and D, that is not possible in PyTorch: A and B have already
        formed a group, and C and D cannot join it. This function is a
        workaround for that limitation.

        `torch.distributed.init_process_group` is a global call, while this
        function is a stateless call. It returns a `StatelessProcessGroup`
        object that can be used to exchange metadata. With this function,
        process A and process B can call `StatelessProcessGroup.create` to
        form a group, and then processes A, B, C, and D can call
        `StatelessProcessGroup.create` to form another group.
        """ # noqa
        store = TCPStore(
            host_name=host,
            port=port,
            world_size=world_size,
            is_master=(rank == 0),
            timeout=datetime.timedelta(seconds=store_timeout),
        )

        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
            data_expiration_seconds=data_expiration_seconds)


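# Illustrative usage sketch (comment only), assuming two processes running
# the same script, each with its own `rank` (0 or 1). The host and port are
# placeholder values; every rank must pass the same host/port and all ranks
# must participate in the collective calls.
#
#     pg = StatelessProcessGroup.create(host="127.0.0.1", port=23456,
#                                       rank=rank, world_size=2)
#     if rank == 0:
#         pg.send_obj({"ready": True}, dst=1)
#     else:
#         msg = pg.recv_obj(src=0)
#     all_ranks = pg.all_gather_obj(rank)  # [0, 1] on both ranks
#     pg.barrier()

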
def stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because they depend on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have processes
    1, 2, ..., 8 that want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel between
    processes 9 and 10, without affecting the communication channel among
    processes 1, 2, ..., 8?

    One possible solution is to figure out beforehand whether processes 9 and
    10 are the same as processes 1 and 5, and then form a communication channel
    based on that information, adjusting the ranks and world_size, etc.
    However, figuring out that information is not always easy, and it would
    interfere with the main communication channel.

    Our solution is to always form a communication channel with processes 1, 2,
    ..., 8, and then use this function to form another communication channel
    with processes 9 and 10. This way, regardless of whether processes 9 and 10
    are the same as processes 1 and 5, the main communication channel is
    always formed with processes 1, 2, ..., 8, and the additional communication
    channel is formed with processes 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
    )

    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
    else:
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

    pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg


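# Illustrative sketch (comment only), assuming two processes on one machine
# using the CPU-only "gloo" backend; host and port are placeholder values and
# collectives are issued directly on the ProcessGroup object, which is what
# this stateless group supports.
#
#     pg = stateless_init_torch_distributed_process_group(
#         host="127.0.0.1", port=23457, rank=rank, world_size=2,
#         backend="gloo")
#     t = torch.ones(4)
#     pg.allreduce([t]).wait()  # in-place sum across both ranks
#     stateless_destroy_torch_distributed_process_group(pg)

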
def stateless_destroy_torch_distributed_process_group(
        pg: ProcessGroup) -> None:
    """
    Destroy ProcessGroup returned by
    stateless_init_torch_distributed_process_group().
    """
    # Lazy import for non-CUDA backends.
    from torch.distributed.distributed_c10d import _shutdown_backend
    _shutdown_backend(pg)
    _unregister_process_group(pg.group_name)