[platform] Allow platform specify attention backend (#11609)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
commit 405eb8e396
parent 65097ca0af
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -1,10 +1,10 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch

 import pytest
 import torch

 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -12,6 +12,13 @@ from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):

     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == "TORCH_SDPA"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == "ROCM_FLASH"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == "OPENVINO"
+                   OpenVinoPlatform()), patch.dict('sys.modules',
+                                                   {'openvino': Mock()}):
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+        assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-        assert backend.name == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+            assert backend.get_name() == name


 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend

     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL

     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL


 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
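Note on the autouse clear_cache fixture in this test file: get_attn_backend delegates to the cached helper _cached_get_attn_backend, and cache_clear() is the standard functools.lru_cache API, so each parametrized case starts from a cold cache. A minimal sketch of the pattern, illustrative only and not vLLM code (the helper name below is made up):

from functools import lru_cache

@lru_cache(maxsize=None)
def _cached_select(device: str) -> str:
    # Stand-in for _cached_get_attn_backend: results are memoized per
    # argument tuple, so tests must clear the cache between cases.
    return f"backend-for-{device}"

assert _cached_select("cpu") == "backend-for-cpu"
_cached_select.cache_clear()  # what the clear_cache fixture does for the real helper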
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

 logger = init_logger(__name__)

@@ -114,83 +114,19 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend

-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
-    """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend

     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
@@ -201,64 +137,13 @@ def which_attn_to_use(head_size: int,
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)

-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)


 @contextmanager
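The selector no longer dispatches on a _Backend enum; each platform now answers with a fully-qualified class path that resolve_obj_by_qualname (from vllm.utils) turns into the backend class. A minimal sketch of that idea, illustrative only and not the vLLM implementation:

import importlib

def resolve_obj_by_qualname_sketch(qualname: str):
    """Resolve 'package.module.ClassName' to the named object."""
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)

# For example, the CPU platform's answer would resolve like this:
# resolve_obj_by_qualname_sketch(
#     "vllm.attention.backends.torch_sdpa.TorchSDPABackend")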
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -28,10 +28,13 @@ class CpuPlatform(Platform):
         return "cpu"

     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
+        logger.info("Using Torch SDPA backend.")
+        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"

     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -16,7 +16,7 @@ import vllm._C # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger

-from .interface import DeviceCapability, Platform, PlatformEnum
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -141,6 +141,81 @@ class CudaPlatformBase(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16

+    @classmethod
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info("Using FlashInfer backend.")
+            return "vllm.attention.backends.flashinfer.FlashInferBackend"
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+        elif selected_backend == _Backend.FLASH_ATTN:
+            pass
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}")
+
+        target_backend = _Backend.FLASH_ATTN
+        if not cls.has_device_capability(80):
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            target_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            target_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and \
+                kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            logger.warning(
+                "Please use FlashInfer backend with FP8 KV Cache for "
+                "better performance by setting environment variable "
+                "VLLM_ATTENTION_BACKEND=FLASHINFER")
+            target_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            target_backend = _Backend.XFORMERS
+
+        # FlashAttn is valid for the model, checking if the package is
+        # installed.
+        if target_backend == _Backend.FLASH_ATTN:
+            try:
+                import vllm.vllm_flash_attn  # noqa: F401
+                from vllm.attention.backends.flash_attn import (  # noqa: F401
+                    FlashAttentionBackend)
+
+                supported_sizes = \
+                    FlashAttentionBackend.get_supported_head_sizes()
+                if head_size not in supported_sizes:
+                    logger.info(
+                        "Cannot use FlashAttention-2 backend for head size %d.",
+                        head_size)
+                    target_backend = _Backend.XFORMERS
+            except ImportError:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend because the "
+                    "vllm.vllm_flash_attn package is not found. "
+                    "Make sure that vllm_flash_attn was built and installed "
+                    "(on by default).")
+                target_backend = _Backend.XFORMERS
+
+        if target_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+
+        logger.info("Using Flash Attention backend.")
+        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
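With this hook in place, the CUDA fallback rules are unchanged, but the platform now answers with a class path instead of a _Backend enum. A hedged usage sketch (assumes a CUDA-capable machine with vLLM installed; the result depends on the GPU, dtype, and installed packages):

import torch
from vllm.platforms.cuda import CudaPlatform

# Nothing forced via VLLM_ATTENTION_BACKEND or the global setting, so
# selected_backend is None and the default FlashAttention/XFormers rules apply.
cls_path = CudaPlatform.get_attn_backend_cls(None, 64, torch.float16, None,
                                             16, False)
print(cls_path)  # e.g. "vllm.attention.backends.flash_attn.FlashAttentionBackend"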
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -21,8 +21,11 @@ class HpuPlatform(Platform):
     dispatch_key: str = "HPU"

     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
-        return _Backend.HPU_ATTN
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        logger.info("Using HPUAttention backend.")
+        return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"

     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -112,9 +112,11 @@ class Platform:
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend):
-        """Get the default attention backend of a device."""
-        return None
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        """Get the attention backend class of a device."""
+        return ""

     @classmethod
     def get_device_capability(
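This Platform hook is what an out-of-tree platform plugin would override to point the selector at its own attention backend. A hypothetical sketch; the plugin class and backend path below are invented for illustration and are not part of this commit:

from typing import Optional

import torch

from vllm.platforms.interface import Platform, _Backend


class MyAcceleratorPlatform(Platform):

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        # Return a fully-qualified class name; get_attn_backend resolves it
        # via resolve_obj_by_qualname.
        return "my_plugin.attention.MyAttentionBackend"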
--- a/vllm/platforms/openvino.py
+++ b/vllm/platforms/openvino.py
@@ -28,10 +28,13 @@ class OpenVinoPlatform(Platform):
     dispatch_key: str = "CPU"

     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.OPENVINO:
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
-        return _Backend.OPENVINO
+        logger.info("Using OpenVINO Attention backend.")
+        return "vllm.attention.backends.openvino.OpenVINOAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -70,7 +70,8 @@ class RocmPlatform(Platform):
     ]

     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
@@ -79,7 +80,8 @@ class RocmPlatform(Platform):
                 logger.info("flash_attn is not supported on NAVI GPUs.")
         else:
             logger.info("%s is not supported in AMD GPUs.", selected_backend)
-        return _Backend.ROCM_FLASH
+        logger.info("Using ROCmFlashAttention backend.")
+        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501

     @classmethod
     @lru_cache(maxsize=8)
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -24,10 +24,13 @@ class TpuPlatform(Platform):
     ]

     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        return _Backend.PALLAS
+        logger.info("Using Pallas backend.")
+        return "vllm.attention.backends.pallas.PallasAttentionBackend"

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -21,10 +21,13 @@ class XPUPlatform(Platform):
     dispatch_key: str = "XPU"

     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
-        return _Backend.IPEX
+        logger.info("Using IPEX attention backend.")
+        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability: