# Mirror of https://git.datalinker.icu/vllm-project/vllm.git
# Synced 2025-12-09 23:54:56 +08:00
# Signed-off-by: dongbo910220 <1275604947@qq.com>
# Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import multiprocessing
import random
import time

import numpy as np
import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables
|
|
|
|
|
|
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
    """Produce ``n`` deterministic integer arrays of pseudo-random length.

    Lengths are drawn uniformly from ``[1, 10_000)`` and element values from
    ``[1, 100)``, so each array holds ~5k elements on average (~40 KB each as
    int64). Seeding the global numpy RNG makes the result reproducible: every
    caller with the same ``(n, seed)`` gets identical arrays, which lets the
    broadcast reader regenerate the writer's data independently.
    """
    np.random.seed(seed)
    lengths = np.random.randint(1, 10_000, n)
    return [np.random.randint(1, 100, length) for length in lengths]
|
|
|
|
|
|
def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` separate processes and wait for them.

    Each child receives a dict of torch.distributed bootstrap environment
    variables (RANK, WORLD_SIZE, MASTER_ADDR, ...) as its single argument.
    Raises ``AssertionError`` if any child exits with a non-zero code.
    """
    workers = []
    for rank in range(world_size):
        child_env = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
        proc = multiprocessing.Process(target=fn, args=(child_env,))
        workers.append(proc)
        proc.start()

    # Wait for every worker before inspecting exit codes, so one failure
    # does not leave siblings orphaned.
    for proc in workers:
        proc.join()

    for proc in workers:
        assert proc.exitcode == 0
|
|
|
|
|
|
def worker_fn_wrapper(fn):
    """Decorator that turns ``fn`` into a ``multiprocessing.Process`` target.

    ``multiprocessing.Process`` cannot accept environment variables directly,
    so we pass them as the single positional argument and apply them inside
    the child process before initializing the gloo process group (which reads
    MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from the environment).

    Args:
        fn: zero-argument worker body to run after distributed init.

    Returns:
        A ``wrapped_fn(env)`` callable suitable as a process target.
    """

    # functools.wraps preserves fn's __name__/__doc__ on the wrapper,
    # which makes failures in spawned workers easier to attribute.
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn
|
|
|
|
|
|
@worker_fn_wrapper
def worker_fn():
    # Exercise MessageQueue broadcast over two group types — the default
    # torch.distributed WORLD group and a StatelessProcessGroup — asserting
    # that readers receive exactly the arrays the writer sent.
    rank = dist.get_rank()
    if rank == 0:
        # Rank 0 picks a free port and shares (ip, port) with the other
        # ranks so they can all build the StatelessProcessGroup out-of-band.
        port = get_open_port()
        ip = "127.0.0.1"
        dist.broadcast_object_list([ip, port], src=0)
    else:
        recv = [None, None]
        dist.broadcast_object_list(recv, src=0)
        ip, port = recv  # type: ignore

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())

    for pg in [dist.group.WORLD, stateless_pg]:
        # Rank 2 is the single writer; everyone else reads.
        writer_rank = 2
        broadcaster = MessageQueue.create_from_process_group(
            pg, 40 * 1024, 2, writer_rank
        )
        if rank == writer_rank:
            # Writer draws a fresh seed each run and shares it, so readers
            # can regenerate identical arrays locally via get_arrays(N, seed).
            seed = random.randint(0, 1000)
            dist.broadcast_object_list([seed], writer_rank)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, writer_rank)
            seed = recv[0]  # type: ignore

        # StatelessProcessGroup has its own barrier; dist.barrier() only
        # covers the WORLD group.
        if pg == dist.group.WORLD:
            dist.barrier()
        else:
            pg.barrier()

        # in case we find a race condition
        # print the seed so that we can reproduce the error
        print(f"Rank {rank} got seed {seed}")
        # test broadcasting with about 400MB of data
        N = 10_000
        if rank == writer_rank:
            arrs = get_arrays(N, seed)
            for x in arrs:
                broadcaster.broadcast_object(x)
                # Random sub-millisecond jitter to shake out timing races.
                time.sleep(random.random() / 1000)
        else:
            # Readers rebuild the same arrays from the shared seed and
            # compare each against what arrives over the queue.
            arrs = get_arrays(N, seed)
            for x in arrs:
                y = broadcaster.broadcast_object(None)
                assert np.array_equal(x, y)
                time.sleep(random.random() / 1000)

        if pg == dist.group.WORLD:
            dist.barrier()
            print(f"torch distributed passed the test! Rank {rank}")
        else:
            pg.barrier()
            print(f"StatelessProcessGroup passed the test! Rank {rank}")
|
|
|
|
|
|
def test_shm_broadcast():
    """Run the shared-memory broadcast worker across 4 local processes."""
    distributed_run(worker_fn, 4)
|