mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 12:55:37 +08:00
[bugfix][distributed] fix shm broadcast when the queue size is full (#5801)
This commit is contained in:
parent
3aa7b6cf66
commit
515080ad2f
@ -1,7 +1,9 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from vllm.distributed.device_communicators.shm_broadcast import (
|
from vllm.distributed.device_communicators.shm_broadcast import (
|
||||||
@ -9,6 +11,14 @@ from vllm.distributed.device_communicators.shm_broadcast import (
|
|||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
|
|
||||||
|
def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]:
|
||||||
|
np.random.seed(seed)
|
||||||
|
sizes = np.random.randint(1, 10_000, n)
|
||||||
|
# on average, each array will have 5k elements
|
||||||
|
# with int64, each array will have 40kb
|
||||||
|
return [np.random.randint(1, 100, i) for i in sizes]
|
||||||
|
|
||||||
|
|
||||||
def distributed_run(fn, world_size):
|
def distributed_run(fn, world_size):
|
||||||
number_of_processes = world_size
|
number_of_processes = world_size
|
||||||
processes = []
|
processes = []
|
||||||
@ -47,24 +57,31 @@ def worker_fn_wrapper(fn):
|
|||||||
def worker_fn():
|
def worker_fn():
|
||||||
writer_rank = 2
|
writer_rank = 2
|
||||||
broadcaster = ShmRingBufferIO.create_from_process_group(
|
broadcaster = ShmRingBufferIO.create_from_process_group(
|
||||||
dist.group.WORLD, 1024, 2, writer_rank)
|
dist.group.WORLD, 1024 * 1024, 2, writer_rank)
|
||||||
if dist.get_rank() == writer_rank:
|
if dist.get_rank() == writer_rank:
|
||||||
time.sleep(random.random())
|
seed = random.randint(0, 1000)
|
||||||
broadcaster.broadcast_object(0)
|
dist.broadcast_object_list([seed], writer_rank)
|
||||||
time.sleep(random.random())
|
|
||||||
broadcaster.broadcast_object({})
|
|
||||||
time.sleep(random.random())
|
|
||||||
broadcaster.broadcast_object([])
|
|
||||||
else:
|
else:
|
||||||
time.sleep(random.random())
|
recv = [None]
|
||||||
a = broadcaster.broadcast_object(None)
|
dist.broadcast_object_list(recv, writer_rank)
|
||||||
time.sleep(random.random())
|
seed = recv[0] # type: ignore
|
||||||
b = broadcaster.broadcast_object(None)
|
dist.barrier()
|
||||||
time.sleep(random.random())
|
# in case we find a race condition
|
||||||
c = broadcaster.broadcast_object(None)
|
# print the seed so that we can reproduce the error
|
||||||
assert a == 0
|
print(f"Rank {dist.get_rank()} got seed {seed}")
|
||||||
assert b == {}
|
# test broadcasting with about 400MB of data
|
||||||
assert c == []
|
N = 10_000
|
||||||
|
if dist.get_rank() == writer_rank:
|
||||||
|
arrs = get_arrays(N, seed)
|
||||||
|
for x in arrs:
|
||||||
|
broadcaster.broadcast_object(x)
|
||||||
|
time.sleep(random.random() / 1000)
|
||||||
|
else:
|
||||||
|
arrs = get_arrays(N, seed)
|
||||||
|
for x in arrs:
|
||||||
|
y = broadcaster.broadcast_object(None)
|
||||||
|
assert np.array_equal(x, y)
|
||||||
|
time.sleep(random.random() / 1000)
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,12 @@ from vllm.logger import init_logger
|
|||||||
|
|
||||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||||
|
|
||||||
|
# time to wait if the queue is full or empty
|
||||||
|
# if we sleep for too short, it will consume too much CPU
|
||||||
|
# if we sleep for too long, it will slow down the writer/reader
|
||||||
|
# 0.1 us is a good balance
|
||||||
|
RINGBUFFER_SLEEP_INTERVAL = 1e-7
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -145,8 +151,7 @@ class ShmRingBufferIO:
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def acquire_write(self):
|
def acquire_write(self):
|
||||||
assert self._is_writer, "Only writers can acquire write"
|
assert self._is_writer, "Only writers can acquire write"
|
||||||
start_index = self.current_idx
|
start_time = time.monotonic()
|
||||||
start_time = time.time()
|
|
||||||
n_warning = 1
|
n_warning = 1
|
||||||
while True:
|
while True:
|
||||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||||
@ -154,19 +159,21 @@ class ShmRingBufferIO:
|
|||||||
written_flag = metadata_buffer[0]
|
written_flag = metadata_buffer[0]
|
||||||
if written_flag and read_count != self.buffer.n_reader:
|
if written_flag and read_count != self.buffer.n_reader:
|
||||||
# this block is written and not read by all readers
|
# this block is written and not read by all readers
|
||||||
# try to write to the next block
|
# for writers, `self.current_idx` is the next block to write
|
||||||
self.current_idx = (self.current_idx +
|
# if this block is not ready to write,
|
||||||
1) % self.buffer.max_chunks
|
# we need to wait until it is read by all readers
|
||||||
if self.current_idx == start_index:
|
|
||||||
# no empty block found
|
# wait for a while
|
||||||
if time.time(
|
time.sleep(RINGBUFFER_SLEEP_INTERVAL)
|
||||||
|
|
||||||
|
# if we wait for a long time, we should warn the user
|
||||||
|
if time.monotonic(
|
||||||
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
|
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No available block found in %s second. ",
|
"No available block found in %s second. ",
|
||||||
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||||
n_warning += 1
|
n_warning += 1
|
||||||
# wait for a while (0.1 us)
|
|
||||||
time.sleep(1e-7)
|
|
||||||
continue
|
continue
|
||||||
# found a block that is either
|
# found a block that is either
|
||||||
# (1) not written
|
# (1) not written
|
||||||
@ -188,13 +195,14 @@ class ShmRingBufferIO:
|
|||||||
metadata_buffer[i] = 0
|
metadata_buffer[i] = 0
|
||||||
# mark the block as written
|
# mark the block as written
|
||||||
metadata_buffer[0] = 1
|
metadata_buffer[0] = 1
|
||||||
|
self.current_idx = (self.current_idx +
|
||||||
|
1) % self.buffer.max_chunks
|
||||||
break
|
break
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def acquire_read(self):
|
def acquire_read(self):
|
||||||
assert self._is_reader, "Only readers can acquire read"
|
assert self._is_reader, "Only readers can acquire read"
|
||||||
start_index = self.current_idx
|
start_time = time.monotonic()
|
||||||
start_time = time.time()
|
|
||||||
n_warning = 1
|
n_warning = 1
|
||||||
while True:
|
while True:
|
||||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||||
@ -204,19 +212,22 @@ class ShmRingBufferIO:
|
|||||||
# this block is either
|
# this block is either
|
||||||
# (1) not written
|
# (1) not written
|
||||||
# (2) already read by this reader
|
# (2) already read by this reader
|
||||||
# try to read the next block
|
|
||||||
self.current_idx = (self.current_idx +
|
# for readers, `self.current_idx` is the next block to read
|
||||||
1) % self.buffer.max_chunks
|
# if this block is not ready,
|
||||||
if self.current_idx == start_index:
|
# we need to wait until it is written
|
||||||
# no block found
|
|
||||||
if time.time(
|
# wait for a while
|
||||||
|
time.sleep(RINGBUFFER_SLEEP_INTERVAL)
|
||||||
|
|
||||||
|
# if we wait for a long time, we should warn the user
|
||||||
|
if time.monotonic(
|
||||||
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
|
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No available block found in %s second. ",
|
"No available block found in %s second. ",
|
||||||
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
VLLM_RINGBUFFER_WARNING_INTERVAL)
|
||||||
n_warning += 1
|
n_warning += 1
|
||||||
# wait for a while (0.1 us)
|
|
||||||
time.sleep(1e-7)
|
|
||||||
continue
|
continue
|
||||||
# found a block that is not read by this reader
|
# found a block that is not read by this reader
|
||||||
# let caller read from the buffer
|
# let caller read from the buffer
|
||||||
@ -226,6 +237,8 @@ class ShmRingBufferIO:
|
|||||||
# caller has read from the buffer
|
# caller has read from the buffer
|
||||||
# set the read flag
|
# set the read flag
|
||||||
metadata_buffer[self.reader_rank + 1] = 1
|
metadata_buffer[self.reader_rank + 1] = 1
|
||||||
|
self.current_idx = (self.current_idx +
|
||||||
|
1) % self.buffer.max_chunks
|
||||||
break
|
break
|
||||||
|
|
||||||
def enqueue(self, obj):
|
def enqueue(self, obj):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user