[Core][Distributed] add shm broadcast (#5399)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
youkaichao 2024-06-20 22:12:35 -07:00 committed by GitHub
parent 67005a07bc
commit d9a252bc8e
5 changed files with 384 additions and 10 deletions


@@ -28,9 +28,11 @@ steps:
- label: Distributed Comm Ops Test
#mirror_hardwares: [amd]
command: pytest -v -s distributed/test_comm_ops.py
working_dir: "/vllm-workspace/tests"
num_gpus: 2
commands:
- pytest -v -s distributed/test_comm_ops.py
- pytest -v -s distributed/test_shm_broadcast.py
- label: Distributed Tests (2 GPUs)
mirror_hardwares: [amd]


@@ -0,0 +1,82 @@
import multiprocessing
import random
import time
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import (
ShmRingBuffer, ShmRingBufferIO)
from vllm.utils import update_environment_variables
def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = {}
env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost'
env['MASTER_PORT'] = '12345'
p = multiprocessing.Process(target=fn, args=(env, ))
processes.append(p)
p.start()
for p in processes:
p.join()
for p in processes:
assert p.exitcode == 0
def worker_fn_wrapper(fn):
# `multiprocessing.Process` cannot accept environment variables directly,
# so we pass them as an argument and apply them inside the worker
# before initializing the process group
def wrapped_fn(env):
update_environment_variables(env)
dist.init_process_group(backend="gloo")
fn()
return wrapped_fn
@worker_fn_wrapper
def worker_fn():
writer_rank = 2
broadcaster = ShmRingBufferIO.create_from_process_group(
dist.group.WORLD, 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
time.sleep(random.random())
broadcaster.broadcast_object(0)
time.sleep(random.random())
broadcaster.broadcast_object({})
time.sleep(random.random())
broadcaster.broadcast_object([])
else:
time.sleep(random.random())
a = broadcaster.broadcast_object(None)
time.sleep(random.random())
b = broadcaster.broadcast_object(None)
time.sleep(random.random())
c = broadcaster.broadcast_object(None)
assert a == 0
assert b == {}
assert c == []
dist.barrier()
def test_shm_broadcast():
distributed_run(worker_fn, 4)
def test_single_process():
buffer = ShmRingBuffer(1, 1024, 4)
reader = ShmRingBufferIO(buffer, reader_rank=0)
writer = ShmRingBufferIO(buffer, reader_rank=-1)
writer.enqueue([0])
writer.enqueue([1])
assert reader.dequeue() == [0]
assert reader.dequeue() == [1]
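
The distributed test above drives one writer and three readers across processes, while the single-process test exercises only one reader, so neither shows the broadcast property in isolation. A minimal single-process sketch (illustrative values, same classes and signatures as above) in which every reader observes every enqueued item:

from vllm.distributed.device_communicators.shm_broadcast import (
    ShmRingBuffer, ShmRingBufferIO)

# one writer, two readers, four chunks of up to 1024 bytes each
buffer = ShmRingBuffer(n_reader=2, max_chunk_bytes=1024, max_chunks=4)
writer = ShmRingBufferIO(buffer, reader_rank=-1)  # reader_rank=-1 marks the writer
reader0 = ShmRingBufferIO(buffer, reader_rank=0)
reader1 = ShmRingBufferIO(buffer, reader_rank=1)

writer.enqueue({"step": 1})
writer.enqueue({"step": 2})

# unlike a work-sharing queue, each reader sees every item, in order
assert reader0.dequeue() == {"step": 1}
assert reader1.dequeue() == {"step": 1}
assert reader0.dequeue() == {"step": 2}
assert reader1.dequeue() == {"step": 2}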


@@ -0,0 +1,259 @@
import pickle
import time
from contextlib import contextmanager
from multiprocessing import shared_memory
from typing import Optional
from unittest.mock import patch
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
logger = init_logger(__name__)
class ShmRingBuffer:
def __init__(self,
n_reader: int,
max_chunk_bytes: int,
max_chunks: int,
name: Optional[str] = None):
"""
A shared memory ring buffer implementation for broadcast communication.
Essentially, it is a queue where only one process will `enqueue` and
multiple processes will `dequeue`. The maximum size of each item and the
maximum number of items that can be stored in the buffer are known in
advance. Because of these fixed bounds, access to the buffer needs no
additional synchronization; per-chunk flags are sufficient.
Buffer memory layout:
            data                                    metadata
              |                                         |
              | (current_idx)                           | (current_idx)
              v                                         v
+--------------------------------+-----------------------------------------+
| chunk0 | chunk1 | ... | chunkN | metadata0 | metadata1 | ... | metadataN |
+--------------------------------+-----------------------------------------+
| max_chunks x max_chunk_bytes   | max_chunks x (1 + n_reader) bytes       |
+--------------------------------+-----------------------------------------+

metadata memory layout: each byte is a flag. The first byte is the written
flag, and the remaining n_reader bytes are reader flags. All flags are
initialized to 0.

+--------------+--------------+--------------+-----+--------------+
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+--------------+--------------+--------------+-----+--------------+
During creation, `name` is None and a new shared memory block is allocated.
The creating process can then pass the object to other processes by
pickling it: unpickling reconstructs the object from the name of the
shared memory block, so every process attaches to the same underlying
buffer.
"""  # noqa
self.n_reader = n_reader
self.metadata_size = 1 + n_reader
self.max_chunk_bytes = max_chunk_bytes
self.max_chunks = max_chunks
self.total_bytes_of_buffer = (self.max_chunk_bytes +
self.metadata_size) * self.max_chunks
self.data_offset = 0
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
if name is None:
# we are creating a buffer
self.is_creator = True
self.shared_memory = shared_memory.SharedMemory(
create=True, size=self.total_bytes_of_buffer)
# initialize the metadata section to 0
with memoryview(self.shared_memory.buf[self.metadata_offset:]
) as metadata_buffer:
torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
else:
# we are opening an existing buffer
self.is_creator = False
# workaround for https://stackoverflow.com/q/62748654/9191338 :
# Python's resource tracker incorrectly tracks shared memory segments
# even if they were not created by this process; patching `register`
# to a no-op avoids the spurious cleanup.
with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None):
self.shared_memory = shared_memory.SharedMemory(name=name)
assert self.shared_memory.size == self.total_bytes_of_buffer
with memoryview(self.shared_memory.buf[self.metadata_offset:]
) as metadata_buffer:
tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
assert torch.all(tensor == 0)
def __reduce__(self):
return (
self.__class__,
(self.n_reader, self.max_chunk_bytes, self.max_chunks,
self.shared_memory.name),
)
def __del__(self):
self.shared_memory.close()
if self.is_creator:
self.shared_memory.unlink()
@contextmanager
def get_data(self, current_idx: int):
start = self.data_offset + current_idx * self.max_chunk_bytes
end = start + self.max_chunk_bytes
with memoryview(self.shared_memory.buf[start:end]) as buf:
yield buf
@contextmanager
def get_metadata(self, current_idx: int):
start = self.metadata_offset + current_idx * self.metadata_size
end = start + self.metadata_size
with memoryview(self.shared_memory.buf[start:end]) as buf:
yield buf
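
To make the layout arithmetic above concrete, here is a small worked sketch mirroring `__init__`, `get_data`, and `get_metadata` (the parameter values are chosen purely for illustration and are not taken from the diff):

# illustrative parameters: 2 readers, 1 KiB chunks, 4 chunks
n_reader, max_chunk_bytes, max_chunks = 2, 1024, 4

metadata_size = 1 + n_reader                    # written flag + 2 reader flags = 3 bytes
total_bytes = (max_chunk_bytes + metadata_size) * max_chunks  # (1024 + 3) * 4 = 4108
data_offset = 0                                 # chunks start at the beginning
metadata_offset = max_chunk_bytes * max_chunks  # data region ends at 4096

# chunk i lives at  [data_offset + i * max_chunk_bytes,
#                    data_offset + (i + 1) * max_chunk_bytes)
# its flags live at [metadata_offset + i * metadata_size,
#                    metadata_offset + (i + 1) * metadata_size)
assert metadata_offset + max_chunks * metadata_size == total_bytes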
class ShmRingBufferIO:
def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
self.buffer = buffer
self.reader_rank = reader_rank
self._is_writer = self.reader_rank == -1
self._is_reader = not self._is_writer
if self._is_reader:
assert 0 <= self.reader_rank < buffer.n_reader, \
(f"Invalid reader rank {self.reader_rank} for buffer"
f" created with {buffer.n_reader} readers")
self.current_idx = 0
@contextmanager
def acquire_write(self):
assert self._is_writer, "Only writers can acquire write"
start_index = self.current_idx
start_time = time.time()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
# this block is written and not read by all readers
# try to write to the next block
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
if self.current_idx == start_index:
# no empty block found
if time.time(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
# wait for a while (0.1 us)
time.sleep(1e-7)
continue
# found a block that is either
# (1) not written
# (2) read by all readers
# mark the block as not written
metadata_buffer[0] = 0
# let caller write to the buffer
with self.buffer.get_data(self.current_idx) as buf:
yield buf
# caller has written to the buffer
# mark the block as written
metadata_buffer[0] = 1
for i in range(1, self.buffer.n_reader + 1):
# set read flag to 0, meaning it is not read yet
metadata_buffer[i] = 0
break
@contextmanager
def acquire_read(self):
assert self._is_reader, "Only readers can acquire read"
start_index = self.current_idx
start_time = time.time()
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_flag = metadata_buffer[self.reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
# this block is either
# (1) not written
# (2) already read by this reader
# try to read the next block
self.current_idx = (self.current_idx +
1) % self.buffer.max_chunks
if self.current_idx == start_index:
# no block found
if time.time(
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa
logger.warning(
"No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
# wait for a while (0.1 us)
time.sleep(1e-7)
continue
# found a block that is not read by this reader
# let caller read from the buffer
with self.buffer.get_data(self.current_idx) as buf:
yield buf
# caller has read from the buffer
# set the read flag
metadata_buffer[self.reader_rank + 1] = 1
break
def enqueue(self, obj):
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if len(serialized_obj) > self.buffer.max_chunk_bytes:
raise RuntimeError(
f"{len(serialized_obj)=} is larger than the allowed value "
f"{self.buffer.max_chunk_bytes}. "
"Please increase the max_chunk_bytes parameter.")
with self.acquire_write() as buf:
buf[:len(serialized_obj)] = serialized_obj
def dequeue(self):
assert self._is_reader, "Only readers can dequeue"
with self.acquire_read() as buf:
# no need to know the size of the serialized object: the pickle
# stream is self-delimiting, so any trailing bytes in the chunk are
# simply ignored, see https://docs.python.org/3/library/pickle.html
obj = pickle.loads(buf)
return obj
def broadcast_object(self, obj=None):
if self._is_writer:
self.enqueue(obj)
return obj
else:
return self.dequeue()
def create_from_process_group(pg: ProcessGroup,
max_chunk_bytes,
max_chunks,
writer_rank=0) -> "ShmRingBufferIO":
group_rank = dist.get_rank(pg)
group_world_size = dist.get_world_size(pg)
ranks_inside_group = list(range(group_world_size))
global_ranks = dist.get_process_group_ranks(pg)
n_reader = group_world_size - 1
buffer: ShmRingBuffer
if group_rank == writer_rank:
buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
dist.broadcast_object_list([buffer], src=global_ranks[writer_rank])
dist.barrier(pg)
return ShmRingBufferIO(buffer, -1)
else:
recv = [None]
dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
dist.barrier(pg)
buffer = recv[0] # type: ignore
rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))
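
One detail worth calling out: `enqueue` writes only `len(serialized_obj)` bytes into a chunk that may still hold bytes from an earlier, longer message, and `dequeue` never learns the payload length. This works because a pickle stream is self-delimiting: `pickle.loads` stops at the STOP opcode and ignores whatever follows. A standalone sketch of that behaviour (no shared memory involved):

import pickle

payload = pickle.dumps({"rank": 0}, protocol=pickle.HIGHEST_PROTOCOL)
# pad the payload the way a partially overwritten chunk would look
chunk = payload + b"\x00" * 32
assert pickle.loads(chunk) == {"rank": 0}  # trailing bytes are ignored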


@@ -98,6 +98,7 @@ class GroupCoordinator:
# communicators are only created for world size > 1
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
shm_broadcaster: Optional[Any] # shared memory broadcaster
def __init__(
self,
@@ -162,6 +163,13 @@ class GroupCoordinator:
else:
self.ca_comm = None
from vllm.distributed.device_communicators.shm_broadcast import (
ShmRingBufferIO)
self.shm_broadcaster: Optional[ShmRingBufferIO] = None
if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
self.shm_broadcaster = ShmRingBufferIO.create_from_process_group(
self.cpu_group, 1 << 20, 6)
@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
@@ -324,6 +332,30 @@
group=self.device_group)
return input_
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
"""Broadcast the input object.
NOTE: `src` is the local rank of the source rank.
"""
assert src < self.world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return obj
if self.shm_broadcaster is not None:
assert src == 0, "Shared memory broadcaster only supports src=0"
return self.shm_broadcaster.broadcast_object(obj)
if self.rank_in_group == src:
torch.distributed.broadcast_object_list([obj],
src=self.ranks[src],
group=self.cpu_group)
return obj
else:
recv = [None]
torch.distributed.broadcast_object_list(recv,
src=self.ranks[src],
group=self.cpu_group)
return recv[0]
def broadcast_object_list(self,
obj_list: List[Any],
src: int = 0,
@@ -371,9 +403,7 @@
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` has serialization & deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=metadata_group)
self.broadcast_object(metadata_list, src=src)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
@@ -396,14 +426,10 @@
async_handle.wait()
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=metadata_group)
assert recv_metadata_list[0] is not None
metadata_list = self.broadcast_object(None, src=src)
tensor_dict = {}
async_handles = []
for key, value in recv_metadata_list[0]:
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,

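Net effect of the `GroupCoordinator` changes above: when every rank of the group lives on the same node, `broadcast_object` (and, through it, the CPU-side metadata exchange in `broadcast_tensor_dict`) goes through a shared-memory ring buffer with 6 chunks of 1 MiB (`1 << 20`) each; otherwise it falls back to `torch.distributed.broadcast_object_list` on the CPU group. A hedged sketch of the calling convention, where `group` stands for any `GroupCoordinator` instance (the variable name and values are illustrative):

# on the source rank (src must be 0 while the shm broadcaster is active)
metadata = {"seq_len": 128, "dtype": "float16"}
result = group.broadcast_object(metadata, src=0)  # returns metadata unchanged

# on every other rank in the group
result = group.broadcast_object(None, src=0)      # returns the broadcast dict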

@@ -5,6 +5,7 @@ if TYPE_CHECKING:
VLLM_HOST_IP: str = ""
VLLM_PORT: Optional[int] = None
VLLM_USE_MODELSCOPE: bool = False
VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
VLLM_INSTANCE_ID: Optional[str] = None
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
@@ -114,6 +115,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_INSTANCE_ID":
lambda: os.environ.get("VLLM_INSTANCE_ID", None),
# Interval in seconds to log a warning message when the ring buffer is full
"VLLM_RINGBUFFER_WARNING_INTERVAL":
lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")),
# path to cudatoolkit home directory, under which should be bin, include,
# and lib directories.
"CUDA_HOME":