mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:45:49 +08:00
[misc][distributed] auto port selection and disable tests (#10226)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
4800339c62
commit
8a7fe47d32
@ -1,3 +1,5 @@
|
|||||||
|
import socket
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import ray
|
import ray
|
||||||
import torch
|
import torch
|
||||||
@ -5,7 +7,7 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
from vllm.distributed.utils import StatelessProcessGroup
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.utils import (cuda_device_count_stateless,
|
from vllm.utils import (cuda_device_count_stateless, get_open_port,
|
||||||
update_environment_variables)
|
update_environment_variables)
|
||||||
|
|
||||||
from ..utils import multi_gpu_test
|
from ..utils import multi_gpu_test
|
||||||
@ -40,14 +42,13 @@ def test_cuda_device_count_stateless():
|
|||||||
assert ray.get(actor.get_count.remote()) == 0
|
assert ray.get(actor.get_count.remote()) == 0
|
||||||
|
|
||||||
|
|
||||||
def cpu_worker(rank, WORLD_SIZE):
|
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29500",
|
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29501",
|
pg2 = StatelessProcessGroup.create(
|
||||||
rank=rank,
|
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
|
||||||
world_size=3)
|
|
||||||
data = torch.tensor([rank])
|
data = torch.tensor([rank])
|
||||||
data = pg1.broadcast_obj(data, src=2)
|
data = pg1.broadcast_obj(data, src=2)
|
||||||
assert data.item() == 2
|
assert data.item() == 2
|
||||||
@ -59,17 +60,16 @@ def cpu_worker(rank, WORLD_SIZE):
|
|||||||
pg1.barrier()
|
pg1.barrier()
|
||||||
|
|
||||||
|
|
||||||
def gpu_worker(rank, WORLD_SIZE):
|
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29502",
|
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||||
pynccl1.disabled = False
|
pynccl1.disabled = False
|
||||||
if rank <= 2:
|
if rank <= 2:
|
||||||
pg2 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29503",
|
pg2 = StatelessProcessGroup.create(
|
||||||
rank=rank,
|
init_method=f"tcp://127.0.0.1:{port2}", rank=rank, world_size=3)
|
||||||
world_size=3)
|
|
||||||
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||||
pynccl2.disabled = False
|
pynccl2.disabled = False
|
||||||
data = torch.tensor([rank]).cuda()
|
data = torch.tensor([rank]).cuda()
|
||||||
@ -88,8 +88,8 @@ def gpu_worker(rank, WORLD_SIZE):
|
|||||||
assert item == 18
|
assert item == 18
|
||||||
|
|
||||||
|
|
||||||
def broadcast_worker(rank, WORLD_SIZE):
|
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29504",
|
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
if rank == 2:
|
if rank == 2:
|
||||||
@ -100,8 +100,8 @@ def broadcast_worker(rank, WORLD_SIZE):
|
|||||||
pg1.barrier()
|
pg1.barrier()
|
||||||
|
|
||||||
|
|
||||||
def allgather_worker(rank, WORLD_SIZE):
|
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
||||||
pg1 = StatelessProcessGroup.create(init_method="tcp://127.0.0.1:29505",
|
pg1 = StatelessProcessGroup.create(init_method=f"tcp://127.0.0.1:{port1}",
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=WORLD_SIZE)
|
world_size=WORLD_SIZE)
|
||||||
data = pg1.all_gather_obj(rank)
|
data = pg1.all_gather_obj(rank)
|
||||||
@ -109,17 +109,24 @@ def allgather_worker(rank, WORLD_SIZE):
|
|||||||
pg1.barrier()
|
pg1.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: investigate why this test is flaky. It hangs during initialization.
|
||||||
|
@pytest.mark.skip("Skip the test because it is flaky.")
|
||||||
@multi_gpu_test(num_gpus=4)
|
@multi_gpu_test(num_gpus=4)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
||||||
def test_stateless_process_group(worker):
|
def test_stateless_process_group(worker):
|
||||||
|
port1 = get_open_port()
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("", port1))
|
||||||
|
port2 = get_open_port()
|
||||||
WORLD_SIZE = 4
|
WORLD_SIZE = 4
|
||||||
from multiprocessing import get_context
|
from multiprocessing import get_context
|
||||||
ctx = get_context("fork")
|
ctx = get_context("fork")
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(WORLD_SIZE):
|
for i in range(WORLD_SIZE):
|
||||||
rank = i
|
rank = i
|
||||||
processes.append(ctx.Process(target=worker, args=(rank, WORLD_SIZE)))
|
processes.append(
|
||||||
|
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)))
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.start()
|
p.start()
|
||||||
for p in processes:
|
for p in processes:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user