# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import datetime
import pickle
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple

import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                _get_default_timeout,
                                                _unregister_process_group,
                                                is_nccl_available)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator.

    Raises:
        AssertionError: if ``numerator % denominator`` is non-zero.
    """
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}")


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value.

    Raises:
        AssertionError: if the division is not exact.
    """
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns:
        A sequence of ``num_partitions`` tensors (a tuple).

    Raises:
        AssertionError: if the last dimension is not divisible by
            ``num_partitions``.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # NOTE: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    """Try to evenly distribute layers across partitions.

    If the number of layers is not divisible by the number of partitions,
    the remaining layers are evenly distributed across all but the last
    partition. The last partition is excluded because it often contains an
    additional norm layer and we are attempting to balance compute.

    If `pp_size > 2` and the number of remaining layers is
    `0 < x <= pp_size - 2` then the remaining layers are evenly distributed
    across the middle partitions. The first and last partitions are excluded
    because they contain the input and output embeddings respectively and we
    are attempting to reduce maximum memory consumption across partitions.

    The automatic split is bypassed when the `VLLM_PP_LAYER_PARTITION`
    environment variable is set to a comma-separated list of per-partition
    layer counts.

    Returns:
        `(start_layer, end_layer)` for `pp_rank`; the end index is exclusive.

    Raises:
        ValueError: if `VLLM_PP_LAYER_PARTITION` cannot be parsed as integers,
            has a length different from `pp_size`, does not sum to
            `num_hidden_layers`, or contains a negative layer count.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION
    if partition_list_str is not None:
        try:
            partitions = [
                int(layer) for layer in partition_list_str.split(",")
            ]
        except ValueError as err:
            raise ValueError(
                f"Invalid partition string: {partition_list_str}") from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
        # A negative entry can still satisfy the length and sum checks above
        # (e.g. "5,-1,4") but would yield end_layer < start_layer.
        if any(num_layers < 0 for num_layers in partitions):
            raise ValueError(
                "Negative layer count in partition string: "
                f"{partition_list_str}")
    else:
        layers_per_partition = num_hidden_layers // pp_size
        partitions = [layers_per_partition for _ in range(pp_size)]

        if remaining_layers := num_hidden_layers % pp_size:
            # Hand out the remainder starting from the second-to-last
            # partition and walking towards the first one.
            for i in range(2, remaining_layers + 2):
                partitions[-i] += 1
            logger.info("Hidden layers were unevenly partitioned: %s",
                        ",".join(str(p) for p in partitions))
            logger.info("This can be manually overridden using the "
                        "VLLM_PP_LAYER_PARTITION environment variable")

    start_layer = sum(partitions[:pp_rank])
    end_layer = start_layer + partitions[pp_rank]

    return (start_layer, end_layer)


@dataclasses.dataclass
class StatelessProcessGroup:
    """A dataclass to hold a metadata store, and the rank, world_size of the
    group. Only use it to communicate metadata between processes.
    For data-plane communication, create NCCL-related objects.

    NOTE(security): objects are serialized with `pickle` and unpickled
    straight from the store, so the store must only be reachable by trusted
    peers.
    """
    rank: int
    world_size: int
    store: torch._C._distributed_c10d.Store
    data_expiration_seconds: int = 3600  # 1 hour

    # dst rank -> counter
    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
    # src rank -> counter
    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
    broadcast_send_counter: int = 0
    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
        default_factory=dict)

    # A deque to store the data entries, with key and timestamp.
    entries: Deque[Tuple[str,
                         float]] = dataclasses.field(default_factory=deque)

    def __post_init__(self):
        assert self.rank < self.world_size
        # Every peer starts at sequence number 0 in each direction.
        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
        self.broadcast_recv_src_counter = {
            i: 0
            for i in range(self.world_size)
        }

    @staticmethod
    def _p2p_key(src: int, dst: int, seq: int) -> str:
        """Build the store key of the `seq`-th object sent from `src` to
        `dst`.

        The sender rank is part of the key so that two ranks sending to the
        same destination never write to the same key.
        """
        return f"send_to/{dst}/from/{src}/{seq}"

    def send_obj(self, obj: Any, dst: int):
        """Send an object to a destination rank.

        The pickled object is written to the store under a key unique to the
        (sender, destination, sequence number) triple, and the key is
        recorded so that `expire_data` can delete it later.
        """
        self.expire_data()
        key = self._p2p_key(self.rank, dst, self.send_dst_counter[dst])
        self.store.set(key, pickle.dumps(obj))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def expire_data(self):
        """Expire data that is older than `data_expiration_seconds` seconds."""
        while self.entries:
            # check the oldest entry
            key, timestamp = self.entries[0]
            if time.time() - timestamp > self.data_expiration_seconds:
                self.store.delete_key(key)
                self.entries.popleft()
            else:
                # Entries are appended in time order, so the rest is newer.
                break

    def recv_obj(self, src: int) -> Any:
        """Receive an object from a source rank.

        Reads the next not-yet-consumed object that `src` sent to this rank.
        """
        key = self._p2p_key(src, self.rank, self.recv_src_counter[src])
        # NOTE(security): unpickles whatever is stored under the key.
        obj = pickle.loads(self.store.get(key))
        self.recv_src_counter[src] += 1
        return obj

    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
        """Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object.
        Use it for limited times, e.g., for initialization.

        On the source rank `obj` is published and returned unchanged; on any
        other rank the `obj` argument is ignored and the object read from the
        store is returned.
        """
        if self.rank == src:
            self.expire_data()
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_send_counter}")
            self.store.set(key, pickle.dumps(obj))
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return obj
        else:
            key = (f"broadcast_from/{src}/"
                   f"{self.broadcast_recv_src_counter[src]}")
            # NOTE(security): unpickles whatever is stored under the key.
            recv_obj = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return recv_obj

    def all_gather_obj(self, obj: Any) -> list[Any]:
        """All gather an object from all ranks.

        Returns a list indexed by rank; the entry for this rank is `obj`.
        """
        gathered_objs = []
        for i in range(self.world_size):
            if i == self.rank:
                gathered_objs.append(obj)
                self.broadcast_obj(obj, src=self.rank)
            else:
                recv_obj = self.broadcast_obj(None, src=i)
                gathered_objs.append(recv_obj)
        return gathered_objs

    def barrier(self):
        """A barrier to synchronize all ranks."""
        # Each rank takes a turn broadcasting a dummy object, so every rank
        # reads one key from every other rank before returning.
        for i in range(self.world_size):
            self.broadcast_obj(None, src=i)

    @staticmethod
    def create(
        host: str,
        port: int,
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
        store_timeout: int = 300,
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does not
        pollute the global state.

        If we have process A and process B called `torch.distributed.init_process_group`
        to form a group, and then we want to form another group with process A, B, C,
        D, it is not possible in PyTorch, because process A and process B have already
        formed a group, and process C and process D cannot join that group. This
        function is a workaround for this issue.

        `torch.distributed.init_process_group` is a global call, while this function
        is a stateless call. It will return a `StatelessProcessGroup` object that can be
        used for exchanging metadata. With this function, process A and process B
        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
        C, and D can call `StatelessProcessGroup.create` to form another group.
        """ # noqa
        # Rank 0 is the TCPStore master; `store_timeout` is in seconds.
        store = TCPStore(
            host_name=host,
            port=port,
            world_size=world_size,
            is_master=(rank == 0),
            timeout=datetime.timedelta(seconds=store_timeout),
        )

        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
            data_expiration_seconds=data_expiration_seconds)


def stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because it depends on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have process
    1, 2, ..., 8 who want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel within
    process 9 and 10, without affecting the communication channel within
    process 1, 2, ..., 8?

    One possible solution is to figure out if process 9 and 10 are the same
    as process 1 and 5 beforehand, and then form a communication channel
    based on the information, adjusting the ranks and world_size etc. However,
    figuring out the information is not always easy, and it will interfere
    with the main communication channel.

    Our solution is to always form a communication channel with process 1, 2,
    ..., 8, and then use this function to form another communication channel
    with process 9 and 10. This way, regardless of whether process 9 and 10
    are the same as process 1 and 5, the main communication channel is
    always formed with process 1, 2, ..., 8, and the additional communication
    channel is formed with process 9 and 10.

    Raises:
        RuntimeError: if `backend` is neither "gloo" nor "nccl".
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
    )

    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")
    else:
        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

    pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg


def stateless_destroy_torch_distributed_process_group(
        pg: ProcessGroup) -> None:
    """
    Destroy ProcessGroup returned by
        stateless_init_torch_distributed_process_group().
    """
    # Lazy import for non-CUDA backends.
    from torch.distributed.distributed_c10d import _shutdown_backend
    _shutdown_backend(pg)
    _unregister_process_group(pg.group_name)