- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to python source files as recommended to the project by the Linux Foundation. These headers provide a concise way that is both human and machine readable for communicating license information for each source file. It helps avoid any ambiguity about the license of the code and can also be easily used by tools to help manage license compliance.

The Linux Foundation runs license scans against the codebase to help ensure we are in compliance with the licenses of the code we use, including dependencies. Having these headers in place helps that tool do its job.

More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
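In practice, the header added by this commit is a single comment as the very first line of each Python source file, exactly as it appears at the top of the file below:

```python
# SPDX-License-Identifier: Apache-2.0
```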
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
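The second commit wires this into pre-commit so a missing header is caught before review. A minimal sketch of the kind of check such a hook can run is shown below; the script name, hook wiring, and exact matching rules are illustrative assumptions, not necessarily what the vLLM repository ships:

```python
# Hypothetical standalone checker a pre-commit hook could invoke with the
# staged file paths; it fails if any Python file lacks the SPDX header.
import sys

SPDX_HEADER = "# SPDX-License-Identifier: Apache-2.0"


def has_spdx_header(path: str) -> bool:
    # Look at the first two lines so a shebang line does not cause a
    # false negative.
    with open(path, encoding="utf-8") as f:
        first_lines = [f.readline().strip() for _ in range(2)]
    return SPDX_HEADER in first_lines


def main() -> int:
    missing = [p for p in sys.argv[1:] if not has_spdx_header(p)]
    for path in missing:
        print(f"missing SPDX header: {path}")
    return 1 if missing else 0


if __name__ == "__main__":
    sys.exit(main())
```

Pre-commit passes the matching staged files as positional arguments, so a `repo: local` hook entry pointing at a script like this is all the configuration such a check needs.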
389 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0

"""Code inside this file can safely assume cuda platform, e.g. importing
pynvml. However, it should not initialize cuda context.
"""

import os
from functools import lru_cache, wraps
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
                    Union)

import pynvml
import torch
from typing_extensions import ParamSpec

# import custom ops, trigger op registration
import vllm._C  # noqa
import vllm.envs as envs
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

logger = init_logger(__name__)

_P = ParamSpec("_P")
_R = TypeVar("_R")

if pynvml.__file__.endswith("__init__.py"):
    logger.warning(
        "You are using a deprecated `pynvml` package. Please install"
        " `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
        " When both of them are installed, `pynvml` will take precedence"
        " and cause errors. See https://pypi.org/project/pynvml "
        "for more information.")

# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)

def device_id_to_physical_device_id(device_id: int) -> int:
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        if device_ids == [""]:
            msg = (
                "CUDA_VISIBLE_DEVICES is set to empty string, which means"
                " GPU support is disabled. If you are using ray, please unset"
                " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
                " worker/actor. "
                "Check https://github.com/vllm-project/vllm/issues/8402 for"
                " more information.")
            raise RuntimeError(msg)
        physical_device_id = device_ids[device_id]
        return int(physical_device_id)
    else:
        return device_id


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            pynvml.nvmlShutdown()

    return wrapper

class CudaPlatformBase(Platform):
    _enum = PlatformEnum.CUDA
    device_name: str = "cuda"
    device_type: str = "cuda"
    dispatch_key: str = "CUDA"
    ray_device_key: str = "GPU"
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @classmethod
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        raise NotImplementedError

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        if enforce_eager:
            logger.warning(
                "To see benefits of async output processing, enable CUDA "
                "graph. Since, enforce-eager is enabled, async output "
                "processor cannot be used")
            return False
        return True

    @classmethod
    def is_full_nvlink(cls, device_ids: List[int]) -> bool:
        raise NotImplementedError

    @classmethod
    def log_warnings(cls):
        pass

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config

        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on VLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_worker.MultiStepWorker"
            elif vllm_config.speculative_config:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Speculative decoding is not yet supported on VLLM V1."
                    )
                else:
                    parallel_config.worker_cls = \
                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
                    parallel_config.sd_worker_cls = \
                        "vllm.worker.worker.Worker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm.v1.worker.gpu_worker.Worker"
                else:
                    parallel_config.worker_cls = "vllm.worker.worker.Worker"

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        torch.cuda.reset_peak_memory_stats(device)
        return torch.cuda.max_memory_allocated(device)

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        if use_v1:
            logger.info("Using Flash Attention backend on V1 engine.")
            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"
        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
        elif selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.FLASH_ATTN:
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif kv_cache_dtype is not None and \
                kv_cache_dtype.startswith("fp8"):
            logger.info(
                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
            logger.warning(
                "Please use FlashInfer backend with FP8 KV Cache for "
                "better performance by setting environment variable "
                "VLLM_ATTENTION_BACKEND=FLASHINFER")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_capability(cls,
                              device_id: int = 0
                              ) -> Optional[DeviceCapability]:
        try:
            physical_device_id = device_id_to_physical_device_id(device_id)
            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
            return DeviceCapability(major=major, minor=minor)
        except RuntimeError:
            return None

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def has_device_capability(
        cls,
        capability: Union[Tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        try:
            return super().has_device_capability(capability, device_id)
        except RuntimeError:
            return False

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_name(cls, device_id: int = 0) -> str:
        physical_device_id = device_id_to_physical_device_id(device_id)
        return cls._get_physical_device_name(physical_device_id)

    @classmethod
    @lru_cache(maxsize=8)
    @with_nvml_context
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        physical_device_id = device_id_to_physical_device_id(device_id)
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

    @classmethod
    @with_nvml_context
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
        handles = [
            pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids
        ]
        for i, handle in enumerate(handles):
            for j, peer_handle in enumerate(handles):
                if i < j:
                    try:
                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                            handle,
                            peer_handle,
                            pynvml.NVML_P2P_CAPS_INDEX_NVLINK,
                        )
                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                            return False
                    except pynvml.NVMLError:
                        logger.exception(
                            "NVLink detection failed. This is normal if"
                            " your machine has no NVLink equipped.")
                        return False
        return True

    @classmethod
    def _get_physical_device_name(cls, device_id: int = 0) -> str:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        return pynvml.nvmlDeviceGetName(handle)

    @classmethod
    @with_nvml_context
    def log_warnings(cls):
        device_ids: int = pynvml.nvmlDeviceGetCount()
        if device_ids > 1:
            device_names = [
                cls._get_physical_device_name(i) for i in range(device_ids)
            ]
            if (len(set(device_names)) > 1
                    and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
                logger.warning(
                    "Detected different devices in the system: \n%s\nPlease"
                    " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
                    "avoid unexpected behavior.",
                    "\n".join(device_names),
                )


class NonNvmlCudaPlatform(CudaPlatformBase):

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
        major, minor = torch.cuda.get_device_capability(device_id)
        return DeviceCapability(major=major, minor=minor)

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        device_props = torch.cuda.get_device_properties(device_id)
        return device_props.total_memory

    @classmethod
    def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
        logger.exception(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")
        return False


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
nvml_available = False
try:
    try:
        pynvml.nvmlInit()
        nvml_available = True
    except Exception:
        # On Jetson, NVML is not supported.
        nvml_available = False
finally:
    if nvml_available:
        pynvml.nvmlShutdown()

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform

try:
    from sphinx.ext.autodoc.mock import _MockModule

    if not isinstance(pynvml, _MockModule):
        CudaPlatform.log_warnings()
except ModuleNotFoundError:
    CudaPlatform.log_warnings()