[distributed][misc] add specialized method for cuda platform (#7249)

youkaichao 2024-08-07 08:54:52 -07:00 committed by GitHub
parent 66d617e343
commit 639159b2a6
3 changed files with 42 additions and 53 deletions

vllm/distributed/device_communicators/custom_all_reduce.py

@@ -11,7 +11,8 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless, is_full_nvlink
+from vllm.platforms import current_platform
+from vllm.utils import cuda_device_count_stateless
 
 try:
     assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
@@ -113,7 +114,10 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = is_full_nvlink(physical_device_ids)
+        assert current_platform.is_cuda()
+        from vllm.platforms.cuda import CudaPlatform
+        cuda_platform: CudaPlatform = current_platform
+        full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"

vllm/platforms/cuda.py

@@ -4,12 +4,21 @@ pynvml. However, it should not initialize cuda context.
 
 import os
 from functools import lru_cache, wraps
-from typing import Tuple
+from typing import List, Tuple
 
 import pynvml
 
+from vllm.logger import init_logger
+
 from .interface import Platform, PlatformEnum
 
+logger = init_logger(__name__)
+
+# NVML utils
+# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+# all the related functions work on real physical device ids.
+# the major benefit of using NVML is that it will not initialize CUDA
+
 
 def with_nvml_context(fn):
 
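Only the `def with_nvml_context(fn):` line appears as context here; the decorator body is not touched by this commit. For reference, a sketch of what the decorator does, inferred from the near-identical copy being deleted from `vllm/utils.py` below (minus the `pynvml is None` guards, since this module imports pynvml unconditionally):

```python
from functools import wraps

import pynvml  # safe to import unconditionally on the CUDA platform


def with_nvml_context(fn):
    """Run the wrapped function inside an NVML init/shutdown pair so
    decorated functions can issue NVML queries without managing the
    library lifecycle themselves."""

    @wraps(fn)
    def wrapper(*args, **kwargs):
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper
```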
@@ -47,3 +56,29 @@ class CudaPlatform(Platform):
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_capability(physical_device_id)
+
+    @staticmethod
+    @with_nvml_context
+    def is_full_nvlink(physical_device_ids: List[int]) -> bool:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [
+            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
+        ]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle,
+                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError as error:
+                        logger.error(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped.",
+                            exc_info=error)
+                        return False
+        return True
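The method checks every unordered pair of devices and requires the NVLink P2P status to be OK for all of them. To sanity-check the same NVML query outside vLLM, a small standalone script works (hypothetical example: assumes the `pynvml` package is installed and at least two physical GPUs are visible):

```python
import pynvml

pynvml.nvmlInit()
try:
    # NVML indices are physical device ids, unaffected by
    # CUDA_VISIBLE_DEVICES.
    h0 = pynvml.nvmlDeviceGetHandleByIndex(0)
    h1 = pynvml.nvmlDeviceGetHandleByIndex(1)
    status = pynvml.nvmlDeviceGetP2PStatus(
        h0, h1, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
    print("NVLink P2P 0<->1:",
          "OK" if status == pynvml.NVML_P2P_STATUS_OK else "unavailable")
finally:
    pynvml.nvmlShutdown()
```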

vllm/utils.py

@@ -1034,56 +1034,6 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
 
-
-# NVML utils
-# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
-# all the related functions work on real physical device ids.
-# the major benefit of using NVML is that it will not initialize CUDA
-
-try:
-    import pynvml
-except ImportError:
-    # For non-NV devices
-    pynvml = None
-
-
-def with_nvml_context(fn):
-
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if pynvml is not None:
-            pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            if pynvml is not None:
-                pynvml.nvmlShutdown()
-
-    return wrapper
-
-
-@with_nvml_context
-def is_full_nvlink(device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError as error:
-                    logger.error(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped.",
-                        exc_info=error)
-                    return False
-    return True
-
 
 # From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f):