[Core][Distributed] improve p2p cache generation (#5528)
This commit is contained in:
parent 28c145eb57
commit f5bb85b435
vllm/distributed/device_communicators/cuda_wrapper.py (new file, 146 lines)
@@ -0,0 +1,146 @@
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""

import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch  # noqa

from vllm.logger import init_logger

logger = init_logger(__name__)

# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html

cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int


class cudaIpcMemHandle_t(ctypes.Structure):
    _fields_ = [("internal", ctypes.c_byte * 128)]


@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]


class CudaRTLibrary:
    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function("cudaDeviceSynchronize", cudaError_t, []),
        # cudaError_t cudaDeviceReset ( void )
        Function("cudaDeviceReset", cudaError_t, []),

        # const char* cudaGetErrorString ( cudaError_t error )
        Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),

        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function("cudaMalloc", cudaError_t,
                 [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
        # cudaError_t cudaFree ( void* devPtr )
        Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function("cudaMemset", cudaError_t,
                 [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function("cudaMemcpy", cudaError_t, [
            ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
        ]),

        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function("cudaIpcGetMemHandle", cudaError_t,
                 [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function("cudaIpcOpenMemHandle", cudaError_t, [
            ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
        ]),
    ]

    # class attribute to store the mapping from the path to the library
    # to avoid loading the same library multiple times
    path_to_library_cache: Dict[str, Any] = {}

    # class attribute to store the mapping from library path
    # to the corresponding dictionary
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        if so_file is None:
            assert torch.version.cuda is not None
            major_version = torch.version.cuda.split(".")[0]
            so_file = f"libcudart.so.{major_version}"
        if so_file not in CudaRTLibrary.path_to_library_cache:
            lib = ctypes.CDLL(so_file)
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]

        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            _funcs = {}
            for func in CudaRTLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        if result != 0:
            error_str = self.cudaGetErrorString(result)
            raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        return self.funcs["cudaGetErrorString"](error).decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

    def cudaDeviceSynchronize(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

    def cudaDeviceReset(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
        return devPtr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
                   count: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
                   count: int) -> None:
        cudaMemcpyDefault = 4
        kind = cudaMemcpyDefault
        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

    def cudaIpcGetMemHandle(self,
                            devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        handle = cudaIpcMemHandle_t()
        self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
            ctypes.byref(handle), devPtr))
        return handle

    def cudaIpcOpenMemHandle(self,
                             handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        cudaIpcMemLazyEnablePeerAccess = 1
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
            ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
        return devPtr
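Aside (not part of the commit): a minimal usage sketch of the wrapper above, assuming a machine with a CUDA driver and libcudart available. It allocates a buffer on GPU 0, fills it, copies it back to the host, and verifies the contents; every call shown is a method defined in the class listed above.

import ctypes

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

lib = CudaRTLibrary()            # loads libcudart.so.<CUDA major version>
lib.cudaSetDevice(0)             # select GPU 0
ptr = lib.cudaMalloc(1024)       # allocate 1 KiB of device memory
lib.cudaMemset(ptr, 7, 1024)     # fill every byte with 7
lib.cudaDeviceSynchronize()
host = (ctypes.c_char * 1024)()  # host-side buffer
lib.cudaMemcpy(host, ptr, 1024)  # cudaMemcpyDefault infers the direction
assert all(ord(b) == 7 for b in host)
lib.cudaFree(ptr)
lib.cudaDeviceReset()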
vllm/distributed/device_communicators/custom_all_reduce_utils.py
@@ -1,87 +1,98 @@
+import ctypes
 import json
 import os
-import sys
-import tempfile
-import time
-from contextlib import contextmanager
-from typing import Callable, Dict, List, Optional
+from itertools import product
+from typing import Dict, Optional, Sequence
 
-import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
 import vllm.envs as envs
+from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
 from vllm.utils import cuda_device_count_stateless
 
 logger = init_logger(__name__)
 
 
-@contextmanager
-def mute_output():
-    with open(os.devnull, "w") as f:
-        sys.stderr = f
-        sys.stdout = f
-        yield
-
-
-def producer(i: int,
-             init_method: str,
+def producer(batch_src: Sequence[int],
+             producer_queue,
+             consumer_queue,
+             result_queue,
              cuda_visible_devices: Optional[str] = None):
     if cuda_visible_devices is not None:
         os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
-    with mute_output():
-        dist.init_process_group(
-            backend="gloo",
-            init_method=init_method,
-            world_size=2,
-            rank=0,
-        )
-        # produce a tensor in GPU i
-        data = torch.zeros((128, ), device=f"cuda:{i}")
-        # get the information to reconstruct the shared tensor
-        func, args = torch.multiprocessing.reductions.reduce_tensor(data)
-        args = list(args)
-        dist.broadcast_object_list([(func, args)], src=0)
-        dist.barrier()
-        torch.cuda.synchronize()
-        assert torch.all(data == 1).item()
+
+    lib = CudaRTLibrary()
+    for i in batch_src:
+        lib.cudaSetDevice(i)
+        pointer = lib.cudaMalloc(1024)
+        lib.cudaMemset(pointer, 1, 1024)
+        lib.cudaDeviceSynchronize()
+        handle = lib.cudaIpcGetMemHandle(pointer)
+        producer_queue.put(handle)
+        open_success = consumer_queue.get()
+        if open_success:
+            # use two queues to simulate barrier
+            producer_queue.put(0)
+            consumer_queue.get()
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()
 
 
-def consumer(j: int,
-             init_method: str,
+def consumer(batch_tgt: Sequence[int],
+             producer_queue,
+             consumer_queue,
+             result_queue,
              cuda_visible_devices: Optional[str] = None):
     if cuda_visible_devices is not None:
         os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
-    with mute_output():
-        dist.init_process_group(
-            backend="gloo",
-            init_method=init_method,
-            world_size=2,
-            rank=1,
-        )
-        torch.cuda.set_device(j)
-        recv = [None]
-        dist.broadcast_object_list(recv, src=0)
-        func: Callable
-        args: List
-        func, args = recv[0]  # type: ignore
-        # `args[6]` is the device id
-        # by default pytorch will use `i` from the producer
-        # here we need to set it to `j` to test P2P access
-        args[6] = j
-        data = func(*args)
-        data += 1
-        dist.barrier()
-        torch.cuda.synchronize()
-        assert torch.all(data == 1).item()
+
+    lib = CudaRTLibrary()
+    for j in batch_tgt:
+        lib.cudaSetDevice(j)
+        handle = producer_queue.get()
+        open_success = False
+        try:
+            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
+            open_success = True
+        except RuntimeError:
+            # cannot error out here, because the producer process
+            # is still waiting for the response.
+            pass
+        consumer_queue.put(open_success)
+        if open_success:
+            # modify the memory
+            lib.cudaMemset(pointer, 2, 1024)
+            # use two queues to simulate barrier
+            producer_queue.get()
+            consumer_queue.put(0)
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()
 
 
-def can_actually_p2p(i, j):
+def can_actually_p2p(
+    batch_src: Sequence[int],
+    batch_tgt: Sequence[int],
+):
     """
     Usually, checking if P2P access is enabled can be done by
-    `torch.cuda.can_device_access_peer(i, j)`. However, sometimes
-    the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)`
+    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
+    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
     returns `True` even if P2P access is not actually possible.
     See https://github.com/vllm-project/vllm/issues/2728 and
     https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
@@ -90,41 +101,50 @@ def can_actually_p2p(i, j):
 
     Note on p2p and cuda IPC:
     Usually, one process uses one GPU:
-    GPU i --> cuda context i --> tensor i --> process i
+    GPU src --> cuda context src --> tensor src --> process src
 
     We need to combine p2p and cuda IPC, so that:
-    GPU i --> cuda context i --> tensor i --> process i
-                                 |shared|
-    GPU j --> cuda context j --> tensor j --> process j
-    That is to say, process i creates a tensor in GPU i, passes IPC handle to
-    process j, and process j accesses the tensor in GPU j. Any operation on the
-    tensor in process j will be reflected in the tensor in process i, because
+    GPU src --> cuda context src --> tensor src --> process src
+                                     |shared|
+    GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
+    That is to say, process src creates a tensor in GPU src, passes IPC handle to
+    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
+    tensor in process tgt will be reflected in the tensor in process src, because
     they are the same memory segment.
-    It is important to note that process j accesses the tensor in GPU j, not
-    GPU i. That's why we need p2p access. # noqa
-    """
+    It is important to note that process tgt accesses the tensor in GPU tgt, not
+    GPU src. That's why we need p2p access.
+
+    The most time-consuming part is the process creation. To avoid creating
+    processes for every pair of GPUs, we use batched testing. We create two
+    processes for testing all pairs of GPUs in batch. The trick is to reset
+    the device after each test (which is not available in PyTorch).
+    """  # noqa
     cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
     # pass the CUDA_VISIBLE_DEVICES to the child process
     # to make sure they see the same set of GPUs
-
-    # make sure the temp file is not the same across different calls
-    temp_path = tempfile.mktemp() + str(time.time())
-    # create an empty file
-    with open(temp_path, "w"):
-        pass
-    init_method = f"file://{temp_path}"
 
     # make sure the processes are spawned
     smp = mp.get_context("spawn")
-    pi = smp.Process(target=producer,
-                     args=(i, init_method, cuda_visible_devices))
-    pj = smp.Process(target=consumer,
-                     args=(j, init_method, cuda_visible_devices))
-    pi.start()
-    pj.start()
-    pi.join()
-    pj.join()
-    return pi.exitcode == 0 and pj.exitcode == 0
+    producer_queue = smp.Queue()
+    consumer_queue = smp.Queue()
+    result_queue = smp.Queue()
+    p_src = smp.Process(target=producer,
+                        args=(batch_src, producer_queue, consumer_queue,
+                              result_queue, cuda_visible_devices))
+    p_tgt = smp.Process(target=consumer,
+                        args=(batch_tgt, producer_queue, consumer_queue,
+                              result_queue, cuda_visible_devices))
+    p_src.start()
+    p_tgt.start()
+    p_src.join()
+    p_tgt.join()
+    result = []
+    for src, tgt in zip(batch_src, batch_tgt):
+        a = result_queue.get()
+        b = result_queue.get()
+        assert a == b
+        result.append(a)
+    return result
 
 
 # why do we need this cache?
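Aside (not part of the diff): the producer/consumer handshake above relies on the "use two queues to simulate barrier" trick. The sketch below, with made-up helper names, shows that pattern in isolation: each process puts a token on one queue and blocks on the other, so neither proceeds until both have reached the rendezvous point.

import torch.multiprocessing as mp


def _step_a(producer_queue, consumer_queue):
    # ... producer-side work would happen here ...
    producer_queue.put(0)   # signal: this side has finished its step
    consumer_queue.get()    # wait until the other side has finished its step


def _step_b(producer_queue, consumer_queue):
    # ... consumer-side work would happen here ...
    producer_queue.get()    # wait for the producer's signal
    consumer_queue.put(0)   # release the producer


if __name__ == "__main__":
    smp = mp.get_context("spawn")
    producer_queue, consumer_queue = smp.Queue(), smp.Queue()
    pa = smp.Process(target=_step_a, args=(producer_queue, consumer_queue))
    pb = smp.Process(target=_step_b, args=(producer_queue, consumer_queue))
    pa.start()
    pb.start()
    pa.join()
    pb.join()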
@@ -142,14 +162,14 @@ def can_actually_p2p(i, j):
 _gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
 
 
-def gpu_p2p_access_check(i: int, j: int) -> bool:
-    """Check if GPU i can access GPU j."""
+def gpu_p2p_access_check(src: int, tgt: int) -> bool:
+    """Check if GPU src can access GPU tgt."""
 
     # if the cache variable is already calculated,
     # read from the cache instead of checking it again
     global _gpu_p2p_access_cache
     if _gpu_p2p_access_cache is not None:
-        return _gpu_p2p_access_cache[f"{i}->{j}"]
+        return _gpu_p2p_access_cache[f"{src}->{tgt}"]
 
     is_distributed = dist.is_initialized()
 
@@ -169,9 +189,12 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
         # enter this block to calculate the cache
         logger.info("generating GPU P2P access cache in %s", path)
         cache = {}
-        for _i in range(num_dev):
-            for _j in range(num_dev):
-                cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j)
+        ids = list(range(num_dev))
+        # batch of all pairs of GPUs
+        batch_src, batch_tgt = zip(*list(product(ids, ids)))
+        result = can_actually_p2p(batch_src, batch_tgt)
+        for _i, _j, r in zip(batch_src, batch_tgt, result):
+            cache[f"{_i}->{_j}"] = r
         with open(path, "w") as f:
             json.dump(cache, f, indent=4)
     if is_distributed:
@@ -180,7 +203,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
     with open(path, "r") as f:
         cache = json.load(f)
     _gpu_p2p_access_cache = cache
-    return _gpu_p2p_access_cache[f"{i}->{j}"]
+    return _gpu_p2p_access_cache[f"{src}->{tgt}"]
 
 
 __all__ = ["gpu_p2p_access_check"]
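Aside (not part of the diff): the cache generated above is a flat JSON mapping from "src->tgt" pairs to booleans, written once and then read back by gpu_p2p_access_check(src, tgt). The sketch below is illustrative only (actual values depend on the hardware and driver); it mirrors the batching logic from the hunk above for a hypothetical 2-GPU machine.

import json
from itertools import product

num_dev = 2
ids = list(range(num_dev))
# all ordered (src, tgt) pairs, tested by a single producer/consumer pair
batch_src, batch_tgt = zip(*product(ids, ids))   # (0, 0, 1, 1) and (0, 1, 0, 1)
# pretend can_actually_p2p returned these results for the four pairs
result = [True, True, True, True]
cache = {f"{_i}->{_j}": r for _i, _j, r in zip(batch_src, batch_tgt, result)}
print(json.dumps(cache, indent=4))
# {
#     "0->0": true,
#     "0->1": true,
#     "1->0": true,
#     "1->1": true
# }

Subsequent calls to gpu_p2p_access_check(src, tgt) then simply look up cache[f"{src}->{tgt}"], as shown in the last hunk.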