mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 11:06:25 +08:00
[Core][Distributed] improve p2p cache generation (#5528)
This commit is contained in:
parent
28c145eb57
commit
f5bb85b435
146
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
146
vllm/distributed/device_communicators/cuda_wrapper.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""This file is a pure Python wrapper for the cudart library.
|
||||
It avoids the need to compile a separate shared library, and is
|
||||
convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# === export types and functions from cudart to Python ===
|
||||
# for the original cudart definition, please check
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
|
||||
|
||||
cudaError_t = ctypes.c_int
|
||||
cudaMemcpyKind = ctypes.c_int
|
||||
|
||||
|
||||
class cudaIpcMemHandle_t(ctypes.Structure):
|
||||
_fields_ = [("internal", ctypes.c_byte * 128)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
name: str
|
||||
restype: Any
|
||||
argtypes: List[Any]
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
|
||||
# cudaError_t cudaDeviceSynchronize ( void )
|
||||
Function("cudaDeviceSynchronize", cudaError_t, []),
|
||||
# cudaError_t cudaDeviceReset ( void )
|
||||
Function("cudaDeviceReset", cudaError_t, []),
|
||||
|
||||
# const char* cudaGetErrorString ( cudaError_t error )
|
||||
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
|
||||
|
||||
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
|
||||
Function("cudaMalloc", cudaError_t,
|
||||
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
|
||||
# cudaError_t cudaFree ( void* devPtr )
|
||||
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
|
||||
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
|
||||
Function("cudaMemset", cudaError_t,
|
||||
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
|
||||
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
|
||||
Function("cudaMemcpy", cudaError_t, [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
|
||||
]),
|
||||
|
||||
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
|
||||
Function("cudaIpcGetMemHandle", cudaError_t,
|
||||
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
|
||||
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
|
||||
Function("cudaIpcOpenMemHandle", cudaError_t, [
|
||||
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
|
||||
]),
|
||||
]
|
||||
|
||||
# class attribute to store the mapping from the path to the library
|
||||
# to avoid loading the same library multiple times
|
||||
path_to_library_cache: Dict[str, Any] = {}
|
||||
|
||||
# class attribute to store the mapping from library path
|
||||
# to the corresponding dictionary
|
||||
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
assert torch.version.cuda is not None
|
||||
major_version = torch.version.cuda.split(".")[0]
|
||||
so_file = f"libcudart.so.{major_version}"
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in CudaRTLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in CudaRTLibrary.exported_functions:
|
||||
f = getattr(self.lib, func.name)
|
||||
f.restype = func.restype
|
||||
f.argtypes = func.argtypes
|
||||
_funcs[func.name] = f
|
||||
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
|
||||
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
|
||||
|
||||
def CUDART_CHECK(self, result: cudaError_t) -> None:
|
||||
if result != 0:
|
||||
error_str = self.cudaGetErrorString(result)
|
||||
raise RuntimeError(f"CUDART error: {error_str}")
|
||||
|
||||
def cudaGetErrorString(self, error: cudaError_t) -> str:
|
||||
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
|
||||
|
||||
def cudaSetDevice(self, device: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
|
||||
|
||||
def cudaDeviceSynchronize(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
|
||||
|
||||
def cudaDeviceReset(self) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
|
||||
|
||||
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
|
||||
return devPtr
|
||||
|
||||
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
|
||||
|
||||
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
|
||||
count: int) -> None:
|
||||
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
|
||||
|
||||
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
|
||||
count: int) -> None:
|
||||
cudaMemcpyDefault = 4
|
||||
kind = cudaMemcpyDefault
|
||||
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
|
||||
|
||||
def cudaIpcGetMemHandle(self,
|
||||
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
|
||||
handle = cudaIpcMemHandle_t()
|
||||
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
|
||||
ctypes.byref(handle), devPtr))
|
||||
return handle
|
||||
|
||||
def cudaIpcOpenMemHandle(self,
|
||||
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
|
||||
cudaIpcMemLazyEnablePeerAccess = 1
|
||||
devPtr = ctypes.c_void_p()
|
||||
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
|
||||
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
|
||||
return devPtr
|
||||
@ -1,87 +1,98 @@
|
||||
import ctypes
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from itertools import product
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mute_output():
|
||||
with open(os.devnull, "w") as f:
|
||||
sys.stderr = f
|
||||
sys.stdout = f
|
||||
yield
|
||||
|
||||
|
||||
def producer(i: int,
|
||||
init_method: str,
|
||||
def producer(batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
if cuda_visible_devices is not None:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
|
||||
with mute_output():
|
||||
dist.init_process_group(
|
||||
backend="gloo",
|
||||
init_method=init_method,
|
||||
world_size=2,
|
||||
rank=0,
|
||||
)
|
||||
# produce a tensor in GPU i
|
||||
data = torch.zeros((128, ), device=f"cuda:{i}")
|
||||
# get the information to reconstruct the shared tensor
|
||||
func, args = torch.multiprocessing.reductions.reduce_tensor(data)
|
||||
args = list(args)
|
||||
dist.broadcast_object_list([(func, args)], src=0)
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize()
|
||||
assert torch.all(data == 1).item()
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
lib.cudaSetDevice(i)
|
||||
pointer = lib.cudaMalloc(1024)
|
||||
lib.cudaMemset(pointer, 1, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
producer_queue.put(handle)
|
||||
open_success = consumer_queue.get()
|
||||
if open_success:
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.put(0)
|
||||
consumer_queue.get()
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(j: int,
|
||||
init_method: str,
|
||||
def consumer(batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: Optional[str] = None):
|
||||
if cuda_visible_devices is not None:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
|
||||
with mute_output():
|
||||
dist.init_process_group(
|
||||
backend="gloo",
|
||||
init_method=init_method,
|
||||
world_size=2,
|
||||
rank=1,
|
||||
)
|
||||
torch.cuda.set_device(j)
|
||||
recv = [None]
|
||||
dist.broadcast_object_list(recv, src=0)
|
||||
func: Callable
|
||||
args: List
|
||||
func, args = recv[0] # type: ignore
|
||||
# `args[6]` is the device id
|
||||
# by default pytorch will use `i` from the producer
|
||||
# here we need to set it to `j` to test P2P access
|
||||
args[6] = j
|
||||
data = func(*args)
|
||||
data += 1
|
||||
dist.barrier()
|
||||
torch.cuda.synchronize()
|
||||
assert torch.all(data == 1).item()
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
lib.cudaSetDevice(j)
|
||||
handle = producer_queue.get()
|
||||
open_success = False
|
||||
try:
|
||||
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
||||
open_success = True
|
||||
except RuntimeError:
|
||||
# cannot error out here, because the producer process
|
||||
# is still waiting for the response.
|
||||
pass
|
||||
consumer_queue.put(open_success)
|
||||
if open_success:
|
||||
# modify the memory
|
||||
lib.cudaMemset(pointer, 2, 1024)
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.get()
|
||||
consumer_queue.put(0)
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def can_actually_p2p(i, j):
|
||||
def can_actually_p2p(
|
||||
batch_src: Sequence[int],
|
||||
batch_tgt: Sequence[int],
|
||||
):
|
||||
"""
|
||||
Usually, checking if P2P access is enabled can be done by
|
||||
`torch.cuda.can_device_access_peer(i, j)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)`
|
||||
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
||||
returns `True` even if P2P access is not actually possible.
|
||||
See https://github.com/vllm-project/vllm/issues/2728 and
|
||||
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
||||
@ -90,41 +101,50 @@ def can_actually_p2p(i, j):
|
||||
|
||||
Note on p2p and cuda IPC:
|
||||
Usually, one process uses one GPU:
|
||||
GPU i --> cuda context i --> tensor i --> process i
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|
||||
We need to combine p2p and cuda IPC, so that:
|
||||
GPU i --> cuda context i --> tensor i --> process i
|
||||
|shared|
|
||||
GPU j --> cuda context j --> tensor j --> process j
|
||||
That is to say, process i creates a tensor in GPU i, passes IPC handle to
|
||||
process j, and process j accesses the tensor in GPU j. Any operation on the
|
||||
tensor in process j will be reflected in the tensor in process i, because
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|shared|
|
||||
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
||||
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
||||
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
||||
tensor in process tgt will be reflected in the tensor in process src, because
|
||||
they are the same memory segment.
|
||||
It is important to note that process j accesses the tensor in GPU j, not
|
||||
GPU i. That's why we need p2p access. # noqa
|
||||
"""
|
||||
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
||||
GPU src. That's why we need p2p access.
|
||||
|
||||
The most time-consuming part is the process creation. To avoid creating
|
||||
processes for every pair of GPUs, we use batched testing. We create two
|
||||
processes for testing all pairs of GPUs in batch. The trick is to reset
|
||||
the device after each test (which is not available in PyTorch).
|
||||
""" # noqa
|
||||
cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
|
||||
# pass the CUDA_VISIBLE_DEVICES to the child process
|
||||
# to make sure they see the same set of GPUs
|
||||
|
||||
# make sure the temp file is not the same across different calls
|
||||
temp_path = tempfile.mktemp() + str(time.time())
|
||||
# create an empty file
|
||||
with open(temp_path, "w"):
|
||||
pass
|
||||
init_method = f"file://{temp_path}"
|
||||
|
||||
# make sure the processes are spawned
|
||||
smp = mp.get_context("spawn")
|
||||
pi = smp.Process(target=producer,
|
||||
args=(i, init_method, cuda_visible_devices))
|
||||
pj = smp.Process(target=consumer,
|
||||
args=(j, init_method, cuda_visible_devices))
|
||||
pi.start()
|
||||
pj.start()
|
||||
pi.join()
|
||||
pj.join()
|
||||
return pi.exitcode == 0 and pj.exitcode == 0
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(target=producer,
|
||||
args=(batch_src, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_tgt = smp.Process(target=consumer,
|
||||
args=(batch_tgt, producer_queue, consumer_queue,
|
||||
result_queue, cuda_visible_devices))
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
result = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
assert a == b
|
||||
result.append(a)
|
||||
return result
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
@ -142,14 +162,14 @@ def can_actually_p2p(i, j):
|
||||
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(i: int, j: int) -> bool:
|
||||
"""Check if GPU i can access GPU j."""
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
"""Check if GPU src can access GPU tgt."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{i}->{j}"]
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
@ -169,9 +189,12 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
cache = {}
|
||||
for _i in range(num_dev):
|
||||
for _j in range(num_dev):
|
||||
cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j)
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
result = can_actually_p2p(batch_src, batch_tgt)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
cache[f"{_i}->{_j}"] = r
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
@ -180,7 +203,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
|
||||
with open(path, "r") as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{i}->{j}"]
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user