vllm/vllm/platforms/cuda.py
Chendi.Xue 0a71900bc9
Remove hard-dependencies of Speculative decode to CUDA workers (#10587)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
2024-11-26 17:57:11 -08:00

241 lines
8.2 KiB
Python

"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""
import os
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Callable, List, TypeVar
import pynvml
import torch
from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
# `VllmConfig` is only used in annotations in this module: the real class is
# imported for static type checkers, while at runtime the name is bound to a
# `None` placeholder.
if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

# Type variables used to annotate decorators so that the wrapped callable's
# parameters and return type are carried through.
_P = ParamSpec("_P")
_R = TypeVar("_R")

logger = init_logger(__name__)

# A `pynvml` module whose file is a package `__init__.py` is treated as the
# deprecated `pynvml` distribution; warn so the user can switch packages.
if pynvml.__file__.endswith("__init__.py"):
    logger.warning(
        "You are using a deprecated `pynvml` package. Please install"
        " `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
        " When both of them are installed, `pynvml` will take precedence"
        " and cause errors. See https://pypi.org/project/pynvml "
        "for more information.")

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` is used as an index
    into its comma-separated entries and the selected entry is returned as an
    integer. When the variable is not set, ``device_id`` is returned as is.

    Args:
        device_id: Index of the device as seen by the current process.

    Returns:
        The corresponding entry of ``CUDA_VISIBLE_DEVICES`` converted to
        ``int``, or ``device_id`` itself when the variable is unset.

    Raises:
        RuntimeError: If ``CUDA_VISIBLE_DEVICES`` is set to an empty string.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        # No remapping is in effect.
        return device_id

    if visible_devices == "":
        raise RuntimeError(
            "CUDA_VISIBLE_DEVICES is set to empty string, which means"
            " GPU support is disabled. If you are using ray, please unset"
            " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
            " worker/actor. "
            "Check https://github.com/vllm-project/vllm/issues/8402 for"
            " more information.")

    return int(visible_devices.split(",")[device_id])
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Decorate ``fn`` so that it runs between NVML init and shutdown.

    The returned wrapper calls ``pynvml.nvmlInit()``, invokes ``fn`` with the
    given arguments, and always calls ``pynvml.nvmlShutdown()`` afterwards,
    including when ``fn`` raises.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper with the same signature that returns whatever ``fn``
        returns.
    """

    @wraps(fn)
    def _nvml_scoped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            result = fn(*args, **kwargs)
        finally:
            # Runs on both the success and the failure path.
            pynvml.nvmlShutdown()
        return result

    return _nvml_scoped
class CudaPlatformBase(Platform):
    """Common base for the CUDA platform implementations in this module.

    The device-query methods here only raise ``NotImplementedError``; the
    subclasses defined in this module override them. Worker-class selection
    in ``check_and_update_config`` is implemented here.
    """

    _enum = PlatformEnum.CUDA
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability of ``device_id`` (not implemented)."""
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ``device_id`` (not implemented)."""
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the total memory of ``device_id`` (not implemented)."""
        raise NotImplementedError

    @classmethod
    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
        """Report whether the given devices are NVLink-connected
        (not implemented)."""
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        """Log platform warnings; this base implementation does nothing."""

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Replace an ``"auto"`` worker class with a concrete one, in place.

        Nothing is changed unless ``parallel_config.worker_cls`` equals
        ``"auto"``. Otherwise the worker class is chosen in this order:
        the multi-step worker when ``scheduler_config.is_multi_step`` is
        truthy; the speculative-decode worker factory (also setting
        ``sd_worker_cls``) when ``speculative_config`` is truthy; the
        plain worker in every other case.

        Args:
            vllm_config: Configuration object whose ``parallel_config`` is
                updated.
        """
        parallel = vllm_config.parallel_config
        scheduler = vllm_config.scheduler_config

        # An explicitly configured worker class is left untouched.
        if parallel.worker_cls != "auto":
            return

        if scheduler.is_multi_step:
            parallel.worker_cls = (
                "vllm.worker.multi_step_worker.MultiStepWorker")
            return

        if vllm_config.speculative_config:
            parallel.worker_cls = (
                "vllm.spec_decode.spec_decode_worker.create_spec_worker")
            parallel.sd_worker_cls = "vllm.worker.worker.Worker"
            return

        parallel.worker_cls = "vllm.worker.worker.Worker"
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform whose device queries go through NVML (``pynvml``).

    Methods that take a logical ``device_id`` first map it with
    ``device_id_to_physical_device_id`` and then query NVML with the
    resulting index. Each public query runs inside ``with_nvml_context``.
    """

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the CUDA compute capability of logical ``device_id``.

        Results are memoized by ``lru_cache``.
        """
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            device_id_to_physical_device_id(device_id))
        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(nvml_handle)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML device name of logical ``device_id``.

        Results are memoized by ``lru_cache``.
        """
        return cls._get_physical_device_name(
            device_id_to_physical_device_id(device_id))

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return the ``total`` field of NVML's memory info for logical
        ``device_id``, as an ``int``.

        Results are memoized by ``lru_cache``.
        """
        nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(
            device_id_to_physical_device_id(device_id))
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_handle)
        return int(memory_info.total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)

        Every unordered pair of the given devices is checked once. The
        result is ``False`` as soon as one pair does not report an OK
        NVLink P2P status or the NVML query raises; otherwise ``True``.
        The ids are passed to NVML without any remapping.
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(physical_id)
            for physical_id in physical_device_ids
        ]
        for position, handle in enumerate(handles):
            # Only look at peers that come later in the list so that each
            # pair is examined exactly once.
            for peer_handle in handles[position + 1:]:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle,
                        peer_handle,
                        pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                    )
                except pynvml.NVMLError:
                    logger.exception(
                        "NVLink detection failed. This is normal if"
                        " your machine has no NVLink equipped.")
                    return False
                if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                    return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        """Return the NVML name of the device at index ``device_id``.

        No NVML init/shutdown is done here and the index is passed to NVML
        unchanged; callers in this class invoke it from within
        ``with_nvml_context``-wrapped methods.
        """
        return pynvml.nvmlDeviceGetName(
            pynvml.nvmlDeviceGetHandleByIndex(device_id))

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        """Warn when devices with different names are present and
        ``CUDA_DEVICE_ORDER`` is not ``PCI_BUS_ID``.

        All devices reported by ``nvmlDeviceGetCount`` are inspected.
        """
        device_count: int = pynvml.nvmlDeviceGetCount()
        if device_count <= 1:
            return

        device_names = [
            cls._get_physical_device_name(index)
            for index in range(device_count)
        ]
        if len(set(device_names)) <= 1:
            # All devices share one name; nothing to warn about.
            return
        if os.environ.get("CUDA_DEVICE_ORDER") == "PCI_BUS_ID":
            return

        logger.warning(
            "Detected different devices in the system: \n%s\nPlease"
            " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
            "avoid unexpected behavior.",
            "\n".join(device_names),
        )
class NonNvmlCudaPlatform(CudaPlatformBase):
    """CUDA platform that queries devices through ``torch.cuda``.

    Selected by this module when NVML cannot be initialized. Device ids are
    handed to ``torch.cuda`` unchanged.

    NOTE(review): unlike the NVML-based implementation, these ``torch.cuda``
    queries presumably initialize CUDA -- confirm that this is acceptable
    for callers.
    """

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        """Return the compute capability reported by
        ``torch.cuda.get_device_capability``."""
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the device name reported by
        ``torch.cuda.get_device_name``."""
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        """Return ``total_memory`` from the device properties reported by
        ``torch.cuda.get_device_properties``."""
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """Always report that the devices are not NVLink-connected.

        No detection is attempted on this platform; a warning is logged and
        ``False`` is returned regardless of ``physical_device_ids``.
        """
        # This is an expected fallback, not an error raised inside an
        # ``except`` block, so ``logger.exception`` (which attaches the
        # current exception's traceback and would print "NoneType: None"
        # here) is the wrong call; log a plain warning instead.
        logger.warning(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False
# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    pynvml.nvmlInit()
except Exception:
    # On Jetson, NVML is not supported.
    nvml_available = False
else:
    # Initialization worked: record that and release NVML again right away,
    # since this was only a probe.
    nvml_available = True
    pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

# Emit the platform warnings once at import time, except when `pynvml` has
# been replaced by Sphinx autodoc's mock module. If Sphinx is not installed,
# `pynvml` cannot be that mock, so the warnings are emitted unconditionally.
try:
    from sphinx.ext.autodoc.mock import _MockModule

    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()
except ModuleNotFoundError:
    CudaPlatform.log_warnings()