diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py
index 3c7facc12c59..d40b09a8b868 100644
--- a/tests/distributed/test_utils.py
+++ b/tests/distributed/test_utils.py
@@ -1,10 +1,10 @@
 import pytest
 import ray
 import torch
-import torch.distributed as dist
 
 import vllm.envs as envs
-from vllm.distributed.utils import stateless_init_process_group
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.utils import (cuda_device_count_stateless,
                         update_environment_variables)
 
@@ -41,42 +41,45 @@ def test_cuda_device_count_stateless():
 
 
 def cpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29500",
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
                                        rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="gloo")
+                                       world_size=WORLD_SIZE)
     if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29501",
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
                                            rank=rank,
-                                           world_size=3,
-                                           backend="gloo")
+                                           world_size=3)
     data = torch.tensor([rank])
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    data = pg1.broadcast_obj(data, src=2)
+    assert data.item() == 2
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
-    item = data[0].item()
-    print(f"rank: {rank}, item: {item}")
-    if rank == 3:
-        assert item == 6
-    else:
-        assert item == 18
+        data = torch.tensor([rank + 1])
+        data = pg2.broadcast_obj(data, src=2)
+        assert data.item() == 3
+        pg2.barrier()
+    pg1.barrier()
 
 
 def gpu_worker(rank, WORLD_SIZE):
-    pg1 = stateless_init_process_group(init_method="tcp://127.0.0.1:29502",
-                                       rank=rank,
-                                       world_size=WORLD_SIZE,
-                                       backend="nccl")
-    if rank <= 2:
-        pg2 = stateless_init_process_group(init_method="tcp://127.0.0.1:29503",
-                                           rank=rank,
-                                           world_size=3,
-                                           backend="nccl")
     torch.cuda.set_device(rank)
-    data = torch.tensor([rank]).cuda()
-    dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg1)
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    pynccl1 = PyNcclCommunicator(pg1, device=rank)
+    pynccl1.disabled = False
     if rank <= 2:
-        dist.all_reduce(data, op=dist.ReduceOp.SUM, group=pg2)
+        pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
+                                           rank=rank,
+                                           world_size=3)
+        pynccl2 = PyNcclCommunicator(pg2, device=rank)
+        pynccl2.disabled = False
+    data = torch.tensor([rank]).cuda()
+    pynccl1.all_reduce(data)
+    pg1.barrier()
+    torch.cuda.synchronize()
+    if rank <= 2:
+        pynccl2.all_reduce(data)
+        pg2.barrier()
+        torch.cuda.synchronize()
     item = data[0].item()
     print(f"rank: {rank}, item: {item}")
     if rank == 3:
@@ -85,9 +88,31 @@ def gpu_worker(rank, WORLD_SIZE):
         assert item == 18
 
 
+def broadcast_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29504",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    if rank == 2:
+        pg1.broadcast_obj("secret", src=2)
+    else:
+        obj = pg1.broadcast_obj(None, src=2)
+        assert obj == "secret"
+    pg1.barrier()
+
+
+def allgather_worker(rank, WORLD_SIZE):
+    pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29505",
+                                       rank=rank,
+                                       world_size=WORLD_SIZE)
+    data = pg1.all_gather_obj(rank)
+    assert data == list(range(WORLD_SIZE))
+    pg1.barrier()
+
+
 @multi_gpu_test(num_gpus=4)
-@pytest.mark.parametrize("worker", [cpu_worker, gpu_worker])
-def test_stateless_init_process_group(worker):
+@pytest.mark.parametrize(
+    "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
+def test_stateless_process_group(worker):
     WORLD_SIZE = 4
     from multiprocessing import get_context
     ctx = get_context("fork")
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 731956654567..7c6f48e88637 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -9,6 +9,7 @@ from torch.distributed import ProcessGroup, ReduceOp
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
+from vllm.distributed.utils import StatelessProcessGroup
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -18,7 +19,7 @@ class PyNcclCommunicator:
 
     def __init__(
         self,
-        group: ProcessGroup,
+        group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
     ):
@@ -33,13 +34,18 @@ class PyNcclCommunicator:
         It is the caller's responsibility to make sure each communicator
         is bind to a unique device.
         """
-        assert dist.is_initialized()
-        assert dist.get_backend(group) != dist.Backend.NCCL, (
-            "PyNcclCommunicator should be attached to a non-NCCL group.")
+        if not isinstance(group, StatelessProcessGroup):
+            assert dist.is_initialized()
+            assert dist.get_backend(group) != dist.Backend.NCCL, (
+                "PyNcclCommunicator should be attached to a non-NCCL group.")
+            # note: this rank is the rank in the group
+            self.rank = dist.get_rank(group)
+            self.world_size = dist.get_world_size(group)
+        else:
+            self.rank = group.rank
+            self.world_size = group.world_size
+
         self.group = group
-        # note: this rank is the rank in the group
-        self.rank = dist.get_rank(group)
-        self.world_size = dist.get_world_size(group)
 
         # if world_size == 1, no need to create communicator
         if self.world_size == 1:
@@ -68,13 +74,17 @@ class PyNcclCommunicator:
         else:
             # construct an empty unique id
             self.unique_id = ncclUniqueId()
-        tensor = torch.ByteTensor(list(self.unique_id.internal))
-        ranks = dist.get_process_group_ranks(group)
-        # arg `src` in `broadcast` is the global rank
-        dist.broadcast(tensor, src=ranks[0], group=group)
-        byte_list = tensor.tolist()
-        for i, byte in enumerate(byte_list):
-            self.unique_id.internal[i] = byte
+
+        if not isinstance(group, StatelessProcessGroup):
+            tensor = torch.ByteTensor(list(self.unique_id.internal))
+            ranks = dist.get_process_group_ranks(group)
+            # arg `src` in `broadcast` is the global rank
+            dist.broadcast(tensor, src=ranks[0], group=group)
+            byte_list = tensor.tolist()
+            for i, byte in enumerate(byte_list):
+                self.unique_id.internal[i] = byte
+        else:
+            self.unique_id = group.broadcast_obj(self.unique_id, src=0)
         if isinstance(device, int):
             device = torch.device(f"cuda:{device}")
         elif isinstance(device, str):
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index d24ce898707f..a77b41322f37 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -2,13 +2,13 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import Sequence, Tuple
+import dataclasses
+import pickle
+import time
+from collections import deque
+from typing import Any, Deque, Dict, Optional, Sequence, Tuple
 
 import torch
-from torch.distributed import ProcessGroup
-from torch.distributed.distributed_c10d import (Backend, PrefixStore,
-                                                _get_default_timeout,
-                                                is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
 
 import vllm.envs as envs
@@ -91,69 +91,139 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
     return (start_layer, end_layer)
 
 
-def stateless_init_process_group(init_method: str, rank: int, world_size: int,
-                                 backend: str) -> ProcessGroup:
-    """A replacement for `torch.distributed.init_process_group` that does not
-    pollute the global state.
+@dataclasses.dataclass
+class StatelessProcessGroup:
+    """A dataclass to hold a metadata store, and the rank and world_size of
+    the group. Only use it to communicate metadata between processes.
+    For data-plane communication, create NCCL-related objects.
+    """
+    prefix: str
+    rank: int
+    world_size: int
+    store: torch._C._distributed_c10d.Store
+    data_expiration_seconds: int = 3600  # 1 hour
 
-    If we have process A and process B called `torch.distributed.init_process_group`
-    to form a group, and then we want to form another group with process A, B, C,
-    D, it is not possible in PyTorch, because process A and process B have already
-    formed a group, and process C and process D cannot join that group. This
-    function is a workaround for this issue.
+    # dst rank -> counter
+    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    # src rank -> counter
+    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    broadcast_send_counter: int = 0
+    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(
+        default_factory=dict)
 
-    `torch.distributed.init_process_group` is a global call, while this function
-    is a stateless call. It will return a `ProcessGroup` object that can be used
-    for collective communication. With this function, process A and process B
-    can call `stateless_init_process_group` to form a group, and then process A, B,
-    C, and D can call `stateless_init_process_group` to form another group.
-    """ # noqa
+    # A deque to store the data entries, with key and timestamp.
+    entries: Deque[Tuple[str,
+                         float]] = dataclasses.field(default_factory=deque)
 
-    backend = Backend(backend)  # it is basically string
-    timeout = _get_default_timeout(backend)
+    def __post_init__(self):
+        assert self.rank < self.world_size
+        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
+        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
+        self.broadcast_recv_src_counter = {
+            i: 0
+            for i in range(self.world_size)
+        }
 
-    store, rank, world_size = next(
-        rendezvous(init_method, rank, world_size, timeout=timeout))
-    store.set_timeout(timeout)
+    def send_obj(self, obj: Any, dst: int):
+        """Send an object to a destination rank."""
+        self.expire_data()
+        key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
+        self.store.set(key, pickle.dumps(obj))
+        self.send_dst_counter[dst] += 1
+        self.entries.append((key, time.time()))
 
-    group_rank = rank
-    group_size = world_size
+    def expire_data(self):
+        """Expire data that is older than `data_expiration_seconds` seconds."""
+        while self.entries:
+            # check the oldest entry
+            key, timestamp = self.entries[0]
+            if time.time() - timestamp > self.data_expiration_seconds:
+                self.store.delete_key(key)
+                self.entries.popleft()
+            else:
+                break
 
-    # Use a PrefixStore to avoid accidental overrides of keys used by
-    # different systems (e.g. RPC) in case the store is multi-tenant.
-    prefix_store = PrefixStore(init_method, store)
+    def recv_obj(self, src: int) -> Any:
+        """Receive an object from a source rank."""
+        obj = pickle.loads(
+            self.store.get(
+                f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
+            ))
+        self.recv_src_counter[src] += 1
+        return obj
 
-    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
+    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
+        """Broadcast an object from a source rank to all other ranks.
+        It does not clean up after all ranks have received the object.
+        Use it only a limited number of times, e.g., for initialization.
+ """ + if self.rank == src: + self.expire_data() + key = (f"{self.prefix}/broadcast_from/{src}/" + f"{self.broadcast_send_counter}") + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return obj + else: + key = (f"{self.prefix}/broadcast_from/{src}/" + f"{self.broadcast_recv_src_counter[src]}") + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj - pg: ProcessGroup = ProcessGroup( - prefix_store, - group_rank, - group_size, - pg_options, - ) + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs - if backend == "gloo": - from torch.distributed.distributed_c10d import ProcessGroupGloo - backend_class = ProcessGroupGloo(prefix_store, - group_rank, - group_size, - timeout=timeout) - backend_type = ProcessGroup.BackendType.GLOO - device = torch.device("cpu") - elif backend == "nccl": - assert is_nccl_available() - from torch.distributed.distributed_c10d import ProcessGroupNCCL + def barrier(self): + """A barrier to synchronize all ranks.""" + for i in range(self.world_size): + if i == self.rank: + self.broadcast_obj(None, src=self.rank) + else: + self.broadcast_obj(None, src=i) - backend_options = ProcessGroupNCCL.Options() - backend_options._timeout = timeout + @staticmethod + def create( + init_method: str, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. - backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, - backend_options) - backend_type = ProcessGroup.BackendType.NCCL - device = torch.device("cuda") + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. This + function is a workaround for this issue. - backend_class._set_sequence_number_for_group() + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. + """ # noqa + from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT + timeout = _DEFAULT_PG_TIMEOUT - pg._register_backend(device, backend_type, backend_class) + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) - return pg + return StatelessProcessGroup( + prefix=init_method, + rank=rank, + world_size=world_size, + store=store, + data_expiration_seconds=data_expiration_seconds)