[core] add nccl symmetric memory for all reduce (#24532)

Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Amir Samani 2025-09-23 11:33:06 -07:00 committed by GitHub
parent a3a7828010
commit 8c1c81a3de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 489 additions and 6 deletions

View File

@ -1039,3 +1039,4 @@ steps:
num_gpus: 2
commands:
- pytest -v -s tests/distributed/test_context_parallel.py
- pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py

View File

@ -7,6 +7,10 @@ Benchmark script for device communicators:
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
and SymmMemCommunicator (multimem, two-shot).
for NCCL symmetric memory you need to set the environment variables
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1, otherwise NCCL does
not use fast NVLS implementation for all reduce.
Usage:
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
@ -26,7 +30,13 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
register_nccl_symmetric_ops,
)
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id,
)
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
@ -98,6 +108,7 @@ class CommunicatorBenchmark:
)
if not self.pynccl_comm.disabled:
logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
register_nccl_symmetric_ops(self.pynccl_comm)
else:
logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
self.pynccl_comm = None
@ -194,6 +205,15 @@ class CommunicatorBenchmark:
None, # no env variable needed
)
)
communicators.append(
(
"pynccl-symm",
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
)
)
if self.symm_mem_comm_multimem is not None:
comm = self.symm_mem_comm_multimem
@ -271,7 +291,9 @@ class CommunicatorBenchmark:
# Capture the graph using context manager
with context:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
graph_pool = torch.cuda.graph_pool_handle()
set_graph_pool_id(graph_pool)
with torch.cuda.graph(graph, pool=graph_pool):
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
allreduce_fn(graph_input)

View File

@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import typing
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
from vllm.distributed.device_communicators.pynccl_allocator import (
get_nccl_mem_pool, is_symmetric_memory_enabled)
from vllm.distributed.parallel_state import (get_tp_group,
init_distributed_environment,
initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
torch.manual_seed(42)
random.seed(44)
test_size_elements = 4 * 1024 * 1024
def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
monkeypatch = pytest.MonkeyPatch()
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
cuda_communicator = typing.cast(CudaCommunicator,
get_tp_group().device_communicator)
pynccl_comm = cuda_communicator.pynccl_comm
if get_nccl_mem_pool() is None:
pytest.skip("NCCL allocator compilation failed "
"(probably missing NCCL headers).")
if not is_symmetric_memory_enabled():
pytest.skip("NCCL symmetric memory allreduce is disabled.")
register_nccl_symmetric_ops(pynccl_comm)
input = torch.randint(1,
23, (test_size_elements, ),
dtype=dtype,
device=device)
input_clone = input.clone()
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
assert output is not None
group = get_tp_group().device_group
dist.all_reduce(input_clone, group=group)
torch.testing.assert_close(output, input_clone, atol=2.5, rtol=0.1)
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
)
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Enable SymmMemCommunicator
monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1")
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
mp.spawn(nccl_symm_mem_allreduce_worker,
args=(world_size, ),
nprocs=world_size)
cleanup_dist_env_and_memory()

View File

@ -12,6 +12,8 @@ import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -154,6 +156,10 @@ class CUDAGraphWrapper:
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool

View File

@ -10,8 +10,9 @@ import sys
import tempfile
from collections.abc import Sequence
from itertools import product
from typing import Optional
from typing import Any, Optional
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
@ -56,6 +57,30 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
}
}
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4,
"thresholds": {
4: 2 * MiB, # 2 MB
8: 1 * MiB, # 1 MB
},
"always_use_above_world_size": 8 # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int,
input_tensor: torch.Tensor) -> bool:
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled)
if not is_symmetric_memory_enabled():
return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
return False
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
if threshold is not None and input_tensor.nbytes >= threshold:
return True
return (world_size
> NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
def producer(batch_src: Sequence[int],
producer_queue,

View File

@ -7,6 +7,12 @@ import torch
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.distributed.device_communicators.all_reduce_utils import (
should_nccl_symm_mem_allreduce)
from vllm.distributed.device_communicators.pynccl import (
register_nccl_symmetric_ops)
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled)
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -53,6 +59,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
group=self.cpu_group,
device=self.device,
)
if is_symmetric_memory_enabled():
register_nccl_symmetric_ops(self.pynccl_comm)
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
@ -107,6 +115,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_):
# since currently we perform copy input -> symm_input -> out-of-place AR
# return symm_output, we don't need to check if input is symmetric
if self.pynccl_comm is not None and \
should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size,input_):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm

View File

@ -17,6 +17,39 @@ from vllm.utils import current_stream
logger = init_logger(__name__)
_NCCL_SYMM_OPS_REGISTERED = False
def register_nccl_symmetric_ops(pynccl_comm):
from vllm.distributed.device_communicators.pynccl_allocator import (
nccl_symm_mem_context)
from vllm.utils import direct_register_custom_op
global _NCCL_SYMM_OPS_REGISTERED
if _NCCL_SYMM_OPS_REGISTERED:
return
_NCCL_SYMM_OPS_REGISTERED = True
def all_reduce_symmetric_with_copy_impl(
input_tensor: torch.Tensor) -> torch.Tensor:
with nccl_symm_mem_context(pynccl_comm):
symm_input = torch.empty_like(input_tensor)
symm_output = torch.empty_like(input_tensor)
symm_input.copy_(input_tensor)
symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
return symm_output
def all_reduce_symmetric_with_copy_fake(
input_tensor: torch.Tensor) -> torch.Tensor:
return torch.empty_like(input_tensor)
direct_register_custom_op(
op_name="all_reduce_symmetric_with_copy",
op_func=all_reduce_symmetric_with_copy_impl,
mutates_args=[],
fake_impl=all_reduce_symmetric_with_copy_fake,
)
class PyNcclCommunicator:
@ -67,6 +100,7 @@ class PyNcclCommunicator:
self.available = True
self.disabled = False
self.nccl_version = self.nccl.ncclGetRawVersion()
logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
if self.rank == 0:
@ -109,6 +143,7 @@ class PyNcclCommunicator:
def all_reduce(self,
in_tensor: torch.Tensor,
out_tensor: torch.Tensor = None,
op: ReduceOp = ReduceOp.SUM,
stream=None) -> torch.Tensor:
if self.disabled:
@ -120,7 +155,8 @@ class PyNcclCommunicator:
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {in_tensor.device}")
out_tensor = torch.empty_like(in_tensor)
if out_tensor is None:
out_tensor = torch.empty_like(in_tensor)
if stream is None:
stream = current_stream()
@ -288,3 +324,18 @@ class PyNcclCommunicator:
def group_end(self):
self.nccl.ncclGroupEnd()
def register_comm_window(self, tensor: torch.Tensor):
return self.nccl.ncclCommWindowRegister(
self.comm,
buffer_type(tensor.data_ptr()),
tensor.numel() * tensor.element_size(),
1,
)
def register_comm_window_raw(self, ptr: int, size: int):
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
size, 1)
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)

View File

@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import contextlib
import tempfile
from typing import Any, Optional
import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator
from torch.utils.cpp_extension import load_inline
from vllm import envs
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import find_nccl_include_paths
logger = init_logger(__name__)
nccl_allocator_source = """
#include <nccl.h>
extern "C" {
void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;
}
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
ncclResult_t err = ncclMemFree(ptr);
}
}
"""
_allocator = None
_allocator_wrapper = None
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None
_nccl_allocator_failed_to_compile = False
_cached_pool_snapshot = None
def is_symmetric_memory_enabled():
global _nccl_allocator_failed_to_compile
return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile
def is_symmetric_memory_tensor(tensor: torch.Tensor):
if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
return False
for segment in _cached_pool_snapshot:
for block in segment["blocks"]:
if block["address"] == tensor.untyped_storage().data_ptr():
return True
return False
def set_graph_pool_id(graph_pool_id):
global _graph_pool_id
_graph_pool_id = graph_pool_id
def compile_nccl_allocator():
global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile
if not current_platform.is_cuda():
_nccl_allocator_failed_to_compile = True
return
try:
out_dir = tempfile.gettempdir()
nccl_allocator_libname = "nccl_allocator"
nccl_include_paths = find_nccl_include_paths()
load_inline(
name=nccl_allocator_libname,
cpp_sources=nccl_allocator_source,
with_cuda=True,
extra_ldflags=["-lnccl"],
verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
is_python_module=False,
build_directory=out_dir,
extra_include_paths=nccl_include_paths,
)
_allocator_wrapper = CUDAPluggableAllocator(
f"{out_dir}/{nccl_allocator_libname}.so",
"nccl_alloc_plug",
"nccl_free_plug",
)
_allocator = _allocator_wrapper.allocator()
except Exception as e:
_nccl_allocator_failed_to_compile = True
logger.warning(
"Failed to compile NCCL memory allocator. "
"Symmetric memory will be disabled. "
"This is expected if NCCL headers are not available. "
"optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
"containing the NCCL header. "
"Error: %s", str(e))
def get_nccl_mem_pool():
global _mem_pool, _nccl_allocator_failed_to_compile
if _mem_pool is None and not _nccl_allocator_failed_to_compile:
compile_nccl_allocator()
if _allocator is not None:
_mem_pool = torch.cuda.MemPool(_allocator)
return _mem_pool
def _cleanup_nccl_mem_pool():
global _mem_pool
_mem_pool = None
def _cleanup_nccl_allocator_wrapper():
global _allocator_wrapper
_allocator_wrapper = None
atexit.register(_cleanup_nccl_mem_pool)
atexit.register(_cleanup_nccl_allocator_wrapper)
class nccl_symm_mem_context:
def __init__(
self,
pynccl_comm: PyNcclCommunicator,
disabled: bool = False,
):
self.disabled = (disabled or not is_symmetric_memory_enabled()
or pynccl_comm.world_size == 1
or not current_platform.is_cuda()
or get_nccl_mem_pool() is None or version.parse(
torch.__version__) < version.parse("2.8.0.a0"))
if self.disabled:
self.pynccl_comm: Optional[PyNcclCommunicator] = None
self._mem_pool_ctx: contextlib.AbstractContextManager[
Any] = contextlib.nullcontext()
self.is_graph_capture = None
self.device = None
else:
self.pynccl_comm = pynccl_comm
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
def __enter__(self):
if self.disabled:
return self
assert (
self.pynccl_comm
is not None), "Symmetric memory requires pynccl to be initalized"
assert (
self.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
if self.is_graph_capture:
assert (
_graph_pool_id
is not None), "graph_pool_id is not set under graph capture"
# Pause graph memory pool to use symmetric memory with cuda graph
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
self._mem_pool_ctx.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.disabled:
return
global _cached_pool_snapshot
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
_pool = get_nccl_mem_pool()
assert _pool is not None
_cached_pool_snapshot = _pool.snapshot()
assert self.pynccl_comm is not None
for segment in _cached_pool_snapshot:
if segment["address"] not in _registered_base_addrs:
self.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"])
_registered_base_addrs.add(segment["address"])
if self.is_graph_capture:
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id)

View File

@ -41,6 +41,7 @@ logger = init_logger(__name__)
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
ncclWindow_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
@ -222,6 +223,24 @@ class NCCLLibrary:
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
# ncclResult_t ncclCommWindowRegister(
# ncclComm_t comm, void* buff, size_t size,
# ncclWindow_t* win, int winFlags);
Function(
"ncclCommWindowRegister",
ncclResult_t,
[
ncclComm_t,
buffer_type,
ctypes.c_size_t,
ctypes.POINTER(ncclWindow_t),
ctypes.c_int,
],
),
# ncclResult_t ncclCommWindowDeregister(
# ncclComm_t comm, ncclWindow_t win);
Function("ncclCommWindowDeregister", ncclResult_t,
[ncclComm_t, ncclWindow_t]),
]
# class attribute to store the mapping from the path to the library
@ -271,10 +290,14 @@ class NCCLLibrary:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str:
def ncclGetRawVersion(self) -> int:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903
return version.value
def ncclGetVersion(self) -> str:
version_str = str(self.ncclGetRawVersion())
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
@ -375,6 +398,17 @@ class NCCLLibrary:
def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
size: int, win_flags: int) -> ncclWindow_t:
window = ncclWindow_t()
self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"](
comm, buff, size, ctypes.byref(window), win_flags))
return window
def ncclCommWindowDeregister(self, comm: ncclComm_t,
window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",

View File

@ -193,6 +193,8 @@ if TYPE_CHECKING:
VLLM_DBO_COMM_SMS: int = 20
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
VLLM_USE_NCCL_SYMM_MEM: bool = False
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
def get_default_cache_root():
@ -1410,6 +1412,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
["container",
"code_interpreter",
"web_search_preview"]),
# Flag to enable NCCL symmetric memory allocation and registration
"VLLM_USE_NCCL_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
# NCCL header path
"VLLM_NCCL_INCLUDE_PATH":
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
}
# --8<-- [end:env-vars-definition]

View File

@ -1383,6 +1383,38 @@ def find_nccl_library() -> str:
return so_file
def find_nccl_include_paths() -> Optional[list[str]]:
"""
We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
environment variable, or we find the library file brought by
nvidia-nccl-cuXX. load_inline by default uses
torch.utils.cpp_extension.include_paths
"""
paths: list[str] = []
inc = envs.VLLM_NCCL_INCLUDE_PATH
if inc and os.path.isdir(inc):
paths.append(inc)
try:
import importlib.util
spec = importlib.util.find_spec("nvidia.nccl")
if spec and getattr(spec, "submodule_search_locations", None):
for loc in spec.submodule_search_locations:
inc_dir = os.path.join(loc, "include")
if os.path.exists(os.path.join(inc_dir, "nccl.h")):
paths.append(inc_dir)
except Exception:
pass
seen = set()
out: list[str] = []
for p in paths:
if p and p not in seen:
out.append(p)
seen.add(p)
return out or None
prev_set_stream = torch.cuda.set_stream
_current_stream_tls = threading.local()

View File

@ -11,6 +11,8 @@ import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_ep_group
from vllm.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id)
from vllm.forward_context import (create_forward_context, get_forward_context,
override_forward_context)
from vllm.logger import init_logger
@ -206,6 +208,10 @@ class UBatchWrapper:
cudagraph=torch.cuda.CUDAGraph(),
ubatch_metadata=ubatch_metadata,
)
if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())
with torch.cuda.graph(cudagraph_metadata.cudagraph,
stream=compute_stream,
pool=self.graph_pool):