[Misc] Set default backend to SDPA for get_vit_attn_backend (#12235)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan 2025-01-22 03:52:11 +08:00 committed by GitHub
parent 347eeebe3b
commit fa9ee08121
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -82,7 +82,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if backend_by_env_var is not None: if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var) selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None: if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead. if current_platform.is_cuda():
device_available = current_platform.has_device_capability(80) device_available = current_platform.has_device_capability(80)
if device_available and support_fa: if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available from transformers.utils import is_flash_attn_2_available
@ -90,15 +90,17 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
else: else:
logger.warning_once( logger.warning_once(
"Current `vllm-flash-attn` has a bug inside vision module, " "Current `vllm-flash-attn` has a bug inside vision "
"so we use xformers backend instead. You can run " "module, so we use xformers backend instead. You can "
"`pip install flash-attn` to use flash-attention backend.") "run `pip install flash-attn` to use flash-attention "
"backend.")
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
elif current_platform.is_cpu() or current_platform.is_rocm():
# ROCM doesn't support xformers
selected_backend = _Backend.TORCH_SDPA
else: else:
# For Volta and Turing GPUs, use xformers instead.
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
else:
# Default to torch SDPA for other non-GPU platforms.
selected_backend = _Backend.TORCH_SDPA
return selected_backend return selected_backend