# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import importlib.metadata import os import threading from collections.abc import Callable, Collection from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeVar import numpy as np import numpy.typing as npt import torch from packaging import version from packaging.version import Version from torch.library import Library import vllm.envs as envs if TYPE_CHECKING: from vllm.config import ModelConfig from vllm.sequence import IntermediateTensors else: ModelConfig = object IntermediateTensors = object STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, "half": torch.half, "bfloat16": torch.bfloat16, "float": torch.float, "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, "int8": torch.int8, "fp8_inc": torch.float8_e4m3fn, "fp8_ds_mla": torch.uint8, } TORCH_DTYPE_TO_NUMPY_DTYPE = { torch.float16: np.float16, torch.float32: np.float32, torch.float64: np.float64, torch.uint8: np.uint8, torch.int32: np.int32, torch.int64: np.int64, } T = TypeVar("T") @contextlib.contextmanager def set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" old_dtype = torch.get_default_dtype() torch.set_default_dtype(dtype) yield torch.set_default_dtype(old_dtype) @contextlib.contextmanager def set_default_torch_num_threads(num_threads: int): """Sets the default number of threads for PyTorch to the given value.""" old_num_threads = torch.get_num_threads() torch.set_num_threads(num_threads) yield torch.set_num_threads(old_num_threads) @contextlib.contextmanager def guard_cuda_initialization(): """Avoid unexpected CUDA initialization.""" from vllm.platforms import current_platform if not current_platform.is_cuda(): yield return had_key = "CUDA_VISIBLE_DEVICES" in os.environ old_value = os.environ.get("CUDA_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = "" try: yield except Exception as e: if "No CUDA GPUs are available" in str(e): err_msg = "CUDA initialization is blocked." else: err_msg = str(e) raise RuntimeError(err_msg) from e finally: if had_key: os.environ["CUDA_VISIBLE_DEVICES"] = old_value else: os.environ.pop("CUDA_VISIBLE_DEVICES") def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" return torch.tensor([], dtype=dtype).element_size() # bool = 0, int = 1, float = 2, complex = 3 def _get_precision_level(dtype: torch.dtype) -> int: # NOTE: Complex dtypes return `is_floating_point=False` return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): """ Test whether it is lossless to cast a tensor from `src_dtype` to `tgt_dtype`. """ if src_dtype == tgt_dtype: return True src_level = _get_precision_level(src_dtype) tgt_level = _get_precision_level(tgt_dtype) if src_level < tgt_level: return True if src_level > tgt_level: return False # Compare integral types if not src_dtype.is_floating_point and not src_dtype.is_complex: src_info = torch.iinfo(src_dtype) tgt_info = torch.iinfo(tgt_dtype) return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max # Compare floating-point types src_info = torch.finfo(src_dtype) tgt_info = torch.finfo(tgt_dtype) return ( src_info.min >= tgt_info.min and src_info.max <= tgt_info.max and src_info.resolution >= tgt_info.resolution ) def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): """ Get the common `dtype` where all of the other `dtypes` can be cast to it without losing any information. """ return max( dtypes, key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), ) def _generate_random_fp8( tensor: torch.Tensor, low: float, high: float, ) -> None: # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, # it may occur Inf or NaN if we directly use torch.randint # to generate random data for fp8 data. # For example, s.11111.00 in fp8e5m2 format represents Inf. # | E4M3 | E5M2 # -----|-------------|------------------- # Inf | N/A | s.11111.00 # NaN | s.1111.111 | s.11111.{01,10,11} from vllm import _custom_ops as ops tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) tensor_tmp.uniform_(low, high) ops.convert_fp8(tensor, tensor_tmp) del tensor_tmp def get_kv_cache_torch_dtype( cache_dtype: str | torch.dtype | None, model_dtype: str | torch.dtype | None = None, ) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] elif isinstance(model_dtype, torch.dtype): torch_dtype = model_dtype else: raise ValueError(f"Invalid model dtype: {model_dtype}") elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") elif isinstance(cache_dtype, torch.dtype): torch_dtype = cache_dtype else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") return torch_dtype def kv_cache_dtype_str_to_dtype( kv_cache_dtype: str, model_config: ModelConfig ) -> torch.dtype: if kv_cache_dtype == "auto": # Model config may not be specified for unit tests, default to float16 return model_config.dtype if model_config else torch.half return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] def create_kv_caches_with_random_flash( num_blocks: int, block_size: int, num_layers: int, num_heads: int, head_size: int, cache_dtype: str | torch.dtype | None, model_dtype: str | torch.dtype | None = None, seed: int | None = None, device: str | None = "cuda", cache_layout: str | None = "NHD", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: from vllm.platforms import current_platform current_platform.seed_everything(seed) dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) assert cache_layout in ("NHD", "HND") stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) scale = head_size**-0.5 key_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = [] for _ in range(num_layers): key_value_cache = torch.empty( size=kv_cache_allocation_shape, dtype=dtype, device=device ).permute(*stride_order) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_value_cache.uniform_(-scale, scale) elif cache_dtype == "fp8": _generate_random_fp8(key_value_cache, -scale, scale) else: raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_value_cache[:, 0]) value_caches.append(key_value_cache[:, 1]) return key_caches, value_caches def create_kv_caches_with_random( num_blocks: int, block_size: int, num_layers: int, num_heads: int, head_size: int, cache_dtype: str | torch.dtype | None, model_dtype: str | torch.dtype | None = None, seed: int | None = None, device: str | None = "cuda", ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: if cache_dtype == "fp8" and head_size % 16: raise ValueError( f"Does not support key cache of type fp8 with head_size {head_size}" ) from vllm.platforms import current_platform current_platform.seed_everything(seed) dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_caches: list[torch.Tensor] = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) if cache_dtype in ["auto", "half", "bfloat16", "float"]: key_cache.uniform_(-scale, scale) elif cache_dtype == "fp8": _generate_random_fp8(key_cache, -scale, scale) else: raise ValueError(f"Does not support key cache of type {cache_dtype}") key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_caches: list[torch.Tensor] = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) if cache_dtype in ["auto", "half", "bfloat16", "float"]: value_cache.uniform_(-scale, scale) elif cache_dtype == "fp8": _generate_random_fp8(value_cache, -scale, scale) else: raise ValueError(f"Does not support value cache of type {cache_dtype}") value_caches.append(value_cache) return key_caches, value_caches def async_tensor_h2d( data: list, dtype: torch.dtype, target_device: str | torch.device, pin_memory: bool, ) -> torch.Tensor: """Asynchronously create a tensor and copy it from host to device.""" t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") return t.to(device=target_device, non_blocking=True) def make_ndarray_with_pad( x: list[list[T]], pad: T, dtype: npt.DTypeLike, *, max_len: int | None = None, ) -> npt.NDArray: """ Make a padded array from 2D inputs. The padding is applied to the end of each inner list until it reaches `max_len`. """ if max_len is None: # Unlike for most functions, map is faster than a genexpr over `len` max_len = max(map(len, x), default=0) padded_x = np.full((len(x), max_len), pad, dtype=dtype) for ind, blocktb in enumerate(x): assert len(blocktb) <= max_len padded_x[ind, : len(blocktb)] = blocktb return padded_x def make_tensor_with_pad( x: list[list[T]], pad: T, dtype: torch.dtype, *, max_len: int | None = None, device: str | torch.device | None = None, pin_memory: bool = False, ) -> torch.Tensor: """ Make a padded tensor from 2D inputs. The padding is applied to the end of each inner list until it reaches `max_len`. """ np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) tensor = torch.from_numpy(padded_x).to(device) if pin_memory: tensor = tensor.pin_memory() return tensor prev_set_stream = torch.cuda.set_stream _current_stream_tls = threading.local() def _patched_set_stream(stream: torch.cuda.Stream) -> None: _current_stream_tls.value = stream prev_set_stream(stream) torch.cuda.set_stream = _patched_set_stream class _StreamPlaceholder: def __init__(self): self.synchronize = lambda: None def current_stream() -> torch.cuda.Stream: """ replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. it turns out that `torch.cuda.current_stream()` is quite expensive, as it will construct a new stream object at each call. here we patch `torch.cuda.set_stream` to keep track of the current stream directly, so that we can avoid calling `torch.cuda.current_stream()`. the underlying hypothesis is that we do not call `torch._C._cuda_setStream` from C/C++ code. """ from vllm.platforms import current_platform if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: # when this function is called before any stream is set, # we return the default stream. # On ROCm using the default 0 stream in combination with RCCL # is hurting performance. Therefore creating a dedicated stream # per process if current_platform.is_rocm(): # torch.cuda.set_stream here is the alias of _pathed_set_stream torch.cuda.set_stream(torch.cuda.Stream()) elif current_platform.is_cpu(): _current_stream_tls.value = _StreamPlaceholder() else: current_stream = current_platform.current_stream if current_stream is not None: _current_stream_tls.value = current_stream() else: raise ValueError( "Fail to set current stream, current platform " "may not support current_stream with torch API" ) return _current_stream_tls.value # Global auxilary stream for running operations in background streams. # We have single global auxilary stream to avoid an explosion of streams # for every layer (and make profiling look sane). # # aux_stream() is currently used for: # - MoE shared_expert overlap with router _aux_stream: torch.cuda.Stream | None = None def aux_stream() -> torch.cuda.Stream | None: """ Ensures aux_stream is initialized only once """ global _aux_stream from vllm.platforms import current_platform if _aux_stream is None and current_platform.is_cuda_alike(): _aux_stream = torch.cuda.Stream() return _aux_stream @lru_cache(maxsize=8) def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for # LRU Cache purposes. # Code below is based on # https://github.com/pytorch/pytorch/blob/ # c1cd946818442aca8c7f812b16d187ce1586c3bc/ # torch/cuda/__init__.py#L831C1-L831C17 import torch.cuda import torch.version from vllm.platforms import current_platform if not torch.cuda._is_compiled(): return 0 if current_platform.is_rocm(): # ROCm uses amdsmi instead of nvml for stateless device count # This requires a sufficiently modern version of Torch 2.4.0 raw_count = ( torch.cuda._device_count_amdsmi() if (hasattr(torch.cuda, "_device_count_amdsmi")) else -1 ) else: raw_count = torch.cuda._device_count_nvml() r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count return r def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" # This can be removed and simply replaced with torch.cuda.get_device_count # after https://github.com/pytorch/pytorch/pull/122815 is released. return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) def weak_ref_tensor(tensor: Any) -> Any: """ Create a weak reference to a tensor. The new tensor will share the same data as the original tensor, but will not keep the original tensor alive. """ if isinstance(tensor, torch.Tensor): return torch.ops._C.weak_ref_tensor(tensor) else: return tensor def weak_ref_tensors( tensors: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor] | IntermediateTensors, ) -> torch.Tensor | list[Any] | tuple[Any] | Any: """ Convenience function to create weak references to tensors, for single tensor, list of tensors or tuple of tensors. """ if isinstance(tensors, torch.Tensor): return weak_ref_tensor(tensors) if isinstance(tensors, list): return [weak_ref_tensor(t) for t in tensors] if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) # For IntermediateTensors used in pipeline parallelism from vllm.sequence import IntermediateTensors if isinstance(tensors, IntermediateTensors): ret = IntermediateTensors( {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} ) return ret raise ValueError("Invalid type for tensors") def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: """ Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). """ assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) # Helper function used in testing. def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: torch_version = version.parse(torch_version) return torch_version >= version.parse(target) def is_torch_equal_or_newer(target: str) -> bool: """Check if the installed torch version is >= the target version. Args: target: a version string, like "2.6.0". Returns: Whether the condition meets. """ try: return _is_torch_equal_or_newer(str(torch.__version__), target) except Exception: # Fallback to PKG-INFO to load the package info, needed by the doc gen. return Version(importlib.metadata.version("torch")) >= Version(target) def _is_torch_equal(target: str) -> bool: assert target.count(".") == 2 torch_version = str(torch.__version__) torch_version = version.parse(torch_version) # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" # or "2.6.0+cu128" but never "2.6.0.1" return ( torch_version >= version.parse(target) and version.parse(target + ".1") > torch_version ) def is_torch_equal(target: str) -> bool: """Check if the installed torch version is == the target version. Args: target: a version string, like "2.6.0". Returns: Whether the condition meets. """ try: return _is_torch_equal(target) except Exception: return Version(importlib.metadata.version("torch")) == Version(target) # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. # In particular, the FakeScalarType is not supported for earlier versions of # PyTorch which breaks dynamo for any ops registered using ScalarType. def supports_dynamo() -> bool: return is_torch_equal_or_newer("2.4.0") # Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform def supports_xccl() -> bool: return ( is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() ) # Some backends use pytorch version < 2.4.0 which doesn't # support `torch.library.custom_op`. def supports_custom_op() -> bool: return hasattr(torch.library, "custom_op") # create a library to hold the custom op vllm_lib = Library("vllm", "FRAGMENT") # noqa def direct_register_custom_op( op_name: str, op_func: Callable, mutates_args: list[str] | None = None, fake_impl: Callable | None = None, target_lib: Library | None = None, dispatch_key: str | None = None, tags: tuple[torch.Tag, ...] = (), ): """ `torch.library.custom_op` can have significant overhead because it needs to consider complicated dispatching logic. This function directly registers a custom op and dispatches it to the CUDA backend. See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 for more details. By default, the custom op is registered to the vLLM library. If you want to register it to a different library, you can pass the library object to the `target_lib` argument. IMPORTANT: the lifetime of the operator is tied to the lifetime of the library object. If you want to bind the operator to a different library, make sure the library object is alive when the operator is used. """ if not supports_custom_op(): from vllm.platforms import current_platform assert not current_platform.is_cuda_alike(), ( "cuda platform needs torch>=2.4 to support custom op, " "chances are you are using an old version of pytorch " "or a custom build of pytorch. It is recommended to " "use vLLM in a fresh new environment and let it install " "the required dependencies." ) return if mutates_args is None: mutates_args = [] if dispatch_key is None: from vllm.platforms import current_platform dispatch_key = current_platform.dispatch_key import torch.library if hasattr(torch.library, "infer_schema"): schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 import torch._custom_op.impl schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str, tags=tags) my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) if fake_impl is not None: my_lib._register_fake(op_name, fake_impl)