mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
Signed-off-by: ilmarkov <imarkov@redhat.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: ilmarkov <imarkov@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
109 lines
4.5 KiB
Python
109 lines
4.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import random
|
|
import typing
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
|
|
import vllm.envs as envs
|
|
from vllm.distributed import cleanup_dist_env_and_memory
|
|
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
|
from vllm.distributed.device_communicators.cuda_communicator import (
|
|
CudaCommunicator)
|
|
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
|
get_tp_group,
|
|
init_distributed_environment,
|
|
initialize_model_parallel)
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils import update_environment_variables
|
|
|
|
# Fixed seeds so every spawned worker process draws identical test data.
torch.manual_seed(42)
random.seed(44)

# Number of elements in each all-reduce input tensor (4 Mi elements);
# sized so the symmetric-memory path is eligible for the tested world size.
test_size_elements = 4 * 1024 * 1024
|
|
|
|
|
|
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
    """Per-rank worker (launched via ``mp.spawn``) that initializes a
    tensor-parallel group of ``world_size`` ranks and validates the
    symmetric-memory all-reduce against ``torch.distributed.all_reduce``.

    Args:
        local_rank: Rank of this process; also selects ``cuda:{local_rank}``.
        world_size: Total number of ranks / tensor-parallel size.
    """
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        # Drop any inherited device mask so local_rank maps directly onto
        # the physical CUDA device index.
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        # Device/dtype defaults must be pinned before any collective setup.
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        # Rendezvous configuration read by init_distributed_environment().
        # NOTE(review): MASTER_PORT is hard-coded; concurrent test runs on
        # one host could collide — confirm acceptable for this suite.
        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        # The TP group's device communicator owns the symmetric-memory
        # communicator under test.
        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        symm_mem_comm = cuda_communicator.symm_mem_comm
        if symm_mem_comm is None or symm_mem_comm.disabled:
            pytest.skip("SymmMemCommunicator is not available or disabled.")

        inp_direct_symm_mem = torch.randint(1,
                                            23, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        # The symm-mem path only activates for eligible world/input sizes.
        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            pytest.skip(
                "SymmMemCommunicator isn't used for this world and input size."
            )

        # Case 1: call the symmetric-memory communicator directly and
        # compare against the reference torch.distributed all-reduce.
        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
        assert out_direct_symm_mem is not None

        group = get_tensor_model_parallel_group().device_group
        dist.all_reduce(original_inp_direct_symm_mem, group=group)
        # Loose tolerances: bf16 reduction order differs between the two
        # implementations.
        torch.testing.assert_close(out_direct_symm_mem,
                                   original_inp_direct_symm_mem,
                                   atol=2.5,
                                   rtol=0.1)

        # Test tensor_model_parallel_all_reduce which should use symm_mem
        # (case 2: the high-level vLLM entry point).
        inp_tensor_parallel = torch.randint(-23,
                                            1, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        original_inp_tensor_parallel = inp_tensor_parallel.clone()
        out_tensor_parallel = tensor_model_parallel_all_reduce(
            inp_tensor_parallel)
        dist.all_reduce(original_inp_tensor_parallel, group=group)
        torch.testing.assert_close(out_tensor_parallel,
                                   original_inp_tensor_parallel,
                                   atol=2.5,
                                   rtol=0.1)
|
|
|
|
|
|
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                            pipeline_parallel_size):
    """Launch one worker per rank and verify symmetric-memory all-reduce.

    Skips when the host has fewer GPUs than the requested parallel size.
    """
    num_ranks = tp_size * pipeline_parallel_size
    if torch.cuda.device_count() < num_ranks:
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator — workers inherit this env var on spawn.
    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

    mp.spawn(symm_mem_allreduce_worker, args=(num_ranks, ), nprocs=num_ranks)
    cleanup_dist_env_and_memory()
|