[core][distributed] use tcp store directly (#10275)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-12 17:36:08 -08:00 committed by GitHub
parent 112fa0bbe5
commit 0d4ea3fb5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 25 deletions

View File

@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
def cpu_worker(rank, WORLD_SIZE, port1, port2): def cpu_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
if rank <= 2: if rank <= 2:
pg2 = StatelessProcessGroup.create( pg2 = StatelessProcessGroup.create(host="127.0.0.1",
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3) port=port2,
rank=rank,
world_size=3)
data = torch.tensor([rank]) data = torch.tensor([rank])
data = pg1.broadcast_obj(data, src=2) data = pg1.broadcast_obj(data, src=2)
assert data.item() == 2 assert data.item() == 2
@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
def gpu_worker(rank, WORLD_SIZE, port1, port2): def gpu_worker(rank, WORLD_SIZE, port1, port2):
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
pynccl1 = PyNcclCommunicator(pg1, device=rank) pynccl1 = PyNcclCommunicator(pg1, device=rank)
pynccl1.disabled = False pynccl1.disabled = False
if rank <= 2: if rank <= 2:
pg2 = StatelessProcessGroup.create( pg2 = StatelessProcessGroup.create(host="127.0.0.1",
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3) port=port2,
rank=rank,
world_size=3)
pynccl2 = PyNcclCommunicator(pg2, device=rank) pynccl2 = PyNcclCommunicator(pg2, device=rank)
pynccl2.disabled = False pynccl2.disabled = False
data = torch.tensor([rank]).cuda() data = torch.tensor([rank]).cuda()
@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
def broadcast_worker(rank, WORLD_SIZE, port1, port2): def broadcast_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
if rank == 2: if rank == 2:
@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
def allgather_worker(rank, WORLD_SIZE, port1, port2): def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}", pg1 = StatelessProcessGroup.create(host="127.0.0.1",
port=port1,
rank=rank, rank=rank,
world_size=WORLD_SIZE) world_size=WORLD_SIZE)
data = pg1.all_gather_obj(rank) data = pg1.all_gather_obj(rank)
@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
pg1.barrier() pg1.barrier()
# TODO: investigate why this test is flaky. It hangs during initialization.
@pytest.mark.skip("Skip the test because it is flaky.")
@multi_gpu_test(num_gpus=4) @multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]) "worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])

View File

@ -9,7 +9,7 @@ from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch import torch
from torch.distributed.rendezvous import rendezvous from torch.distributed import TCPStore
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
@ -97,7 +97,6 @@ class StatelessProcessGroup:
group. Only use it to communicate metadata between processes. group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects. For data-plane communication, create NCCL-related objects.
""" """
prefix: str
rank: int rank: int
world_size: int world_size: int
store: torch._C._distributed_c10d.Store store: torch._C._distributed_c10d.Store
@ -127,7 +126,7 @@ class StatelessProcessGroup:
def send_obj(self, obj: Any, dst: int): def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank.""" """Send an object to a destination rank."""
self.expire_data() self.expire_data()
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}" key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj)) self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1 self.send_dst_counter[dst] += 1
self.entries.append((key, time.time())) self.entries.append((key, time.time()))
@ -147,8 +146,7 @@ class StatelessProcessGroup:
"""Receive an object from a source rank.""" """Receive an object from a source rank."""
obj = pickle.loads( obj = pickle.loads(
self.store.get( self.store.get(
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}" f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
))
self.recv_src_counter[src] += 1 self.recv_src_counter[src] += 1
return obj return obj
@ -159,14 +157,14 @@ class StatelessProcessGroup:
""" """
if self.rank == src: if self.rank == src:
self.expire_data() self.expire_data()
key = (f"{self.prefix}/broadcast_from/{src}/" key = (f"broadcast_from/{src}/"
f"{self.broadcast_send_counter}") f"{self.broadcast_send_counter}")
self.store.set(key, pickle.dumps(obj)) self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1 self.broadcast_send_counter += 1
self.entries.append((key, time.time())) self.entries.append((key, time.time()))
return obj return obj
else: else:
key = (f"{self.prefix}/broadcast_from/{src}/" key = (f"broadcast_from/{src}/"
f"{self.broadcast_recv_src_counter[src]}") f"{self.broadcast_recv_src_counter[src]}")
recv_obj = pickle.loads(self.store.get(key)) recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1 self.broadcast_recv_src_counter[src] += 1
@ -194,7 +192,8 @@ class StatelessProcessGroup:
@staticmethod @staticmethod
def create( def create(
init_method: str, host: str,
port: int,
rank: int, rank: int,
world_size: int, world_size: int,
data_expiration_seconds: int = 3600, data_expiration_seconds: int = 3600,
@ -214,15 +213,14 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B, can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group. C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa """ # noqa
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT store = TCPStore(
timeout = _DEFAULT_PG_TIMEOUT host_name=host,
port=port,
store, rank, world_size = next( world_size=world_size,
rendezvous(init_method, rank, world_size, timeout=timeout)) is_master=(rank == 0),
store.set_timeout(timeout) )
return StatelessProcessGroup( return StatelessProcessGroup(
prefix=init_method,
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
store=store, store=store,