[Misc] Set default backend to SDPA for get_vit_attn_backend (#12235)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan 2025-01-22 03:52:11 +08:00 committed by GitHub
parent 347eeebe3b
commit fa9ee08121

@@ -82,23 +82,25 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
     if backend_by_env_var is not None:
         selected_backend = backend_name_to_enum(backend_by_env_var)
     if selected_backend is None:
-        # For Volta and Turing GPUs, use xformers instead.
-        device_available = current_platform.has_device_capability(80)
-        if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
+        if current_platform.is_cuda():
+            device_available = current_platform.has_device_capability(80)
+            if device_available and support_fa:
+                from transformers.utils import is_flash_attn_2_available
+                if is_flash_attn_2_available():
+                    selected_backend = _Backend.FLASH_ATTN
+                else:
+                    logger.warning_once(
+                        "Current `vllm-flash-attn` has a bug inside vision "
+                        "module, so we use xformers backend instead. You can "
+                        "run `pip install flash-attn` to use flash-attention "
+                        "backend.")
+                    selected_backend = _Backend.XFORMERS
             else:
-                logger.warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision module, "
-                    "so we use xformers backend instead. You can run "
-                    "`pip install flash-attn` to use flash-attention backend.")
+                # For Volta and Turing GPUs, use xformers instead.
                 selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu() or current_platform.is_rocm():
-            # ROCM doesn't support xformers
-            selected_backend = _Backend.TORCH_SDPA
         else:
-            selected_backend = _Backend.XFORMERS
+            # Default to torch SDPA for other non-GPU platforms.
+            selected_backend = _Backend.TORCH_SDPA
     return selected_backend
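
For reference, a minimal, self-contained sketch of the selection order this change produces. Backend, Platform, pick_vit_attn_backend, and the flash_attn_installed flag below are illustrative stand-ins (not vLLM APIs) for _Backend, current_platform, get_vit_attn_backend, and the is_flash_attn_2_available() probe:

from dataclasses import dataclass
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


@dataclass
class Platform:
    is_cuda: bool = False
    has_sm80: bool = False  # compute capability >= 8.0 (Ampere or newer)


def pick_vit_attn_backend(platform: Platform,
                          support_fa: bool = False,
                          flash_attn_installed: bool = False) -> Backend:
    if platform.is_cuda:
        if platform.has_sm80 and support_fa:
            # Ampere+ with FA requested: prefer flash-attention, falling
            # back to xformers when flash-attn is not installed.
            return (Backend.FLASH_ATTN
                    if flash_attn_installed else Backend.XFORMERS)
        # Volta/Turing, or FA not supported for this model: use xformers.
        return Backend.XFORMERS
    # Any non-CUDA platform now defaults to torch SDPA.
    return Backend.TORCH_SDPA


assert pick_vit_attn_backend(Platform()) is Backend.TORCH_SDPA
assert pick_vit_attn_backend(Platform(is_cuda=True)) is Backend.XFORMERS
assert pick_vit_attn_backend(
    Platform(is_cuda=True, has_sm80=True),
    support_fa=True,
    flash_attn_installed=True) is Backend.FLASH_ATTN

The first assert captures the point of the patch: previously only CPU and ROCm were routed to TORCH_SDPA and any other non-CUDA platform fell through to XFORMERS; with this change every non-CUDA platform defaults to torch SDPA, while the CUDA paths are unchanged.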