mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 03:25:01 +08:00
[CI] Fix race condition with StatelessProcessGroup.barrier (#18506)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
176d62e4ea
commit
6e0fd34d3c
@ -9,7 +9,7 @@ import torch.distributed as dist
|
|||||||
|
|
||||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||||
from vllm.distributed.utils import StatelessProcessGroup
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.utils import get_ip, get_open_port, update_environment_variables
|
from vllm.utils import get_open_port, update_environment_variables
|
||||||
|
|
||||||
|
|
||||||
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
|
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
|
||||||
@ -60,12 +60,12 @@ def worker_fn():
|
|||||||
rank = dist.get_rank()
|
rank = dist.get_rank()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
port = get_open_port()
|
port = get_open_port()
|
||||||
ip = get_ip()
|
ip = '127.0.0.1'
|
||||||
dist.broadcast_object_list([ip, port], src=0)
|
dist.broadcast_object_list([ip, port], src=0)
|
||||||
else:
|
else:
|
||||||
recv = [None, None]
|
recv = [None, None]
|
||||||
dist.broadcast_object_list(recv, src=0)
|
dist.broadcast_object_list(recv, src=0)
|
||||||
ip, port = recv
|
ip, port = recv # type: ignore
|
||||||
|
|
||||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
||||||
dist.get_world_size())
|
dist.get_world_size())
|
||||||
@ -107,10 +107,10 @@ def worker_fn():
|
|||||||
|
|
||||||
if pg == dist.group.WORLD:
|
if pg == dist.group.WORLD:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
print("torch distributed passed the test!")
|
print(f"torch distributed passed the test! Rank {rank}")
|
||||||
else:
|
else:
|
||||||
pg.barrier()
|
pg.barrier()
|
||||||
print("StatelessProcessGroup passed the test!")
|
print(f"StatelessProcessGroup passed the test! Rank {rank}")
|
||||||
|
|
||||||
|
|
||||||
def test_shm_broadcast():
|
def test_shm_broadcast():
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import os
|
|
||||||
import pickle
|
import pickle
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@ -19,7 +17,7 @@ from zmq import IPV6 # type: ignore
|
|||||||
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed.utils import StatelessProcessGroup
|
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
|
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
|
||||||
is_valid_ipv6_address)
|
is_valid_ipv6_address)
|
||||||
@ -28,20 +26,6 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# We prefer to use os.sched_yield as it results in tighter polling loops,
|
|
||||||
# measured to be around 3e-7 seconds. However on earlier versions of Python
|
|
||||||
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
|
|
||||||
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
|
|
||||||
or (sys.version_info[:2] == (3, 10)
|
|
||||||
and sys.version_info[2] >= 8))
|
|
||||||
|
|
||||||
|
|
||||||
def sched_yield():
|
|
||||||
if USE_SCHED_YIELD:
|
|
||||||
os.sched_yield()
|
|
||||||
else:
|
|
||||||
time.sleep(0)
|
|
||||||
|
|
||||||
|
|
||||||
class ShmRingBuffer:
|
class ShmRingBuffer:
|
||||||
|
|
||||||
|
|||||||
@ -6,9 +6,12 @@
|
|||||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import datetime
|
import datetime
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import socket
|
import socket
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
@ -27,6 +30,20 @@ from vllm.utils import get_tcp_uri, is_torch_equal_or_newer
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# We prefer to use os.sched_yield as it results in tighter polling loops,
|
||||||
|
# measured to be around 3e-7 seconds. However on earlier versions of Python
|
||||||
|
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
|
||||||
|
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
|
||||||
|
or (sys.version_info[:2] == (3, 10)
|
||||||
|
and sys.version_info[2] >= 8))
|
||||||
|
|
||||||
|
|
||||||
|
def sched_yield():
|
||||||
|
if USE_SCHED_YIELD:
|
||||||
|
os.sched_yield()
|
||||||
|
else:
|
||||||
|
time.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
def ensure_divisibility(numerator, denominator):
|
def ensure_divisibility(numerator, denominator):
|
||||||
"""Ensure that numerator is divisible by the denominator."""
|
"""Ensure that numerator is divisible by the denominator."""
|
||||||
@ -212,10 +229,141 @@ class StatelessProcessGroup:
|
|||||||
gathered_objs.append(recv_obj)
|
gathered_objs.append(recv_obj)
|
||||||
return gathered_objs
|
return gathered_objs
|
||||||
|
|
||||||
def barrier(self):
|
def barrier(self, timeout: float = 30.0):
|
||||||
"""A barrier to synchronize all ranks."""
|
"""A robust barrier to synchronize all ranks.
|
||||||
|
|
||||||
|
|
||||||
|
Uses a multi-phase approach to ensure all processes reach the barrier
|
||||||
|
before proceeding:
|
||||||
|
|
||||||
|
1. Each process signals it has reached the barrier
|
||||||
|
|
||||||
|
2. Each process signals that it has confirmed the arrival of all other
|
||||||
|
ranks.
|
||||||
|
|
||||||
|
3. Rank 0 waits for all other ranks to signal their departure to ensure
|
||||||
|
that all ranks have departed the barrier first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum time in seconds to wait for each phase (in seconds)
|
||||||
|
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If coordination fails or times out
|
||||||
|
"""
|
||||||
|
# Generate a barrier ID that is globally unique
|
||||||
|
try:
|
||||||
|
if self.rank == 0:
|
||||||
|
barrier_id = f"barrier_{uuid.uuid4()}"
|
||||||
|
self.broadcast_obj(barrier_id, src=0)
|
||||||
|
else:
|
||||||
|
barrier_id = self.broadcast_obj(None, src=0)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError("Failed to broadcast barrier_id") from e
|
||||||
|
|
||||||
|
# Phase 1: Signal arrival at barrier
|
||||||
|
# Wait for all processes to arrive
|
||||||
|
# We need all ranks to confirm the arrival of all other ranks.
|
||||||
|
# This is the key synchronization point.
|
||||||
|
arrival_key = f"arrival_{barrier_id}_{self.rank}"
|
||||||
|
try:
|
||||||
|
self.store.set(arrival_key, b"1")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError("Failed to signal barrier arrival") from e
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
processes_arrived: set[int] = set()
|
||||||
|
|
||||||
|
while len(processes_arrived) < self.world_size:
|
||||||
|
# Check for timeout
|
||||||
|
cur_time = time.time()
|
||||||
|
if cur_time - start_time > timeout:
|
||||||
|
raise RuntimeError("Barrier timed out after %f seconds",
|
||||||
|
timeout)
|
||||||
|
|
||||||
|
# Check for each process
|
||||||
for i in range(self.world_size):
|
for i in range(self.world_size):
|
||||||
self.broadcast_obj(None, src=i)
|
if i in processes_arrived:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = f"arrival_{barrier_id}_{i}"
|
||||||
|
try:
|
||||||
|
# Try to get the key - if it exists, we'll get a value
|
||||||
|
# If it doesn't exist, it will throw an exception
|
||||||
|
self.store.get(key)
|
||||||
|
processes_arrived.add(i)
|
||||||
|
except KeyError:
|
||||||
|
# Key doesn't exist yet
|
||||||
|
pass
|
||||||
|
except Exception as check_e:
|
||||||
|
logger.debug("Error checking key existence: %s", check_e)
|
||||||
|
sched_yield()
|
||||||
|
|
||||||
|
# Short sleep to avoid tight polling
|
||||||
|
if len(processes_arrived) < self.world_size:
|
||||||
|
sched_yield()
|
||||||
|
|
||||||
|
# Phase 2: Signal departure from barrier
|
||||||
|
# We only care to block at this stage in rank 0, which runs the
|
||||||
|
# server side of the TCPStore. We want to make sure that all
|
||||||
|
# clients have departed the barrier before rank 0 in case the
|
||||||
|
# next thing after the barrier is a shutdown, including tearing
|
||||||
|
# down the TCPStore. Other ranks can exit the barrier immediately
|
||||||
|
# after signaling their departure.
|
||||||
|
departure_key = f"departure_{barrier_id}_{self.rank}"
|
||||||
|
try:
|
||||||
|
self.store.set(departure_key, b"1")
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError("Failed to signal barrier departure") from e
|
||||||
|
|
||||||
|
if self.rank != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Make rank 0 wait for all processes to signal departure
|
||||||
|
start_time = time.time()
|
||||||
|
processes_departed: set[int] = set()
|
||||||
|
|
||||||
|
while len(processes_departed) < self.world_size:
|
||||||
|
# Check for timeout
|
||||||
|
if time.time() - start_time > timeout:
|
||||||
|
raise RuntimeError("Barrier departure timed out after %f s",
|
||||||
|
timeout)
|
||||||
|
|
||||||
|
# Check for each process
|
||||||
|
for i in range(self.world_size):
|
||||||
|
if i in processes_departed:
|
||||||
|
continue
|
||||||
|
|
||||||
|
key = f"departure_{barrier_id}_{i}"
|
||||||
|
try:
|
||||||
|
# Try to get the key - if it exists, we'll get a value
|
||||||
|
# If it doesn't exist, it will throw an exception
|
||||||
|
self.store.get(key)
|
||||||
|
processes_departed.add(i)
|
||||||
|
except KeyError:
|
||||||
|
# Key doesn't exist yet
|
||||||
|
pass
|
||||||
|
except Exception as check_e:
|
||||||
|
logger.debug("Error checking key existence: %s", check_e)
|
||||||
|
sched_yield()
|
||||||
|
|
||||||
|
# Short sleep to avoid tight polling
|
||||||
|
if len(processes_departed) < self.world_size:
|
||||||
|
sched_yield()
|
||||||
|
|
||||||
|
# Clean up keys to avoid leaking memory in the store
|
||||||
|
for i in range(self.world_size):
|
||||||
|
try:
|
||||||
|
self.store.delete_key(f"arrival_{barrier_id}_{i}")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Error deleting key: %s",
|
||||||
|
f'arrival_{barrier_id}_{i}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.store.delete_key(f"departure_{barrier_id}_{i}")
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Error deleting key: %s",
|
||||||
|
f'departure_{barrier_id}_{i}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user