vllm/vllm/distributed/utils.py
Michael Goin 8d6a89dffd
[UX] Suppress gloo log spam (#29250)
Signed-off-by: mgoin <mgoin64@gmail.com>
2025-11-25 17:19:35 -08:00

546 lines
20 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import os
import pickle
import socket
import sys
import time
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import timedelta
from typing import Any
import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
_get_default_timeout,
_unregister_process_group,
)
from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import is_torch_equal_or_newer
# Module-level logger used by the helpers in this file (e.g. get_pp_indices
# and StatelessProcessGroup.barrier).
logger = init_logger(__name__)
# `os.sched_yield` gives the tightest polling loop (roughly 3e-7 seconds per
# call), but on older interpreters it does not release the GIL. It is only
# used on CPython 3.10.8+ within the 3.10 series, or on 3.11.1 and newer;
# every other version falls back to `time.sleep(0)`.
USE_SCHED_YIELD = sys.version_info[:3] >= (3, 11, 1) or (
    (3, 10, 8) <= sys.version_info[:3] < (3, 11)
)


def sched_yield():
    """Briefly give up the CPU inside a polling loop.

    Uses `os.sched_yield` when `USE_SCHED_YIELD` is set, otherwise
    `time.sleep(0)`.
    """
    if not USE_SCHED_YIELD:
        time.sleep(0)
        return
    os.sched_yield()
def ensure_divisibility(numerator, denominator):
    """Assert that `numerator` is an exact multiple of `denominator`."""
    remainder = numerator % denominator
    assert remainder == 0, f"{numerator} is not divisible by {denominator}"


def divide(numerator, denominator):
    """Return `numerator // denominator`, asserting the division is exact."""
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split `tensor` into `num_partitions` equal chunks along its last dim.

    Arguments:
        tensor: the tensor to split.
        num_partitions: number of chunks to produce. It must evenly divide
            the size of the last dimension (checked via `divide`).
        contiguous_split_chunks: when True, each chunk is copied into
            contiguous memory.

    Returns:
        A tuple of tensors, one per partition.
    """
    final_axis = tensor.dim() - 1
    chunk_width = divide(tensor.size()[final_axis], num_partitions)
    pieces = torch.split(tensor, chunk_width, dim=final_axis)
    if not contiguous_split_chunks:
        # NOTE: torch.split returns views that are not contiguous by default.
        return pieces
    return tuple(piece.contiguous() for piece in pieces)
def get_pp_indices(
    num_hidden_layers: int, pp_rank: int, pp_size: int
) -> tuple[int, int]:
    """Return the `[start, end)` layer range owned by pipeline stage `pp_rank`.

    By default, layers are split as evenly as possible across the `pp_size`
    stages. Leftover layers are handed out one at a time, starting from the
    second-to-last stage and moving toward the first. The last stage never
    gets an extra layer, because it often also holds an additional norm
    layer and we try to balance compute.

    When `pp_size > 2` and the leftover count `x` satisfies
    `0 < x <= pp_size - 2`, only the middle stages receive extras. The first
    and last stages hold the input and output embeddings respectively, so
    keeping them lighter lowers the peak memory across stages.

    The `VLLM_PP_LAYER_PARTITION` environment variable overrides this split.
    It is a comma-separated list of per-stage layer counts, which must have
    `pp_size` entries that sum to `num_hidden_layers`.

    Raises:
        ValueError: if the override cannot be parsed as integers, has the
            wrong number of entries, or does not sum to `num_hidden_layers`.
    """
    override = envs.VLLM_PP_LAYER_PARTITION
    if override is None:
        base, remaining_layers = divmod(num_hidden_layers, pp_size)
        partitions = [base] * pp_size
        if remaining_layers:
            # Index -1 (the last stage) is skipped; stages -2, -3, ... each
            # receive one of the leftover layers.
            for back_offset in range(2, remaining_layers + 2):
                partitions[-back_offset] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION environment variable",
                ",".join(map(str, partitions)),
            )
    else:
        try:
            partitions = [int(layer) for layer in override.split(",")]
        except ValueError as err:
            raise ValueError(f"Invalid partition string: {override}") from err
        # The `=` f-string specifier prints the expression text, so the names
        # `partitions`, `pp_size` and `num_hidden_layers` appear in the message.
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    start_layer = sum(partitions[:pp_rank])
    return start_layer, start_layer + partitions[pp_rank]
@dataclasses.dataclass
class StatelessProcessGroup:
"""A dataclass to hold a metadata store, and the rank, world_size of the
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
# stores a reference to the socket so that the file descriptor stays alive
socket: socket.socket | None
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
# src rank -> counter
recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
broadcast_send_counter: int = 0
broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
# A deque to store the data entries, with key and timestamp.
entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)
def __post_init__(self):
assert self.rank < self.world_size
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
def expire_data(self):
"""Expire data that is older than `data_expiration_seconds` seconds."""
while self.entries:
# check the oldest entry
key, timestamp = self.entries[0]
if time.time() - timestamp > self.data_expiration_seconds:
self.store.delete_key(key)
self.entries.popleft()
else:
break
def recv_obj(self, src: int) -> Any:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
)
self.recv_src_counter[src] += 1
return obj
def broadcast_obj(self, obj: Any | None, src: int) -> Any:
"""Broadcast an object from a source rank to all other ranks.
It does not clean up after all ranks have received the object.
Use it for limited times, e.g., for initialization.
"""
if self.rank == src:
self.expire_data()
key = f"broadcast_from/{src}/{self.broadcast_send_counter}"
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}"
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return recv_obj
def all_gather_obj(self, obj: Any) -> list[Any]:
"""All gather an object from all ranks."""
gathered_objs = []
for i in range(self.world_size):
if i == self.rank:
gathered_objs.append(obj)
self.broadcast_obj(obj, src=self.rank)
else:
recv_obj = self.broadcast_obj(None, src=i)
gathered_objs.append(recv_obj)
return gathered_objs
def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks.
Uses a multi-phase approach to ensure all processes reach the barrier
before proceeding:
1. Each process signals it has reached the barrier
2. Each process signals that it has confirmed the arrival of all other
ranks.
3. Rank 0 waits for all other ranks to signal their departure to ensure
that all ranks have departed the barrier first.
Args:
timeout: Maximum time in seconds to wait for each phase (in seconds)
Raises:
RuntimeError: If coordination fails or times out
"""
# Generate a barrier ID that is globally unique
try:
if self.rank == 0:
barrier_id = f"barrier_{uuid.uuid4()}"
self.broadcast_obj(barrier_id, src=0)
else:
barrier_id = self.broadcast_obj(None, src=0)
except Exception as e:
raise RuntimeError("Failed to broadcast barrier_id") from e
# Phase 1: Signal arrival at barrier
# Wait for all processes to arrive
# We need all ranks to confirm the arrival of all other ranks.
# This is the key synchronization point.
arrival_key = f"arrival_{barrier_id}_{self.rank}"
try:
self.store.set(arrival_key, b"1")
except Exception as e:
raise RuntimeError("Failed to signal barrier arrival") from e
start_time = time.time()
processes_arrived: set[int] = set()
while len(processes_arrived) < self.world_size:
# Check for timeout
cur_time = time.time()
if cur_time - start_time > timeout:
raise RuntimeError(f"Barrier timed out after {timeout:.2f} seconds")
# Check for each process
for i in range(self.world_size):
if i in processes_arrived:
continue
key = f"arrival_{barrier_id}_{i}"
try:
# Try to get the key - if it exists, we'll get a value
# If it doesn't exist, it will throw an exception
self.store.get(key)
processes_arrived.add(i)
except KeyError:
# Key doesn't exist yet
pass
except Exception as check_e:
logger.debug("Error checking key existence: %s", check_e)
sched_yield()
# Short sleep to avoid tight polling
if len(processes_arrived) < self.world_size:
sched_yield()
# Phase 2: Signal departure from barrier
# We only care to block at this stage in rank 0, which runs the
# server side of the TCPStore. We want to make sure that all
# clients have departed the barrier before rank 0 in case the
# next thing after the barrier is a shutdown, including tearing
# down the TCPStore. Other ranks can exit the barrier immediately
# after signaling their departure.
departure_key = f"departure_{barrier_id}_{self.rank}"
try:
self.store.set(departure_key, b"1")
except Exception as e:
raise RuntimeError("Failed to signal barrier departure") from e
if self.rank != 0:
return
# Make rank 0 wait for all processes to signal departure
start_time = time.time()
processes_departed: set[int] = set()
while len(processes_departed) < self.world_size:
# Check for timeout
if time.time() - start_time > timeout:
raise RuntimeError(
f"Barrier departure timed out after {timeout:.2f} seconds"
)
# Check for each process
for i in range(self.world_size):
if i in processes_departed:
continue
key = f"departure_{barrier_id}_{i}"
try:
# Try to get the key - if it exists, we'll get a value
# If it doesn't exist, it will throw an exception
self.store.get(key)
processes_departed.add(i)
except KeyError:
# Key doesn't exist yet
pass
except Exception as check_e:
logger.debug("Error checking key existence: %s", check_e)
sched_yield()
# Short sleep to avoid tight polling
if len(processes_departed) < self.world_size:
sched_yield()
# Clean up keys to avoid leaking memory in the store
for i in range(self.world_size):
try:
self.store.delete_key(f"arrival_{barrier_id}_{i}")
except Exception:
logger.debug("Error deleting key: %s", f"arrival_{barrier_id}_{i}")
try:
self.store.delete_key(f"departure_{barrier_id}_{i}")
except Exception:
logger.debug("Error deleting key: %s", f"departure_{barrier_id}_{i}")
@staticmethod
def create(
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
store_timeout: int = 300,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
If we have process A and process B called `torch.distributed.init_process_group`
to form a group, and then we want to form another group with process A, B, C,
D, it is not possible in PyTorch, because process A and process B have already
formed a group, and process C and process D cannot join that group. This
function is a workaround for this issue.
`torch.distributed.init_process_group` is a global call, while this function
is a stateless call. It will return a `StatelessProcessGroup` object that can be
used for exchanging metadata. With this function, process A and process B
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
launch_server = rank == 0
if launch_server:
# listen on the specified interface (instead of 0.0.0.0)
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind((host, port))
listen_socket.listen()
listen_fd = listen_socket.fileno()
else:
listen_socket = None
listen_fd = None
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=launch_server,
timeout=timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd,
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
socket=listen_socket,
data_expiration_seconds=data_expiration_seconds,
)
def init_gloo_process_group(
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """Build a gloo-backed ProcessGroup without touching global torch state.

    Handles the ProcessGroup constructor difference between torch versions
    before 2.6 and from 2.6 on. Anything printed to stdout during setup is
    suppressed.

    Args:
        prefix_store: store used by both the ProcessGroup and the gloo backend.
        group_rank: this process's rank within the group.
        group_size: number of processes in the group.
        timeout: timeout passed to the gloo backend.

    Returns:
        A ProcessGroup with a ProcessGroupGloo registered for the CPU device.
    """
    with suppress_stdout():
        torch_2_6_or_newer = is_torch_equal_or_newer("2.6")
        if torch_2_6_or_newer:
            pg = ProcessGroup(prefix_store, group_rank, group_size)
        else:
            # Older torch takes the backend name via ProcessGroup.Options.
            gloo_options = ProcessGroup.Options(backend="gloo")
            pg = ProcessGroup(prefix_store, group_rank, group_size, gloo_options)

        from torch.distributed.distributed_c10d import ProcessGroupGloo

        gloo_backend = ProcessGroupGloo(
            prefix_store, group_rank, group_size, timeout=timeout
        )
        gloo_type = ProcessGroup.BackendType.GLOO
        if torch_2_6_or_newer:
            # _set_default_backend is supported in torch >= 2.6
            pg._set_default_backend(gloo_type)
        gloo_backend._set_sequence_number_for_group()
        pg._register_backend(torch.device("cpu"), gloo_type, gloo_backend)
    return pg
def stateless_init_torch_distributed_process_group(
    host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
    """
    Create a ProcessGroup without calling `torch.distributed.init_process_group`,
    so no global torch.distributed state is modified.

    The returned group supports operations that don't depend on the global
    rank, such as `allreduce`. Operations that do depend on it, such as
    `broadcast`, cannot be used.
    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    Why this exists: sometimes we don't know how many distinct processes take
    part. For example, processes 1..8 already communicate, and we need a
    channel between processes 9 and 10. Process 9 may or may not be the same
    process as 1, and process 10 may or may not be the same as 5.

    We could work out those identities up front and adjust ranks and
    world_size accordingly. However, that is not always easy and would
    interfere with the main channel. Instead, the main channel is always
    formed among processes 1..8, and this function forms an independent
    channel among processes 9 and 10, whichever processes they turn out to be.

    A device-specific group is created through the current platform. If the
    platform raises NotImplementedError, a CPU gloo group is built instead.
    """
    init_method = get_tcp_uri(host, port)
    backend = Backend(backend)  # it is basically a string
    timeout = _get_default_timeout(backend)

    store, group_rank, group_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout)
    )
    store.set_timeout(timeout)

    # Namespace all keys under the init URI so that other systems sharing a
    # multi-tenant store (e.g. RPC) can't accidentally override ours.
    prefix_store = PrefixStore(init_method, store)
    pg_kwargs = dict(
        prefix_store=prefix_store,
        group_rank=group_rank,
        group_size=group_size,
        timeout=timeout,
    )

    try:
        from vllm.platforms import current_platform

        return current_platform.stateless_init_device_torch_dist_pg(
            backend=backend, **pg_kwargs
        )
    except NotImplementedError:
        # Platforms without a stateless_init_device_torch_dist_pg
        # implementation raise NotImplementedError; fall back to gloo.
        return init_gloo_process_group(**pg_kwargs)
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
    """
    Tear down a ProcessGroup created by
    stateless_init_torch_distributed_process_group().

    Shuts the group down (via `pg.shutdown()` on torch >= 2.7, otherwise via
    `_shutdown_backend`), then unregisters it by name.
    """
    if not is_torch_equal_or_newer("2.7"):
        # Lazy import for non-CUDA backends.
        from torch.distributed.distributed_c10d import _shutdown_backend

        _shutdown_backend(pg)
    else:
        pg.shutdown()
    _unregister_process_group(pg.group_name)