# vllm/vllm/attention/selector.py

from functools import lru_cache
from typing import Type

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip

logger = init_logger(__name__)


@lru_cache(maxsize=None)
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
    """Select the attention backend class for this platform and dtype."""
    if _can_use_flash_attn(dtype):
        logger.info("Using FlashAttention backend.")
        from vllm.attention.backends.flash_attn import (  # noqa: F401
            FlashAttentionBackend)
        return FlashAttentionBackend
    elif is_cpu():
        logger.info("Using Torch SDPA backend.")
        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
        return TorchSDPABackend
    else:
        logger.info("Using XFormers backend.")
        from vllm.attention.backends.xformers import (  # noqa: F401
            XFormersBackend)
        return XFormersBackend


def _can_use_flash_attn(dtype: torch.dtype) -> bool:
    """Return True if the FlashAttention backend can be used on this setup."""
    if is_hip():
        # AMD GPUs.
        logger.info("Cannot use FlashAttention backend for AMD GPUs.")
        return False
    if is_cpu():
        # CPU builds use the Torch SDPA backend instead.
        return False
    if torch.cuda.get_device_capability()[0] < 8:
        # Volta and Turing NVIDIA GPUs.
        logger.info("Cannot use FlashAttention backend for Volta and Turing "
                    "GPUs.")
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        logger.info("Cannot use FlashAttention backend for dtype other than "
                    "torch.float16 or torch.bfloat16.")
        return False
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        logger.info(
            "Cannot use FlashAttention because the package is not found. "
            "Please install it for better performance.")
        return False
    return True
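

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module; added only to
    # illustrate how the selector is typically exercised). In vLLM the model
    # runner calls get_attn_backend() during initialization. On a CUDA machine
    # with flash-attn installed and a half-precision dtype this resolves to the
    # FlashAttention backend; otherwise it falls back to XFormers (GPU) or
    # Torch SDPA (CPU builds). The result is cached by @lru_cache, so repeated
    # calls with the same dtype return the same class without re-probing.
    backend_cls = get_attn_backend(torch.float16)
    print(f"Selected attention backend: {backend_cls.__name__}")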