mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 10:55:43 +08:00
[Misc] Set default backend to SDPA for get_vit_attn_backend (#12235)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
347eeebe3b
commit
fa9ee08121
@ -82,23 +82,25 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
|
|||||||
if backend_by_env_var is not None:
|
if backend_by_env_var is not None:
|
||||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||||
if selected_backend is None:
|
if selected_backend is None:
|
||||||
# For Volta and Turing GPUs, use xformers instead.
|
if current_platform.is_cuda():
|
||||||
device_available = current_platform.has_device_capability(80)
|
device_available = current_platform.has_device_capability(80)
|
||||||
if device_available and support_fa:
|
if device_available and support_fa:
|
||||||
from transformers.utils import is_flash_attn_2_available
|
from transformers.utils import is_flash_attn_2_available
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
selected_backend = _Backend.FLASH_ATTN
|
selected_backend = _Backend.FLASH_ATTN
|
||||||
|
else:
|
||||||
|
logger.warning_once(
|
||||||
|
"Current `vllm-flash-attn` has a bug inside vision "
|
||||||
|
"module, so we use xformers backend instead. You can "
|
||||||
|
"run `pip install flash-attn` to use flash-attention "
|
||||||
|
"backend.")
|
||||||
|
selected_backend = _Backend.XFORMERS
|
||||||
else:
|
else:
|
||||||
logger.warning_once(
|
# For Volta and Turing GPUs, use xformers instead.
|
||||||
"Current `vllm-flash-attn` has a bug inside vision module, "
|
|
||||||
"so we use xformers backend instead. You can run "
|
|
||||||
"`pip install flash-attn` to use flash-attention backend.")
|
|
||||||
selected_backend = _Backend.XFORMERS
|
selected_backend = _Backend.XFORMERS
|
||||||
elif current_platform.is_cpu() or current_platform.is_rocm():
|
|
||||||
# ROCM doesn't support xformers
|
|
||||||
selected_backend = _Backend.TORCH_SDPA
|
|
||||||
else:
|
else:
|
||||||
selected_backend = _Backend.XFORMERS
|
# Default to torch SDPA for other non-GPU platforms.
|
||||||
|
selected_backend = _Backend.TORCH_SDPA
|
||||||
return selected_backend
|
return selected_backend
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user