[Bugfix][MM encoder] Fix ViT attention backend resolving for Turing GPU (#29614)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent a24ea5414b
commit 38658ec6f3
@@ -264,14 +264,15 @@ class CudaPlatformBase(Platform):
         cls, head_size: int, dtype: torch.dtype
     ) -> "AttentionBackendEnum":
         # Try FlashAttention first
-        try:
-            backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-            if backend_class.supports_head_size(
-                head_size
-            ) and backend_class.supports_dtype(dtype):
-                return AttentionBackendEnum.FLASH_ATTN
-        except ImportError:
-            pass
+        if (cc := cls.get_device_capability()) and cc.major >= 8:
+            try:
+                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
+                if backend_class.supports_head_size(
+                    head_size
+                ) and backend_class.supports_dtype(dtype):
+                    return AttentionBackendEnum.FLASH_ATTN
+            except ImportError:
+                pass
 
         return AttentionBackendEnum.TORCH_SDPA
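Turing GPUs (e.g. T4, RTX 20xx) report compute capability 7.5, so the added `(cc := cls.get_device_capability()) and cc.major >= 8` guard skips the FlashAttention probe on them and the resolver falls back to TORCH_SDPA. Below is a minimal standalone sketch of that gating decision only; the enum and DeviceCapability class are simplified stand-ins for vLLM's own types, and the head-size, dtype, and ImportError checks from the real code are omitted.

    from dataclasses import dataclass
    from enum import Enum


    class VitAttnBackend(Enum):
        # Simplified stand-in for vLLM's AttentionBackendEnum
        FLASH_ATTN = "FLASH_ATTN"
        TORCH_SDPA = "TORCH_SDPA"


    @dataclass
    class DeviceCapability:
        # Hypothetical stand-in for the object returned by cls.get_device_capability()
        major: int
        minor: int


    def resolve_vit_attn_backend(cc: "DeviceCapability | None") -> VitAttnBackend:
        # FlashAttention is only considered on compute capability 8.x or newer;
        # Turing reports 7.5, so it falls through to torch SDPA.
        if cc is not None and cc.major >= 8:
            return VitAttnBackend.FLASH_ATTN
        return VitAttnBackend.TORCH_SDPA


    # Turing (7.5) -> TORCH_SDPA; Ampere (8.0) and newer -> FLASH_ATTN
    assert resolve_vit_attn_backend(DeviceCapability(7, 5)) is VitAttnBackend.TORCH_SDPA
    assert resolve_vit_attn_backend(DeviceCapability(8, 0)) is VitAttnBackend.FLASH_ATTN
    assert resolve_vit_attn_backend(None) is VitAttnBackend.TORCH_SDPA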