Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-06-03 07:24:26 +08:00)
[Bugfix][MM encoder] Fix ViT attention backend resolving for Turing GPU (#29614)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Parent: a24ea5414b
Commit: 38658ec6f3
@@ -264,14 +264,15 @@ class CudaPlatformBase(Platform):
         cls, head_size: int, dtype: torch.dtype
     ) -> "AttentionBackendEnum":
         # Try FlashAttention first
-        try:
-            backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-            if backend_class.supports_head_size(
-                head_size
-            ) and backend_class.supports_dtype(dtype):
-                return AttentionBackendEnum.FLASH_ATTN
-        except ImportError:
-            pass
+        if (cc := cls.get_device_capability()) and cc.major >= 8:
+            try:
+                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
+                if backend_class.supports_head_size(
+                    head_size
+                ) and backend_class.supports_dtype(dtype):
+                    return AttentionBackendEnum.FLASH_ATTN
+            except ImportError:
+                pass
 
         return AttentionBackendEnum.TORCH_SDPA
 
Loading…
x
Reference in New Issue
Block a user