vllm/tests/distributed/test_shm_broadcast.py
Harry Mellor d6953beb91
Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-10-05 07:06:22 -07:00

117 lines
3.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import random
import time
import numpy as np
import torch.distributed as dist
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_open_port, update_environment_variables
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
    """Build ``n`` pseudo-random integer arrays reproducible from ``seed``.

    Calling this twice with the same arguments yields element-wise identical
    arrays, which lets a reader regenerate what a writer sent and compare.

    Args:
        n: Number of arrays to generate.
        seed: Seed for the private random generator.

    Returns:
        A list of ``n`` one-dimensional arrays with values in ``[1, 100)``
        and lengths in ``[1, 10_000)``.
    """
    # A private generator produces the same stream as seeding the
    # module-level one, but leaves NumPy's global random state untouched.
    rng = np.random.RandomState(seed)
    sizes = rng.randint(1, 10_000, n)
    # on average, each array will have 5k elements
    # with int64, each array will have 40kb
    return [rng.randint(1, 100, size) for size in sizes]
def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` local processes and require all to succeed.

    Each process receives, as its only argument, a dict holding the
    rendezvous environment variables ``RANK``, ``LOCAL_RANK``,
    ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT``.

    Args:
        fn: Callable taking the environment dict; used as the process target.
        world_size: Number of processes to launch.

    Raises:
        AssertionError: If any process exits with a non-zero exit code.
    """
    # Choose the rendezvous port once in the parent so every rank agrees on
    # it. A fixed port breaks the run whenever that port is already taken,
    # e.g. by another test session on the same machine.
    master_port = str(get_open_port())

    processes = []
    for rank in range(world_size):
        env = {
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": master_port,
        }
        p = multiprocessing.Process(target=fn, args=(env,))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for rank, p in enumerate(processes):
        assert p.exitcode == 0, f"rank {rank} exited with code {p.exitcode}"
def worker_fn_wrapper(fn):
    """Turn a zero-argument worker into a process target that takes an env dict.

    The returned function applies the given environment variables, initializes
    the default process group with the ``gloo`` backend, and then calls ``fn``.

    Args:
        fn: Zero-argument callable holding the actual worker logic.

    Returns:
        A function accepting one ``env`` dict argument.
    """
    import functools

    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    # `functools.wraps` keeps the worker's own name and docstring on the
    # result instead of reporting every worker as `wrapped_fn`.
    @functools.wraps(fn)
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn
@worker_fn_wrapper
def worker_fn():
    """Check that a MessageQueue writer's arrays reach every reader unchanged.

    The same scenario runs twice: first over the default torch.distributed
    group, then over a StatelessProcessGroup whose endpoint rank 0 picks and
    shares with the other ranks.
    """
    rank = dist.get_rank()

    # Rank 0 decides the endpoint of the stateless group; the rest receive it.
    if rank == 0:
        ip, port = "127.0.0.1", get_open_port()
        dist.broadcast_object_list([ip, port], src=0)
    else:
        endpoint = [None, None]
        dist.broadcast_object_list(endpoint, src=0)
        ip, port = endpoint  # type: ignore

    stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())

    for pg in [dist.group.WORLD, stateless_pg]:
        writer_rank = 2
        is_writer = rank == writer_rank
        uses_torch_group = pg == dist.group.WORLD
        # Each kind of group exposes its barrier differently.
        sync = dist.barrier if uses_torch_group else pg.barrier

        broadcaster = MessageQueue.create_from_process_group(
            pg, 40 * 1024, 2, writer_rank
        )

        # The writer draws the seed and hands it to everyone else.
        if is_writer:
            seed = random.randint(0, 1000)
            dist.broadcast_object_list([seed], writer_rank)
        else:
            seed_box = [None]
            dist.broadcast_object_list(seed_box, writer_rank)
            seed = seed_box[0]  # type: ignore

        sync()

        # in case we find a race condition
        # print the seed so that we can reproduce the error
        print(f"Rank {rank} got seed {seed}")

        # test broadcasting with about 400MB of data
        num_arrays = 10_000
        # Every rank builds the same arrays from the shared seed: the writer
        # sends them, the readers use them as the expected values.
        arrays = get_arrays(num_arrays, seed)
        if is_writer:
            for payload in arrays:
                broadcaster.broadcast_object(payload)
                time.sleep(random.random() / 1000)
        else:
            for expected in arrays:
                received = broadcaster.broadcast_object(None)
                assert np.array_equal(expected, received)
                time.sleep(random.random() / 1000)

        sync()
        if uses_torch_group:
            print(f"torch distributed passed the test! Rank {rank}")
        else:
            print(f"StatelessProcessGroup passed the test! Rank {rank}")
def test_shm_broadcast():
    """Run the MessageQueue broadcast worker in four local processes."""
    world_size = 4
    distributed_run(worker_fn, world_size)