vllm/tests/distributed/test_symm_mem_allreduce.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import typing

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_group,
                                             init_distributed_environment,
                                             initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

torch.manual_seed(42)
random.seed(44)

test_size_elements = 4 * 1024 * 1024


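# Per-rank worker spawned by mp.spawn in the test below. It sets up a
# single-node tensor-parallel group and checks that the symmetric-memory
# all-reduce matches a plain torch.distributed all-reduce.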
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

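        # On CUDA the TP group's device communicator is a CudaCommunicator,
        # which owns the (optional) SymmMemCommunicator under test.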
        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        symm_mem_comm = cuda_communicator.symm_mem_comm
        if symm_mem_comm is None or symm_mem_comm.disabled:
            pytest.skip("SymmMemCommunicator is not available or disabled.")

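        # First exercise the SymmMemCommunicator directly: all-reduce the
        # input and compare with a reference all-reduce over the same group.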
        inp_direct_symm_mem = torch.randint(1,
                                            23, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            pytest.skip("SymmMemCommunicator isn't used for this world size "
                        "and input size.")

        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
        assert out_direct_symm_mem is not None

        group = get_tensor_model_parallel_group().device_group
        dist.all_reduce(original_inp_direct_symm_mem, group=group)
        torch.testing.assert_close(out_direct_symm_mem,
                                   original_inp_direct_symm_mem,
                                   atol=2.5,
                                   rtol=0.1)

        # Test tensor_model_parallel_all_reduce which should use symm_mem
        inp_tensor_parallel = torch.randint(-23,
                                            1, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        original_inp_tensor_parallel = inp_tensor_parallel.clone()
        out_tensor_parallel = tensor_model_parallel_all_reduce(
            inp_tensor_parallel)
        dist.all_reduce(original_inp_tensor_parallel, group=group)
        torch.testing.assert_close(out_tensor_parallel,
                                   original_inp_tensor_parallel,
                                   atol=2.5,
                                   rtol=0.1)


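# Spawns one worker per rank and enables the symmetric-memory path via
# VLLM_ALLREDUCE_USE_SYMM_MEM; skipped when fewer GPUs than world_size.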
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                            pipeline_parallel_size):
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator
    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
    cleanup_dist_env_and_memory()