[CI] Fix race condition with StatelessProcessGroup.barrier (#18506)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 6e0fd34d3c
parent 176d62e4ea
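This commit rewrites StatelessProcessGroup.barrier. The old implementation (visible as the removed lines in the last hunk below) was a round-robin of broadcast_obj calls, so rank 0, which hosts the TCPStore server, could finish the barrier and begin teardown while other ranks were still inside it; in CI this surfaced as intermittent connection failures. The new barrier records each rank's arrival and departure under per-rank store keys and makes rank 0 leave last. A minimal sketch of the worker pattern being protected (not part of the commit; it assumes one process per rank on a single host, all of which already agree on port):

    from vllm.distributed.utils import StatelessProcessGroup


    def worker(rank: int, world_size: int, port: int) -> None:
        pg = StatelessProcessGroup.create("127.0.0.1", port, rank, world_size)
        # ... the actual test work goes here (broadcast_obj, all_gather_obj) ...
        pg.barrier()
        # The process exits right after the barrier. Rank 0 owns the TCPStore
        # server; if it exited first, the other ranks' store operations could
        # fail with connection errors. That is the race this commit fixes.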
tests/distributed/test_shm_broadcast.py
@@ -9,7 +9,7 @@ import torch.distributed as dist
 
 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.distributed.utils import StatelessProcessGroup
-from vllm.utils import get_ip, get_open_port, update_environment_variables
+from vllm.utils import get_open_port, update_environment_variables
 
 
 def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
@@ -60,12 +60,12 @@ def worker_fn():
     rank = dist.get_rank()
     if rank == 0:
         port = get_open_port()
-        ip = get_ip()
+        ip = '127.0.0.1'
         dist.broadcast_object_list([ip, port], src=0)
     else:
         recv = [None, None]
         dist.broadcast_object_list(recv, src=0)
-        ip, port = recv
+        ip, port = recv  # type: ignore
 
     stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                                 dist.get_world_size())
@@ -107,10 +107,10 @@ def worker_fn():
 
     if pg == dist.group.WORLD:
         dist.barrier()
-        print("torch distributed passed the test!")
+        print(f"torch distributed passed the test! Rank {rank}")
     else:
         pg.barrier()
-        print("StatelessProcessGroup passed the test!")
+        print(f"StatelessProcessGroup passed the test! Rank {rank}")
 
 
 def test_shm_broadcast():
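Two details of the test change are worth noting: each rank now prints its own rank so CI logs show which worker reached the end, and the rendezvous address is pinned to loopback. get_ip() resolves an externally routable interface, which is environment-dependent on CI runners, while loopback always works when all ranks share one host. A small illustration (not from the commit):

    from vllm.utils import get_ip, get_open_port

    host = "127.0.0.1"      # what the test now pins for its single-host workers
    port = get_open_port()  # a free TCP port chosen at runtime, as in the test
    print(f"rendezvous at {host}:{port}; get_ip() would give {get_ip()}")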
vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 import pickle
-import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -19,7 +17,7 @@ from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 import vllm.envs as envs
-from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.utils import StatelessProcessGroup, sched_yield
 from vllm.logger import init_logger
 from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
                         is_valid_ipv6_address)
@@ -28,20 +26,6 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
 logger = init_logger(__name__)
 
-# We prefer to use os.sched_yield as it results in tighter polling loops,
-# measured to be around 3e-7 seconds. However on earlier versions of Python
-# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
-USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
-                   or (sys.version_info[:2] == (3, 10)
-                       and sys.version_info[2] >= 8))
-
-
-def sched_yield():
-    if USE_SCHED_YIELD:
-        os.sched_yield()
-    else:
-        time.sleep(0)
-
-
 class ShmRingBuffer:
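The hunks above delete sched_yield and its USE_SCHED_YIELD gate from shm_broadcast.py and import them from vllm.distributed.utils instead, so the MessageQueue polling loops and the new barrier share one implementation. A quick check of which branch the running interpreter takes (a sketch, assuming vllm is importable):

    import sys

    from vllm.distributed.utils import USE_SCHED_YIELD, sched_yield

    print(sys.version_info[:3],
          "os.sched_yield" if USE_SCHED_YIELD else "time.sleep(0)")
    sched_yield()  # cooperative yield either way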
vllm/distributed/utils.py
@@ -6,9 +6,12 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
 import datetime
+import os
 import pickle
 import socket
+import sys
 import time
+import uuid
 from collections import deque
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -27,6 +30,20 @@ from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
 
 logger = init_logger(__name__)
 
+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+
+
 def ensure_divisibility(numerator, denominator):
     """Ensure that numerator is divisible by the denominator."""
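The relocated helper carries its rationale with it: os.sched_yield keeps polling loops at roughly 3e-7 seconds per iteration, but it only releases the GIL on Python 3.11.1+ (or 3.10.8+), so older interpreters fall back to time.sleep(0). A minimal sketch of the polling idiom this supports; wait_for and predicate are illustrative names, not vllm API. The next hunk, further down the same file, applies the idiom inside the rewritten barrier.

    import time

    from vllm.distributed.utils import sched_yield


    def wait_for(predicate, timeout: float) -> bool:
        """Poll `predicate` in a tight loop, yielding the CPU between checks."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            sched_yield()
        return False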
vllm/distributed/utils.py (continued)
@@ -212,10 +229,141 @@ class StatelessProcessGroup:
             gathered_objs.append(recv_obj)
         return gathered_objs
 
-    def barrier(self):
-        """A barrier to synchronize all ranks."""
-        for i in range(self.world_size):
-            self.broadcast_obj(None, src=i)
+    def barrier(self, timeout: float = 30.0):
+        """A robust barrier to synchronize all ranks.
+
+        Uses a multi-phase approach to ensure all processes reach the
+        barrier before proceeding:
+
+        1. Each process signals that it has reached the barrier.
+
+        2. Each process confirms that it has seen the arrival of all
+           other ranks.
+
+        3. Rank 0 waits for all other ranks to signal their departure,
+           to ensure that all ranks have departed the barrier first.
+
+        Args:
+            timeout: Maximum time in seconds to wait for each phase.
+
+        Raises:
+            RuntimeError: If coordination fails or times out.
+        """
+        # Generate a barrier ID that is globally unique
+        try:
+            if self.rank == 0:
+                barrier_id = f"barrier_{uuid.uuid4()}"
+                self.broadcast_obj(barrier_id, src=0)
+            else:
+                barrier_id = self.broadcast_obj(None, src=0)
+        except Exception as e:
+            raise RuntimeError("Failed to broadcast barrier_id") from e
+
+        # Phase 1: Signal arrival at the barrier, then wait for all
+        # processes to arrive. Every rank confirming the arrival of all
+        # other ranks is the key synchronization point.
+        arrival_key = f"arrival_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(arrival_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier arrival") from e
+
+        start_time = time.time()
+        processes_arrived: set[int] = set()
+
+        while len(processes_arrived) < self.world_size:
+            # Check for timeout
+            cur_time = time.time()
+            if cur_time - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_arrived:
+                    continue
+
+                key = f"arrival_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a
+                    # value; if it doesn't exist yet, an exception is
+                    # raised
+                    self.store.get(key)
+                    processes_arrived.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s",
+                                 check_e)
+                    sched_yield()
+
+            # Yield briefly to avoid tight polling
+            if len(processes_arrived) < self.world_size:
+                sched_yield()
+
+        # Phase 2: Signal departure from the barrier.
+        # We only need to block at this stage on rank 0, which runs the
+        # server side of the TCPStore. We want to make sure that all
+        # clients have departed the barrier before rank 0 in case the
+        # next thing after the barrier is a shutdown, including tearing
+        # down the TCPStore. Other ranks can exit the barrier immediately
+        # after signaling their departure.
+        departure_key = f"departure_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(departure_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier departure") from e
+
+        if self.rank != 0:
+            return
+
+        # Make rank 0 wait for all processes to signal departure
+        start_time = time.time()
+        processes_departed: set[int] = set()
+
+        while len(processes_departed) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier departure timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_departed:
+                    continue
+
+                key = f"departure_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a
+                    # value; if it doesn't exist yet, an exception is
+                    # raised
+                    self.store.get(key)
+                    processes_departed.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s",
+                                 check_e)
+                    sched_yield()
+
+            # Yield briefly to avoid tight polling
+            if len(processes_departed) < self.world_size:
+                sched_yield()
+
+        # Clean up keys to avoid leaking memory in the store
+        for i in range(self.world_size):
+            try:
+                self.store.delete_key(f"arrival_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f'arrival_{barrier_id}_{i}')
+
+            try:
+                self.store.delete_key(f"departure_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f'departure_{barrier_id}_{i}')
 
     @staticmethod
     def create(
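The docstring above names three phases; the sketch below replays the same arrival/departure key protocol on a bare torch TCPStore, outside vllm, to make the ordering visible. Port 29555 and a world size of 2 are arbitrary assumptions; run one copy per rank with the rank as argv[1]. The real barrier additionally generates a unique barrier_id, polls with sched_yield(), enforces timeouts, and deletes the keys afterwards.

    import datetime
    import sys

    from torch.distributed import TCPStore

    rank, world_size = int(sys.argv[1]), 2
    store = TCPStore("127.0.0.1", 29555, world_size, is_master=(rank == 0),
                     timeout=datetime.timedelta(seconds=30))

    barrier_id = "demo"  # vllm broadcasts f"barrier_{uuid.uuid4()}" from rank 0
    store.set(f"arrival_{barrier_id}_{rank}", "1")
    # Phase 1: wait until every rank's arrival key exists.
    while not store.check([f"arrival_{barrier_id}_{i}"
                           for i in range(world_size)]):
        pass  # the real code yields with sched_yield() instead of spinning

    # Phase 2: signal departure; only rank 0 (the store server) must wait,
    # so the store outlives every client still inside the barrier.
    store.set(f"departure_{barrier_id}_{rank}", "1")
    if rank == 0:
        while not store.check([f"departure_{barrier_id}_{i}"
                               for i in range(world_size)]):
            pass
    print(f"rank {rank} left the barrier")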