From 0032903a5bb7c7c655f52f4efdfcc221947e9ca8 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Thu, 20 Mar 2025 20:20:16 -0600 Subject: [PATCH] [Bugfix] detect alibi and revert to FA2 (#15231) Signed-off-by: Travis Johnson --- vllm/attention/backends/flash_attn.py | 3 ++- vllm/fa_utils.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e981ac780b007..4cb0b916739a0 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -630,7 +630,8 @@ class FlashAttentionImpl(AttentionImpl): self.sliding_window = ((sliding_window - 1, 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype - self.vllm_flash_attn_version = get_flash_attn_version() + self.vllm_flash_attn_version = get_flash_attn_version( + requires_alibi=self.alibi_slopes is not None) if (is_quantized_kv_cache(self.kv_cache_dtype) and self.vllm_flash_attn_version != 3): raise NotImplementedError( diff --git a/vllm/fa_utils.py b/vllm/fa_utils.py index 028c96b839fb6..4176534901586 100644 --- a/vllm/fa_utils.py +++ b/vllm/fa_utils.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger logger = init_logger(__name__) -def get_flash_attn_version() -> Optional[int]: +def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform try: @@ -28,8 +28,14 @@ def get_flash_attn_version() -> Optional[int]: # 3. fallback for unsupported combinations if device_capability.major == 10 and fa_version == 3: - logger.warning("Cannot use FA version 3 on Blackwell platform", - "defaulting to FA version 2.") + logger.warning_once( + "Cannot use FA version 3 on Blackwell platform " + "defaulting to FA version 2.") + fa_version = 2 + + if requires_alibi and fa_version == 3: + logger.warning_once("Cannot use FA version 3 with ALiBi, " + "defaulting to FA version 2.") fa_version = 2 if not is_fa_version_supported(fa_version):