[core] add nccl symmetric memory for all reduce (#24532)
Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent a3a7828010
commit 8c1c81a3de
@@ -1039,3 +1039,4 @@ steps:
   num_gpus: 2
   commands:
     - pytest -v -s tests/distributed/test_context_parallel.py
+    - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py
@@ -7,6 +7,10 @@ Benchmark script for device communicators:
 CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
 and SymmMemCommunicator (multimem, two-shot).

+For NCCL symmetric memory, set the environment variables
+NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1; otherwise NCCL
+will not use the fast NVLS implementation for all-reduce.
+
 Usage:
     torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]

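The three flags above have to reach every rank before the first NCCL communicator is created. A minimal launcher sketch, not part of this commit, assuming torchrun is on PATH and the benchmark script is in the current directory (the rank count is illustrative):

# Hypothetical launcher: export the flags from the docstring, then spawn torchrun
# so every rank inherits them before NCCL initializes.
import os
import subprocess

env = dict(os.environ)
env.update({
    "NCCL_NVLS_ENABLE": "1",       # let NCCL pick the NVLS algorithm
    "NCCL_CUMEM_ENABLE": "1",      # cuMem-based buffers, needed for symmetric memory
    "VLLM_USE_NCCL_SYMM_MEM": "1", # turn on vLLM's NCCL symmetric-memory path
})

subprocess.run(
    ["torchrun", "--nproc_per_node=2", "benchmark_device_communicators.py"],
    env=env,
    check=True,
)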
@@ -26,7 +30,13 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup

 from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
-from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl import (
+    PyNcclCommunicator,
+    register_nccl_symmetric_ops,
+)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
 from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
 from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser
@@ -98,6 +108,7 @@ class CommunicatorBenchmark:
         )
         if not self.pynccl_comm.disabled:
             logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
+            register_nccl_symmetric_ops(self.pynccl_comm)
         else:
             logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
             self.pynccl_comm = None
@@ -194,6 +205,15 @@ class CommunicatorBenchmark:
                     None,  # no env variable needed
                 )
             )
+            communicators.append(
+                (
+                    "pynccl-symm",
+                    lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
+                    lambda t: True,  # Always available if initialized
+                    nullcontext(),
+                    None,  # no env variable needed
+                )
+            )

         if self.symm_mem_comm_multimem is not None:
             comm = self.symm_mem_comm_multimem
@@ -271,7 +291,9 @@ class CommunicatorBenchmark:
         # Capture the graph using context manager
         with context:
             graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(graph):
+            graph_pool = torch.cuda.graph_pool_handle()
+            set_graph_pool_id(graph_pool)
+            with torch.cuda.graph(graph, pool=graph_pool):
                 for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                     allreduce_fn(graph_input)

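For context, a minimal standalone sketch of the capture pattern the benchmark switches to here: obtain a pool id with torch.cuda.graph_pool_handle() and pass it to torch.cuda.graph(..., pool=...), so a hook like the set_graph_pool_id call above can record which pool is active during capture. This assumes a CUDA device; the tensors and the doubling op are illustrative.

import torch

x = torch.ones(1024, device="cuda")
y = torch.zeros_like(x)

# Warm-up so any lazy CUDA initialization happens outside of capture.
y.copy_(x * 2)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
graph_pool = torch.cuda.graph_pool_handle()  # id of a shareable capture pool

with torch.cuda.graph(graph, pool=graph_pool):
    y.copy_(x * 2)  # work recorded into the graph, allocated from graph_pool

graph.replay()
torch.cuda.synchronize()
assert torch.allclose(y, x * 2)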
tests/distributed/test_nccl_symm_mem_allreduce.py | 94 (new file)
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import random
+import typing
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import vllm.envs as envs
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.distributed.device_communicators.cuda_communicator import (
+    CudaCommunicator)
+from vllm.distributed.device_communicators.pynccl import (
+    register_nccl_symmetric_ops)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    get_nccl_mem_pool, is_symmetric_memory_enabled)
+from vllm.distributed.parallel_state import (get_tp_group,
+                                             init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.platforms import current_platform
+from vllm.utils import update_environment_variables
+
+torch.manual_seed(42)
+random.seed(44)
+
+test_size_elements = 4 * 1024 * 1024
+
+
+def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
+    monkeypatch = pytest.MonkeyPatch()
+    with monkeypatch.context() as m:
+        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+        dtype = torch.bfloat16
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+        update_environment_variables({
+            "RANK": str(local_rank),
+            "LOCAL_RANK": str(local_rank),
+            "WORLD_SIZE": str(world_size),
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": "12345",
+        })
+
+        init_distributed_environment()
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+        cuda_communicator = typing.cast(CudaCommunicator,
+                                        get_tp_group().device_communicator)
+        pynccl_comm = cuda_communicator.pynccl_comm
+        if get_nccl_mem_pool() is None:
+            pytest.skip("NCCL allocator compilation failed "
+                        "(probably missing NCCL headers).")
+        if not is_symmetric_memory_enabled():
+            pytest.skip("NCCL symmetric memory allreduce is disabled.")
+
+        register_nccl_symmetric_ops(pynccl_comm)
+        input = torch.randint(1,
+                              23, (test_size_elements, ),
+                              dtype=dtype,
+                              device=device)
+        input_clone = input.clone()
+        output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
+        assert output is not None
+
+        group = get_tp_group().device_group
+        dist.all_reduce(input_clone, group=group)
+        torch.testing.assert_close(output, input_clone, atol=2.5, rtol=0.1)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
+)
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+
+    # Enable SymmMemCommunicator
+    monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1")
+    monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
+    monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
+
+    mp.spawn(nccl_symm_mem_allreduce_worker,
+             args=(world_size, ),
+             nprocs=world_size)
+    cleanup_dist_env_and_memory()
@@ -12,6 +12,8 @@ import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
 from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -154,6 +156,10 @@ class CUDAGraphWrapper:
                     stack.enter_context(
                         patch("torch.cuda.empty_cache", lambda: None))

+                if self.graph_pool is not None:
+                    set_graph_pool_id(self.graph_pool)
+                else:
+                    set_graph_pool_id(current_platform.graph_pool_handle())
                 # mind-exploding: carefully manage the reference and memory.
                 with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                     # `output` is managed by pytorch's cudagraph pool
@@ -10,8 +10,9 @@ import sys
 import tempfile
 from collections.abc import Sequence
 from itertools import product
-from typing import Optional
+from typing import Any, Optional

+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp

@@ -56,6 +57,30 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
     }
 }

+NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
+    "min_world_size": 4,
+    "thresholds": {
+        4: 2 * MiB,  # 2 MB
+        8: 1 * MiB,  # 1 MB
+    },
+    "always_use_above_world_size": 8  # Always use symm mem for world_size > 8
+}
+
+
+def should_nccl_symm_mem_allreduce(world_size: int,
+                                   input_tensor: torch.Tensor) -> bool:
+    from vllm.distributed.device_communicators.pynccl_allocator import (
+        is_symmetric_memory_enabled)
+    if not is_symmetric_memory_enabled():
+        return False
+    if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
+        return False
+    threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
+    if threshold is not None and input_tensor.nbytes >= threshold:
+        return True
+    return (world_size
+            > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
+
+
 def producer(batch_src: Sequence[int],
              producer_queue,
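To make the gating above concrete: small world sizes never take the symmetric-memory path, the listed world sizes take it only above a per-world-size byte threshold, and anything larger than 8 ranks always takes it. A self-contained sketch of the same decision table (it assumes the symmetric-memory path is otherwise enabled; the helper name is illustrative):

MiB = 1024 * 1024

def uses_nccl_symm_mem(world_size: int, nbytes: int) -> bool:
    # Mirrors NCCL_SYMM_MEM_ALL_REDUCE_CONFIG above.
    if world_size < 4:                       # min_world_size
        return False
    thresholds = {4: 2 * MiB, 8: 1 * MiB}
    threshold = thresholds.get(world_size)
    if threshold is not None and nbytes >= threshold:
        return True
    return world_size > 8                    # always_use_above_world_size

for ws, nbytes in [(2, 4 * MiB), (4, 1 * MiB), (4, 2 * MiB), (8, 1 * MiB), (16, 64)]:
    print(ws, nbytes, uses_nccl_symm_mem(ws, nbytes))
# 2 ranks -> False (too few), 4 ranks/1 MiB -> False (below the 2 MiB threshold),
# 4 ranks/2 MiB -> True, 8 ranks/1 MiB -> True, 16 ranks -> True regardless of size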
@@ -7,6 +7,12 @@ import torch
 from torch.distributed import ProcessGroup

 import vllm.envs as envs
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    should_nccl_symm_mem_allreduce)
+from vllm.distributed.device_communicators.pynccl import (
+    register_nccl_symmetric_ops)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    is_symmetric_memory_enabled)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform

@@ -53,6 +59,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
             group=self.cpu_group,
             device=self.device,
         )
+        if is_symmetric_memory_enabled():
+            register_nccl_symmetric_ops(self.pynccl_comm)

         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
@@ -107,6 +115,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")

     def all_reduce(self, input_):
+        # since currently we perform copy input -> symm_input -> out-of-place AR
+        # return symm_output, we don't need to check if input is symmetric
+        if self.pynccl_comm is not None and \
+            should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size, input_):
+            out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
+            if out is not None:
+                return out
         # always try quick reduce first, then custom allreduce,
         # and then pynccl. (quick reduce just for ROCM MI3*)
         qr_comm = self.qr_comm
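The resulting backend priority in CudaCommunicator.all_reduce is: NCCL symmetric memory (when should_nccl_symm_mem_allreduce agrees), then quick reduce, then custom allreduce, then PyNccl. A runnable toy of that fall-through shape, with stand-in callables rather than vLLM APIs:

from typing import Callable, Optional
import torch

def all_reduce_with_fallback(
    input_: torch.Tensor,
    symm_mem_fn: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]],
    backends: list,  # e.g. [quick_reduce, custom_allreduce, pynccl], best first
) -> torch.Tensor:
    # 1. Try the symmetric-memory op; it may decline by returning None.
    if symm_mem_fn is not None:
        out = symm_mem_fn(input_)
        if out is not None:
            return out
    # 2..4. Otherwise walk the remaining backends in priority order.
    for backend in backends:
        if backend is not None:
            return backend(input_)
    raise RuntimeError("no usable all-reduce backend")

# Toy check: no symm-mem path, falls through to a stand-in final backend.
x = torch.ones(4)
print(all_reduce_with_fallback(x, None, [None, None, lambda t: t * 2]))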
@@ -17,6 +17,39 @@ from vllm.utils import current_stream

 logger = init_logger(__name__)

+_NCCL_SYMM_OPS_REGISTERED = False
+
+
+def register_nccl_symmetric_ops(pynccl_comm):
+    from vllm.distributed.device_communicators.pynccl_allocator import (
+        nccl_symm_mem_context)
+    from vllm.utils import direct_register_custom_op
+
+    global _NCCL_SYMM_OPS_REGISTERED
+    if _NCCL_SYMM_OPS_REGISTERED:
+        return
+    _NCCL_SYMM_OPS_REGISTERED = True
+
+    def all_reduce_symmetric_with_copy_impl(
+            input_tensor: torch.Tensor) -> torch.Tensor:
+        with nccl_symm_mem_context(pynccl_comm):
+            symm_input = torch.empty_like(input_tensor)
+            symm_output = torch.empty_like(input_tensor)
+        symm_input.copy_(input_tensor)
+        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
+        return symm_output
+
+    def all_reduce_symmetric_with_copy_fake(
+            input_tensor: torch.Tensor) -> torch.Tensor:
+        return torch.empty_like(input_tensor)
+
+    direct_register_custom_op(
+        op_name="all_reduce_symmetric_with_copy",
+        op_func=all_reduce_symmetric_with_copy_impl,
+        mutates_args=[],
+        fake_impl=all_reduce_symmetric_with_copy_fake,
+    )
+

 class PyNcclCommunicator:

@@ -67,6 +100,7 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False

+        self.nccl_version = self.nccl.ncclGetRawVersion()
         logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())

         if self.rank == 0:
@@ -109,6 +143,7 @@ class PyNcclCommunicator:

     def all_reduce(self,
                    in_tensor: torch.Tensor,
+                   out_tensor: torch.Tensor = None,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None) -> torch.Tensor:
         if self.disabled:
@@ -120,7 +155,8 @@ class PyNcclCommunicator:
                 f"this nccl communicator is created to work on {self.device}, "
                 f"but the input tensor is on {in_tensor.device}")

-        out_tensor = torch.empty_like(in_tensor)
+        if out_tensor is None:
+            out_tensor = torch.empty_like(in_tensor)

         if stream is None:
             stream = current_stream()
@@ -288,3 +324,18 @@ class PyNcclCommunicator:

     def group_end(self):
         self.nccl.ncclGroupEnd()
+
+    def register_comm_window(self, tensor: torch.Tensor):
+        return self.nccl.ncclCommWindowRegister(
+            self.comm,
+            buffer_type(tensor.data_ptr()),
+            tensor.numel() * tensor.element_size(),
+            1,
+        )
+
+    def register_comm_window_raw(self, ptr: int, size: int):
+        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
+                                                size, 1)
+
+    def deregister_comm_window(self, window):
+        return self.nccl.ncclCommWindowDeregister(self.comm, window)
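To see the data movement behind all_reduce_symmetric_with_copy, here is a hedged single-process sketch using plain torch.distributed with the gloo backend instead of NCCL symmetric memory: the input is staged into a freshly allocated buffer (standing in for the symmetric allocation), the reduction runs on that buffer, and the caller's tensor is left untouched. Unlike the real op, the sketch reduces in place on the staged copy; the port number is arbitrary.

import os
import torch
import torch.distributed as dist

def staged_all_reduce(input_tensor: torch.Tensor) -> torch.Tensor:
    staged = torch.empty_like(input_tensor)  # stands in for the symmetric buffer
    staged.copy_(input_tensor)               # copy input -> staged buffer
    dist.all_reduce(staged)                  # reduce on the staged copy
    return staged                            # caller's input is unchanged

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29555")
    dist.init_process_group("gloo", rank=0, world_size=1)
    x = torch.arange(4, dtype=torch.float32)
    print(staged_all_reduce(x))  # equals x when world_size == 1
    dist.destroy_process_group()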
vllm/distributed/device_communicators/pynccl_allocator.py | 186 (new file)
@@ -0,0 +1,186 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import atexit
+import contextlib
+import tempfile
+from typing import Any, Optional
+
+import torch
+from packaging import version
+from torch.cuda.memory import CUDAPluggableAllocator
+from torch.utils.cpp_extension import load_inline
+
+from vllm import envs
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.utils import find_nccl_include_paths
+
+logger = init_logger(__name__)
+
+nccl_allocator_source = """
+#include <nccl.h>
+extern "C" {
+
+void* nccl_alloc_plug(size_t size, int device, void* stream) {
+  void* ptr;
+  ncclResult_t err = ncclMemAlloc(&ptr, size);
+  return ptr;
+}
+
+void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
+  ncclResult_t err = ncclMemFree(ptr);
+}
+
+}
+"""
+
+_allocator = None
+_allocator_wrapper = None
+_mem_pool = None
+_registered_base_addrs = set()
+_graph_pool_id = None
+_nccl_allocator_failed_to_compile = False
+_cached_pool_snapshot = None
+
+
+def is_symmetric_memory_enabled():
+    global _nccl_allocator_failed_to_compile
+    return envs.VLLM_USE_NCCL_SYMM_MEM and not _nccl_allocator_failed_to_compile
+
+
+def is_symmetric_memory_tensor(tensor: torch.Tensor):
+    if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
+        return False
+    for segment in _cached_pool_snapshot:
+        for block in segment["blocks"]:
+            if block["address"] == tensor.untyped_storage().data_ptr():
+                return True
+    return False
+
+
+def set_graph_pool_id(graph_pool_id):
+    global _graph_pool_id
+    _graph_pool_id = graph_pool_id
+
+
+def compile_nccl_allocator():
+    global _allocator, _allocator_wrapper, _nccl_allocator_failed_to_compile
+    if not current_platform.is_cuda():
+        _nccl_allocator_failed_to_compile = True
+        return
+    try:
+        out_dir = tempfile.gettempdir()
+        nccl_allocator_libname = "nccl_allocator"
+        nccl_include_paths = find_nccl_include_paths()
+        load_inline(
+            name=nccl_allocator_libname,
+            cpp_sources=nccl_allocator_source,
+            with_cuda=True,
+            extra_ldflags=["-lnccl"],
+            verbose=envs.VLLM_LOGGING_LEVEL == "DEBUG",
+            is_python_module=False,
+            build_directory=out_dir,
+            extra_include_paths=nccl_include_paths,
+        )
+        _allocator_wrapper = CUDAPluggableAllocator(
+            f"{out_dir}/{nccl_allocator_libname}.so",
+            "nccl_alloc_plug",
+            "nccl_free_plug",
+        )
+        _allocator = _allocator_wrapper.allocator()
+    except Exception as e:
+        _nccl_allocator_failed_to_compile = True
+        logger.warning(
+            "Failed to compile NCCL memory allocator. "
+            "Symmetric memory will be disabled. "
+            "This is expected if NCCL headers are not available. "
+            "Optionally set VLLM_NCCL_INCLUDE_PATH to point to a directory "
+            "containing the NCCL header. "
+            "Error: %s", str(e))
+
+
+def get_nccl_mem_pool():
+    global _mem_pool, _nccl_allocator_failed_to_compile
+    if _mem_pool is None and not _nccl_allocator_failed_to_compile:
+        compile_nccl_allocator()
+        if _allocator is not None:
+            _mem_pool = torch.cuda.MemPool(_allocator)
+    return _mem_pool
+
+
+def _cleanup_nccl_mem_pool():
+    global _mem_pool
+    _mem_pool = None
+
+
+def _cleanup_nccl_allocator_wrapper():
+    global _allocator_wrapper
+    _allocator_wrapper = None
+
+
+atexit.register(_cleanup_nccl_mem_pool)
+atexit.register(_cleanup_nccl_allocator_wrapper)
+
+
+class nccl_symm_mem_context:
+
+    def __init__(
+        self,
+        pynccl_comm: PyNcclCommunicator,
+        disabled: bool = False,
+    ):
+        self.disabled = (disabled or not is_symmetric_memory_enabled()
+                         or pynccl_comm.world_size == 1
+                         or not current_platform.is_cuda()
+                         or get_nccl_mem_pool() is None or version.parse(
+                             torch.__version__) < version.parse("2.8.0.a0"))
+        if self.disabled:
+            self.pynccl_comm: Optional[PyNcclCommunicator] = None
+            self._mem_pool_ctx: contextlib.AbstractContextManager[
+                Any] = contextlib.nullcontext()
+            self.is_graph_capture = None
+            self.device = None
+        else:
+            self.pynccl_comm = pynccl_comm
+            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
+            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
+            self.device = torch.cuda.current_device()
+
+    def __enter__(self):
+        if self.disabled:
+            return self
+        assert (
+            self.pynccl_comm
+            is not None), "Symmetric memory requires pynccl to be initialized"
+        assert (
+            self.pynccl_comm.nccl_version >= 22703
+        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
+        if self.is_graph_capture:
+            assert (
+                _graph_pool_id
+                is not None), "graph_pool_id is not set under graph capture"
+            # Pause graph memory pool to use symmetric memory with cuda graph
+            torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
+        self._mem_pool_ctx.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.disabled:
+            return
+        global _cached_pool_snapshot
+        global _registered_base_addrs
+        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
+        _pool = get_nccl_mem_pool()
+        assert _pool is not None
+        _cached_pool_snapshot = _pool.snapshot()
+        assert self.pynccl_comm is not None
+        for segment in _cached_pool_snapshot:
+            if segment["address"] not in _registered_base_addrs:
+                self.pynccl_comm.register_comm_window_raw(
+                    segment["address"], segment["total_size"])
+                _registered_base_addrs.add(segment["address"])
+        if self.is_graph_capture:
+            torch._C._cuda_beginAllocateCurrentThreadToPool(
+                self.device, _graph_pool_id)
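The allocator module leans on three torch pieces visible above: torch.cuda.MemPool wrapping a pluggable allocator, torch.cuda.use_mem_pool to route allocations into it, and MemPool.snapshot() to find the segment base addresses that get registered as NCCL windows. A minimal sketch of that flow, using the default CUDA allocator instead of the compiled ncclMemAlloc hooks (assumes a CUDA device and a PyTorch new enough to expose MemPool):

import torch

# Default backing allocator; the module above passes the compiled
# CUDAPluggableAllocator so allocations come from ncclMemAlloc instead.
pool = torch.cuda.MemPool()

with torch.cuda.use_mem_pool(pool):
    t = torch.empty(1 << 20, device="cuda")  # allocated from `pool`

# Same structure nccl_symm_mem_context.__exit__ walks to register windows.
for segment in pool.snapshot():
    print(hex(segment["address"]), segment["total_size"])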
@@ -41,6 +41,7 @@ logger = init_logger(__name__)

 ncclResult_t = ctypes.c_int
 ncclComm_t = ctypes.c_void_p
+ncclWindow_t = ctypes.c_void_p


 class ncclUniqueId(ctypes.Structure):
@@ -222,6 +223,24 @@ class NCCLLibrary:
         Function("ncclGroupStart", ncclResult_t, []),
         # ncclResult_t ncclGroupEnd();
         Function("ncclGroupEnd", ncclResult_t, []),
+        # ncclResult_t ncclCommWindowRegister(
+        #     ncclComm_t comm, void* buff, size_t size,
+        #     ncclWindow_t* win, int winFlags);
+        Function(
+            "ncclCommWindowRegister",
+            ncclResult_t,
+            [
+                ncclComm_t,
+                buffer_type,
+                ctypes.c_size_t,
+                ctypes.POINTER(ncclWindow_t),
+                ctypes.c_int,
+            ],
+        ),
+        # ncclResult_t ncclCommWindowDeregister(
+        #     ncclComm_t comm, ncclWindow_t win);
+        Function("ncclCommWindowDeregister", ncclResult_t,
+                 [ncclComm_t, ncclWindow_t]),
     ]

     # class attribute to store the mapping from the path to the library
@@ -271,10 +290,14 @@ class NCCLLibrary:
             error_str = self.ncclGetErrorString(result)
             raise RuntimeError(f"NCCL error: {error_str}")

-    def ncclGetVersion(self) -> str:
+    def ncclGetRawVersion(self) -> int:
         version = ctypes.c_int()
         self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
-        version_str = str(version.value)
+        # something like 21903
+        return version.value
+
+    def ncclGetVersion(self) -> str:
+        version_str = str(self.ncclGetRawVersion())
         # something like 21903 --> "2.19.3"
         major = version_str[0].lstrip("0")
         minor = version_str[1:3].lstrip("0")
@@ -375,6 +398,17 @@ class NCCLLibrary:
     def ncclGroupEnd(self) -> None:
         self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())

+    def ncclCommWindowRegister(self, comm: ncclComm_t, buff: buffer_type,
+                               size: int, win_flags: int) -> ncclWindow_t:
+        window = ncclWindow_t()
+        self.NCCL_CHECK(self._funcs["ncclCommWindowRegister"](
+            comm, buff, size, ctypes.byref(window), win_flags))
+        return window
+
+    def ncclCommWindowDeregister(self, comm: ncclComm_t,
+                                 window: ncclWindow_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
+

 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
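ncclCommWindowRegister follows the usual ctypes out-parameter pattern: the C call fills a handle through a pointer, so the wrapper allocates the handle and passes ctypes.byref(...). A generic, NCCL-free sketch of the same pattern against POSIX libc, with posix_memalign standing in for a handle-returning C function (POSIX-only, illustrative):

import ctypes

libc = ctypes.CDLL(None)  # symbols already loaded into the current process
Handle = ctypes.c_void_p

# int posix_memalign(void **memptr, size_t alignment, size_t size);
libc.posix_memalign.restype = ctypes.c_int
libc.posix_memalign.argtypes = [
    ctypes.POINTER(Handle), ctypes.c_size_t, ctypes.c_size_t
]

handle = Handle()
ret = libc.posix_memalign(ctypes.byref(handle), 64, 1024)
assert ret == 0 and handle.value is not None  # handle now holds the allocation
libc.free(handle)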
vllm/envs.py | 11
@@ -193,6 +193,8 @@ if TYPE_CHECKING:
     VLLM_DBO_COMM_SMS: int = 20
     GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
     VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
+    VLLM_USE_NCCL_SYMM_MEM: bool = False
+    VLLM_NCCL_INCLUDE_PATH: Optional[str] = None


 def get_default_cache_root():
@@ -1410,6 +1412,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
         ["container",
          "code_interpreter",
          "web_search_preview"]),
+
+    # Flag to enable NCCL symmetric memory allocation and registration
+    "VLLM_USE_NCCL_SYMM_MEM":
+    lambda: bool(int(os.getenv("VLLM_USE_NCCL_SYMM_MEM", "0"))),
+
+    # NCCL header path
+    "VLLM_NCCL_INCLUDE_PATH":
+    lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
+
 }

 # --8<-- [end:env-vars-definition]
@@ -1383,6 +1383,38 @@ def find_nccl_library() -> str:
     return so_file


+def find_nccl_include_paths() -> Optional[list[str]]:
+    """
+    We either use the nccl.h specified by the `VLLM_NCCL_INCLUDE_PATH`
+    environment variable, or we find the header shipped with
+    nvidia-nccl-cuXX. load_inline by default uses
+    torch.utils.cpp_extension.include_paths.
+    """
+    paths: list[str] = []
+    inc = envs.VLLM_NCCL_INCLUDE_PATH
+    if inc and os.path.isdir(inc):
+        paths.append(inc)
+
+    try:
+        import importlib.util
+        spec = importlib.util.find_spec("nvidia.nccl")
+        if spec and getattr(spec, "submodule_search_locations", None):
+            for loc in spec.submodule_search_locations:
+                inc_dir = os.path.join(loc, "include")
+                if os.path.exists(os.path.join(inc_dir, "nccl.h")):
+                    paths.append(inc_dir)
+    except Exception:
+        pass
+
+    seen = set()
+    out: list[str] = []
+    for p in paths:
+        if p and p not in seen:
+            out.append(p)
+            seen.add(p)
+    return out or None
+
+
 prev_set_stream = torch.cuda.set_stream

 _current_stream_tls = threading.local()
@@ -11,6 +11,8 @@ import vllm.envs as envs
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.distributed import get_ep_group
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id)
 from vllm.forward_context import (create_forward_context, get_forward_context,
                                   override_forward_context)
 from vllm.logger import init_logger
@@ -206,6 +208,10 @@ class UBatchWrapper:
                 cudagraph=torch.cuda.CUDAGraph(),
                 ubatch_metadata=ubatch_metadata,
             )
+            if self.graph_pool is not None:
+                set_graph_pool_id(self.graph_pool)
+            else:
+                set_graph_pool_id(current_platform.graph_pool_handle())
             with torch.cuda.graph(cudagraph_metadata.cudagraph,
                                   stream=compute_stream,
                                   pool=self.graph_pool):