[Core][Distributed] add shm broadcast (#5399)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
commit d9a252bc8e (parent 67005a07bc)

.buildkite/test-pipeline.yaml
@@ -28,9 +28,11 @@ steps:
- label: Distributed Comm Ops Test
  #mirror_hardwares: [amd]
  command: pytest -v -s distributed/test_comm_ops.py
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2
  commands:
  - pytest -v -s distributed/test_comm_ops.py
  - pytest -v -s distributed/test_shm_broadcast.py

- label: Distributed Tests (2 GPUs)
  mirror_hardwares: [amd]
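
The new suite initializes a gloo process group through multiprocessing (see the test file below), so it can also be run outside CI. A minimal sketch, assuming the working directory is the repository's tests/ folder:

# programmatic equivalent of the CI command added above
import sys

import pytest

sys.exit(pytest.main(["-v", "-s", "distributed/test_shm_broadcast.py"]))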

tests/distributed/test_shm_broadcast.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import multiprocessing
import random
import time

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import (
    ShmRingBuffer, ShmRingBufferIO)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    writer_rank = 2
    broadcaster = ShmRingBufferIO.create_from_process_group(
        dist.group.WORLD, 1024, 2, writer_rank)
    if dist.get_rank() == writer_rank:
        time.sleep(random.random())
        broadcaster.broadcast_object(0)
        time.sleep(random.random())
        broadcaster.broadcast_object({})
        time.sleep(random.random())
        broadcaster.broadcast_object([])
    else:
        time.sleep(random.random())
        a = broadcaster.broadcast_object(None)
        time.sleep(random.random())
        b = broadcaster.broadcast_object(None)
        time.sleep(random.random())
        c = broadcaster.broadcast_object(None)
        assert a == 0
        assert b == {}
        assert c == []
    dist.barrier()


def test_shm_broadcast():
    distributed_run(worker_fn, 4)


def test_singe_process():
    buffer = ShmRingBuffer(1, 1024, 4)
    reader = ShmRingBufferIO(buffer, reader_rank=0)
    writer = ShmRingBufferIO(buffer, reader_rank=-1)
    writer.enqueue([0])
    writer.enqueue([1])
    assert reader.dequeue() == [0]
    assert reader.dequeue() == [1]
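
For orientation, a minimal single-process sketch of the same API the test exercises (not part of the commit; it assumes a vLLM build that contains this change): one ShmRingBuffer shared by a writer view and two reader views, where every reader sees every enqueued item exactly once.

from vllm.distributed.device_communicators.shm_broadcast import (
    ShmRingBuffer, ShmRingBufferIO)

# one writer, two readers, 1 KiB chunks, ring depth of 4
buffer = ShmRingBuffer(n_reader=2, max_chunk_bytes=1024, max_chunks=4)
writer = ShmRingBufferIO(buffer, reader_rank=-1)  # reader_rank=-1 marks the writer
reader0 = ShmRingBufferIO(buffer, reader_rank=0)
reader1 = ShmRingBufferIO(buffer, reader_rank=1)

writer.enqueue({"step": 1})
writer.enqueue({"step": 2})

# each reader dequeues every item exactly once, in order
assert reader0.dequeue() == {"step": 1}
assert reader1.dequeue() == {"step": 1}
assert reader0.dequeue() == {"step": 2}
assert reader1.dequeue() == {"step": 2}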

vllm/distributed/device_communicators/shm_broadcast.py (new file, 259 lines)
@@ -0,0 +1,259 @@
import pickle
import time
from contextlib import contextmanager
from multiprocessing import shared_memory
from typing import Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

logger = init_logger(__name__)


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
        the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata  |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes       |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)
            assert self.shared_memory.size == self.total_bytes_of_buffer
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
                assert torch.all(tensor == 0)

    def __reduce__(self):
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        self.shared_memory.close()
        if self.is_creator:
            self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf


class ShmRingBufferIO:

    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
        self.buffer = buffer
        self.reader_rank = reader_rank
        self._is_writer = self.reader_rank == -1
        self._is_reader = not self._is_writer
        if self._is_reader:
            assert 0 <= self.reader_rank < buffer.n_reader, \
                (f"Invalid reader rank {self.reader_rank} for buffer"
                 f" created with {buffer.n_reader} readers")
        self.current_idx = 0

    @contextmanager
    def acquire_write(self):
        assert self._is_writer, "Only writers can acquire write"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # try to write to the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no empty block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s second. ",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # mark the block as written
                metadata_buffer[0] = 1
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                break

    @contextmanager
    def acquire_read(self):
        assert self._is_reader, "Only readers can acquire read"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # try to read the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s second. ",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.reader_rank + 1] = 1
                break

    def enqueue(self, obj):
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(serialized_obj) > self.buffer.max_chunk_bytes:
            raise RuntimeError(
                f"{len(serialized_obj)=} larger than the allowed value "
                f"{self.buffer.max_chunk_bytes},"
                "Please increase the max_chunk_bytes parameter.")
        with self.acquire_write() as buf:
            buf[:len(serialized_obj)] = serialized_obj

    def dequeue(self):
        assert self._is_reader, "Only readers can dequeue"
        with self.acquire_read() as buf:
            # no need to know the size of serialized object
            # pickle format itself contains the size information internally
            # see https://docs.python.org/3/library/pickle.html
            obj = pickle.loads(buf)
        return obj

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "ShmRingBufferIO":
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        ranks_inside_group = list(range(group_world_size))
        global_ranks = dist.get_process_group_ranks(pg)
        n_reader = group_world_size - 1
        buffer: ShmRingBuffer
        if group_rank == writer_rank:
            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
            dist.broadcast_object_list([buffer], src=global_ranks[writer_rank])
            dist.barrier(pg)
            return ShmRingBufferIO(buffer, -1)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
            dist.barrier(pg)
            buffer = recv[0]  # type: ignore
            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
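
As a worked example of the layout arithmetic in the ShmRingBuffer docstring (a sketch, not from the commit), here are the sizes for the parameters GroupCoordinator passes below, max_chunk_bytes = 1 << 20 and max_chunks = 6, with a hypothetical 4-rank group, i.e. n_reader = 3:

# layout arithmetic from the ShmRingBuffer docstring, for illustrative values
n_reader = 3                  # hypothetical: 4 ranks, 1 writer + 3 readers
max_chunk_bytes = 1 << 20     # 1 MiB per serialized object (value used below)
max_chunks = 6                # ring depth (value used below)

metadata_size = 1 + n_reader  # written_flag + one flag byte per reader
total_bytes_of_buffer = (max_chunk_bytes + metadata_size) * max_chunks
data_offset = 0
metadata_offset = max_chunk_bytes * max_chunks

assert metadata_offset == 6 * 1024 * 1024        # data region: the first 6 MiB
assert total_bytes_of_buffer == 6291456 + 6 * 4  # plus 4 metadata bytes per chunk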

vllm/distributed/parallel_state.py
@@ -98,6 +98,7 @@ class GroupCoordinator:
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
    shm_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
@@ -162,6 +163,13 @@ class GroupCoordinator:
        else:
            self.ca_comm = None

        from vllm.distributed.device_communicators.shm_broadcast import (
            ShmRingBufferIO)
        self.shm_broadcaster: Optional[ShmRingBufferIO] = None
        if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
            self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
                self.cpu_group, 1 << 20, 6)

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
@@ -324,6 +332,30 @@ class GroupCoordinator:
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.shm_broadcaster is not None:
            assert src == 0, "Shared memory broadcaster only supports src=0"
            return self.shm_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
@@ -371,9 +403,7 @@ class GroupCoordinator:
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            torch.distributed.broadcast_object_list([metadata_list],
                                                    src=src,
                                                    group=metadata_group)
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
@@ -396,14 +426,10 @@ class GroupCoordinator:
                async_handle.wait()

        else:
            recv_metadata_list = [None]
            torch.distributed.broadcast_object_list(recv_metadata_list,
                                                    src=src,
                                                    group=metadata_group)
            assert recv_metadata_list[0] is not None
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in recv_metadata_list[0]:
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
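
A hedged caller-side sketch (not part of the commit) of the new GroupCoordinator.broadcast_object path: every rank calls it, rank 0 supplies the payload and the other ranks pass None. When all ranks share a node the object travels through shm_broadcaster; otherwise the method falls back to torch.distributed.broadcast_object_list on the CPU group. Here `group` is assumed to be an already-initialized GroupCoordinator.

# every rank executes this; only the source rank provides the object
payload = {"num_seqs": 8} if group.rank_in_group == 0 else None
received = group.broadcast_object(payload, src=0)
# after the call, `received` is {"num_seqs": 8} on every rank
# (note that the shared-memory path above asserts src == 0)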

vllm/envs.py
@@ -5,6 +5,7 @@ if TYPE_CHECKING:
    VLLM_HOST_IP: str = ""
    VLLM_PORT: Optional[int] = None
    VLLM_USE_MODELSCOPE: bool = False
    VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
    VLLM_INSTANCE_ID: Optional[str] = None
    VLLM_NCCL_SO_PATH: Optional[str] = None
    LD_LIBRARY_PATH: Optional[str] = None
@@ -114,6 +115,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_INSTANCE_ID":
    lambda: os.environ.get("VLLM_INSTANCE_ID", None),

    # Interval in seconds to log a warning message when the ring buffer is full
    "VLLM_RINGBUFFER_WARNING_INTERVAL":
    lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),

    # path to cudatoolkit home directory, under which should be bin, include,
    # and lib directories.
    "CUDA_HOME":
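
Finally, a sketch (not part of the commit) of overriding the new knob. It assumes vllm.envs resolves variables lazily through the callables in environment_variables, and it relies on the detail, visible in the shm_broadcast.py diff above, that the module captures the value at import time, so the variable should be set before that module is imported.

import os

# default is 60 seconds; warn less often if readers are expected to lag
os.environ["VLLM_RINGBUFFER_WARNING_INTERVAL"] = "120"

from vllm.distributed.device_communicators import shm_broadcast
assert shm_broadcast.VLLM_RINGBUFFER_WARNING_INTERVAL == 120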