[CI] Fix race condition with StatelessProcessGroup.barrier (#18506)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant, 2025-05-21 23:19:13 -04:00 (committed by GitHub)
parent 176d62e4ea
commit 6e0fd34d3c
3 changed files with 157 additions and 25 deletions
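For context on the race: in the removed implementation (visible in the final hunk of vllm/distributed/utils.py below), barrier() was just a round of broadcast_obj calls, so rank 0, which serves the TCPStore, could return and begin tearing the store down while slower ranks were still reading from it. The rewrite instead namespaces per-rank arrival and departure keys in the store under a per-barrier UUID. A minimal sketch of that key layout, with rank standing in for self.rank:

import uuid

# Key layout used by the new barrier (names taken from the diff below);
# rank 0 picks barrier_id and broadcasts it, so every barrier() call gets
# a fresh, globally unique key namespace. `rank` stands in for self.rank.
rank = 0
barrier_id = f"barrier_{uuid.uuid4()}"
arrival_key = f"arrival_{barrier_id}_{rank}"        # phase 1: set on arrival
departure_key = f"departure_{barrier_id}_{rank}"    # phase 2: set on departure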

tests/distributed/test_shm_broadcast.py

@@ -9,7 +9,7 @@ import torch.distributed as dist
 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.distributed.utils import StatelessProcessGroup
-from vllm.utils import get_ip, get_open_port, update_environment_variables
+from vllm.utils import get_open_port, update_environment_variables
 
 
 def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
@@ -60,12 +60,12 @@ def worker_fn():
     rank = dist.get_rank()
     if rank == 0:
         port = get_open_port()
-        ip = get_ip()
+        ip = '127.0.0.1'
         dist.broadcast_object_list([ip, port], src=0)
     else:
         recv = [None, None]
         dist.broadcast_object_list(recv, src=0)
-        ip, port = recv
+        ip, port = recv  # type: ignore
     stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                                 dist.get_world_size())
@@ -107,10 +107,10 @@ def worker_fn():
     if pg == dist.group.WORLD:
         dist.barrier()
-        print("torch distributed passed the test!")
+        print(f"torch distributed passed the test! Rank {rank}")
     else:
         pg.barrier()
-        print("StatelessProcessGroup passed the test!")
+        print(f"StatelessProcessGroup passed the test! Rank {rank}")
 
 
 def test_shm_broadcast():

vllm/distributed/device_communicators/shm_broadcast.py

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-import os
 import pickle
-import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -19,7 +17,7 @@ from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 import vllm.envs as envs
-from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.utils import StatelessProcessGroup, sched_yield
 from vllm.logger import init_logger
 from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
                         is_valid_ipv6_address)
@@ -28,20 +26,6 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
 logger = init_logger(__name__)
 
-# We prefer to use os.sched_yield as it results in tighter polling loops,
-# measured to be around 3e-7 seconds. However on earlier versions of Python
-# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
-USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
-                   or (sys.version_info[:2] == (3, 10)
-                       and sys.version_info[2] >= 8))
-
-
-def sched_yield():
-    if USE_SCHED_YIELD:
-        os.sched_yield()
-    else:
-        time.sleep(0)
-
 
 class ShmRingBuffer:
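The sched_yield() helper above moves verbatim into vllm/distributed/utils.py (next file) so the new barrier can reuse it. As a rough illustration of the polling pattern it supports, not part of this diff, with wait_for_key and store as hypothetical names:

import time

from vllm.distributed.utils import sched_yield


def wait_for_key(store, key: str, timeout: float = 30.0) -> bool:
    """Hypothetical helper: poll a store until `key` appears or we time out."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            store.get(key)  # raises if the key has not been set yet
            return True
        except Exception:
            sched_yield()  # tight, cheap yield instead of a fixed sleep
    return False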

vllm/distributed/utils.py

@@ -6,9 +6,12 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
 import datetime
+import os
 import pickle
 import socket
+import sys
+import time
 import uuid
 from collections import deque
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -27,6 +30,20 @@ from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
 
 logger = init_logger(__name__)
 
+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+
 
 def ensure_divisibility(numerator, denominator):
     """Ensure that numerator is divisible by the denominator."""
@@ -212,10 +229,141 @@ class StatelessProcessGroup:
             gathered_objs.append(recv_obj)
         return gathered_objs
 
-    def barrier(self):
-        """A barrier to synchronize all ranks."""
+    def barrier(self, timeout: float = 30.0):
+        """A robust barrier to synchronize all ranks.
+
+        Uses a multi-phase approach to ensure all processes reach the
+        barrier before proceeding:
+
+        1. Each process signals that it has arrived at the barrier.
+        2. Each process waits until it has confirmed the arrival of all
+           other ranks.
+        3. Rank 0 waits for all other ranks to signal their departure,
+           ensuring that every rank has left the barrier before rank 0
+           returns.
+
+        Args:
+            timeout: Maximum time in seconds to wait for each phase.
+
+        Raises:
+            RuntimeError: If coordination fails or times out.
+        """
+        # Generate a barrier ID that is globally unique
+        try:
+            if self.rank == 0:
+                barrier_id = f"barrier_{uuid.uuid4()}"
+                self.broadcast_obj(barrier_id, src=0)
+            else:
+                barrier_id = self.broadcast_obj(None, src=0)
+        except Exception as e:
+            raise RuntimeError("Failed to broadcast barrier_id") from e
+
+        # Phase 1: Signal arrival at the barrier.
+        arrival_key = f"arrival_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(arrival_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier arrival") from e
+
+        # Wait for all processes to arrive. Every rank must confirm the
+        # arrival of every other rank; this is the key synchronization
+        # point.
+        start_time = time.time()
+        processes_arrived: set[int] = set()
+        while len(processes_arrived) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_arrived:
+                    continue
+
+                key = f"arrival_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value;
+                    # if it doesn't exist yet, an exception is raised.
+                    self.store.get(key)
+                    processes_arrived.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Short yield to avoid a tight polling loop
+            if len(processes_arrived) < self.world_size:
+                sched_yield()
+
+        # Phase 2: Signal departure from the barrier.
+        # We only need to block at this stage on rank 0, which runs the
+        # server side of the TCPStore. We want to make sure that all
+        # clients have departed the barrier before rank 0 does, in case
+        # the next thing after the barrier is a shutdown, including
+        # tearing down the TCPStore. All other ranks can exit the barrier
+        # immediately after signaling their departure.
+        departure_key = f"departure_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(departure_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier departure") from e
+
+        if self.rank != 0:
+            return
+
+        # Make rank 0 wait for all processes to signal departure
+        start_time = time.time()
+        processes_departed: set[int] = set()
+        while len(processes_departed) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier departure timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_departed:
+                    continue
+
+                key = f"departure_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value;
+                    # if it doesn't exist yet, an exception is raised.
+                    self.store.get(key)
+                    processes_departed.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Short yield to avoid a tight polling loop
+            if len(processes_departed) < self.world_size:
+                sched_yield()
+
+        # Clean up keys to avoid leaking memory in the store
         for i in range(self.world_size):
-            self.broadcast_obj(None, src=i)
+            try:
+                self.store.delete_key(f"arrival_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: arrival_%s_%s",
+                             barrier_id, i)
+            try:
+                self.store.delete_key(f"departure_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: departure_%s_%s",
+                             barrier_id, i)
 
     @staticmethod
     def create(
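For reference, a minimal caller-side sketch of the patched barrier, based on the create(...) call in the test diff above; the port, world size, and rank plumbing are illustrative:

import sys

from vllm.distributed.utils import StatelessProcessGroup

# Illustrative bootstrap: each of the two processes passes its own rank.
rank, world_size = int(sys.argv[1]), 2

pg = StatelessProcessGroup.create("127.0.0.1", 5555, rank, world_size)
# ... exchange data, e.g. pg.broadcast_obj(obj, src=0) ...
pg.barrier(timeout=30.0)
# Rank 0 serves the TCPStore and is the last to return from barrier(), so
# shutting down right after the barrier no longer races with slower clients.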