[Bugfix][V1] Allow manual FlashAttention for Blackwell (#19492)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin, 2025-06-12 06:40:24 -04:00 (committed by GitHub)
commit af09b3f0a0, parent 4f6c42fa0a

@@ -226,15 +226,21 @@ class CudaPlatformBase(Platform):
             if selected_backend == _Backend.FLASHINFER:
                 logger.info_once("Using FlashInfer backend on V1 engine.")
                 return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
-            if selected_backend == _Backend.FLEX_ATTENTION:
+            elif selected_backend == _Backend.FLEX_ATTENTION:
                 logger.info("Using FlexAttenion backend on V1 engine.")
                 return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
-            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+            elif selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
                 logger.info_once("Using Triton backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "triton_attn.TritonAttentionBackend")
+            elif selected_backend == _Backend.FLASH_ATTN:
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "flash_attn.FlashAttentionBackend")
 
             # Default backends for V1 engine
-            # Prefer FlashInfer for Blackwell GPUs if installed
             if cls.is_device_capability(100):
+                # Prefer FlashInfer for V1 on Blackwell GPUs if installed
                 try:
                     import flashinfer  # noqa: F401
                     logger.info_once(
@@ -248,10 +254,13 @@ class CudaPlatformBase(Platform):
                         "Blackwell (SM 10.0) GPUs; it is recommended to "
                         "install FlashInfer for better performance.")
                     pass
-            if cls.has_device_capability(80):
+            # FlashAttention is the default for SM 8.0+ GPUs
+            elif cls.has_device_capability(80):
                 logger.info_once("Using Flash Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
                         "flash_attn.FlashAttentionBackend")
 
         # Backends for V0 engine
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"