Enable symmetric memory all-reduce by default, enabling it only for TP (#25070)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: Ilya Markov, 2025-09-23 21:53:00 +02:00 (committed by GitHub)
parent a8ffc4f0f2
commit 8bdd8b5c51
4 changed files with 56 additions and 16 deletions
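
In short: with this change, PyTorch symmetric-memory all-reduce is on by default (VLLM_ALLREDUCE_USE_SYMM_MEM now defaults to 1), but the communicator is only constructed for the tensor-parallel group, so other communicator groups (e.g. EP) keep it off. A minimal opt-out sketch, assuming the flag is still read lazily from the environment as the lambda table in envs.py suggests (illustrative code, not part of this commit):

```python
# Hypothetical opt-out: export VLLM_ALLREDUCE_USE_SYMM_MEM=0 (or set it in
# Python before vLLM reads it) to fall back to the previous behaviour.
import os

os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

import vllm.envs as envs  # read after the variable is set

print(envs.VLLM_ALLREDUCE_USE_SYMM_MEM)  # expected: False
```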

View File

@@ -164,6 +164,7 @@ steps:
 - tests/v1/test_internal_lb_dp.py
 - tests/v1/test_hybrid_lb_dp.py
 - tests/v1/engine/test_engine_core_client.py
+- tests/distributed/test_symm_mem_allreduce.py
 commands:
 # test with torchrun tp=2 and external_dp=2
 - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@@ -188,6 +189,7 @@ steps:
 - pytest -v -s compile/test_basic_correctness.py
 - pytest -v -s distributed/test_pynccl.py
 - pytest -v -s distributed/test_events.py
+- pytest -v -s distributed/test_symm_mem_allreduce.py
 # TODO: create a dedicated test section for multi-GPU example tests
 # when we have multiple distributed example tests
 - pushd ../examples/offline_inference

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import queue
 import random
 import typing
@@ -10,26 +11,31 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import vllm.envs as envs
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.cuda_communicator import (
     CudaCommunicator)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group,
+from vllm.distributed.parallel_state import (get_tp_group,
                                              init_distributed_environment,
                                              initialize_model_parallel)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
 torch.manual_seed(42)
 random.seed(44)
-test_size_elements = 4 * 1024 * 1024
+test_size_elements = 1024 * 1024
-def symm_mem_allreduce_worker(local_rank: int, world_size: int):
+def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
     monkeypatch = pytest.MonkeyPatch()
-    with monkeypatch.context() as m:
+    config = VllmConfig(parallel_config=ParallelConfig(
+        tensor_parallel_size=world_size))
+    with monkeypatch.context() as m, set_current_vllm_config(config):
         m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
         dtype = torch.bfloat16
         device = torch.device(f"cuda:{local_rank}")
@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
                                         get_tp_group().device_communicator)
         symm_mem_comm = cuda_communicator.symm_mem_comm
         if symm_mem_comm is None or symm_mem_comm.disabled:
-            pytest.skip("SymmMemCommunicator is not available or disabled.")
+            # can't use skip under multiprocessing
+            q.put("SymmMemCommunicator is not available or disabled.")
+            return
         inp_direct_symm_mem = torch.randint(1,
                                             23, (test_size_elements, ),
                                             dtype=dtype,
                                             device=device)
         if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
-            pytest.skip(
+            # can't use skip under multiprocessing
+            q.put(
                 "SymmMemCommunicator isn't used for this world and input size."
             )
+            return
         original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
         out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
         assert out_direct_symm_mem is not None
-        group = get_tensor_model_parallel_group().device_group
+        group = get_tp_group().device_group
         dist.all_reduce(original_inp_direct_symm_mem, group=group)
         torch.testing.assert_close(out_direct_symm_mem,
                                    original_inp_direct_symm_mem,
@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
-    # Enable SymmMemCommunicator
-    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
-    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
-    cleanup_dist_env_and_memory()
+    q = mp.get_context('spawn').Queue()
+    mp.spawn(symm_mem_allreduce_worker,
+             args=(world_size, q),
+             nprocs=world_size)
+    try:
+        val = q.get(timeout=1)
+    except queue.Empty:
+        val = None
+    finally:
+        cleanup_dist_env_and_memory()
+    if val is not None:
+        pytest.skip(val)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="SymmMemAllreduce is only available for CUDA platforms.")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
+    world_size = 4
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    # Verify that the DataParallel runs without error
+    engine_args = EngineArgs(model="distilbert/distilgpt2",
+                             enforce_eager=True,
+                             enable_prefix_caching=True,
+                             data_parallel_size=2,
+                             tensor_parallel_size=2,
+                             data_parallel_backend="mp")
+    LLMEngine.from_engine_args(engine_args)
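
Note the skip handling above: pytest.skip() raises inside the spawned child, where it cannot mark the parent test as skipped, so the worker reports the reason through a queue and the parent calls pytest.skip() itself. A stripped-down sketch of that pattern (worker and helper names are illustrative, not vLLM code):

```python
import queue

import torch.multiprocessing as mp


def _worker(rank: int, q: mp.Queue):
    # A spawned child can't call pytest.skip(); it reports the reason back
    # to the parent and returns early instead.
    if rank == 0:
        q.put("feature not available in this environment")
    # ... real work would run here ...


def run_and_maybe_skip(world_size: int = 2) -> None:
    q = mp.get_context("spawn").Queue()
    mp.spawn(_worker, args=(q, ), nprocs=world_size)
    try:
        reason = q.get(timeout=1)
    except queue.Empty:
        reason = None
    if reason is not None:
        # In the real test this is where pytest.skip(reason) is called.
        print(f"skipping: {reason}")


if __name__ == "__main__":
    run_and_maybe_skip()
```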

View File

@@ -30,18 +30,21 @@ class CudaCommunicator(DeviceCommunicatorBase):
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
         if "tp" not in unique_name:
-            # only tp uses custom allreduce
+            # custom allreduce or torch symm mem can be used only by tp
             use_custom_allreduce = False
+            use_torch_symm_mem = False
         else:
             from vllm.distributed.parallel_state import (
                 _ENABLE_CUSTOM_ALL_REDUCE)
             use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+            use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
         # ep does not use pynccl
         use_pynccl = "ep" not in unique_name
         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_torch_symm_mem = use_torch_symm_mem
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -65,7 +68,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
-        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+        if use_torch_symm_mem and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
                 group=self.cpu_group,
                 device=self.device,
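
Taken together, the two hunks above mean the symmetric-memory communicator is only ever created when three things hold: the communicator is the TP one, the env flag is enabled, and the platform is CUDA. A condensed restatement of that gating as a plain function (for illustration only; the real logic lives in CudaCommunicator.__init__):

```python
def wants_symm_mem_comm(unique_name: str, flag_enabled: bool,
                        is_cuda: bool) -> bool:
    # Custom allreduce and torch symmetric memory are wired up only for the
    # tensor-parallel communicator; other groups (e.g. "ep") never use them.
    if "tp" not in unique_name:
        return False
    # For TP, the env flag (now on by default) and a CUDA platform are
    # still required.
    return flag_enabled and is_cuda


# e.g. wants_symm_mem_comm("tp", True, True)  -> True
#      wants_symm_mem_comm("ep", True, True)  -> False
```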

View File

@@ -182,7 +182,7 @@ if TYPE_CHECKING:
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
     VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
-    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
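
The typed default above moves together with the parsing lambda changed in the next hunk. As a reminder of what that parser accepts, a standalone restatement (helper name is mine, not part of vLLM):

```python
import os


def read_symm_mem_flag() -> bool:
    # Same expression as the envs.py lambda: unset or "1" -> True,
    # "0" -> False; non-integer strings such as "true" raise ValueError
    # rather than silently toggling the feature.
    return bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")))
```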
@@ -1370,7 +1370,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Whether to use pytorch symmetric memory for allreduce
     "VLLM_ALLREDUCE_USE_SYMM_MEM":
-    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
+    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
     # Allows vllm to find tuned config under customized folder
     "VLLM_TUNED_CONFIG_FOLDER":