From 8c1c81a3de344fba340cb2efc82f29d2c1563a2d Mon Sep 17 00:00:00 2001
From: Amir Samani
Date: Tue, 23 Sep 2025 11:33:06 -0700
Subject: [PATCH] [core] add nccl symmetric memory for all reduce (#24532)

Signed-off-by: Amir Samani
Signed-off-by: Michael Goin
Co-authored-by: Michael Goin
---
 .buildkite/test-pipeline.yaml                 |   1 +
 .../kernels/benchmark_device_communicators.py |  26 ++-
 .../test_nccl_symm_mem_allreduce.py           |  94 +++++++++
 vllm/compilation/cuda_graph.py                |   6 +
 .../device_communicators/all_reduce_utils.py  |  27 ++-
 .../device_communicators/cuda_communicator.py |  15 ++
 .../device_communicators/pynccl.py            |  53 ++++-
 .../device_communicators/pynccl_allocator.py  | 186 ++++++++++++++++++
 .../device_communicators/pynccl_wrapper.py    |  38 +++-
 vllm/envs.py                                  |  11 ++
 vllm/utils/__init__.py                        |  32 +++
 vllm/v1/worker/gpu_ubatch_wrapper.py          |   6 +
 12 files changed, 489 insertions(+), 6 deletions(-)
 create mode 100644 tests/distributed/test_nccl_symm_mem_allreduce.py
 create mode 100644 vllm/distributed/device_communicators/pynccl_allocator.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 49316eb4f607..cf32087ed3b9 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -1039,3 +1039,4 @@ steps:
   num_gpus: 2
   commands:
     - pytest -v -s tests/distributed/test_context_parallel.py
+    - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py
index a61c17edc1e2..4cbdde5a5b2c 100644
--- a/benchmarks/kernels/benchmark_device_communicators.py
+++ b/benchmarks/kernels/benchmark_device_communicators.py
@@ -7,6 +7,10 @@
 Benchmark script for device communicators: CustomAllreduce (oneshot, twoshot),
 PyNcclCommunicator, and SymmMemCommunicator (multimem, two-shot).
 
+For NCCL symmetric memory, set the environment variables NCCL_NVLS_ENABLE=1,
+NCCL_CUMEM_ENABLE=1, and VLLM_USE_NCCL_SYMM_MEM=1; otherwise NCCL will not
+use its fast NVLS implementation for all reduce.
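+For example, on a node with two GPUs (the GPU count here is illustrative):
+
+    NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1 \
+        torchrun --nproc_per_node=2 benchmark_device_communicators.py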
+
 Usage:
     torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
@@ -26,7 +30,13 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
-from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl import (
+    PyNcclCommunicator,
+    register_nccl_symmetric_ops,
+)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
 from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
 from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser
@@ -98,6 +108,7 @@ class CommunicatorBenchmark:
         )
         if not self.pynccl_comm.disabled:
             logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
+            register_nccl_symmetric_ops(self.pynccl_comm)
         else:
             logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
             self.pynccl_comm = None
@@ -194,6 +205,15 @@ class CommunicatorBenchmark:
                 None,  # no env variable needed
             )
         )
+        communicators.append(
+            (
+                "pynccl-symm",
+                lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
+                lambda t: True,  # Always available if initialized
+                nullcontext(),
+                None,  # no env variable needed
+            )
+        )
 
         if self.symm_mem_comm_multimem is not None:
             comm = self.symm_mem_comm_multimem
@@ -271,7 +291,9 @@ class CommunicatorBenchmark:
         # Capture the graph using context manager
         with context:
             graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(graph):
+            graph_pool = torch.cuda.graph_pool_handle()
+            set_graph_pool_id(graph_pool)
+            with torch.cuda.graph(graph, pool=graph_pool):
                 for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                     allreduce_fn(graph_input)
diff --git a/tests/distributed/test_nccl_symm_mem_allreduce.py b/tests/distributed/test_nccl_symm_mem_allreduce.py
new file mode 100644
index 000000000000..ffc913742620
--- /dev/null
+++ b/tests/distributed/test_nccl_symm_mem_allreduce.py
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import random
+import typing
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import vllm.envs as envs
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.distributed.device_communicators.cuda_communicator import (
+    CudaCommunicator)
+from vllm.distributed.device_communicators.pynccl import (
+    register_nccl_symmetric_ops)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    get_nccl_mem_pool, is_symmetric_memory_enabled)
+from vllm.distributed.parallel_state import (get_tp_group,
+                                             init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.platforms import current_platform
+from vllm.utils import update_environment_variables
+
+torch.manual_seed(42)
+random.seed(44)
+
+test_size_elements = 4 * 1024 * 1024
+
+
+def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
+    monkeypatch = pytest.MonkeyPatch()
+    with monkeypatch.context() as m:
+        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+        dtype = torch.bfloat16
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+        update_environment_variables({
+            "RANK": str(local_rank),
+            "LOCAL_RANK": str(local_rank),
+            "WORLD_SIZE": str(world_size),
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": "12345",
+        })
+
+        init_distributed_environment()
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+        cuda_communicator = typing.cast(CudaCommunicator,
+                                        get_tp_group().device_communicator)
+        pynccl_comm = cuda_communicator.pynccl_comm
+        if get_nccl_mem_pool() is None:
+            pytest.skip("NCCL allocator compilation failed "
+                        "(probably missing NCCL headers).")
+        if not is_symmetric_memory_enabled():
+            pytest.skip("NCCL symmetric memory allreduce is disabled.")
+
+        register_nccl_symmetric_ops(pynccl_comm)
+        input = torch.randint(1,
+                              23, (test_size_elements, ),
+                              dtype=dtype,
+                              device=device)
+        input_clone = input.clone()
+        output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
+        assert output is not None
+
+        group = get_tp_group().device_group
+        dist.all_reduce(input_clone, group=group)
+        torch.testing.assert_close(output, input_clone, atol=2.5, rtol=0.1)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
+)
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+
+    # Enable NCCL symmetric memory (not the SymmMemCommunicator path)
+    monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1")
+    monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
+    monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
+
+    mp.spawn(nccl_symm_mem_allreduce_worker,
+             args=(world_size, ),
+             nprocs=world_size)
+    cleanup_dist_env_and_memory()
diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py
index e233f959c0a4..befb7736d75a 100644
--- a/vllm/compilation/cuda_graph.py
+++ b/vllm/compilation/cuda_graph.py
@@ -12,6 +12,8 @@ import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
 from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -154,6 +156,10 @@ class CUDAGraphWrapper:
                 stack.enter_context(
                     patch("torch.cuda.empty_cache", lambda: None))
 
+            # Record the graph pool id so nccl_symm_mem_context can pause
+            # allocation to this pool while symmetric buffers are allocated
+            # during capture (see pynccl_allocator.py).
+            if self.graph_pool is not None:
+                set_graph_pool_id(self.graph_pool)
+            else:
+                set_graph_pool_id(current_platform.graph_pool_handle())
             # mind-exploding: carefully manage the reference and memory.
             with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                 # `output` is managed by pytorch's cudagraph pool
diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py
index 805a88854b77..87e0f8e1a967 100644
--- a/vllm/distributed/device_communicators/all_reduce_utils.py
+++ b/vllm/distributed/device_communicators/all_reduce_utils.py
@@ -10,8 +10,9 @@ import sys
 import tempfile
 from collections.abc import Sequence
 from itertools import product
-from typing import Optional
+from typing import Any, Optional
 
+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
@@ -56,6 +57,30 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
     }
 }
 
+NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
+    "min_world_size": 4,
+    "thresholds": {
+        4: 2 * MiB,  # 2 MiB
+        8: 1 * MiB,  # 1 MiB
+    },
+    "always_use_above_world_size": 8  # Always use symm mem for world_size > 8
+}
+
+
+def should_nccl_symm_mem_allreduce(world_size: int,
+                                   input_tensor: torch.Tensor) -> bool:
+    from vllm.distributed.device_communicators.pynccl_allocator import (
+        is_symmetric_memory_enabled)
+    if not is_symmetric_memory_enabled():
+        return False
+    if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
+        return False
+    threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
+    if threshold is not None and input_tensor.nbytes >= threshold:
+        return True
+    return (world_size
+            > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
+
 
 def producer(batch_src: Sequence[int],
              producer_queue,
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index b2bf3bc3cc2e..6c25bf3cd95c 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -7,6 +7,12 @@ import torch
 from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    should_nccl_symm_mem_allreduce)
+from vllm.distributed.device_communicators.pynccl import (
+    register_nccl_symmetric_ops)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    is_symmetric_memory_enabled)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -53,6 +59,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
             group=self.cpu_group,
             device=self.device,
         )
+        if is_symmetric_memory_enabled():
+            register_nccl_symmetric_ops(self.pynccl_comm)
 
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
@@ -107,6 +115,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")
 
     def all_reduce(self, input_):
+        # Since we currently copy input -> symm_input, run an out-of-place
+        # all-reduce, and return symm_output, there is no need to check
+        # whether the input itself lives in symmetric memory.
+        if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce(
+                self.pynccl_comm.world_size, input_):
+            out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
+            if out is not None:
+                return out
         # always try quick reduce first, then custom allreduce,
         # and then pynccl (quick reduce is only for ROCm MI3xx).
         qr_comm = self.qr_comm
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 3e4d0d250af9..75de85e1b0ab 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -17,6 +17,39 @@ from vllm.utils import current_stream
 
 logger = init_logger(__name__)
 
+_NCCL_SYMM_OPS_REGISTERED = False
+
+
+def register_nccl_symmetric_ops(pynccl_comm):
+    from vllm.distributed.device_communicators.pynccl_allocator import (
+        nccl_symm_mem_context)
+    from vllm.utils import direct_register_custom_op
+
+    global _NCCL_SYMM_OPS_REGISTERED
+    if _NCCL_SYMM_OPS_REGISTERED:
+        return
+    _NCCL_SYMM_OPS_REGISTERED = True
+
+    def all_reduce_symmetric_with_copy_impl(
+            input_tensor: torch.Tensor) -> torch.Tensor:
+        # Allocate both buffers from the symmetric memory pool; the NCCL
+        # windows are registered when the context exits, so the copy and
+        # the all-reduce run outside of it.
+        with nccl_symm_mem_context(pynccl_comm):
+            symm_input = torch.empty_like(input_tensor)
+            symm_output = torch.empty_like(input_tensor)
+        symm_input.copy_(input_tensor)
+        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
+        return symm_output
+
+    def all_reduce_symmetric_with_copy_fake(
+            input_tensor: torch.Tensor) -> torch.Tensor:
+        return torch.empty_like(input_tensor)
+
+    direct_register_custom_op(
+        op_name="all_reduce_symmetric_with_copy",
+        op_func=all_reduce_symmetric_with_copy_impl,
+        mutates_args=[],
+        fake_impl=all_reduce_symmetric_with_copy_fake,
+    )
+
 
 class PyNcclCommunicator:
 
@@ -67,6 +100,7 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False
 
+        self.nccl_version = self.nccl.ncclGetRawVersion()
         logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
 
         if self.rank == 0:
@@ -109,6 +143,7 @@ class PyNcclCommunicator:
 
     def all_reduce(self,
                    in_tensor: torch.Tensor,
+                   out_tensor: Optional[torch.Tensor] = None,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None) -> torch.Tensor:
         if self.disabled:
@@ -120,7 +155,8 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {in_tensor.device}")
 
-        out_tensor = torch.empty_like(in_tensor)
+        if out_tensor is None:
+            out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
             stream = current_stream()
@@ -288,3 +324,18 @@ class PyNcclCommunicator:
 
     def group_end(self):
         self.nccl.ncclGroupEnd()
+
+    def register_comm_window(self, tensor: torch.Tensor):
+        return self.nccl.ncclCommWindowRegister(
+            self.comm,
+            buffer_type(tensor.data_ptr()),
+            tensor.numel() * tensor.element_size(),
+            1,
+        )
+
+    def register_comm_window_raw(self, ptr: int, size: int):
+        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
+                                                size, 1)
+
+    def deregister_comm_window(self, window):
+        return self.nccl.ncclCommWindowDeregister(self.comm, window)
diff --git a/vllm/distributed/device_communicators/pynccl_allocator.py b/vllm/distributed/device_communicators/pynccl_allocator.py
new file mode 100644
index 000000000000..bc874c1e197e
--- /dev/null
+++ b/vllm/distributed/device_communicators/pynccl_allocator.py
@@ -0,0 +1,186 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import atexit
+import contextlib
+import tempfile
+from typing import Any, Optional
+
+import torch
+from packaging import version
+from torch.cuda.memory import CUDAPluggableAllocator
+from torch.utils.cpp_extension import load_inline
+
+from vllm import envs
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import find_nccl_include_paths
+
+logger = init_logger(__name__)
+
+nccl_allocator_source = """
+#include <nccl.h>
+extern "C" {
+
+void* nccl_alloc_plug(size_t size, int device, void* stream) {
+  void* ptr;
+  ncclResult_t err = ncclMemAlloc(&ptr, size);
+  return ptr;
+}
+
+void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
+  ncclResult_t err = ncclMemFree(ptr);
+}
+
+}
+"""
+
+_allocator = None
+_allocator_wrapper = None
+_mem_pool = None
+_registered_base_addrs = set()
+_graph_pool_id = None
+_nccl_allocator_failed_to_compile = False
+_cached_pool_snapshot = None
+
+
+def is_symmetric_memory_enabled():
+    global _nccl_allocator_failed_to_compile
+    return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile
+
+
+def is_symmetric_memory_tensor(tensor: torch.Tensor):
+    if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
+        return False
+    for segment in _cached_pool_snapshot:
+        for block in segment["blocks"]:
+            if block["address"] == tensor.untyped_storage().data_ptr():
+                return True
+    return False
+
+
+def set_graph_pool_id(graph_pool_id):
+    global _graph_pool_id
+    _graph_pool_id = graph_pool_id
+
+
+def compile_nccl_allocator():
+    global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile
+    if not current_platform.is_cuda():
+        _nccl_allocator_failed_to_compile = True
+        return
+    try:
+        out_dir = tempfile.gettempdir()
+        nccl_allocator_libname = "nccl_allocator"
+        nccl_include_paths = find_nccl_include_paths()
+        load_inline(
+            name=nccl_allocator_libname,
+            cpp_sources=nccl_allocator_source,
+            with_cuda=True,
+            extra_ldflags=["-lnccl"],
+            verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
+            is_python_module=False,
+            build_directory=out_dir,
+            extra_include_paths=nccl_include_paths,
+        )
+        _allocator_wrapper = CUDAPluggableAllocator(
+            f"{out_dir}/{nccl_allocator_libname}.so",
+            "nccl_alloc_plug",
+            "nccl_free_plug",
+        )
+        _allocator = _allocator_wrapper.allocator()
+    except Exception as e:
+        _nccl_allocator_failed_to_compile = True
+        logger.warning(
+            "Failed to compile the NCCL memory allocator; symmetric memory "
+            "will be disabled. This is expected if NCCL headers are not "
+            "available. Optionally set VLLM_NCCL_INCLUDE_PATH to point to a "
+            "directory containing the NCCL header. Error: %s", str(e))
+
+
+def get_nccl_mem_pool():
+    global _mem_pool, _nccl_allocator_failed_to_compile
+    if _mem_pool is None and not _nccl_allocator_failed_to_compile:
+        compile_nccl_allocator()
+        if _allocator is not None:
+            _mem_pool = torch.cuda.MemPool(_allocator)
+    return _mem_pool
+
+
+def _cleanup_nccl_mem_pool():
+    global _mem_pool
+    _mem_pool = None
+
+
+def _cleanup_nccl_allocator_wrapper():
+    global _allocator_wrapper
+    _allocator_wrapper = None
+
+
+atexit.register(_cleanup_nccl_mem_pool)
+atexit.register(_cleanup_nccl_allocator_wrapper)
+
+
+class nccl_symm_mem_context:
+
+    def __init__(
+        self,
+        pynccl_comm: PyNcclCommunicator,
+        disabled: bool = False,
+    ):
+        self.disabled = (disabled or not is_symmetric_memory_enabled()
+                         or pynccl_comm.world_size == 1
+                         or not current_platform.is_cuda()
+                         or get_nccl_mem_pool() is None or version.parse(
+                             torch.__version__) < version.parse("2.8.0.a0"))
+        if self.disabled:
+            self.pynccl_comm: Optional[PyNcclCommunicator] = None
+            self._mem_pool_ctx: contextlib.AbstractContextManager[
+                Any] = contextlib.nullcontext()
+            self.is_graph_capture = None
+            self.device = None
+        else:
+            self.pynccl_comm = pynccl_comm
+            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
+            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
+            self.device = torch.cuda.current_device()
+
+    def __enter__(self):
+        if self.disabled:
+            return self
+        assert (
+            self.pynccl_comm
+            is not None), "Symmetric memory requires pynccl to be initialized"
+        assert (
+            self.pynccl_comm.nccl_version >= 22703
+        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
+        if self.is_graph_capture:
+            assert (
+                _graph_pool_id
+                is not None), "graph_pool_id is not set under graph capture"
+            # Pause graph memory pool to use symmetric memory with cuda graph
+            torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
+        self._mem_pool_ctx.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.disabled:
+            return
+        global _cached_pool_snapshot
+        global _registered_base_addrs
+        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
+        _pool = get_nccl_mem_pool()
+        assert _pool is not None
+        _cached_pool_snapshot = _pool.snapshot()
+        assert self.pynccl_comm is not None
+        # Register any newly allocated segments as NCCL windows exactly once.
+        for segment in _cached_pool_snapshot:
+            if segment["address"] not in _registered_base_addrs:
+                self.pynccl_comm.register_comm_window_raw(
+                    segment["address"], segment["total_size"])
+                _registered_base_addrs.add(segment["address"])
+        if self.is_graph_capture:
+            # Resume allocating to the original graph pool after capture.
+            torch._C._cuda_beginAllocateCurrentThreadToPool(
+                self.device, _graph_pool_id)
diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py
index a930b63bc26f..c3e99e177e2d 100644
--- a/vllm/distributed/device_communicators/pynccl_wrapper.py
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -41,6 +41,7 @@ logger = init_logger(__name__)
 
 ncclResult_t = ctypes.c_int
 ncclComm_t = ctypes.c_void_p
+ncclWindow_t = ctypes.c_void_p
 
 
 class ncclUniqueId(ctypes.Structure):
@@ -222,6 +223,24 @@ class NCCLLibrary:
         Function("ncclGroupStart", ncclResult_t, []),
         # ncclResult_t ncclGroupEnd();
         Function("ncclGroupEnd", ncclResult_t, []),
+        # ncclResult_t ncclCommWindowRegister(
+        #     ncclComm_t comm, void* buff, size_t size,
+        #     ncclWindow_t* win, int winFlags);
+        Function(
+            "ncclCommWindowRegister",
+            ncclResult_t,
+            [
+                ncclComm_t,
+                buffer_type,
+                ctypes.c_size_t,
+                ctypes.POINTER(ncclWindow_t),
+                ctypes.c_int,
+            ],
+        ),
+        # ncclResult_t ncclCommWindowDeregister(
+        #     ncclComm_t comm, ncclWindow_t win);
+        Function("ncclCommWindowDeregister", ncclResult_t,
+                 [ncclComm_t, ncclWindow_t]),
     ]
 
     # class attribute to store the mapping from the path to the library
@@ -271,10 +290,14 @@ class NCCLLibrary:
             error_str = self.ncclGetErrorString(result)
             raise RuntimeError(f"NCCL error: {error_str}")
 
-    def ncclGetVersion(self) -> str:
+    def ncclGetRawVersion(self) -> int:
         version = ctypes.c_int()
         self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
-        version_str = str(version.value)
+        # something like 21903
+        return version.value
+
+    def ncclGetVersion(self) -> str:
+        version_str = str(self.ncclGetRawVersion())
         # something like 21903 --> "2.19.3"
         major = version_str[0].lstrip("0")
         minor = version_str[1:3].lstrip("0")
@@ -375,6 +398,17 @@ class NCCLLibrary:
     def ncclGroupEnd(self) -> None:
         self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
 
+    def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
+                               size: int, win_flags: int) -> ncclWindow_t:
+        window = ncclWindow_t()
+        self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"](
+            comm, buff, size, ctypes.byref(window), win_flags))
+        return window
+
+    def ncclCommWindowDeregister(self, comm: ncclComm_t,
+                                 window: ncclWindow_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
+
 
 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
diff --git a/vllm/envs.py b/vllm/envs.py
index f6eafe892ef2..fa6f14d6b037 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -193,6 +193,8 @@ if TYPE_CHECKING:
     VLLM_DBO_COMM_SMS: int = 20
     GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
    VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
+    VLLM_USE_NCCL_SYMM_MEM: bool = False
+    VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -1410,6 +1412,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
         ["container", "code_interpreter", "web_search_preview"]),
 
+
+    # Flag to enable NCCL symmetric memory allocation and registration
+    "VLLM_USE_NCCL_SYMM_MEM":
+    lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
+
+    # NCCL header path
+    "VLLM_NCCL_INCLUDE_PATH":
+    lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
+
 }
 
 # --8<-- [end:env-vars-definition]
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 3399d00fbabb..5d165f166238 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -1383,6 +1383,38 @@ def find_nccl_library() -> str:
     return so_file
 
 
+def find_nccl_include_paths() -> Optional[list[str]]:
+    """
+    We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
+    environment variable, or we find the header shipped with the
+    nvidia-nccl-cuXX package.
+    Note that load_inline by default also uses
+    torch.utils.cpp_extension.include_paths.
+    """
+    paths: list[str] = []
+    inc = envs.VLLM_NCCL_INCLUDE_PATH
+    if inc and os.path.isdir(inc):
+        paths.append(inc)
+
+    try:
+        import importlib.util
+        spec = importlib.util.find_spec("nvidia.nccl")
+        if spec and getattr(spec, "submodule_search_locations", None):
+            for loc in spec.submodule_search_locations:
+                inc_dir = os.path.join(loc, "include")
+                if os.path.exists(os.path.join(inc_dir, "nccl.h")):
+                    paths.append(inc_dir)
+    except Exception:
+        pass
+
+    # De-duplicate while preserving order
+    seen = set()
+    out: list[str] = []
+    for p in paths:
+        if p and p not in seen:
+            out.append(p)
+            seen.add(p)
+    return out or None
+
+
 prev_set_stream = torch.cuda.set_stream
 
 _current_stream_tls = threading.local()
diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py
index bfc3743ea417..d636e7af72ea 100644
--- a/vllm/v1/worker/gpu_ubatch_wrapper.py
+++ b/vllm/v1/worker/gpu_ubatch_wrapper.py
@@ -11,6 +11,8 @@ import vllm.envs as envs
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import get_ep_group
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id)
 from vllm.forward_context import (create_forward_context, get_forward_context,
                                   override_forward_context)
 from vllm.logger import init_logger
@@ -206,6 +208,10 @@ class UBatchWrapper:
             cudagraph=torch.cuda.CUDAGraph(),
             ubatch_metadata=ubatch_metadata,
         )
+        # Record the graph pool id for nccl_symm_mem_context before capture
+        # (see pynccl_allocator.py).
+        if self.graph_pool is not None:
+            set_graph_pool_id(self.graph_pool)
+        else:
+            set_graph_pool_id(current_platform.graph_pool_handle())
         with torch.cuda.graph(cudagraph_metadata.cudagraph,
                               stream=compute_stream,
                               pool=self.graph_pool):
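
--
A minimal standalone sketch of the size/world-size dispatch rule added in
all_reduce_utils.py above (assumptions: MiB = 1024 * 1024 as in vllm.utils,
plain byte counts standing in for input_tensor.nbytes, and the
is_symmetric_memory_enabled() gate omitted); it runs without vLLM or GPUs:

    # Sketch only; mirrors should_nccl_symm_mem_allreduce() from this patch.
    MiB = 1024 * 1024
    CONFIG = {
        "min_world_size": 4,
        "thresholds": {4: 2 * MiB, 8: 1 * MiB},
        "always_use_above_world_size": 8,
    }

    def should_use_symm_mem(world_size: int, nbytes: int) -> bool:
        if world_size < CONFIG["min_world_size"]:
            return False
        threshold = CONFIG["thresholds"].get(world_size)
        if threshold is not None and nbytes >= threshold:
            return True
        return world_size > CONFIG["always_use_above_world_size"]

    assert not should_use_symm_mem(2, 16 * MiB)  # below min_world_size
    assert should_use_symm_mem(4, 2 * MiB)       # TP=4 needs >= 2 MiB
    assert not should_use_symm_mem(4, 1 * MiB)   # below the TP=4 threshold
    assert should_use_symm_mem(8, 1 * MiB)       # TP=8 needs >= 1 MiB
    assert not should_use_symm_mem(6, 64 * MiB)  # no TP=6 threshold, not > 8
    assert should_use_symm_mem(16, 1)            # world_size > 8: always on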