mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-12 10:17:14 +08:00
259 lines
9.5 KiB
Python
259 lines
9.5 KiB
Python
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approaches:
# 1. We tried to use `cupy`. It calls NCCL correctly, but `cupy` itself
#    often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
#    contains many other potential CUDA APIs that are not allowed while
#    capturing the CUDA graph. For further details, please check
#    https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
|
|
|
|
import ctypes
|
|
import platform
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
from torch.distributed import ReduceOp
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.utils import find_nccl_library, nccl_integrity_check
|
|
|
|
# Module-level logger, obtained from vLLM's `init_logger` with this
# module's `__name__`.
logger = init_logger(__name__)
|
|
|
|
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in

# Return/status code of NCCL calls, represented as a C int.
ncclResult_t = ctypes.c_int
# Communicator handle, represented as an opaque `void*`.
ncclComm_t = ctypes.c_void_p
|
|
|
|
|
|
class ncclUniqueId(ctypes.Structure):
    """ctypes layout of NCCL's unique-id struct.

    The structure holds a single opaque ``internal`` buffer of 128 bytes.
    It is filled in by ``ncclGetUniqueId`` and passed by value to
    ``ncclCommInitRank``.
    """

    _fields_ = [
        ("internal", ctypes.c_byte * 128),
    ]
|
|
|
|
|
|
# CUDA stream handle, represented as an opaque `void*`.
cudaStream_t = ctypes.c_void_p
# Send/receive buffers are passed to NCCL as raw `void*` addresses.
buffer_type = ctypes.c_void_p

# NCCL data-type enum values are passed as C ints
# (see `ncclDataTypeEnum` for the actual codes).
ncclDataType_t = ctypes.c_int
|
|
|
|
|
|
class ncclDataTypeEnum:
    """Integer codes for the NCCL data-type enumeration.

    Several names are aliases that share one code (for example
    ``ncclInt8``/``ncclChar`` or ``ncclFloat16``/``ncclHalf``).
    """

    ncclInt8 = 0
    ncclChar = 0
    ncclUint8 = 1
    ncclInt32 = 2
    ncclInt = 2
    ncclUint32 = 3
    ncclInt64 = 4
    ncclUint64 = 5
    ncclFloat16 = 6
    ncclHalf = 6
    ncclFloat32 = 7
    ncclFloat = 7
    ncclFloat64 = 8
    ncclDouble = 8
    ncclBfloat16 = 9
    ncclNumTypes = 10

    @classmethod
    def from_torch(cls, dtype: torch.dtype) -> int:
        """Translate a torch dtype into the matching NCCL data-type code.

        Args:
            dtype: the torch dtype to translate.

        Returns:
            The integer NCCL code for ``dtype``.

        Raises:
            ValueError: if ``dtype`` is not one of the supported dtypes.
        """
        # (torch dtype, NCCL code) pairs, checked in order.
        torch_to_nccl = (
            (torch.int8, cls.ncclInt8),
            (torch.uint8, cls.ncclUint8),
            (torch.int32, cls.ncclInt32),
            (torch.int64, cls.ncclInt64),
            (torch.float16, cls.ncclFloat16),
            (torch.float32, cls.ncclFloat32),
            (torch.float64, cls.ncclFloat64),
            (torch.bfloat16, cls.ncclBfloat16),
        )
        for torch_dtype, nccl_code in torch_to_nccl:
            if dtype == torch_dtype:
                return nccl_code
        raise ValueError(f"Unsupported dtype: {dtype}")
|
|
|
|
|
|
# NCCL reduction-op enum values are passed as C ints
# (see `ncclRedOpTypeEnum` for the actual codes).
ncclRedOp_t = ctypes.c_int
|
|
|
|
|
|
class ncclRedOpTypeEnum:
    """Integer codes for the NCCL reduction-operation enumeration."""

    ncclSum = 0
    ncclProd = 1
    ncclMax = 2
    ncclMin = 3
    ncclAvg = 4
    ncclNumOps = 5

    @classmethod
    def from_torch(cls, op: ReduceOp) -> int:
        """Translate a ``torch.distributed.ReduceOp`` into its NCCL code.

        Args:
            op: the torch reduction operation to translate.

        Returns:
            The integer NCCL code for ``op``.

        Raises:
            ValueError: if ``op`` is not SUM, PRODUCT, MAX, MIN or AVG.
        """
        # (torch reduce op, NCCL code) pairs, checked in order.
        torch_to_nccl = (
            (ReduceOp.SUM, cls.ncclSum),
            (ReduceOp.PRODUCT, cls.ncclProd),
            (ReduceOp.MAX, cls.ncclMax),
            (ReduceOp.MIN, cls.ncclMin),
            (ReduceOp.AVG, cls.ncclAvg),
        )
        for torch_op, nccl_code in torch_to_nccl:
            if op == torch_op:
                return nccl_code
        raise ValueError(f"Unsupported op: {op}")
|
|
|
|
|
|
@dataclass
class Function:
    """Signature description of one C function exported by the library.

    Instances are used to look the symbol up on a loaded ``ctypes`` library
    and to set its ``restype`` / ``argtypes``.
    """

    # symbol name of the function in the shared library
    name: str
    # ctypes type of the return value
    restype: Any
    # ctypes types of the positional arguments, in call order
    argtypes: List[Any]
|
|
|
|
|
|
class NCCLLibrary:
    """Thin ``ctypes`` binding around an NCCL shared library.

    On construction the shared object is loaded (once per path) and every
    entry of ``exported_functions`` is resolved and given its ``restype`` and
    ``argtypes``. The public methods mirror the NCCL C functions of the same
    name; those that return ``ncclResult_t`` are wrapped in ``NCCL_CHECK``
    so that a non-zero status raises ``RuntimeError``.
    """

    exported_functions = [
        # const char* ncclGetErrorString(ncclResult_t result)
        Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
        # ncclResult_t  ncclGetVersion(int *version);
        Function("ncclGetVersion", ncclResult_t,
                 [ctypes.POINTER(ctypes.c_int)]),
        # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
        Function("ncclGetUniqueId", ncclResult_t,
                 [ctypes.POINTER(ncclUniqueId)]),
        # ncclResult_t  ncclCommInitRank(
        #   ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
        # note that ncclComm_t is a pointer type, so the first argument
        # is a pointer to a pointer
        Function("ncclCommInitRank", ncclResult_t, [
            ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId,
            ctypes.c_int
        ]),
        # ncclResult_t  ncclAllReduce(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllReduce", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

        # be cautious! this is a collective call, it will block until all
        # processes in the communicator have called this function.
        # because Python object destruction can happen in random order,
        # it is better not to call it at all.
        # ncclResult_t  ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Load the NCCL shared library and bind its exported functions.

        Args:
            so_file: path of the NCCL shared object. When omitted (or
                otherwise falsy) the path is taken from
                ``find_nccl_library()``.

        Raises:
            Exception: whatever ``nccl_integrity_check`` raises is logged
                and re-raised unchanged.
        """
        so_file = so_file or find_nccl_library()

        try:
            # load the library in another process.
            # if it core dumps, it will not crash the current process
            nccl_integrity_check(so_file)
        except Exception:
            logger.error(
                "Failed to load NCCL library from %s . "
                "It is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted "
                "or it does not support the current platform %s. "
                "One solution is to download libnccl2 version 2.18 from "
                "https://developer.download.nvidia.com/compute/cuda/repos/ "
                "and extract the libnccl.so.2 file. If you already have the "
                "library, please set the environment variable VLLM_NCCL_SO_PATH"
                " to point to the correct nccl library path.", so_file,
                platform.platform())
            # bare raise keeps the original exception and traceback intact
            raise

        # Load each shared object only once per path; the guard checks the
        # same cache that is read back on the next line.
        if so_file not in NCCLLibrary.path_to_library_cache:
            NCCLLibrary.path_to_library_cache[so_file] = ctypes.CDLL(so_file)
        self.lib = NCCLLibrary.path_to_library_cache[so_file]

        # Resolve the exported symbols and attach their C signatures once
        # per library path; later instances share the same dictionary.
        if so_file not in NCCLLibrary.path_to_dict_mapping:
            _funcs = {}
            for func in NCCLLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
        self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]

    def ncclGetErrorString(self, result: ncclResult_t) -> str:
        """Return NCCL's description of the status code ``result``."""
        return self._funcs["ncclGetErrorString"](result).decode("utf-8")

    def NCCL_CHECK(self, result: ncclResult_t) -> None:
        """Raise ``RuntimeError`` when ``result`` is a non-zero status."""
        if result != 0:
            error_str = self.ncclGetErrorString(result)
            raise RuntimeError(f"NCCL error: {error_str}")

    @staticmethod
    def _format_version(version_code: int) -> str:
        """Decode an NCCL integer version code into ``"major.minor.patch"``.

        NCCL encodes X.Y.Z as ``X * 10000 + Y * 100 + Z``; releases up to
        2.8 used ``X * 1000 + Y * 100 + Z``, which always stays below 10000.
        Decoding with integer arithmetic keeps zero components
        (22000 -> "2.20.0") instead of dropping them.
        """
        major_scale = 10000 if version_code >= 10000 else 1000
        major, remainder = divmod(version_code, major_scale)
        minor, patch = divmod(remainder, 100)
        return f"{major}.{minor}.{patch}"

    def ncclGetVersion(self) -> str:
        """Return the library version as a dotted string, e.g. ``"2.19.3"``.

        Raises:
            RuntimeError: if the underlying NCCL call reports an error.
        """
        version = ctypes.c_int()
        self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
        # something like 21903 --> "2.19.3"
        return self._format_version(version.value)

    def ncclGetUniqueId(self) -> ncclUniqueId:
        """Ask NCCL for a fresh unique id and return it.

        Raises:
            RuntimeError: if the underlying NCCL call reports an error.
        """
        unique_id = ncclUniqueId()
        self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](
            ctypes.byref(unique_id)))
        return unique_id

    def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                         rank: int) -> ncclComm_t:
        """Create the communicator handle for ``rank``.

        Args:
            world_size: passed to NCCL as ``nranks``.
            unique_id: passed to NCCL as ``commId``.
            rank: rank of the calling process.

        Returns:
            The communicator handle filled in by NCCL.

        Raises:
            RuntimeError: if the underlying NCCL call reports an error.
        """
        comm = ncclComm_t()
        self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm),
                                                        world_size, unique_id,
                                                        rank))
        return comm

    def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, op: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        """Call ``ncclAllReduce`` with the given arguments.

        Args:
            sendbuff: address of the send buffer.
            recvbuff: address of the receive buffer.
            count: passed to NCCL as ``count`` (``size_t``).
            datatype: NCCL data-type code (see ``ncclDataTypeEnum``).
            op: NCCL reduction-op code (see ``ncclRedOpTypeEnum``).
            comm: communicator handle.
            stream: CUDA stream handle.

        Raises:
            RuntimeError: if the underlying NCCL call reports an error.
        """
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count,
                                                     datatype, op, comm,
                                                     stream))

    def ncclCommDestroy(self, comm: ncclComm_t) -> None:
        """Destroy the communicator ``comm``.

        See the note next to ``ncclCommDestroy`` in ``exported_functions``:
        this is a collective call and is better not called at all.

        Raises:
            RuntimeError: if the underlying NCCL call reports an error.
        """
        self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "NCCLLibrary",
    "ncclDataTypeEnum",
    "ncclRedOpTypeEnum",
    "ncclUniqueId",
    "ncclComm_t",
    "cudaStream_t",
    "buffer_type",
]
|