mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 18:15:27 +08:00)

commit 8bdd8b5c51 (parent a8ffc4f0f2)

Enable symmetric memory all reduce by default only enabling for TP (#25070)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
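
With this change the VLLM_ALLREDUCE_USE_SYMM_MEM default flips from "0" to "1" (see the envs hunks at the end of the diff), and the CUDA communicator honours the flag only for tensor-parallel groups. A minimal opt-out sketch, assuming the flag is still resolved from the process environment when vllm.envs is consulted:

    import os

    # Assumption: the variable must be set before vLLM reads the flag.
    # With the new default of "1", only an explicit "0" disables the
    # symmetric-memory all-reduce path.
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    import vllm.envs as envs

    assert envs.VLLM_ALLREDUCE_USE_SYMM_MEM is False
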
@@ -164,6 +164,7 @@ steps:
   - tests/v1/test_internal_lb_dp.py
   - tests/v1/test_hybrid_lb_dp.py
   - tests/v1/engine/test_engine_core_client.py
+  - tests/distributed/test_symm_mem_allreduce.py
   commands:
   # test with torchrun tp=2 and external_dp=2
   - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@@ -188,6 +189,7 @@ steps:
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
   - pytest -v -s distributed/test_events.py
+  - pytest -v -s distributed/test_symm_mem_allreduce.py
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - pushd ../examples/offline_inference
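
The two hunks above (apparently the Buildkite multi-GPU test pipeline) register the new test twice: once as a source-file dependency, so changes to tests/distributed/test_symm_mem_allreduce.py trigger the step, and once as a command, so the step actually runs the test alongside the other distributed tests.
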
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import queue
 import random
 import typing
 
@@ -10,26 +11,31 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 
 import vllm.envs as envs
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.cuda_communicator import (
     CudaCommunicator)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group,
+from vllm.distributed.parallel_state import (get_tp_group,
                                              init_distributed_environment,
                                              initialize_model_parallel)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
 
 torch.manual_seed(42)
 random.seed(44)
 
-test_size_elements = 4 * 1024 * 1024
+test_size_elements = 1024 * 1024
 
 
-def symm_mem_allreduce_worker(local_rank: int, world_size: int):
+def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
     monkeypatch = pytest.MonkeyPatch()
-    with monkeypatch.context() as m:
+    config = VllmConfig(parallel_config=ParallelConfig(
+        tensor_parallel_size=world_size))
+
+    with monkeypatch.context() as m, set_current_vllm_config(config):
         m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
         dtype = torch.bfloat16
         device = torch.device(f"cuda:{local_rank}")
@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
                                         get_tp_group().device_communicator)
         symm_mem_comm = cuda_communicator.symm_mem_comm
         if symm_mem_comm is None or symm_mem_comm.disabled:
-            pytest.skip("SymmMemCommunicator is not available or disabled.")
+            # can't use skip under multiprocessing
+            q.put("SymmMemCommunicator is not available or disabled.")
+            return
 
         inp_direct_symm_mem = torch.randint(1,
                                             23, (test_size_elements, ),
                                             dtype=dtype,
                                             device=device)
         if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
-            pytest.skip(
+            # can't use skip under multiprocessing
+            q.put(
                 "SymmMemCommunicator isn't used for this world and input size."
             )
+            return
 
         original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
         out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
         assert out_direct_symm_mem is not None
 
-        group = get_tensor_model_parallel_group().device_group
+        group = get_tp_group().device_group
         dist.all_reduce(original_inp_direct_symm_mem, group=group)
         torch.testing.assert_close(out_direct_symm_mem,
                                    original_inp_direct_symm_mem,
@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
-
-    # Enable SymmMemCommunicator
-    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
-
-    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
-    cleanup_dist_env_and_memory()
+    q = mp.get_context('spawn').Queue()
+    mp.spawn(symm_mem_allreduce_worker,
+             args=(world_size, q),
+             nprocs=world_size)
+    try:
+        val = q.get(timeout=1)
+    except queue.Empty:
+        val = None
+    finally:
+        cleanup_dist_env_and_memory()
+    if val is not None:
+        pytest.skip(val)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="SymmMemAllreduce is only available for CUDA platforms.")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
+    world_size = 4
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    # Verify that the DataParallel runs without error
+    engine_args = EngineArgs(model="distilbert/distilgpt2",
+                             enforce_eager=True,
+                             enable_prefix_caching=True,
+                             data_parallel_size=2,
+                             tensor_parallel_size=2,
+                             data_parallel_backend="mp")
+    LLMEngine.from_engine_args(engine_args)
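
The four hunks above rework tests/distributed/test_symm_mem_allreduce.py. The worker now reports unmet preconditions through a multiprocessing queue instead of calling pytest.skip, since a skip raised inside an mp.spawn child only terminates that child process; the parent drains the queue and performs the real skip. A second test, test_dp_with_symm_mem_allreduce, checks that a DP=2/TP=2 engine still starts with the new default. A stripped-down sketch of the skip-via-queue pattern, with hypothetical worker and test names and no vLLM dependencies:

    import queue

    import pytest
    import torch.multiprocessing as mp


    def _worker(local_rank: int, world_size: int, q: mp.Queue):
        # Hypothetical precondition, standing in for the SymmMemCommunicator
        # availability checks in the real worker.
        if world_size < 2:
            q.put("needs at least two ranks")  # can't pytest.skip in a child
            return
        # ... the actual all-reduce checks would run here ...


    def test_skip_via_queue():
        world_size = 1
        q = mp.get_context("spawn").Queue()
        mp.spawn(_worker, args=(world_size, q), nprocs=world_size)
        try:
            reason = q.get(timeout=1)
        except queue.Empty:
            reason = None
        if reason is not None:
            pytest.skip(reason)

The short q.get timeout keeps a passing run, which leaves the queue empty, from blocking the parent process.
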
@@ -30,18 +30,21 @@ class CudaCommunicator(DeviceCommunicatorBase):
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
         if "tp" not in unique_name:
-            # only tp uses custom allreduce
+            # custom allreduce or torch symm mem can be used only by tp
             use_custom_allreduce = False
+            use_torch_symm_mem = False
         else:
             from vllm.distributed.parallel_state import (
                 _ENABLE_CUSTOM_ALL_REDUCE)
             use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+            use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
 
         # ep does not use pynccl
         use_pynccl = "ep" not in unique_name
 
         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_torch_symm_mem = use_torch_symm_mem
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -65,7 +68,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
-        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+        if use_torch_symm_mem and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
                 group=self.cpu_group,
                 device=self.device,
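
The two hunks above change CudaCommunicator: the constructor no longer consults envs.VLLM_ALLREDUCE_USE_SYMM_MEM at the point where the SymmMemCommunicator is created; use_torch_symm_mem is derived once and forced to False for any communicator whose unique_name is not a tensor-parallel group, which is what makes the on-by-default setting safe for DP/PP/EP groups. A simplified model of that gating decision, as a plain function with a hypothetical name (the real communicator can still disable itself later based on world size and input size, as the test above exercises):

    def _should_use_torch_symm_mem(unique_name: str, env_flag: bool,
                                   is_cuda: bool) -> bool:
        # Mirrors the gating in the hunks above: the env flag only takes
        # effect for tensor-parallel communicators on CUDA.
        if "tp" not in unique_name:
            return False
        return env_flag and is_cuda


    # With the flag now defaulting to True, only TP groups opt in.
    assert _should_use_torch_symm_mem("tp", env_flag=True, is_cuda=True)
    assert not _should_use_torch_symm_mem("ep", env_flag=True, is_cuda=True)
    assert not _should_use_torch_symm_mem("tp", env_flag=False, is_cuda=True)
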
@@ -182,7 +182,7 @@ if TYPE_CHECKING:
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
     VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
-    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
@@ -1370,7 +1370,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # Whether to use pytorch symmetric memory for allreduce
     "VLLM_ALLREDUCE_USE_SYMM_MEM":
-    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
+    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
 
     # Allows vllm to find tuned config under customized folder
     "VLLM_TUNED_CONFIG_FOLDER":
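
The envs hunks only change the fallback used when the variable is unset; the parser stays bool(int(os.getenv(...))), so the flag accepts integer-like strings rather than "true"/"false". A small sketch of that behaviour, using a local copy of the parsing expression rather than vllm.envs itself:

    import os


    def parse() -> bool:
        # Same shape as the envs.py entry, with the new default of "1".
        return bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")))


    os.environ.pop("VLLM_ALLREDUCE_USE_SYMM_MEM", None)
    assert parse() is True       # unset -> enabled under the new default

    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
    assert parse() is False      # explicit opt-out

    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "true"
    try:
        parse()                  # non-integer strings raise rather than parse
    except ValueError:
        pass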