from functools import lru_cache
from typing import Type

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip

logger = init_logger(__name__)


@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Select the attention backend class for the device and dtype."""
    if _can_use_flash_attn(dtype):
        logger.info("Using FlashAttention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    elif is_cpu():
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    else:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend


def _can_use_flash_attn(dtype: torch.dtype) -> bool:
    """Return True if the FlashAttention backend can be used."""
    if is_hip():
        # AMD GPUs.
        logger.info("Cannot use FlashAttention backend for AMD GPUs.")
        return False
    if is_cpu():
        # CPU-only execution.
        return False
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return False

    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention because the package is not found. "
            "Please install it for better performance.")
        return False
    return True
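

# Example (illustrative sketch): selecting the backend for float16 activations.
# Assumes the imports above resolve (i.e. vLLM is installed); which class is
# returned depends on the hardware and on whether flash_attn is importable.
if __name__ == "__main__":
    backend_cls = get_attn_backend(torch.float16)
    # Repeated calls with the same dtype hit the lru_cache and return the
    # same class object.
    print(f"Selected attention backend: {backend_cls.__name__}")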