# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

import contextlib
import functools
import logging
import os
from collections.abc import Callable
from enum import Enum
from typing import Any, Literal

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import triton

logger = logging.getLogger(__name__)

COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"

SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
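
# Note: the flags above are evaluated once, at import time. To change them, set
# the corresponding environment variables before this module is first imported,
# e.g. (illustrative only):
#
#     import os
#     os.environ["FLA_GDN_FIX_BT"] = "1"  # must be set before the first import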


def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator stores the outputs of the decorated function for the most recent
    sets of input tensors. The cache holds a fixed number of entries (8 here); when it
    is full, the oldest entry is evicted. Cache hits compare tensors by identity
    (``is``), not by value.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with caching of recent calls.
    """

    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 8

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries, cache_size
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            if (
                len(args) == len(last_args)
                and len(kwargs) == len(last_kwargs)
                and all(a is b for a, b in zip(args, last_args))
                and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                )
            ):
                # Move the matching entry to the back (most recently used).
                cache_entries = (
                    cache_entries[:i]
                    + cache_entries[i + 1 :]
                    + [(args, kwargs, last_result)]
                )
                return last_result

        result = fn(*args, **kwargs)

        if len(cache_entries) >= cache_size:
            # Evict the least recently used entry.
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
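

# Illustrative usage of `tensor_cache` (a sketch, not part of this module): cache
# hits require the *same* tensor objects (identity via `is`), not merely equal
# values, so it suits helpers that are repeatedly called with the exact same
# metadata tensors, e.g.:
#
#     @tensor_cache
#     def _seqlens_from_cu(cu_seqlens: torch.Tensor) -> torch.Tensor:
#         return cu_seqlens[1:] - cu_seqlens[:-1]
#
#     lens = _seqlens_from_cu(cu_seqlens)  # computed
#     lens = _seqlens_from_cu(cu_seqlens)  # served from the cache (same object)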


def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that makes all tensor arguments contiguous and sets the current
    CUDA device to that of the first tensor argument before calling ``fn``.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (
            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
        )
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        # Find the first tensor argument (positional first, then keyword) to
        # determine which device to run on.
        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = torch.cuda.device(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper
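

# Illustrative usage of `input_guard` (a sketch, not part of this module): wrap a
# kernel launcher so that tensor inputs are contiguous and the active CUDA device
# matches the inputs before the kernel runs, e.g.:
#
#     @input_guard
#     def _launch(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
#         # q and k are contiguous here, and the current device is q's device.
#         return q @ k.transpose(-1, -2)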


@functools.cache
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        return "cpu"


@functools.cache
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    mapping = {
        "cuda": "nvidia",
        "hip": "amd",
        "xpu": "intel",
    }
    # return the mapped value, or the original if not found
    return mapping.get(device, device)


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
device_torch_lib = getattr(torch, device, None)
device_platform = _check_platform()

is_amd = device_platform == "amd"
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0)
    or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
is_gather_supported = hasattr(triton.language, "gather")
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
    hasattr(triton.language, "_experimental_make_tensor_descriptor")
    or hasattr(triton.language, "make_tensor_descriptor")
)
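

# Illustrative usage of the capability flags above (a sketch, not part of this
# module): kernels typically branch on these at launch time, e.g. only taking a
# TMA-based path when the hardware and Triton version support it:
#
#     def _pick_variant() -> str:
#         if is_tma_supported:
#             return "tma"
#         if is_nvidia_hopper:
#             return "hopper"
#         return "default"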


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        return [-1]
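

# Illustrative (a sketch, not part of this module): `get_all_max_shared_mem()`
# returns one entry per visible device, or `[-1]` if the Triton driver query fails:
#
#     shared_mem = get_all_max_shared_mem()
#     # e.g. [232448, 232448] on two H100s (see Backend.HOPPER below)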


class Backend(Enum):
    """Shared-memory capacity thresholds (in bytes) for common GPU architectures."""

    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False
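

# Illustrative usage of `check_shared_mem` (a sketch, not part of this module):
# pick a larger tile size only when the device offers at least Hopper-class
# shared memory, and fall back to a smaller tile otherwise:
#
#     BT = 128 if check_shared_mem("hopper") else 64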