[core][distributed] add stateless process group (#10216)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-11 09:02:14 -08:00; committed by GitHub.
parent 36fc439de0
commit e6de9784d2
3 changed files with 206 additions and 101 deletions
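
At a glance: the commit removes the old stateless_init_process_group helper (which built a real Gloo/NCCL ProcessGroup behind torch.distributed's back) and replaces it with a StatelessProcessGroup that only shuttles metadata through a TCP store, plus a PyNcclCommunicator path that bootstraps NCCL from it. A minimal usage sketch assembled from the diffs below; the address and the worker wrapper are illustrative, only the API names come from the commit:

# Sketch: control plane via StatelessProcessGroup, data plane via PyNcclCommunicator.
# Each participating process runs this with its own rank; 127.0.0.1:29500 is an example address.
import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup


def worker(rank: int, world_size: int):
    torch.cuda.set_device(rank)
    # metadata-only group; torch.distributed.init_process_group is never called
    pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
                                      rank=rank,
                                      world_size=world_size)
    assert pg.all_gather_obj(rank) == list(range(world_size))
    # NCCL communicator bootstrapped through the metadata group
    comm = PyNcclCommunicator(pg, device=rank)
    comm.disabled = False
    data = torch.tensor([rank], device=f"cuda:{rank}")
    comm.all_reduce(data)  # in-place sum across ranks, as in the test below
    pg.barrier()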

File 1 of 3 — distributed tests (test_stateless_process_group):

@@ -1,10 +1,10 @@
 import pytest
 import ray
 import torch
-import torch.distributed as dist

 import vllm.envs as envs
-from vllm.distributed.utils import stateless_init_process_group
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)

@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():


 def cpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
-                                       rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="gloo")
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
-                                           rank=rank,
-                                           world_size=3,
-                                           backend="gloo")
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
+                                           rank=rank,
+                                           world_size=3)
     data = torch.tensor([rank])
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    data = pg1.broadcast_obj(data, src=2)
+    assert data.item() == 2
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
-    item = data[0].item()
-    print(f"rank: {rank}, item: {item}")
-    if rank == 3:
-        assert item == 6
-    else:
-        assert item == 18
+        data = torch.tensor([rank + 1])
+        data = pg2.broadcast_obj(data, src=2)
+        assert data.item() == 3
+        pg2.barrier()
+    pg1.barrier()


 def gpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
-                                       rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="nccl")
-    if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
-                                           rank=rank,
-                                           world_size=3,
-                                           backend="nccl")
     torch.cuda.set_device(rank)
-    data = torch.tensor([rank]).cuda()
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    pynccl1 = PyNcclCommunicator(pg1, device=rank)
+    pynccl1.disabled = False
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
+                                           rank=rank,
+                                           world_size=3)
+        pynccl2 = PyNcclCommunicator(pg2, device=rank)
+        pynccl2.disabled = False
+    data = torch.tensor([rank]).cuda()
+    pynccl1.all_reduce(data)
+    pg1.barrier()
+    torch.cuda.synchronize()
+    if rank <= 2:
+        pynccl2.all_reduce(data)
+        pg2.barrier()
+        torch.cuda.synchronize()
     item = data[0].item()
     print(f"rank: {rank}, item: {item}")
     if rank == 3:
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
         assert item == 18


+def broadcast_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29504",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    if rank == 2:
+        pg1.broadcast_obj("secret", src=2)
+    else:
+        obj = pg1.broadcast_obj(None, src=2)
+        assert obj == "secret"
+    pg1.barrier()
+
+
+def allgather_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29505",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    data = pg1.all_gather_obj(rank)
+    assert data == list(range(WORLD_SIZE))
+    pg1.barrier()
+
+
 @multi_gpu_test(num_gpus=4)
-@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
-def test_stateless_init_process_group(worker):
+@pytest.mark.parametrize(
+    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
+def test_stateless_process_group(worker):
     WORLD_SIZE = 4
     from multiprocessing import get_context
     ctx = get_context("fork")
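
The hunk stops at ctx = get_context("fork"); the code that actually spawns the workers is outside the diff. For orientation only, a typical launch loop for such a parametrized worker might look like this (an assumption, not code from this commit):

# Hypothetical driver for the parametrized worker above (not part of the diff).
processes = []
for i in range(WORLD_SIZE):
    p = ctx.Process(target=worker, args=(i, WORLD_SIZE))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
    assert p.exitcode == 0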

File 2 of 3 — vllm.distributed.device_communicators.pynccl (PyNcclCommunicator):

@@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger

 logger = init_logger(__name__)
@@ -18,7 +19,7 @@ class PyNcclCommunicator:

     def __init__(
         self,
-        group: ProcessGroup,
+        group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
     ):
@@ -33,13 +34,18 @@
         It is the caller's responsibility to make sure each communicator
         is bind to a unique device.
         """
-        assert dist.is_initialized()
-        assert dist.get_backend(group) != dist.Backend.NCCL, (
-            "PyNcclCommunicator should be attached to a non-NCCL group.")
-        self.group = group
-        # note: this rank is the rank in the group
-        self.rank = dist.get_rank(group)
-        self.world_size = dist.get_world_size(group)
+        if not isinstance(group, StatelessProcessGroup):
+            assert dist.is_initialized()
+            assert dist.get_backend(group) != dist.Backend.NCCL, (
+                "PyNcclCommunicator should be attached to a non-NCCL group.")
+            # note: this rank is the rank in the group
+            self.rank = dist.get_rank(group)
+            self.world_size = dist.get_world_size(group)
+        else:
+            self.rank = group.rank
+            self.world_size = group.world_size
+
+        self.group = group

         # if world_size == 1, no need to create communicator
         if self.world_size == 1:
@@ -68,6 +74,8 @@
         else:
             # construct an empty unique id
             self.unique_id = ncclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal))
-        ranks = dist.get_process_group_ranks(group)
-        # arg `src` in `broadcast` is the global rank
+
+        if not isinstance(group, StatelessProcessGroup):
+            tensor = torch.ByteTensor(list(self.unique_id.internal))
+            ranks = dist.get_process_group_ranks(group)
+            # arg `src` in `broadcast` is the global rank
@@ -75,6 +83,8 @@
-        byte_list = tensor.tolist()
-        for i, byte in enumerate(byte_list):
-            self.unique_id.internal[i] = byte
+            byte_list = tensor.tolist()
+            for i, byte in enumerate(byte_list):
+                self.unique_id.internal[i] = byte
+        else:
+            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
         if isinstance(device, int):
             device = torch.device(f"cuda:{device}")
         elif isinstance(device, str):
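
Net effect of the hunks above: when group is a StatelessProcessGroup, PyNcclCommunicator takes rank and world_size from the group object, skips the torch.distributed backend checks, and ships the ncclUniqueId through group.broadcast_obj(..., src=0) instead of dist.broadcast. A minimal sketch under that assumption (the address and variable names are examples; the API calls are the ones added in this commit):

# No torch.distributed global state is touched here.
pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
                                  rank=rank,
                                  world_size=world_size)
comm = PyNcclCommunicator(pg, device=rank)  # rank 0 publishes the ncclUniqueId via pg.broadcast_obj
comm.disabled = False                       # the test above flips this flag before using all_reduce
comm.all_reduce(torch.tensor([rank], device=f"cuda:{rank}"))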

File 3 of 3 — vllm.distributed.utils:

@@ -2,13 +2,13 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import Sequence, Tuple
+import dataclasses
+import pickle
+import time
+from collections import deque
+from typing import Any, Deque, Dict, Optional, Sequence, Tuple

 import torch
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import (Backend, PrefixStore,
-                                                _get_default_timeout,
-                                                is_nccl_available)
 from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
@@ -91,8 +91,114 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
     return (start_layer, end_layer)


-def stateless_init_process_group(init_method: str, rank: int, world_size: int,
-                                 backend: str) -> ProcessGroup:
-    """A replacement for `torch.distributed.init_process_group` that does not
+@dataclasses.dataclass
+class StatelessProcessGroup:
+    """A dataclass to hold a metadata store, and the rank, world_size of the
+    group. Only use it to communicate metadata between processes.
+    For data-plane communication, create NCCL-related objects.
+    """
+    prefix: str
+    rank: int
+    world_size: int
+    store: torch._C._distributed_c10d.Store
+    data_expiration_seconds: int = 3600  # 1 hour
+
+    # dst rank -> counter
+    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    # src rank -> counter
+    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    broadcast_send_counter: int = 0
+    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
+        default_factory=dict)
+
+    # A deque to store the data entries, with key and timestamp.
+    entries: Deque[Tuple[str,
+                         float]] = dataclasses.field(default_factory=deque)
+
+    def __post_init__(self):
+        assert self.rank < self.world_size
+        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
+        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
+        self.broadcast_recv_src_counter = {
+            i: 0
+            for i in range(self.world_size)
+        }
+
+    def send_obj(self, obj: Any, dst: int):
+        """Send an object to a destination rank."""
+        self.expire_data()
+        key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
+        self.store.set(key, pickle.dumps(obj))
+        self.send_dst_counter[dst] += 1
+        self.entries.append((key, time.time()))
+
+    def expire_data(self):
+        """Expire data that is older than `data_expiration_seconds` seconds."""
+        while self.entries:
+            # check the oldest entry
+            key, timestamp = self.entries[0]
+            if time.time() - timestamp > self.data_expiration_seconds:
+                self.store.delete_key(key)
+                self.entries.popleft()
+            else:
+                break
+
+    def recv_obj(self, src: int) -> Any:
+        """Receive an object from a source rank."""
+        obj = pickle.loads(
+            self.store.get(
+                f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
+            ))
+        self.recv_src_counter[src] += 1
+        return obj
+
+    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
+        """Broadcast an object from a source rank to all other ranks.
+        It does not clean up after all ranks have received the object.
+        Use it for limited times, e.g., for initialization.
+        """
+        if self.rank == src:
+            self.expire_data()
+            key = (f"{self.prefix}/broadcast_from/{src}/"
+                   f"{self.broadcast_send_counter}")
+            self.store.set(key, pickle.dumps(obj))
+            self.broadcast_send_counter += 1
+            self.entries.append((key, time.time()))
+            return obj
+        else:
+            key = (f"{self.prefix}/broadcast_from/{src}/"
+                   f"{self.broadcast_recv_src_counter[src]}")
+            recv_obj = pickle.loads(self.store.get(key))
+            self.broadcast_recv_src_counter[src] += 1
+            return recv_obj
+
+    def all_gather_obj(self, obj: Any) -> list[Any]:
+        """All gather an object from all ranks."""
+        gathered_objs = []
+        for i in range(self.world_size):
+            if i == self.rank:
+                gathered_objs.append(obj)
+                self.broadcast_obj(obj, src=self.rank)
+            else:
+                recv_obj = self.broadcast_obj(None, src=i)
+                gathered_objs.append(recv_obj)
+        return gathered_objs
+
+    def barrier(self):
+        """A barrier to synchronize all ranks."""
+        for i in range(self.world_size):
+            if i == self.rank:
+                self.broadcast_obj(None, src=self.rank)
+            else:
+                self.broadcast_obj(None, src=i)
+
+    @staticmethod
+    def create(
+        init_method: str,
+        rank: int,
+        world_size: int,
+        data_expiration_seconds: int = 3600,
+    ) -> "StatelessProcessGroup":
+        """A replacement for `torch.distributed.init_process_group` that does not
     pollute the global state.

@@ -103,57 +209,21 @@ def stateless_init_process_group(init_method: str, rank: int, world_size: int,
     function is a workaround for this issue.

     `torch.distributed.init_process_group` is a global call, while this function
-    is a stateless call. It will return a `ProcessGroup` object that can be used
-    for collective communication. With this function, process A and process B
-    can call `stateless_init_process_group` to form a group, and then process A, B,
-    C, and D can call `stateless_init_process_group` to form another group.
+    is a stateless call. It will return a `StatelessProcessGroup` object that can be
+    used for exchanging metadata. With this function, process A and process B
+    can call `StatelessProcessGroup.create` to form a group, and then process A, B,
+    C, and D can call `StatelessProcessGroup.create` to form another group.
     """ # noqa
-
-    backend = Backend(backend)  # it is basically string
-    timeout = _get_default_timeout(backend)
+        from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
+        timeout = _DEFAULT_PG_TIMEOUT

-    store, rank, world_size = next(
-        rendezvous(init_method, rank, world_size, timeout=timeout))
-    store.set_timeout(timeout)
+        store, rank, world_size = next(
+            rendezvous(init_method, rank, world_size, timeout=timeout))
+        store.set_timeout(timeout)

-    group_rank = rank
-    group_size = world_size
-
-    # Use a PrefixStore to avoid accidental overrides of keys used by
-    # different systems (e.g. RPC) in case the store is multi-tenant.
-    prefix_store = PrefixStore(init_method, store)
-
-    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
-
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-        pg_options,
-    )
-
-    if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
+        return StatelessProcessGroup(
+            prefix=init_method,
+            rank=rank,
+            world_size=world_size,
+            store=store,
+            data_expiration_seconds=data_expiration_seconds)
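
The new tests cover broadcast_obj, all_gather_obj, and barrier, but not the point-to-point send_obj/recv_obj path. A sketch of that path, following the key scheme defined above (the address, port, and payload are illustrative):

# Rank 0 process (illustrative):
pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29600",
                                  rank=0, world_size=2)
pg.send_obj({"handle": 123}, dst=1)  # pickled and stored at "<init_method>/send_to/1/0"
pg.barrier()

# Rank 1 process (illustrative):
pg = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29600",
                                  rank=1, world_size=2)
obj = pg.recv_obj(src=0)  # store.get() blocks until rank 0 has set the key
assert obj == {"handle": 123}
pg.barrier()

Entries written by send_obj and broadcast_obj are only garbage-collected lazily: expire_data() runs on the writing side and drops keys older than data_expiration_seconds (default 3600 s), which is why the class is documented as a metadata channel rather than a data-plane transport.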