mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 01:05:28 +08:00
[Misc] Take user preference in attention selector (#4960)
This commit is contained in:
parent
a36de682d4
commit
ee3eea0a1b
84
tests/kernels/test_attention_selector.py
Normal file
84
tests/kernels/test_attention_selector.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.selector import which_attn_to_use
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
|
||||||
|
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
|
||||||
|
def test_env(name: str, device: str):
|
||||||
|
"""Test that the attention selector can be set via environment variable.
|
||||||
|
Note that we do not test FlashAttn because it is the default backend.
|
||||||
|
"""
|
||||||
|
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = name
|
||||||
|
|
||||||
|
if device == "cpu":
|
||||||
|
with patch("vllm.attention.selector.is_cpu", return_value=True):
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
|
||||||
|
torch.float16, 16)
|
||||||
|
assert backend.name == "TORCH_SDPA"
|
||||||
|
elif device == "hip":
|
||||||
|
with patch("vllm.attention.selector.is_hip", return_value=True):
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
|
||||||
|
torch.float16, 16)
|
||||||
|
assert backend.name == "ROCM_FLASH"
|
||||||
|
else:
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
|
||||||
|
torch.float16, 16)
|
||||||
|
assert backend.name == name
|
||||||
|
|
||||||
|
if name_backup is not None:
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
|
||||||
|
|
||||||
|
|
||||||
|
def test_flash_attn():
|
||||||
|
"""Test FlashAttn validation."""
|
||||||
|
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
|
||||||
|
|
||||||
|
# Unsupported CUDA arch
|
||||||
|
with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||||
|
assert backend.name != "FLASH_ATTN"
|
||||||
|
|
||||||
|
# Unsupported data type
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
|
||||||
|
assert backend.name != "FLASH_ATTN"
|
||||||
|
|
||||||
|
# Unsupported kv cache data type
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
|
||||||
|
assert backend.name != "FLASH_ATTN"
|
||||||
|
|
||||||
|
# Unsupported block size
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
|
||||||
|
assert backend.name != "FLASH_ATTN"
|
||||||
|
|
||||||
|
# Unsupported sliding window
|
||||||
|
backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
|
||||||
|
assert backend.name != "FLASH_ATTN"
|
||||||
|
|
||||||
|
# flash-attn is not installed
|
||||||
|
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
|
||||||
|
backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||||
|
assert backend.name != "FLASH_ATTN"
|
||||||
|
|
||||||
|
# Unsupported head size
|
||||||
|
backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
|
||||||
|
assert backend.name != "FLASH_ATTN"
|
||||||
|
|
||||||
|
if name_backup is not None:
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_env():
|
||||||
|
"""Throw an exception if the backend name is invalid."""
|
||||||
|
name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
|
||||||
@ -218,6 +218,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
|
# Prompt run.
|
||||||
assert prefill_meta.block_tables is not None
|
assert prefill_meta.block_tables is not None
|
||||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||||
output = flash_attn_varlen_func(
|
output = flash_attn_varlen_func(
|
||||||
|
|||||||
@ -30,24 +30,16 @@ def get_attn_backend(
|
|||||||
kv_cache_dtype: Optional[str],
|
kv_cache_dtype: Optional[str],
|
||||||
block_size: int,
|
block_size: int,
|
||||||
) -> Type[AttentionBackend]:
|
) -> Type[AttentionBackend]:
|
||||||
backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
|
"""Determine which attention backend to use and only import
|
||||||
|
the selected backend module.
|
||||||
|
"""
|
||||||
|
backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
|
||||||
sliding_window, dtype, kv_cache_dtype,
|
sliding_window, dtype, kv_cache_dtype,
|
||||||
block_size)
|
block_size)
|
||||||
if backend == _Backend.FLASH_ATTN:
|
if backend == _Backend.FLASH_ATTN:
|
||||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||||
FlashAttentionBackend)
|
FlashAttentionBackend)
|
||||||
|
|
||||||
# We check it here not in _which_attn_to_use because we cannot know
|
|
||||||
# the head size until we import FlashAttentionBackend.
|
|
||||||
supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
|
||||||
if head_size in supported_head_sizes:
|
|
||||||
logger.info("Using FlashAttention-2 backend.")
|
|
||||||
return FlashAttentionBackend
|
return FlashAttentionBackend
|
||||||
logger.info(
|
|
||||||
"Cannot use FlashAttention-2 backend for head size %d. "
|
|
||||||
"Using XFormers backend instead.", head_size)
|
|
||||||
backend = _Backend.XFORMERS
|
|
||||||
|
|
||||||
if backend == _Backend.XFORMERS:
|
if backend == _Backend.XFORMERS:
|
||||||
logger.info("Using XFormers backend.")
|
logger.info("Using XFormers backend.")
|
||||||
from vllm.attention.backends.xformers import ( # noqa: F401
|
from vllm.attention.backends.xformers import ( # noqa: F401
|
||||||
@ -64,14 +56,15 @@ def get_attn_backend(
|
|||||||
return TorchSDPABackend
|
return TorchSDPABackend
|
||||||
elif backend == _Backend.FLASHINFER:
|
elif backend == _Backend.FLASHINFER:
|
||||||
logger.info("Using Flashinfer backend.")
|
logger.info("Using Flashinfer backend.")
|
||||||
logger.warning("Eager mode is enforced for the Flashinfer backend.")
|
logger.warning("Eager mode is required for the Flashinfer backend. "
|
||||||
|
"Please make sure --enforce-eager is set.")
|
||||||
from vllm.attention.backends.flashinfer import FlashInferBackend
|
from vllm.attention.backends.flashinfer import FlashInferBackend
|
||||||
return FlashInferBackend
|
return FlashInferBackend
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid attention backend.")
|
raise ValueError("Invalid attention backend.")
|
||||||
|
|
||||||
|
|
||||||
def _which_attn_to_use(
|
def which_attn_to_use(
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
@ -81,54 +74,84 @@ def _which_attn_to_use(
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
) -> _Backend:
|
) -> _Backend:
|
||||||
"""Returns which flash attention backend to use."""
|
"""Returns which flash attention backend to use."""
|
||||||
|
|
||||||
|
# Default case.
|
||||||
|
selected_backend = _Backend.FLASH_ATTN
|
||||||
|
|
||||||
|
# Check the environment variable and override if specified
|
||||||
|
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||||
|
if backend_by_env_var is not None:
|
||||||
|
backend_members = _Backend.__members__
|
||||||
|
if backend_by_env_var not in backend_members:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid attention backend '{backend_by_env_var}'. "
|
||||||
|
f"Available backends: {', '.join(backend_members)} "
|
||||||
|
"(case-sensitive).")
|
||||||
|
selected_backend = _Backend[backend_by_env_var]
|
||||||
|
|
||||||
if is_cpu():
|
if is_cpu():
|
||||||
|
if selected_backend != _Backend.TORCH_SDPA:
|
||||||
|
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||||
return _Backend.TORCH_SDPA
|
return _Backend.TORCH_SDPA
|
||||||
|
|
||||||
if is_hip():
|
if is_hip():
|
||||||
# AMD GPUs.
|
# AMD GPUs.
|
||||||
|
selected_backend = (_Backend.ROCM_FLASH if selected_backend
|
||||||
|
== _Backend.FLASH_ATTN else selected_backend)
|
||||||
|
if selected_backend == _Backend.ROCM_FLASH:
|
||||||
if torch.cuda.get_device_capability()[0] != 9:
|
if torch.cuda.get_device_capability()[0] != 9:
|
||||||
# not Instinct series GPUs.
|
# not Instinct series GPUs.
|
||||||
logger.info("flash_atten is not supported on NAVI GPUs.")
|
logger.info("flash_attn is not supported on NAVI GPUs.")
|
||||||
|
else:
|
||||||
|
logger.info("%s is not supported in AMD GPUs.", selected_backend)
|
||||||
return _Backend.ROCM_FLASH
|
return _Backend.ROCM_FLASH
|
||||||
|
|
||||||
# NVIDIA GPUs.
|
# FlashAttn in NVIDIA GPUs.
|
||||||
|
if selected_backend == _Backend.FLASH_ATTN:
|
||||||
if torch.cuda.get_device_capability()[0] < 8:
|
if torch.cuda.get_device_capability()[0] < 8:
|
||||||
# Volta and Turing NVIDIA GPUs.
|
# Volta and Turing NVIDIA GPUs.
|
||||||
logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
|
logger.info(
|
||||||
|
"Cannot use FlashAttention-2 backend for Volta and Turing "
|
||||||
"GPUs.")
|
"GPUs.")
|
||||||
return _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
|
elif dtype not in (torch.float16, torch.bfloat16):
|
||||||
if dtype not in (torch.float16, torch.bfloat16):
|
logger.info(
|
||||||
logger.info("Cannot use FlashAttention-2 backend for dtype other than "
|
"Cannot use FlashAttention-2 backend for dtype other than "
|
||||||
"torch.float16 or torch.bfloat16.")
|
"torch.float16 or torch.bfloat16.")
|
||||||
return _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
|
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
|
||||||
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
|
logger.info(
|
||||||
logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
|
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
|
||||||
return _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
|
elif block_size % 16 != 0:
|
||||||
if block_size % 16 != 0:
|
logger.info(
|
||||||
logger.info("Cannot use FlashAttention-2 backend for block size not "
|
"Cannot use FlashAttention-2 backend for block size not "
|
||||||
"divisible by 16.")
|
"divisible by 16.")
|
||||||
return _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
|
elif sliding_window is not None:
|
||||||
if sliding_window is not None:
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Cannot use FlashAttention-2 backend due to sliding window.")
|
"Cannot use FlashAttention-2 backend due to sliding window.")
|
||||||
return _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
|
|
||||||
|
# FlashAttn is valid for the model, checking if the package is installed.
|
||||||
|
if selected_backend == _Backend.FLASH_ATTN:
|
||||||
try:
|
try:
|
||||||
import vllm_flash_attn # noqa: F401
|
import vllm_flash_attn # noqa: F401
|
||||||
|
|
||||||
|
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||||
|
FlashAttentionBackend)
|
||||||
|
|
||||||
|
supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_sizes:
|
||||||
|
logger.info(
|
||||||
|
"Cannot use FlashAttention-2 backend for head size %d.",
|
||||||
|
head_size)
|
||||||
|
selected_backend = _Backend.XFORMERS
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Cannot use FlashAttention-2 backend because the vllm_flash_attn "
|
"Cannot use FlashAttention-2 backend because the "
|
||||||
"package is not found. `pip install vllm-flash-attn` for better "
|
"vllm_flash_attn package is not found. "
|
||||||
"performance.")
|
"`pip install vllm-flash-attn` for better performance.")
|
||||||
return _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
|
|
||||||
backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
|
return selected_backend
|
||||||
if backend_by_env_var is not None:
|
|
||||||
return _Backend[backend_by_env_var]
|
|
||||||
|
|
||||||
# Default case.
|
|
||||||
return _Backend.FLASH_ATTN
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user