[bugfix][distributed] fix shm broadcast when the queue size is full (#5801)

youkaichao 2024-06-25 21:56:02 -07:00 committed by GitHub
parent 3aa7b6cf66
commit 515080ad2f
2 changed files with 76 additions and 46 deletions

tests/distributed/test_shm_broadcast.py

@@ -1,7 +1,9 @@
 import multiprocessing
 import random
 import time
+from typing import List
 
+import numpy as np
 import torch.distributed as dist
 
 from vllm.distributed.device_communicators.shm_broadcast import (
@@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import (
 from vllm.utils import update_environment_variables
 
 
+def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
+    np.random.seed(seed)
+    sizes = np.random.randint(1, 10_000, n)
+    # on average, each array will have 5k elements
+    # with int64, each array will have 40kb
+    return [np.random.randint(1, 100, i) for i in sizes]
+
+
 def distributed_run(fn, world_size):
     number_of_processes = world_size
     processes = []
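A note on the helper above: `get_arrays` seeds NumPy's global RNG, so two calls with the same seed produce identical arrays; this is what lets the writer and the readers regenerate the same payloads independently. A quick standalone check (not part of the commit; the helper is inlined here to keep it runnable):

import numpy as np

def get_arrays(n, seed=0):
    np.random.seed(seed)
    sizes = np.random.randint(1, 10_000, n)
    return [np.random.randint(1, 100, i) for i in sizes]

# same seed, same arrays, on every rank
a = get_arrays(3, seed=42)
b = get_arrays(3, seed=42)
assert all(np.array_equal(x, y) for x, y in zip(a, b))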
@@ -47,24 +57,31 @@ def worker_fn_wrapper(fn):
 def worker_fn():
     writer_rank = 2
     broadcaster = ShmRingBufferIO.create_from_process_group(
-        dist.group.WORLD, 1024, 2, writer_rank)
+        dist.group.WORLD, 1024 * 1024, 2, writer_rank)
     if dist.get_rank() == writer_rank:
-        time.sleep(random.random())
-        broadcaster.broadcast_object(0)
-        time.sleep(random.random())
-        broadcaster.broadcast_object({})
-        time.sleep(random.random())
-        broadcaster.broadcast_object([])
+        seed = random.randint(0, 1000)
+        dist.broadcast_object_list([seed], writer_rank)
     else:
-        time.sleep(random.random())
-        a = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        b = broadcaster.broadcast_object(None)
-        time.sleep(random.random())
-        c = broadcaster.broadcast_object(None)
-        assert a == 0
-        assert b == {}
-        assert c == []
+        recv = [None]
+        dist.broadcast_object_list(recv, writer_rank)
+        seed = recv[0]  # type: ignore
+    dist.barrier()
+    # in case we find a race condition
+    # print the seed so that we can reproduce the error
+    print(f"Rank {dist.get_rank()} got seed {seed}")
+    # test broadcasting with about 400MB of data
+    N = 10_000
+    if dist.get_rank() == writer_rank:
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            broadcaster.broadcast_object(x)
+            time.sleep(random.random() / 1000)
+    else:
+        arrs = get_arrays(N, seed)
+        for x in arrs:
+            y = broadcaster.broadcast_object(None)
+            assert np.array_equal(x, y)
+            time.sleep(random.random() / 1000)
     dist.barrier()
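Why these numbers exercise the queue-full path: each array averages about 5,000 int64 elements (roughly 40 KB), so 10,000 arrays amount to roughly 400 MB, while the queue is created with only two chunks of about 1 MB each. The writer therefore fills the ring almost immediately and must repeatedly wait on a full queue, which is exactly the code path this commit fixes. A back-of-the-envelope check (standalone, not part of the commit):

import numpy as np

np.random.seed(0)
sizes = np.random.randint(1, 10_000, 10_000)
print(sizes.mean() * 8 / 1e3)   # ~40 KB per array on average
print((sizes * 8).sum() / 1e6)  # ~400 MB in total, vs ~2 MB of ring buffer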

vllm/distributed/device_communicators/shm_broadcast.py

@@ -14,6 +14,12 @@ from vllm.logger import init_logger
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
+# time to wait if the queue is full or empty
+# if we sleep for too short, it will consume too much CPU
+# if we sleep for too long, it will slow down the writer/reader
+# 0.1 us is a good balance
+RINGBUFFER_SLEEP_INTERVAL = 1e-7
+
 logger = init_logger(__name__)
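One caveat on the comment above: `time.sleep(1e-7)` cannot actually sleep for 0.1 us; OS timer granularity and syscall overhead dominate, so the constant mainly makes the wait loops yield the CPU instead of hard-spinning. A quick measurement sketch (results vary by OS and load; the "tens of microseconds" figure below is an expectation for a typical Linux box, not something the commit states):

import time

t0 = time.monotonic()
for _ in range(10_000):
    time.sleep(1e-7)
per_call = (time.monotonic() - t0) / 10_000
print(f"{per_call * 1e6:.1f} us per sleep(1e-7) call")  # typically tens of us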
@@ -145,8 +151,7 @@ class ShmRingBufferIO:
     @contextmanager
     def acquire_write(self):
         assert self._is_writer, "Only writers can acquire write"
-        start_index = self.current_idx
-        start_time = time.time()
+        start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
@@ -154,19 +159,21 @@ class ShmRingBufferIO:
                 written_flag = metadata_buffer[0]
                 if written_flag and read_count != self.buffer.n_reader:
                     # this block is written and not read by all readers
-                    # try to write to the next block
-                    self.current_idx = (self.current_idx +
-                                        1) % self.buffer.max_chunks
-                    if self.current_idx == start_index:
-                        # no empty block found
-                        if time.time(
-                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
-                            logger.warning(
-                                "No available block found in %s second. ",
-                                VLLM_RINGBUFFER_WARNING_INTERVAL)
-                            n_warning += 1
-                        # wait for a while (0.1 us)
-                        time.sleep(1e-7)
+                    # for writers, `self.current_idx` is the next block to write
+                    # if this block is not ready to write,
+                    # we need to wait until it is read by all readers
+
+                    # wait for a while
+                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                    # if we wait for a long time, we should warn the user
+                    if time.monotonic(
+                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
+                        logger.warning(
+                            "No available block found in %s second. ",
+                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        n_warning += 1
                     continue
                 # found a block that is either
                 # (1) not written
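The substance of the fix, as this hunk shows: the old writer advanced `current_idx` while scanning for a free block and only slept after wrapping all the way back to `start_index`. But blocks are produced and consumed strictly in FIFO order, so the block a writer must reuse next is always `current_idx` itself; the scan-and-wrap logic could spin without sleeping and mishandle the exactly-full queue. The new code waits in place on `current_idx` and advances it only after a successful write (readers are changed symmetrically below). A minimal single-process sketch of the wait-in-place protocol, using threads and plain Python objects instead of shared memory (all names here are illustrative, not vLLM APIs):

import threading
import time

SLEEP = 1e-7          # mirrors RINGBUFFER_SLEEP_INTERVAL
MAX_CHUNKS = 2
N_READERS = 2

class Chunk:
    def __init__(self):
        self.written = False
        self.read_by = [False] * N_READERS
        self.payload = None

chunks = [Chunk() for _ in range(MAX_CHUNKS)]

def writer(items):
    idx = 0
    for item in items:
        # wait in place: this exact chunk must be drained before reuse
        while chunks[idx].written and not all(chunks[idx].read_by):
            time.sleep(SLEEP)
        chunks[idx].payload = item            # write before clearing flags
        chunks[idx].read_by = [False] * N_READERS
        chunks[idx].written = True
        idx = (idx + 1) % MAX_CHUNKS          # advance only after success

def reader(rank, n_items, out):
    idx = 0
    for _ in range(n_items):
        # wait in place: the next chunk in FIFO order, nothing else
        while not chunks[idx].written or chunks[idx].read_by[rank]:
            time.sleep(SLEEP)
        out.append(chunks[idx].payload)
        chunks[idx].read_by[rank] = True
        idx = (idx + 1) % MAX_CHUNKS

data = list(range(1000))
outs = [[] for _ in range(N_READERS)]
threads = [threading.Thread(target=writer, args=(data,))] + [
    threading.Thread(target=reader, args=(r, len(data), outs[r]))
    for r in range(N_READERS)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert all(out == data for out in outs)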
@@ -188,13 +195,14 @@ class ShmRingBufferIO:
                     metadata_buffer[i] = 0
                 # mark the block as written
                 metadata_buffer[0] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
                 break
 
     @contextmanager
     def acquire_read(self):
         assert self._is_reader, "Only readers can acquire read"
-        start_index = self.current_idx
-        start_time = time.time()
+        start_time = time.monotonic()
         n_warning = 1
         while True:
             with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
@@ -204,19 +212,22 @@ class ShmRingBufferIO:
                     # this block is either
                     # (1) not written
                     # (2) already read by this reader
-                    # try to read the next block
-                    self.current_idx = (self.current_idx +
-                                        1) % self.buffer.max_chunks
-                    if self.current_idx == start_index:
-                        # no block found
-                        if time.time(
-                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
-                            logger.warning(
-                                "No available block found in %s second. ",
-                                VLLM_RINGBUFFER_WARNING_INTERVAL)
-                            n_warning += 1
-                        # wait for a while (0.1 us)
-                        time.sleep(1e-7)
+
+                    # for readers, `self.current_idx` is the next block to read
+                    # if this block is not ready,
+                    # we need to wait until it is written
+
+                    # wait for a while
+                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+
+                    # if we wait for a long time, we should warn the user
+                    if time.monotonic(
+                    ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
+                        logger.warning(
+                            "No available block found in %s second. ",
+                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        n_warning += 1
                     continue
                 # found a block that is not read by this reader
                 # let caller read from the buffer
@@ -226,6 +237,8 @@ class ShmRingBufferIO:
                 # caller has read from the buffer
                 # set the read flag
                 metadata_buffer[self.reader_rank + 1] = 1
+                self.current_idx = (self.current_idx +
+                                    1) % self.buffer.max_chunks
                 break
 
     def enqueue(self, obj):
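For context, the entry points exercised by the test above: one rank creates the queue as the writer, every other rank in the process group attaches as a reader, and `broadcast_object` blocks (via the sleep loops above) whenever the ring is full or empty. A condensed usage sketch, assuming the processes were started by an external launcher such as torchrun; the launcher and the gloo backend are assumptions, not part of this commit, while the argument values mirror the test:

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import ShmRingBufferIO

dist.init_process_group(backend="gloo")  # shared memory requires a single node
writer_rank = 0
# buffer arguments mirror the test above: 1024 * 1024 bytes, 2 chunks
queue = ShmRingBufferIO.create_from_process_group(dist.group.WORLD,
                                                  1024 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
    queue.broadcast_object({"hello": "world"})
else:
    obj = queue.broadcast_object(None)  # blocks until the writer publishes
    assert obj == {"hello": "world"}
dist.barrier()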