diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index d5c3a177d9c2b..4bf9401b6b051 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -264,14 +264,15 @@ class CudaPlatformBase(Platform):
         cls, head_size: int, dtype: torch.dtype
     ) -> "AttentionBackendEnum":
         # Try FlashAttention first
-        try:
-            backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
-            if backend_class.supports_head_size(
-                head_size
-            ) and backend_class.supports_dtype(dtype):
-                return AttentionBackendEnum.FLASH_ATTN
-        except ImportError:
-            pass
+        if (cc := cls.get_device_capability()) and cc.major >= 8:
+            try:
+                backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
+                if backend_class.supports_head_size(
+                    head_size
+                ) and backend_class.supports_dtype(dtype):
+                    return AttentionBackendEnum.FLASH_ATTN
+            except ImportError:
+                pass
         return AttentionBackendEnum.TORCH_SDPA
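For context, the new guard only considers FlashAttention on GPUs with compute capability 8.x (Ampere) or newer, falling back to Torch SDPA otherwise. Below is a minimal standalone sketch of the same capability-gating pattern using plain torch.cuda and an optional flash_attn import; the function name pick_attention_impl is hypothetical and the sketch does not use the vLLM AttentionBackendEnum machinery shown in the diff.

    import torch


    def pick_attention_impl() -> str:
        """Prefer FlashAttention only on Ampere (SM 8.x) or newer GPUs."""
        if torch.cuda.is_available():
            # (major, minor) compute capability of the current device.
            major, _minor = torch.cuda.get_device_capability()
            if major >= 8:
                try:
                    import flash_attn  # noqa: F401  (optional dependency)
                    return "flash_attn"
                except ImportError:
                    pass
        # Pre-Ampere GPUs, CPU-only hosts, or a missing flash_attn package
        # all fall back to PyTorch's scaled_dot_product_attention path.
        return "torch_sdpa"

Checking the capability before attempting the import avoids selecting a backend whose kernels cannot run on the device even when the package itself installs and imports cleanly.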