[Bugfix][MM encoder] Fix ViT attention backend resolving for Turing GPU (#29614)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-11-28 03:17:37 +08:00 committed by GitHub
parent a24ea5414b
commit 38658ec6f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -264,14 +264,15 @@ class CudaPlatformBase(Platform):
cls, head_size: int, dtype: torch.dtype cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
# Try FlashAttention first # Try FlashAttention first
try: if (cc := cls.get_device_capability()) and cc.major >= 8:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() try:
if backend_class.supports_head_size( backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
head_size if backend_class.supports_head_size(
) and backend_class.supports_dtype(dtype): head_size
return AttentionBackendEnum.FLASH_ATTN ) and backend_class.supports_dtype(dtype):
except ImportError: return AttentionBackendEnum.FLASH_ATTN
pass except ImportError:
pass
return AttentionBackendEnum.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA