[Core] remove cupy dependency (#3625)

This commit is contained in:
youkaichao 2024-03-27 00:33:26 -07:00 committed by GitHub
parent e66b629c04
commit 8f44facddd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 506 additions and 223 deletions

View File

@ -22,10 +22,13 @@ steps:
working_dir: "/vllm-workspace/tests/distributed" working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now. num_gpus: 2 # only support 1 or 2 for now.
- label: Distributed Correctness Test - label: Distributed Tests
command: pytest -v -s --forked test_basic_distributed_correctness.py
working_dir: "/vllm-workspace/tests/distributed" working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now. num_gpus: 2 # only support 1 or 2 for now.
commands:
- pytest -v -s --forked test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py
- label: Engine Test - label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py command: pytest -v -s engine tokenization test_sequence.py test_config.py

View File

@ -97,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal
#################### RUNTIME BASE IMAGE #################### #################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries. # We used base cuda image because pytorch installs its own cuda libraries.
# However cupy depends on cuda libraries so we had to switch to the runtime image # However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda # In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

View File

@ -23,9 +23,6 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm # In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1" ARG BUILD_FA="1"
# whether to build cupy on rocm
ARG BUILD_CUPY="1"
# Install some basic utilities # Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y RUN apt-get update && apt-get install python3 python3-pip -y
@ -78,23 +75,6 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
# build cupy
RUN if [ "$BUILD_CUPY" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \
&& cd cupy \
&& pip install mpi4py-mpich \
&& pip install scipy==1.9.3 \
&& pip install cython==0.29.* \
&& env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
&& export CUPY_INSTALL_USE_HIP=1 \
&& export ROCM_HOME=/opt/rocm \
&& export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \
&& pip install . \
&& cd ..; \
fi
COPY ./ /app/vllm COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --upgrade pip

View File

@ -14,4 +14,3 @@ prometheus_client >= 0.18.0
pynvml == 11.5.0 pynvml == 11.5.0
triton >= 2.1.0 triton >= 2.1.0
outlines == 0.0.34 outlines == 0.0.34
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.

View File

@ -306,12 +306,6 @@ def get_requirements() -> List[str]:
if _is_cuda(): if _is_cuda():
with open(get_path("requirements.txt")) as f: with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n") requirements = f.read().strip().split("\n")
if get_nvcc_cuda_version() <= Version("11.8"):
# replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
for i in range(len(requirements)):
if requirements[i].startswith("cupy-cuda12x"):
requirements[i] = "cupy-cuda11x"
break
elif _is_hip(): elif _is_hip():
with open(get_path("requirements-rocm.txt")) as f: with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n") requirements = f.read().strip().split("\n")

View File

@ -1,13 +1,22 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling. """Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`. by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
test_basic_distributed_correctness.py
```
""" """
import os
import pytest import pytest
import torch import torch
MODELS = [ MODELS = [
"facebook/opt-125m", os.environ["TEST_DIST_MODEL"],
"meta-llama/Llama-2-7b-hf",
] ]

View File

@ -2,6 +2,8 @@
Run `pytest tests/distributed/test_comm_ops.py --forked`. Run `pytest tests/distributed/test_comm_ops.py --forked`.
""" """
import os
import pytest import pytest
import ray import ray
import torch import torch
@ -16,6 +18,12 @@ from vllm.test_utils import (init_test_distributed_environment,
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int, def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str): distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port) distributed_init_port)
num_elements = 8 num_elements = 8
@ -32,6 +40,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int, def all_gather_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str): distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port) distributed_init_port)
num_dimensions = 3 num_dimensions = 3
@ -54,6 +68,12 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str): distributed_init_port: str):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank, init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port) distributed_init_port)
test_dict = { test_dict = {

View File

@ -0,0 +1,90 @@
import multiprocessing
import os
import pytest
import torch
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
ncclGetUniqueId)
def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = os.environ.copy()
env['RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost'
env['MASTER_PORT'] = '12345'
p = multiprocessing.Process(target=fn, args=(env, ))
processes.append(p)
p.start()
for p in processes:
p.join()
def update_env(fn):
# `multiprocessing.Process` cannot accept environment variables directly
# so we need to pass the environment variables as arguments
# and update the environment variables in the function
def wrapper(env):
import os
os.environ.update(env)
fn()
return wrapper
@update_env
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == comm.world_size
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
distributed_run(worker_fn, 2)
@update_env
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
comm = NCCLCommunicator()
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
torch.cuda.synchronize()
with torch.cuda.graph(graph, stream=comm.stream):
# operation during the graph capture is recorded but not executed
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
comm.all_reduce(a)
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**0
graph.replay()
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**1
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2)
def test_ncclGetUniqueId():
unique_id = ncclGetUniqueId()
# `list(unique_id.internal)` is something like this:
# [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# as long as the function doesn't raise an exception, we're good
assert unique_id is not None

View File

@ -188,11 +188,9 @@ class RayGPUExecutor(ExecutorBase):
is_driver_worker=True, is_driver_worker=True,
) )
# FIXME(woosuk): We are not properly initializing cupy NCCL when # FIXME(woosuk): We are not properly initializing pynccl when
# we have multiple nodes. # we have multiple nodes.
self._run_workers("init_device", self._run_workers("init_device")
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers( self._run_workers(
"load_model", "load_model",
max_concurrent_workers=self.parallel_config. max_concurrent_workers=self.parallel_config.

View File

@ -4,12 +4,12 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.custom_all_reduce import ( from vllm.model_executor.parallel_utils.custom_all_reduce import (
custom_all_reduce) custom_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_cupy_nccl_enabled_for_all_reduce) get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@ -30,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_) out = custom_all_reduce(input_)
if out is not None: if out is not None:
return out return out
if is_cupy_nccl_enabled_for_all_reduce(): if is_pynccl_enabled_for_all_reduce():
# TODO: support multiple parallel groups. # TODO: support multiple parallel groups.
cupy_utils.all_reduce(input_) pynccl_utils.all_reduce(input_)
else: else:
torch.distributed.all_reduce(input_, torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group()) group=get_tensor_model_parallel_group())

View File

@ -1,130 +0,0 @@
"""CuPy utilities for all-reduce.
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.
NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib
import torch
from torch.distributed import ReduceOp
try:
import cupy
from cupy.cuda import nccl
from cupyx.distributed import NCCLBackend
except ImportError as e:
cupy = e
nccl = None
class NCCLBackend:
...
_OP_MAPPING = {
ReduceOp.SUM: "sum",
ReduceOp.PRODUCT: "prod",
ReduceOp.MIN: "min",
ReduceOp.MAX: "max",
}
class NCCLBackendWithBFloat16(NCCLBackend):
# This is enough to add bfloat16 support for most operations,
# but broadcast will fail (will require changes in compiled
# cupy code).
def _get_nccl_dtype_and_count(self, array, count=None):
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
torch_dtype = getattr(array, "_torch_dtype", None)
if torch_dtype is torch.bfloat16:
nccl_dtype = nccl.NCCL_BFLOAT16
return nccl_dtype, count
def barrier(self) -> None:
raise RuntimeError(
"Currently, CuPy NCCL barrier is not supported since the TCP "
"store is immediately stopped after the initialization.")
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return _NCCL_BACKEND is not None
@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
stream.device_index)
with cupy_stream:
yield
def init_process_group(world_size: int, rank: int, host: str,
port: int) -> None:
"""Initializes the CuPy NCCL backend.
# TODO: handle NCCL timeouts.
"""
assert not is_initialized()
if isinstance(cupy, Exception):
raise ImportError(
"NCCLBackend is not available. Please install cupy.") from cupy
# TODO(woosuk): Create TP and PP process groups for CuPy.
global _NCCL_BACKEND
global _WORLD_SIZE
assert world_size > 0, f"{world_size=} should be a positive integer"
assert 0 <= rank < world_size, (
f"{rank=} should be a integer between [0, {world_size})")
cupy.cuda.runtime.setDevice(torch.cuda.current_device())
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
_WORLD_SIZE = world_size
# Stop the TCP store to prevent the deadlock issues at termination time.
# FIXME(woosuk): This is hacky. Find a more robust solution.
if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
_NCCL_BACKEND._store.stop()
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
# Hack to support bfloat16
torch_dtype = input_.dtype
if torch_dtype is torch.bfloat16:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_ = input_.view(torch.float16)
cupy_input = cupy.asarray(input_)
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
out_array=cupy_input,
op=_OP_MAPPING[op])
def destroy_process_group() -> None:
"""Destroys the NCCL backend."""
global _NCCL_BACKEND
global _WORLD_SIZE
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def get_world_size() -> int:
"""Returns the world size."""
return _WORLD_SIZE
def get_nccl_backend():
return _NCCL_BACKEND

View File

@ -7,7 +7,7 @@ import contextlib
import torch import torch
from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils import pynccl_utils
# Tensor model parallel group that the current rank belongs to. # Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
@ -210,36 +210,36 @@ def destroy_model_parallel():
global _PIPELINE_GLOBAL_RANKS global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
# Destroy the cupy states if any. # Destroy the pynccl states if any.
cupy_utils.destroy_process_group() pynccl_utils.destroy_process_group()
# Whether to use cupy for nccl all reduce. # Whether to use pynccl for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed # We use pynccl for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph. # is not well supported by CUDA graph.
_ENABLE_CUPY_FOR_ALL_REDUCE = False _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager @contextlib.contextmanager
def with_cupy_nccl_for_all_reduce(): def with_pynccl_for_all_reduce():
"""use CuPy nccl instead of torch.distributed for all reduce""" """use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1: if tp_size == 1:
# No-op. # No-op.
# NOTE(woosuk): We don't initialize CuPy when tp_size is 1. # NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
yield yield
else: else:
global _ENABLE_CUPY_FOR_ALL_REDUCE global _ENABLE_PYNCCL_FOR_ALL_REDUCE
old = _ENABLE_CUPY_FOR_ALL_REDUCE old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
_ENABLE_CUPY_FOR_ALL_REDUCE = True _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
stream = torch.cuda.current_stream() stream = torch.cuda.current_stream()
with cupy_utils.set_cupy_stream(stream): with pynccl_utils.set_pynccl_stream(stream):
yield yield
_ENABLE_CUPY_FOR_ALL_REDUCE = old _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
def is_cupy_nccl_enabled_for_all_reduce(): def is_pynccl_enabled_for_all_reduce():
"""check if CuPy nccl is enabled for all reduce""" """check if pynccl is enabled for all reduce"""
global _ENABLE_CUPY_FOR_ALL_REDUCE global _ENABLE_PYNCCL_FOR_ALL_REDUCE
return _ENABLE_CUPY_FOR_ALL_REDUCE return _ENABLE_PYNCCL_FOR_ALL_REDUCE

View File

@ -0,0 +1,258 @@
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import ctypes
import datetime
import logging
import os
# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
logger = logging.getLogger(__name__)
so_file = os.environ.get("VLLM_NCCL_SO_PATH", "")
# manually load the nccl library
if so_file:
logger.info(
f"Loading nccl from environment variable VLLM_NCCL_SO_PATH={so_file}")
else:
if torch.version.cuda is not None:
so_file = "libnccl.so"
elif torch.version.hip is not None:
so_file = "librccl.so"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.debug(f"Loading nccl from library {so_file}")
try:
nccl = ctypes.CDLL(so_file)
except Exception as e:
logger.error(
f"Failed to load NCCL library from {so_file} ."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path.")
raise e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion = nccl.ncclGetVersion
_c_ncclGetVersion.restype = ctypes.c_int
_c_ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
def ncclGetVersion() -> str:
version = ctypes.c_int()
result = _c_ncclGetVersion(ctypes.byref(version))
assert result == 0
# something like 21903 --> "2.19.3"
version_str = str(version.value)
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
class NcclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId = nccl.ncclGetUniqueId
_c_ncclGetUniqueId.restype = ctypes.c_int
_c_ncclGetUniqueId.argtypes = [ctypes.POINTER(NcclUniqueId)]
def ncclGetUniqueId() -> NcclUniqueId:
unique_id = NcclUniqueId()
result = _c_ncclGetUniqueId(ctypes.byref(unique_id))
assert result == 0
return unique_id
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank = nccl.ncclCommInitRank
_c_ncclCommInitRank.restype = ctypes.c_int
_c_ncclCommInitRank.argtypes = [
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]
# enums
class ncclDataType_t(ctypes.c_int):
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
class ncclRedOp_t(ctypes.c_int):
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# udaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
]
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy = nccl.ncclCommDestroy
_c_ncclCommDestroy.restype = ctypes.c_int
_c_ncclCommDestroy.argtypes = [ctypes.c_void_p]
class NCCLCommunicator:
def __init__(
self,
backend=None,
init_method=None,
timeout=datetime.timedelta(seconds=10),
world_size: int = -1,
rank: int = -1,
store=None,
group_name: str = "",
pg_options=None,
):
if not dist.is_initialized():
backend = backend or "nccl"
assert backend == 'nccl', (
"only use nccl backend for starting the NCCL communicator")
dist.init_process_group(backend=backend,
init_method=init_method,
timeout=timeout,
world_size=world_size,
rank=rank,
store=store,
group_name=group_name,
pg_options=pg_options)
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
torch.cuda.set_device(self.rank)
if self.rank == 0:
self.unique_id = ncclGetUniqueId()
else:
self.unique_id = NcclUniqueId()
tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
self.rank)
dist.broadcast(tensor, src=0)
byte_list = tensor.cpu().tolist()
self.unique_id = NcclUniqueId()
for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte
self.comm = ctypes.c_void_p()
result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
self.unique_id, self.rank)
assert result == 0
self.stream = torch.cuda.Stream(device=f"cuda:{self.rank}")
def all_reduce(self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None):
if stream is None:
stream = self.stream
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataType_t.from_torch(tensor.dtype),
ncclRedOp_t.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0
def __del__(self):
dist.destroy_process_group()
_c_ncclCommDestroy(self.comm)

View File

@ -0,0 +1,64 @@
import contextlib
import logging
from typing import Optional
import torch
from torch.distributed import ReduceOp
logger = logging.getLogger(__name__)
try:
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
ncclGetVersion)
logger.info(f"vLLM is using nccl=={ncclGetVersion()}")
except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
logger.info(f"Failed to import NCCL library: {e}")
logger.info("It is expected if you are not running on NVIDIA GPUs.")
pass
comm: Optional["NCCLCommunicator"] = None
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return comm is not None
@contextlib.contextmanager
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
comm.stream = stream
yield
finally:
pass
def init_process_group(world_size: int, rank: int, init_method: str) -> None:
assert not is_initialized()
global comm
comm = NCCLCommunicator(init_method=init_method,
world_size=world_size,
rank=rank)
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
comm.all_reduce(input_, op)
def destroy_process_group() -> None:
global comm
comm = None
def get_world_size() -> int:
"""Returns the world size."""
return comm.world_size
def get_nccl_backend():
return comm

View File

@ -16,10 +16,7 @@ def init_test_distributed_environment(
worker_use_ray=True) worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}" distributed_init_method = f"tcp://localhost:{distributed_init_port}"
init_distributed_environment( init_distributed_environment(
parallel_config, parallel_config, rank, distributed_init_method=distributed_init_method)
rank,
cupy_port=None,
distributed_init_method=distributed_init_method)
def multi_process_tensor_parallel( def multi_process_tensor_parallel(

View File

@ -15,11 +15,11 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils, custom_all_reduce from vllm.model_executor.parallel_utils import custom_all_reduce, pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
with_cupy_nccl_for_all_reduce) with_pynccl_for_all_reduce)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
SequenceGroupMetadata) SequenceGroupMetadata)
@ -764,7 +764,7 @@ class ModelRunner:
""" """
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs. # deleted before the CUDA graphs.
self.cupy_nccl_backend = cupy_utils.get_nccl_backend() self.pynccl_backend = pynccl_utils.get_nccl_backend()
assert not self.model_config.enforce_eager assert not self.model_config.enforce_eager
logger.info("Capturing the model for CUDA graphs. This may lead to " logger.info("Capturing the model for CUDA graphs. This may lead to "
@ -794,11 +794,11 @@ class ModelRunner:
] ]
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
# kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or CuPy NCCL. When not using CUDA # either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL. # graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back # We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or CuPy NCCL if it is disabled or not supported. # to PyTorch or pynccl if it is disabled or not supported.
with custom_all_reduce.capture(): with custom_all_reduce.capture():
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
@ -846,12 +846,14 @@ class ModelRunner:
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
def __del__(self) -> None: def __del__(self) -> None:
# Delete the CUDA graphs before deleting the CuPy NCCL communicator. # Delete the CUDA graphs before deleting the pynccl communicator.
# NOTE(woosuk): This is necessary because otherwise deadlocks can # NOTE(woosuk): This is necessary because otherwise deadlocks can
# happen. # happen.
# FIXME(woosuk): This is a bit hacky. Find a more robust solution. # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
# TODO(youkaichao): when we get enough user feedback that pynccl is
# more stable than cupy, we can remove this, e.g. in v0.4.1.
self.graph_runners.clear() self.graph_runners.clear()
self.cupy_nccl_backend = None self.pynccl_backend = None
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
@ -879,7 +881,7 @@ class CUDAGraphRunner:
# Run the model once without capturing the graph. # Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the # This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune). # kernel launches for initial benchmarking (e.g., Triton autotune).
with _maybe_cupy_nccl(): with _maybe_pynccl():
self.model( self.model(
input_ids, input_ids,
positions, positions,
@ -894,7 +896,7 @@ class CUDAGraphRunner:
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self.graph = torch.cuda.CUDAGraph() self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
with _maybe_cupy_nccl(): with _maybe_pynccl():
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
positions, positions,
@ -947,9 +949,10 @@ class CUDAGraphRunner:
@contextlib.contextmanager @contextlib.contextmanager
def _maybe_cupy_nccl(): def _maybe_pynccl():
if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized(): if pynccl_utils.is_initialized(
with with_cupy_nccl_for_all_reduce(): ) and not custom_all_reduce.is_initialized():
with with_pynccl_for_all_reduce():
yield yield
else: else:
yield yield

View File

@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, VisionLanguageConfig) ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict) broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
@ -75,7 +75,7 @@ class Worker:
self.cache_engine = None self.cache_engine = None
self.gpu_cache = None self.gpu_cache = None
def init_device(self, cupy_port: Optional[int] = None) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until # torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow # the synchronization point. This causes the memory usage to grow
@ -98,7 +98,7 @@ class Worker:
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank, init_distributed_environment(self.parallel_config, self.rank,
cupy_port, self.distributed_init_method) self.distributed_init_method)
# Set random seed. # Set random seed.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
@ -250,7 +250,6 @@ class Worker:
def init_distributed_environment( def init_distributed_environment(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
rank: int, rank: int,
cupy_port: Optional[int],
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
) -> None: ) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
@ -273,28 +272,27 @@ def init_distributed_environment(
init_method=distributed_init_method, init_method=distributed_init_method,
) )
if cupy_utils.is_initialized(): if pynccl_utils.is_initialized():
cupy_world_size = cupy_utils.get_world_size() pynccl_world_size = pynccl_utils.get_world_size()
if cupy_world_size != parallel_config.world_size: if pynccl_world_size != parallel_config.world_size:
raise RuntimeError( raise RuntimeError(
"cupy.distributed is already initialized but the cupy world " "pynccl is already initialized but the pynccl world "
"size does not match parallel_config.world_size " "size does not match parallel_config.world_size "
f"({cupy_world_size} vs. {parallel_config.world_size}).") f"({pynccl_world_size} vs. {parallel_config.world_size}).")
elif (parallel_config.world_size > 1 and cupy_port is not None): elif parallel_config.world_size > 1:
# NOTE(woosuk): We don't initialize CuPy process group when world size # NOTE(woosuk): We don't initialize pynccl process group when world size
# is 1. # is 1.
# TODO(woosuk): Support multi-node connection. # TODO(woosuk): Support multi-node connection.
cupy_utils.init_process_group( pynccl_utils.init_process_group(
world_size=parallel_config.world_size, world_size=parallel_config.world_size,
rank=rank, rank=rank,
host="localhost", init_method=distributed_init_method,
port=cupy_port,
) )
# A small all_reduce for warmup. # A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda()) torch.distributed.all_reduce(torch.zeros(1).cuda())
if cupy_utils.is_initialized(): if pynccl_utils.is_initialized():
cupy_utils.all_reduce(torch.zeros(1).cuda()) pynccl_utils.all_reduce(torch.zeros(1).cuda())
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)