mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 10:26:11 +08:00
[core][distributed] add stateless process group (#10216)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
36fc439de0
commit
e6de9784d2
@ -1,10 +1,10 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed.utils import stateless_init_process_group
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.utils import (cuda_device_count_stateless,
|
from vllm.utils import (cuda_device_count_stateless,
|
||||||
update_environment_variables)
|
update_environment_variables)
|
||||||
|
|
||||||
@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
|
|||||||
|
|
||||||
|
|
||||||
def cpu_worker(rank, WORLD_SIZE):
|
def cpu_worker(rank, WORLD_SIZE):
|
||||||
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
|
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE,
|
world_size=WORLD_SIZE)
|
||||||
backend="gloo")
|
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
|
pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=3,
|
world_size=3)
|
||||||
backend="gloo")
|
|
||||||
data = torch.tensor([rank])
|
data = torch.tensor([rank])
|
||||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
|
data = pg1.broadcast_obj(data, src=2)
|
||||||
|
assert data.item() == 2
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
|
data = torch.tensor([rank + 1])
|
||||||
item = data[0].item()
|
data = pg2.broadcast_obj(data, src=2)
|
||||||
print(f"rank: {rank}, item: {item}")
|
assert data.item() == 3
|
||||||
if rank == 3:
|
pg2.barrier()
|
||||||
assert item == 6
|
pg1.barrier()
|
||||||
else:
|
|
||||||
assert item == 18
|
|
||||||
|
|
||||||
|
|
||||||
def gpu_worker(rank, WORLD_SIZE):
|
def gpu_worker(rank, WORLD_SIZE):
|
||||||
pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
|
|
||||||
rank=rank,
|
|
||||||
world_size=WORLD_SIZE,
|
|
||||||
backend="nccl")
|
|
||||||
if rank <= 2:
|
|
||||||
pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
|
|
||||||
rank=rank,
|
|
||||||
world_size=3,
|
|
||||||
backend="nccl")
|
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
data = torch.tensor([rank]).cuda()
|
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
|
||||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
|
rank=rank,
|
||||||
|
world_size=WORLD_SIZE)
|
||||||
|
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||||
|
pynccl1.disabled = False
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
|
pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
|
||||||
|
rank=rank,
|
||||||
|
world_size=3)
|
||||||
|
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||||
|
pynccl2.disabled = False
|
||||||
|
data = torch.tensor([rank]).cuda()
|
||||||
|
pynccl1.all_reduce(data)
|
||||||
|
pg1.barrier()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if rank <= 2:
|
||||||
|
pynccl2.all_reduce(data)
|
||||||
|
pg2.barrier()
|
||||||
|
torch.cuda.synchronize()
|
||||||
item = data[0].item()
|
item = data[0].item()
|
||||||
print(f"rank: {rank}, item: {item}")
|
print(f"rank: {rank}, item: {item}")
|
||||||
if rank == 3:
|
if rank == 3:
|
||||||
@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
|
|||||||
assert item == 18
|
assert item == 18
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_worker(rank, WORLD_SIZE):
|
||||||
|
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29504",
|
||||||
|
rank=rank,
|
||||||
|
world_size=WORLD_SIZE)
|
||||||
|
if rank == 2:
|
||||||
|
pg1.broadcast_obj("secret", src=2)
|
||||||
|
else:
|
||||||
|
obj = pg1.broadcast_obj(None, src=2)
|
||||||
|
assert obj == "secret"
|
||||||
|
pg1.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
def allgather_worker(rank, WORLD_SIZE):
|
||||||
|
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29505",
|
||||||
|
rank=rank,
|
||||||
|
world_size=WORLD_SIZE)
|
||||||
|
data = pg1.all_gather_obj(rank)
|
||||||
|
assert data == list(range(WORLD_SIZE))
|
||||||
|
pg1.barrier()
|
||||||
|
|
||||||
|
|
||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
|
@pytest.mark.parametrize(
|
||||||
def test_stateless_init_process_group(worker):
|
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
||||||
|
def test_stateless_process_group(worker):
|
||||||
WORLD_SIZE = 4
|
WORLD_SIZE = 4
|
||||||
from multiprocessing import get_context
|
from multiprocessing import get_context
|
||||||
ctx = get_context("fork")
|
ctx = get_context("fork")
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp
|
|||||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
||||||
ncclRedOpTypeEnum, ncclUniqueId)
|
ncclRedOpTypeEnum, ncclUniqueId)
|
||||||
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -18,7 +19,7 @@ class PyNcclCommunicator:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
group: ProcessGroup,
|
group: Union[ProcessGroup, StatelessProcessGroup],
|
||||||
device: Union[int, str, torch.device],
|
device: Union[int, str, torch.device],
|
||||||
library_path: Optional[str] = None,
|
library_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
@ -33,13 +34,18 @@ class PyNcclCommunicator:
|
|||||||
It is the caller's responsibility to make sure each communicator
|
It is the caller's responsibility to make sure each communicator
|
||||||
is bind to a unique device.
|
is bind to a unique device.
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(group, StatelessProcessGroup):
|
||||||
assert dist.is_initialized()
|
assert dist.is_initialized()
|
||||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||||
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
||||||
self.group = group
|
|
||||||
# note: this rank is the rank in the group
|
# note: this rank is the rank in the group
|
||||||
self.rank = dist.get_rank(group)
|
self.rank = dist.get_rank(group)
|
||||||
self.world_size = dist.get_world_size(group)
|
self.world_size = dist.get_world_size(group)
|
||||||
|
else:
|
||||||
|
self.rank = group.rank
|
||||||
|
self.world_size = group.world_size
|
||||||
|
|
||||||
|
self.group = group
|
||||||
|
|
||||||
# if world_size == 1, no need to create communicator
|
# if world_size == 1, no need to create communicator
|
||||||
if self.world_size == 1:
|
if self.world_size == 1:
|
||||||
@ -68,6 +74,8 @@ class PyNcclCommunicator:
|
|||||||
else:
|
else:
|
||||||
# construct an empty unique id
|
# construct an empty unique id
|
||||||
self.unique_id = ncclUniqueId()
|
self.unique_id = ncclUniqueId()
|
||||||
|
|
||||||
|
if not isinstance(group, StatelessProcessGroup):
|
||||||
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
tensor = torch.ByteTensor(list(self.unique_id.internal))
|
||||||
ranks = dist.get_process_group_ranks(group)
|
ranks = dist.get_process_group_ranks(group)
|
||||||
# arg `src` in `broadcast` is the global rank
|
# arg `src` in `broadcast` is the global rank
|
||||||
@ -75,6 +83,8 @@ class PyNcclCommunicator:
|
|||||||
byte_list = tensor.tolist()
|
byte_list = tensor.tolist()
|
||||||
for i, byte in enumerate(byte_list):
|
for i, byte in enumerate(byte_list):
|
||||||
self.unique_id.internal[i] = byte
|
self.unique_id.internal[i] = byte
|
||||||
|
else:
|
||||||
|
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
|
||||||
if isinstance(device, int):
|
if isinstance(device, int):
|
||||||
device = torch.device(f"cuda:{device}")
|
device = torch.device(f"cuda:{device}")
|
||||||
elif isinstance(device, str):
|
elif isinstance(device, str):
|
||||||
|
|||||||
@ -2,13 +2,13 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
from typing import Sequence, Tuple
|
import dataclasses
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from collections import deque
|
||||||
|
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
|
||||||
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
|
|
||||||
_get_default_timeout,
|
|
||||||
is_nccl_available)
|
|
||||||
from torch.distributed.rendezvous import rendezvous
|
from torch.distributed.rendezvous import rendezvous
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -91,8 +91,114 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
|
|||||||
return (start_layer, end_layer)
|
return (start_layer, end_layer)
|
||||||
|
|
||||||
|
|
||||||
def stateless_init_process_group(init_method: str, rank: int, world_size: int,
|
@dataclasses.dataclass
|
||||||
backend: str) -> ProcessGroup:
|
class StatelessProcessGroup:
|
||||||
|
"""A dataclass to hold a metadata store, and the rank, world_size of the
|
||||||
|
group. Only use it to communicate metadata between processes.
|
||||||
|
For data-plane communication, create NCCL-related objects.
|
||||||
|
"""
|
||||||
|
prefix: str
|
||||||
|
rank: int
|
||||||
|
world_size: int
|
||||||
|
store: torch._C._distributed_c10d.Store
|
||||||
|
data_expiration_seconds: int = 3600 # 1 hour
|
||||||
|
|
||||||
|
# dst rank -> counter
|
||||||
|
send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||||
|
# src rank -> counter
|
||||||
|
recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
|
||||||
|
broadcast_send_counter: int = 0
|
||||||
|
broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
|
||||||
|
default_factory=dict)
|
||||||
|
|
||||||
|
# A deque to store the data entries, with key and timestamp.
|
||||||
|
entries: Deque[Tuple[str,
|
||||||
|
float]] = dataclasses.field(default_factory=deque)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
assert self.rank < self.world_size
|
||||||
|
self.send_dst_counter = {i: 0 for i in range(self.world_size)}
|
||||||
|
self.recv_src_counter = {i: 0 for i in range(self.world_size)}
|
||||||
|
self.broadcast_recv_src_counter = {
|
||||||
|
i: 0
|
||||||
|
for i in range(self.world_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
def send_obj(self, obj: Any, dst: int):
|
||||||
|
"""Send an object to a destination rank."""
|
||||||
|
self.expire_data()
|
||||||
|
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||||
|
self.store.set(key, pickle.dumps(obj))
|
||||||
|
self.send_dst_counter[dst] += 1
|
||||||
|
self.entries.append((key, time.time()))
|
||||||
|
|
||||||
|
def expire_data(self):
|
||||||
|
"""Expire data that is older than `data_expiration_seconds` seconds."""
|
||||||
|
while self.entries:
|
||||||
|
# check the oldest entry
|
||||||
|
key, timestamp = self.entries[0]
|
||||||
|
if time.time() - timestamp > self.data_expiration_seconds:
|
||||||
|
self.store.delete_key(key)
|
||||||
|
self.entries.popleft()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
def recv_obj(self, src: int) -> Any:
|
||||||
|
"""Receive an object from a source rank."""
|
||||||
|
obj = pickle.loads(
|
||||||
|
self.store.get(
|
||||||
|
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
|
||||||
|
))
|
||||||
|
self.recv_src_counter[src] += 1
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
|
||||||
|
"""Broadcast an object from a source rank to all other ranks.
|
||||||
|
It does not clean up after all ranks have received the object.
|
||||||
|
Use it for limited times, e.g., for initialization.
|
||||||
|
"""
|
||||||
|
if self.rank == src:
|
||||||
|
self.expire_data()
|
||||||
|
key = (f"{self.prefix}/broadcast_from/{src}/"
|
||||||
|
f"{self.broadcast_send_counter}")
|
||||||
|
self.store.set(key, pickle.dumps(obj))
|
||||||
|
self.broadcast_send_counter += 1
|
||||||
|
self.entries.append((key, time.time()))
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
key = (f"{self.prefix}/broadcast_from/{src}/"
|
||||||
|
f"{self.broadcast_recv_src_counter[src]}")
|
||||||
|
recv_obj = pickle.loads(self.store.get(key))
|
||||||
|
self.broadcast_recv_src_counter[src] += 1
|
||||||
|
return recv_obj
|
||||||
|
|
||||||
|
def all_gather_obj(self, obj: Any) -> list[Any]:
|
||||||
|
"""All gather an object from all ranks."""
|
||||||
|
gathered_objs = []
|
||||||
|
for i in range(self.world_size):
|
||||||
|
if i == self.rank:
|
||||||
|
gathered_objs.append(obj)
|
||||||
|
self.broadcast_obj(obj, src=self.rank)
|
||||||
|
else:
|
||||||
|
recv_obj = self.broadcast_obj(None, src=i)
|
||||||
|
gathered_objs.append(recv_obj)
|
||||||
|
return gathered_objs
|
||||||
|
|
||||||
|
def barrier(self):
|
||||||
|
"""A barrier to synchronize all ranks."""
|
||||||
|
for i in range(self.world_size):
|
||||||
|
if i == self.rank:
|
||||||
|
self.broadcast_obj(None, src=self.rank)
|
||||||
|
else:
|
||||||
|
self.broadcast_obj(None, src=i)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(
|
||||||
|
init_method: str,
|
||||||
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
|
data_expiration_seconds: int = 3600,
|
||||||
|
) -> "StatelessProcessGroup":
|
||||||
"""A replacement for `torch.distributed.init_process_group` that does not
|
"""A replacement for `torch.distributed.init_process_group` that does not
|
||||||
pollute the global state.
|
pollute the global state.
|
||||||
|
|
||||||
@ -103,57 +209,21 @@ def stateless_init_process_group(init_method: str, rank: int, world_size: int,
|
|||||||
function is a workaround for this issue.
|
function is a workaround for this issue.
|
||||||
|
|
||||||
`torch.distributed.init_process_group` is a global call, while this function
|
`torch.distributed.init_process_group` is a global call, while this function
|
||||||
is a stateless call. It will return a `ProcessGroup` object that can be used
|
is a stateless call. It will return a `StatelessProcessGroup` object that can be
|
||||||
for collective communication. With this function, process A and process B
|
used for exchanging metadata. With this function, process A and process B
|
||||||
can call `stateless_init_process_group` to form a group, and then process A, B,
|
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||||
C, and D can call `stateless_init_process_group` to form another group.
|
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||||
""" # noqa
|
""" # noqa
|
||||||
|
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
|
||||||
backend = Backend(backend) # it is basically string
|
timeout = _DEFAULT_PG_TIMEOUT
|
||||||
timeout = _get_default_timeout(backend)
|
|
||||||
|
|
||||||
store, rank, world_size = next(
|
store, rank, world_size = next(
|
||||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||||
store.set_timeout(timeout)
|
store.set_timeout(timeout)
|
||||||
|
|
||||||
group_rank = rank
|
return StatelessProcessGroup(
|
||||||
group_size = world_size
|
prefix=init_method,
|
||||||
|
rank=rank,
|
||||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
world_size=world_size,
|
||||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
store=store,
|
||||||
prefix_store = PrefixStore(init_method, store)
|
data_expiration_seconds=data_expiration_seconds)
|
||||||
|
|
||||||
pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
|
|
||||||
|
|
||||||
pg: ProcessGroup = ProcessGroup(
|
|
||||||
prefix_store,
|
|
||||||
group_rank,
|
|
||||||
group_size,
|
|
||||||
pg_options,
|
|
||||||
)
|
|
||||||
|
|
||||||
if backend == "gloo":
|
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
|
||||||
backend_class = ProcessGroupGloo(prefix_store,
|
|
||||||
group_rank,
|
|
||||||
group_size,
|
|
||||||
timeout=timeout)
|
|
||||||
backend_type = ProcessGroup.BackendType.GLOO
|
|
||||||
device = torch.device("cpu")
|
|
||||||
elif backend == "nccl":
|
|
||||||
assert is_nccl_available()
|
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
|
||||||
|
|
||||||
backend_options = ProcessGroupNCCL.Options()
|
|
||||||
backend_options._timeout = timeout
|
|
||||||
|
|
||||||
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
|
|
||||||
backend_options)
|
|
||||||
backend_type = ProcessGroup.BackendType.NCCL
|
|
||||||
device = torch.device("cuda")
|
|
||||||
|
|
||||||
backend_class._set_sequence_number_for_group()
|
|
||||||
|
|
||||||
pg._register_backend(device, backend_type, backend_class)
|
|
||||||
|
|
||||||
return pg
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user