[Core][Distributed] improve p2p cache generation (#5528)

youkaichao 2024-06-14 14:47:45 -07:00 committed by GitHub
parent 28c145eb57
commit f5bb85b435
2 changed files with 261 additions and 92 deletions

@@ -0,0 +1,146 @@
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa
from vllm.logger import init_logger
logger = init_logger(__name__)
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class CudaRTLibrary:
exported_functions = [
# cudaError_t cudaSetDevice ( int device )
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
# cudaError_t cudaDeviceSynchronize ( void )
Function("cudaDeviceSynchronize", cudaError_t, []),
# cudaError_t cudaDeviceReset ( void )
Function("cudaDeviceReset", cudaError_t, []),
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
# cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t,
[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function("cudaIpcOpenMemHandle", cudaError_t, [
ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
assert torch.version.cuda is not None
major_version = torch.version.cuda.split(".")[0]
so_file = f"libcudart.so.{major_version}"
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
def CUDART_CHECK(self, result: cudaError_t) -> None:
if result != 0:
error_str = self.cudaGetErrorString(result)
raise RuntimeError(f"CUDART error: {error_str}")
def cudaGetErrorString(self, error: cudaError_t) -> str:
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
def cudaSetDevice(self, device: int) -> None:
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
def cudaDeviceSynchronize(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
def cudaDeviceReset(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
return devPtr
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
count: int) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
def cudaIpcGetMemHandle(self,
devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
ctypes.byref(handle), devPtr))
return handle
def cudaIpcOpenMemHandle(self,
handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
return devPtr
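
As a quick orientation, here is a minimal usage sketch of the wrapper above. It is hypothetical (not part of this commit), only exercises methods defined in CudaRTLibrary, and assumes a CUDA-capable machine with PyTorch installed so that libcudart can be located.

import ctypes

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

lib = CudaRTLibrary()  # picks libcudart.so.<major> based on torch.version.cuda
lib.cudaSetDevice(0)
ptr = lib.cudaMalloc(1024)    # 1 KiB of device memory
lib.cudaMemset(ptr, 1, 1024)  # fill every byte with 0x01
lib.cudaDeviceSynchronize()

# copy the buffer back to host memory and check its contents
host = (ctypes.c_char * 1024)()
lib.cudaMemcpy(host, ptr, 1024)
assert all(ord(b) == 1 for b in host)

# an IPC handle that another process could open with cudaIpcOpenMemHandle
handle = lib.cudaIpcGetMemHandle(ptr)

lib.cudaFree(ptr)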

@@ -1,87 +1,98 @@
+import ctypes
import json
import os
-import sys
-import tempfile
-import time
-from contextlib import contextmanager
-from typing import Callable, Dict, List, Optional
+from itertools import product
+from typing import Dict, Optional, Sequence

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
+from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless

logger = init_logger(__name__)

-@contextmanager
-def mute_output():
-    with open(os.devnull, "w") as f:
-        sys.stderr = f
-        sys.stdout = f
-        yield


-def producer(i: int,
-             init_method: str,
+def producer(batch_src: Sequence[int],
+             producer_queue,
+             consumer_queue,
+             result_queue,
             cuda_visible_devices: Optional[str] = None):
    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
-    with mute_output():
-        dist.init_process_group(
-            backend="gloo",
-            init_method=init_method,
-            world_size=2,
-            rank=0,
-        )
-        # produce a tensor in GPU i
-        data = torch.zeros((128, ), device=f"cuda:{i}")
-        # get the information to reconstruct the shared tensor
-        func, args = torch.multiprocessing.reductions.reduce_tensor(data)
-        args = list(args)
-        dist.broadcast_object_list([(func, args)], src=0)
-        dist.barrier()
-        torch.cuda.synchronize()
-        assert torch.all(data == 1).item()
+    lib = CudaRTLibrary()
+    for i in batch_src:
+        lib.cudaSetDevice(i)
+        pointer = lib.cudaMalloc(1024)
+        lib.cudaMemset(pointer, 1, 1024)
+        lib.cudaDeviceSynchronize()
+        handle = lib.cudaIpcGetMemHandle(pointer)
+        producer_queue.put(handle)
+        open_success = consumer_queue.get()
+        if open_success:
+            # use two queues to simulate barrier
+            producer_queue.put(0)
+            consumer_queue.get()
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()

-def consumer(j: int,
-             init_method: str,
+def consumer(batch_tgt: Sequence[int],
+             producer_queue,
+             consumer_queue,
+             result_queue,
             cuda_visible_devices: Optional[str] = None):
    if cuda_visible_devices is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
-    with mute_output():
-        dist.init_process_group(
-            backend="gloo",
-            init_method=init_method,
-            world_size=2,
-            rank=1,
-        )
-        torch.cuda.set_device(j)
-        recv = [None]
-        dist.broadcast_object_list(recv, src=0)
-        func: Callable
-        args: List
-        func, args = recv[0]  # type: ignore
-        # `args[6]` is the device id
-        # by default pytorch will use `i` from the producer
-        # here we need to set it to `j` to test P2P access
-        args[6] = j
-        data = func(*args)
-        data += 1
-        dist.barrier()
-        torch.cuda.synchronize()
-        assert torch.all(data == 1).item()
+    lib = CudaRTLibrary()
+    for j in batch_tgt:
+        lib.cudaSetDevice(j)
+        handle = producer_queue.get()
+        open_success = False
+        try:
+            pointer = lib.cudaIpcOpenMemHandle(handle)  # type: ignore
+            open_success = True
+        except RuntimeError:
+            # cannot error out here, because the producer process
+            # is still waiting for the response.
+            pass
+        consumer_queue.put(open_success)
+        if open_success:
+            # modify the memory
+            lib.cudaMemset(pointer, 2, 1024)
+            # use two queues to simulate barrier
+            producer_queue.get()
+            consumer_queue.put(0)
+            # check if the memory is modified
+            host_data = (ctypes.c_char * 1024)()
+            lib.cudaMemcpy(host_data, pointer, 1024)  # type: ignore
+            for i in range(1024):
+                if ord(host_data[i]) != 2:
+                    open_success = False
+                    break
+        result_queue.put(open_success)
+        lib.cudaDeviceReset()
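
The "use two queues to simulate barrier" comments above encode a small rendezvous. A stripped-down sketch of just that handshake (hypothetical, with the CUDA calls removed): the producer signals and then blocks, while the consumer finishes its writes, consumes the signal, and releases the producer.

import torch.multiprocessing as mp


def producer_side(producer_queue, consumer_queue):
    producer_queue.put(0)  # tell the consumer we reached the barrier
    consumer_queue.get()   # block until the consumer reached it too
    # everything the consumer did before its put() has completed by now,
    # so it is safe to read the buffer the consumer modified


def consumer_side(producer_queue, consumer_queue):
    # ... modify the shared buffer here ...
    producer_queue.get()   # block until the producer reached the barrier
    consumer_queue.put(0)  # release the producer


if __name__ == "__main__":
    smp = mp.get_context("spawn")
    pq, cq = smp.Queue(), smp.Queue()
    a = smp.Process(target=producer_side, args=(pq, cq))
    b = smp.Process(target=consumer_side, args=(pq, cq))
    a.start()
    b.start()
    a.join()
    b.join()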

-def can_actually_p2p(i, j):
+def can_actually_p2p(
+    batch_src: Sequence[int],
+    batch_tgt: Sequence[int],
+):
    """
    Usually, checking if P2P access is enabled can be done by
-    `torch.cuda.can_device_access_peer(i, j)`. However, sometimes
-    the driver might be broken, and `torch.cuda.can_device_access_peer(i, j)`
+    `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
+    the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
    returns `True` even if P2P access is not actually possible.
    See https://github.com/vllm-project/vllm/issues/2728 and
    https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
@@ -90,41 +101,50 @@ def can_actually_p2p(i, j):
    Note on p2p and cuda IPC:
    Usually, one process uses one GPU:
-        GPU i --> cuda context i --> tensor i --> process i
+        GPU src --> cuda context src --> tensor src --> process src
    We need to combine p2p and cuda IPC, so that:
-        GPU i --> cuda context i --> tensor i --> process i
-                                     |shared|
-        GPU j --> cuda context j --> tensor j --> process j
-    That is to say, process i creates a tensor in GPU i, passes IPC handle to
-    process j, and process j accesses the tensor in GPU j. Any operation on the
-    tensor in process j will be reflected in the tensor in process i, because
+        GPU src --> cuda context src --> tensor src --> process src
+                                         |shared|
+        GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
+    That is to say, process src creates a tensor in GPU src, passes IPC handle to
+    process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
+    tensor in process tgt will be reflected in the tensor in process src, because
    they are the same memory segment.
-    It is important to note that process j accesses the tensor in GPU j, not
-    GPU i. That's why we need p2p access. # noqa
-    """
+    It is important to note that process tgt accesses the tensor in GPU tgt, not
+    GPU src. That's why we need p2p access.
+    The most time-consuming part is the process creation. To avoid creating
+    processes for every pair of GPUs, we use batched testing. We create two
+    processes for testing all pairs of GPUs in batch. The trick is to reset
+    the device after each test (which is not available in PyTorch).
+    """ # noqa
    cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
    # pass the CUDA_VISIBLE_DEVICES to the child process
    # to make sure they see the same set of GPUs
-    # make sure the temp file is not the same across different calls
-    temp_path = tempfile.mktemp() + str(time.time())
-    # create an empty file
-    with open(temp_path, "w"):
-        pass
-    init_method = f"file://{temp_path}"
    # make sure the processes are spawned
    smp = mp.get_context("spawn")
-    pi = smp.Process(target=producer,
-                     args=(i, init_method, cuda_visible_devices))
-    pj = smp.Process(target=consumer,
-                     args=(j, init_method, cuda_visible_devices))
-    pi.start()
-    pj.start()
-    pi.join()
-    pj.join()
-    return pi.exitcode == 0 and pj.exitcode == 0
+    producer_queue = smp.Queue()
+    consumer_queue = smp.Queue()
+    result_queue = smp.Queue()
+    p_src = smp.Process(target=producer,
+                        args=(batch_src, producer_queue, consumer_queue,
+                              result_queue, cuda_visible_devices))
+    p_tgt = smp.Process(target=consumer,
+                        args=(batch_tgt, producer_queue, consumer_queue,
+                              result_queue, cuda_visible_devices))
+    p_src.start()
+    p_tgt.start()
+    p_src.join()
+    p_tgt.join()
+    result = []
+    for src, tgt in zip(batch_src, batch_tgt):
+        a = result_queue.get()
+        b = result_queue.get()
+        assert a == b
+        result.append(a)
+    return result
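
A hypothetical direct invocation, mirroring how gpu_p2p_access_check uses this function below: test every ordered pair of GPUs with a single producer/consumer process pair and collect the answers into a lookup table. The device count here is an assumption for illustration only.

from itertools import product

if __name__ == "__main__":  # guard needed because can_actually_p2p spawns processes
    num_dev = 4  # assumed number of GPUs, for illustration only
    ids = list(range(num_dev))
    # all ordered pairs: (0, 0), (0, 1), ..., (3, 3)
    batch_src, batch_tgt = zip(*list(product(ids, ids)))
    result = can_actually_p2p(batch_src, batch_tgt)
    p2p_ok = {(s, t): r for s, t, r in zip(batch_src, batch_tgt, result)}
    # p2p_ok[(0, 1)] records whether the (0, 1) pairing passed the IPC test
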
# why do we need this cache?
@@ -142,14 +162,14 @@ def can_actually_p2p(i, j):
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None

-def gpu_p2p_access_check(i: int, j: int) -> bool:
-    """Check if GPU i can access GPU j."""
+def gpu_p2p_access_check(src: int, tgt: int) -> bool:
+    """Check if GPU src can access GPU tgt."""

    # if the cache variable is already calculated,
    # read from the cache instead of checking it again
    global _gpu_p2p_access_cache
    if _gpu_p2p_access_cache is not None:
-        return _gpu_p2p_access_cache[f"{i}->{j}"]
+        return _gpu_p2p_access_cache[f"{src}->{tgt}"]

    is_distributed = dist.is_initialized()
@@ -169,9 +189,12 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
        # enter this block to calculate the cache
        logger.info("generating GPU P2P access cache in %s", path)
        cache = {}
-        for _i in range(num_dev):
-            for _j in range(num_dev):
-                cache[f"{_i}->{_j}"] = can_actually_p2p(_i, _j)
+        ids = list(range(num_dev))
+        # batch of all pairs of GPUs
+        batch_src, batch_tgt = zip(*list(product(ids, ids)))
+        result = can_actually_p2p(batch_src, batch_tgt)
+        for _i, _j, r in zip(batch_src, batch_tgt, result):
+            cache[f"{_i}->{_j}"] = r
        with open(path, "w") as f:
            json.dump(cache, f, indent=4)
    if is_distributed:
@@ -180,7 +203,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
with open(path, "r") as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{i}->{j}"]
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
__all__ = ["gpu_p2p_access_check"]
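
Finally, a sketch of how the cache is consumed. gpu_p2p_access_check and the "src->tgt" key layout are the ones defined above; the call site and the JSON values shown are hypothetical and illustrative only.

# hypothetical call site (gpu_p2p_access_check assumed to be in scope);
# the first call runs the batched test above and writes the JSON cache,
# subsequent calls read the file (and then an in-process copy of the dict)
use_p2p = gpu_p2p_access_check(0, 1)  # per its docstring: can GPU 0 access GPU 1?

# the file written with json.dump(cache, f, indent=4) is a flat mapping
# from "src->tgt" keys to booleans, for example (values illustrative only):
# {
#     "0->0": true,
#     "0->1": false,
#     "1->0": false,
#     "1->1": true
# }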