[misc][cuda] use nvml to avoid accidentally cuda initialization (#6007)

parent: af9ad46fca
commit: 614aa51203
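The motivation: calling torch.cuda.get_device_capability() creates a CUDA context in the calling process as a side effect, while the same query through NVML leaves CUDA untouched. Below is a minimal sketch of the NVML pattern this commit adopts, assuming pynvml (the nvidia-ml-py package) is installed; the function name mirrors the helper added to vllm/utils.py in the last hunk of this diff.

import pynvml
import torch

def get_device_capability_stateless(device_id: int = 0):
    # NVML answers the capability query without creating a CUDA context.
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
    finally:
        pynvml.nvmlShutdown()

print(get_device_capability_stateless(0))  # e.g. (8, 0) on an A100
assert not torch.cuda.is_initialized()     # no CUDA context was created

An accidentally initialized CUDA context matters for fork-based multiprocessing: a forked worker cannot re-initialize CUDA, so keeping these queries stateless avoids that failure mode.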
@@ -8,12 +8,13 @@ import pytest
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.utils import get_device_capability_stateless
 
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 
-capability = torch.cuda.get_device_capability()
+capability = get_device_capability_stateless()
 capability = capability[0] * 10 + capability[1]
@@ -1,6 +1,7 @@
 import torch
 
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.utils import get_device_capability_stateless
 
 
 def is_quant_method_supported(quant_method: str) -> bool:
@@ -8,7 +9,7 @@ def is_quant_method_supported(quant_method: str) -> bool:
     if not torch.cuda.is_available():
         return False
 
-    capability = torch.cuda.get_device_capability()
+    capability = get_device_capability_stateless()
     capability = capability[0] * 10 + capability[1]
     return (capability >=
             QUANTIZATION_METHODS[quant_method].get_min_capability())
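As a worked example of the major * 10 + minor encoding used here and in several hunks below (the device named is illustrative):

capability = (8, 6)  # e.g. what an RTX 3090 reports
capability = capability[0] * 10 + capability[1]  # -> 86
# A quantization method whose get_min_capability() returns 80 (Ampere)
# is then supported, since 86 >= 80.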
@@ -2,13 +2,13 @@ import math
 
 import torch
 
-from vllm.utils import is_cpu, is_hip
+from vllm.utils import get_device_capability_stateless, is_cpu, is_hip
 
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
 
 IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and torch.cuda.get_device_capability()[0] >= 8)
+                         and get_device_capability_stateless()[0] >= 8)
 
 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
@@ -235,4 +235,4 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
                              v,
                              cu_seqlens_k,
                              cu_seqlens_q=cu_seqlens_q,
-                             sm_scale=sm_scale)
+                             sm_scale=sm_scale)
@@ -5,6 +5,8 @@ import torch
 import triton
 import triton.language as tl
 
+from vllm.utils import get_device_capability_stateless
+
 if triton.__version__ >= "2.1.0":
 
     @triton.jit
@@ -683,7 +685,7 @@ if triton.__version__ >= "2.1.0":
                               alibi_slopes=None,
                               sliding_window=None):
 
-        cap = torch.cuda.get_device_capability()
+        cap = get_device_capability_stateless()
         BLOCK = 128 if cap[0] >= 8 else 64
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
@@ -11,66 +11,18 @@ from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import is_in_the_same_node
 from vllm.logger import init_logger
-from vllm.utils import cuda_device_count_stateless
+from vllm.utils import cuda_device_count_stateless, is_full_nvlink
 
 try:
-    import pynvml
-
-    # Simulate ImportError if custom_ar ops are not supported.
-    if not ops.is_custom_op_supported("_C_custom_ar::meta_size"):
-        raise ImportError("custom_ar", __file__)
-
+    assert ops.is_custom_op_supported("_C_custom_ar::meta_size")
     custom_ar = True
-
-    @contextmanager
-    def _nvml():
-        try:
-            pynvml.nvmlInit()
-            yield
-        finally:
-            pynvml.nvmlShutdown()
-
-except ImportError:
-    # For AMD GPUs
+except Exception:
+    # For AMD GPUs and CPUs
    custom_ar = False
-    pynvml = None
-
-    @contextmanager
-    def _nvml():
-        try:
-            yield
-        finally:
-            pass
-
 
 logger = init_logger(__name__)
 
 
-@_nvml()
-def _is_full_nvlink(device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
-    so it works on real physical device ids.
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
-                        return False
-                except pynvml.NVMLError as error:
-                    logger.error(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped.",
-                        exc_info=error)
-                    return False
-    return True
-
-
 def _can_p2p(rank: int, world_size: int) -> bool:
     for i in range(world_size):
         if i == rank:
@@ -161,7 +113,7 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = _is_full_nvlink(physical_device_ids)
+        full_nvlink = is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
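The NVLink probe itself is unchanged; it has simply moved from this module into vllm.utils (see the utils.py hunk below) and lost its leading underscore. Because NVML ignores CUDA_VISIBLE_DEVICES and indexes real physical devices, callers must translate logical device ids to physical ids first, which is what the physical_device_ids argument above implies. A hedged usage sketch, with a simplified mapping that assumes CUDA_VISIBLE_DEVICES is set to a plain comma-separated list of indices:

import os
from vllm.utils import is_full_nvlink

visible = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1")
physical_device_ids = [int(x) for x in visible.split(",")]
# True only if every pair of the given GPUs is NVLink-connected (1 hop):
full_nvlink = is_full_nvlink(physical_device_ids)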
@@ -5,13 +5,14 @@ from typing import Optional
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.utils import get_device_capability_stateless
 
 
 def _check_punica_support():
     if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
         return
 
-    if torch.cuda.get_device_capability() < (8, 0):
+    if get_device_capability_stateless() < (8, 0):
         raise ImportError(
             "punica LoRA kernels require compute capability >= 8.0")
     else:
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
+from vllm.utils import get_device_capability_stateless
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -84,7 +85,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return []
 
     def _check_gptq_and_marlin_can_run(self):
-        capability = torch.cuda.get_device_capability()
+        capability = get_device_capability_stateless()
         capability = capability[0] * 10 + capability[1]
         if capability < 80:
             raise RuntimeError("The quantization config is not supported for ",
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once
+from vllm.utils import get_device_capability_stateless, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -18,7 +18,7 @@ logger = init_logger(__name__)
 
 
 def cutlass_fp8_supported() -> bool:
-    capability = torch.cuda.get_device_capability()
+    capability = get_device_capability_stateless()
     capability = capability[0] * 10 + capability[1]
 
     return ops.cutlass_scaled_mm_supports_fp8(capability)
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.utils import get_device_capability_stateless
 
 logger = init_logger(__name__)
@@ -165,7 +166,7 @@ class GPTQMarlinConfig(QuantizationConfig):
             return False
 
         # If the capability of the device is too low, cannot convert.
-        major, minor = torch.cuda.get_device_capability()
+        major, minor = get_device_capability_stateless()
        device_capability = major * 10 + minor
         if device_capability < cls.get_min_capability():
             return False
@@ -12,8 +12,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
     marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
+from vllm.utils import get_device_capability_stateless
 
-__cuda_arch = torch.cuda.get_device_capability()
+__cuda_arch = get_device_capability_stateless()
 
 MARLIN_TILE = 16
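Note the design implication in this hunk: __cuda_arch is evaluated at module import time, so with the old torch.cuda.get_device_capability() merely importing this module would initialize CUDA in the importing process. The NVML-backed helper keeps the import free of that side effect.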
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.interfaces import (supports_lora,
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import is_tpu
+from vllm.utils import get_device_capability_stateless, is_tpu
 
 logger = init_logger(__name__)
@@ -46,7 +46,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        capability = torch.cuda.get_device_capability()
+        capability = get_device_capability_stateless()
         capability = capability[0] * 10 + capability[1]
         if capability < quant_config.get_min_capability():
             raise ValueError(
@@ -816,6 +816,63 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
 
 
+# NVML utils
+# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+# all the related functions work on real physical device ids.
+# the major benefit of using NVML is that it will not initialize CUDA
+
+try:
+    import pynvml
+except ImportError:
+    # For non-NV devices
+    pynvml = None
+
+
+def with_nvml_context(fn):
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        if pynvml is not None:
+            pynvml.nvmlInit()
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            if pynvml is not None:
+                pynvml.nvmlShutdown()
+
+    return wrapper
+
+
+@with_nvml_context
+def is_full_nvlink(device_ids: List[int]) -> bool:
+    """
+    query if the set of gpus are fully connected by nvlink (1 hop)
+    """
+    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
+    for i, handle in enumerate(handles):
+        for j, peer_handle in enumerate(handles):
+            if i < j:
+                try:
+                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
+                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                        return False
+                except pynvml.NVMLError as error:
+                    logger.error(
+                        "NVLink detection failed. This is normal if your"
+                        " machine has no NVLink equipped.",
+                        exc_info=error)
+                    return False
+    return True
+
+
+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+
+
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f):
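A short usage sketch of the helpers added above, assuming an NVIDIA machine with pynvml available. with_nvml_context brackets each call with nvmlInit()/nvmlShutdown(), and @lru_cache(maxsize=8) means the capability of a given device id is queried from NVML at most once per process:

from vllm.utils import get_device_capability_stateless, is_full_nvlink

major, minor = get_device_capability_stateless()  # e.g. (9, 0) on an H100
print(is_full_nvlink([0, 1]))  # True only if GPUs 0 and 1 share NVLink

On machines without pynvml, the decorator degrades to a no-op wrapper, and callers fall back to the non-NVML paths guarded elsewhere in this diff.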
@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.sequence import ExecuteModelRequest
+from vllm.utils import get_device_capability_stateless
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -322,7 +323,7 @@ def init_worker_distributed_environment(
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
-        compute_capability = torch.cuda.get_device_capability()
+        compute_capability = get_device_capability_stateless()
         if compute_capability[0] < 8:
             gpu_name = torch.cuda.get_device_name()
             raise ValueError(