mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 20:04:25 +08:00
[core][distributed] use tcp store directly (#10275)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
112fa0bbe5
commit
0d4ea3fb5c
@ -43,12 +43,15 @@ def test_cuda_device_count_stateless():
|
|||||||
|
|
||||||
|
|
||||||
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||||
|
port=port1,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(
|
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||||
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
|
port=port2,
|
||||||
|
rank=rank,
|
||||||
|
world_size=3)
|
||||||
data = torch.tensor([rank])
|
data = torch.tensor([rank])
|
||||||
data = pg1.broadcast_obj(data, src=2)
|
data = pg1.broadcast_obj(data, src=2)
|
||||||
assert data.item() == 2
|
assert data.item() == 2
|
||||||
@ -62,14 +65,17 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||||
|
port=port1,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||||
pynccl1.disabled = False
|
pynccl1.disabled = False
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(
|
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||||
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
|
port=port2,
|
||||||
|
rank=rank,
|
||||||
|
world_size=3)
|
||||||
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||||
pynccl2.disabled = False
|
pynccl2.disabled = False
|
||||||
data = torch.tensor([rank]).cuda()
|
data = torch.tensor([rank]).cuda()
|
||||||
@ -89,7 +95,8 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
|
|
||||||
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||||
|
port=port1,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
if rank == 2:
|
if rank == 2:
|
||||||
@ -101,7 +108,8 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
|
|
||||||
|
|
||||||
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||||
|
port=port1,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
data = pg1.all_gather_obj(rank)
|
data = pg1.all_gather_obj(rank)
|
||||||
@ -109,8 +117,6 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
|||||||
pg1.barrier()
|
pg1.barrier()
|
||||||
|
|
||||||
|
|
||||||
# TODO: investigate why this test is flaky. It hangs during initialization.
|
|
||||||
@pytest.mark.skip("Skip the test because it is flaky.")
|
|
||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from collections import deque
|
|||||||
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed.rendezvous import rendezvous
|
from torch.distributed import TCPStore
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -97,7 +97,6 @@ class StatelessProcessGroup:
|
|||||||
group. Only use it to communicate metadata between processes.
|
group. Only use it to communicate metadata between processes.
|
||||||
For data-plane communication, create NCCL-related objects.
|
For data-plane communication, create NCCL-related objects.
|
||||||
"""
|
"""
|
||||||
prefix: str
|
|
||||||
rank: int
|
rank: int
|
||||||
world_size: int
|
world_size: int
|
||||||
store: torch._C._distributed_c10d.Store
|
store: torch._C._distributed_c10d.Store
|
||||||
@ -127,7 +126,7 @@ class StatelessProcessGroup:
|
|||||||
def send_obj(self, obj: Any, dst: int):
|
def send_obj(self, obj: Any, dst: int):
|
||||||
"""Send an object to a destination rank."""
|
"""Send an object to a destination rank."""
|
||||||
self.expire_data()
|
self.expire_data()
|
||||||
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
|
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||||
self.store.set(key, pickle.dumps(obj))
|
self.store.set(key, pickle.dumps(obj))
|
||||||
self.send_dst_counter[dst] += 1
|
self.send_dst_counter[dst] += 1
|
||||||
self.entries.append((key, time.time()))
|
self.entries.append((key, time.time()))
|
||||||
@ -147,8 +146,7 @@ class StatelessProcessGroup:
|
|||||||
"""Receive an object from a source rank."""
|
"""Receive an object from a source rank."""
|
||||||
obj = pickle.loads(
|
obj = pickle.loads(
|
||||||
self.store.get(
|
self.store.get(
|
||||||
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
|
f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
|
||||||
))
|
|
||||||
self.recv_src_counter[src] += 1
|
self.recv_src_counter[src] += 1
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
@ -159,14 +157,14 @@ class StatelessProcessGroup:
|
|||||||
"""
|
"""
|
||||||
if self.rank == src:
|
if self.rank == src:
|
||||||
self.expire_data()
|
self.expire_data()
|
||||||
key = (f"{self.prefix}/broadcast_from/{src}/"
|
key = (f"broadcast_from/{src}/"
|
||||||
f"{self.broadcast_send_counter}")
|
f"{self.broadcast_send_counter}")
|
||||||
self.store.set(key, pickle.dumps(obj))
|
self.store.set(key, pickle.dumps(obj))
|
||||||
self.broadcast_send_counter += 1
|
self.broadcast_send_counter += 1
|
||||||
self.entries.append((key, time.time()))
|
self.entries.append((key, time.time()))
|
||||||
return obj
|
return obj
|
||||||
else:
|
else:
|
||||||
key = (f"{self.prefix}/broadcast_from/{src}/"
|
key = (f"broadcast_from/{src}/"
|
||||||
f"{self.broadcast_recv_src_counter[src]}")
|
f"{self.broadcast_recv_src_counter[src]}")
|
||||||
recv_obj = pickle.loads(self.store.get(key))
|
recv_obj = pickle.loads(self.store.get(key))
|
||||||
self.broadcast_recv_src_counter[src] += 1
|
self.broadcast_recv_src_counter[src] += 1
|
||||||
@ -194,7 +192,8 @@ class StatelessProcessGroup:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
init_method: str,
|
host: str,
|
||||||
|
port: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
data_expiration_seconds: int = 3600,
|
data_expiration_seconds: int = 3600,
|
||||||
@ -214,15 +213,14 @@ class StatelessProcessGroup:
|
|||||||
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||||
""" # noqa
|
""" # noqa
|
||||||
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
|
store = TCPStore(
|
||||||
timeout = _DEFAULT_PG_TIMEOUT
|
host_name=host,
|
||||||
|
port=port,
|
||||||
store, rank, world_size = next(
|
world_size=world_size,
|
||||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
is_master=(rank == 0),
|
||||||
store.set_timeout(timeout)
|
)
|
||||||
|
|
||||||
return StatelessProcessGroup(
|
return StatelessProcessGroup(
|
||||||
prefix=init_method,
|
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
store=store,
|
store=store,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user