[Bugfix][MM encoder] Fix ViT attention backend resolving for Turing GPU (#29614)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date: 2025-11-28 03:17:37 +08:00 (committed by GitHub)
parent a24ea5414b
commit 38658ec6f3


@@ -264,14 +264,15 @@ class CudaPlatformBase(Platform):
         cls, head_size: int, dtype: torch.dtype
     ) -> "AttentionBackendEnum":
         # Try FlashAttention first
-        try:
-            backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-            if backend_class.supports_head_size(
-                head_size
-            ) and backend_class.supports_dtype(dtype):
-                return AttentionBackendEnum.FLASH_ATTN
-        except ImportError:
-            pass
+        if (cc := cls.get_device_capability()) and cc.major >= 8:
+            try:
+                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
+                if backend_class.supports_head_size(
+                    head_size
+                ) and backend_class.supports_dtype(dtype):
+                    return AttentionBackendEnum.FLASH_ATTN
+            except ImportError:
+                pass
         return AttentionBackendEnum.TORCH_SDPA
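
Behaviorally, the fix means FlashAttention is no longer attempted on pre-Ampere GPUs such as Turing (SM 7.5), since FlashAttention 2 requires compute capability 8.0 or newer; those devices now fall straight through to Torch SDPA instead of resolving to a backend they cannot run. Below is a minimal, self-contained sketch of that selection behavior, for illustration only: `DeviceCapability`, the boolean support flags, and `resolve_vit_attn_backend` are simplified stand-ins, not the vLLM API.

```python
from dataclasses import dataclass
from enum import Enum, auto


class AttentionBackendEnum(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


@dataclass
class DeviceCapability:
    major: int
    minor: int


def resolve_vit_attn_backend(
    cc: DeviceCapability | None,
    flash_attn_importable: bool,
    head_size_supported: bool,
    dtype_supported: bool,
) -> AttentionBackendEnum:
    # FlashAttention is only considered on Ampere (SM 8.x) or newer;
    # Turing (SM 7.5) falls straight through to Torch SDPA.
    if cc is not None and cc.major >= 8:
        if flash_attn_importable and head_size_supported and dtype_supported:
            return AttentionBackendEnum.FLASH_ATTN
    return AttentionBackendEnum.TORCH_SDPA


# Turing (SM 7.5): FlashAttention is skipped even if the package imports fine.
assert (
    resolve_vit_attn_backend(DeviceCapability(7, 5), True, True, True)
    is AttentionBackendEnum.TORCH_SDPA
)

# Ampere (SM 8.0): FlashAttention is still tried first, as before.
assert (
    resolve_vit_attn_backend(DeviceCapability(8, 0), True, True, True)
    is AttentionBackendEnum.FLASH_ATTN
)
```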