From fe743b798dfa56aea3e2cb7182365ba3495489ee Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Feb 2025 00:06:56 +0800
Subject: [PATCH] [bugfix] fix early import of flash attention (#12959)

Signed-off-by: youkaichao
---
 vllm/attention/backends/flash_attn.py    | 13 +++++++------
 vllm/attention/backends/mla/utils.py     |  5 +++--
 vllm/attention/backends/utils.py         | 14 ++++++--------
 vllm/v1/attention/backends/flash_attn.py |  7 ++++---
 4 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 971fe411695cb..5aca10079f9be 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -14,8 +14,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadataBuilder,
                                               AttentionType)
 from vllm.attention.backends.utils import (
-    PAD_SLOT_ID, VLLM_FLASH_ATTN_VERSION, CommonAttentionState,
-    compute_slot_mapping, compute_slot_mapping_start_idx,
+    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
+    compute_slot_mapping_start_idx, get_flash_attn_version,
     get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
     is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set,
     is_block_tables_empty)
@@ -640,6 +640,7 @@ class FlashAttentionImpl(AttentionImpl):
                 f"Head size {head_size} is not supported by FlashAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -759,7 +760,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # prefix-enabled attention
@@ -782,7 +783,7 @@ class FlashAttentionImpl(AttentionImpl):
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -811,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -832,7 +833,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
-                    fa_version=VLLM_FLASH_ATTN_VERSION,
+                    fa_version=self.vllm_flash_attn_version,
                 )
 
         return output
diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index c22f7e92103b8..a41140ec83782 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -12,7 +12,7 @@ from vllm import envs
 from vllm.attention.backends.abstract import (AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl, T)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -181,6 +181,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def _v_up_proj_and_o_proj(self, x):
         if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
@@ -515,7 +516,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             max_seqlen_k=max_prefill_seq_len,
             softmax_scale=self.scale,
             causal=True,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index e8a34434122c4..5c1f9916e22c2 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -587,11 +587,11 @@ def get_num_prefill_decode_query_kv_tokens(
             num_decode_query_tokens)
 
 
-try:
-    from vllm.vllm_flash_attn.flash_attn_interface import (
-        fa_version_unsupported_reason, is_fa_version_supported)
+def get_flash_attn_version():
+    try:
+        from vllm.vllm_flash_attn.flash_attn_interface import (
+            fa_version_unsupported_reason, is_fa_version_supported)
 
-    def flash_attn_version():
         # if hopper default to FA3, otherwise stick to FA2 for now
         # TODO(lucas): profile FA3 on ampere to see if it makes sense to
         # use FA3 as default for both
@@ -610,7 +610,5 @@ try:
 
         assert is_fa_version_supported(fa_version)
         return fa_version
-
-    VLLM_FLASH_ATTN_VERSION = flash_attn_version()
-except (ImportError, AssertionError):
-    VLLM_FLASH_ATTN_VERSION = None
+    except (ImportError, AssertionError):
+        return None
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 204afc9f4025d..5cb1e2fd26a5c 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -10,7 +10,7 @@ import triton.language as tl
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionType)
-from vllm.attention.backends.utils import VLLM_FLASH_ATTN_VERSION
+from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -132,6 +132,7 @@ class FlashAttentionImpl(AttentionImpl):
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "FlashAttentionImpl")
+        self.vllm_flash_attn_version = get_flash_attn_version()
 
     def forward(
         self,
@@ -205,7 +206,7 @@ class FlashAttentionImpl(AttentionImpl):
                 window_size=self.sliding_window,
                 block_table=attn_metadata.block_table,
                 softcap=self.logits_soft_cap,
-                fa_version=VLLM_FLASH_ATTN_VERSION,
+                fa_version=self.vllm_flash_attn_version,
             )
             return output
 
@@ -227,7 +228,7 @@ class FlashAttentionImpl(AttentionImpl):
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
-            fa_version=VLLM_FLASH_ATTN_VERSION,
+            fa_version=self.vllm_flash_attn_version,
         )
         return output
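
Note (illustration, not part of the patch): the fix replaces a module-level
constant, computed by a try/except that ran as soon as
vllm/attention/backends/utils.py was imported, with a get_flash_attn_version()
helper that performs the vllm.vllm_flash_attn import lazily; each attention
implementation then caches the result on itself in __init__ and passes it to
the kernel call sites as self.vllm_flash_attn_version. A minimal,
self-contained sketch of the same shape is below; optional_backend,
get_backend_version, and AttnImpl are hypothetical names used only for
illustration, not vLLM APIs.

from typing import Optional


def get_backend_version() -> Optional[int]:
    """Resolve the optional backend's version on first use, not at import."""
    try:
        # Hypothetical heavy/optional dependency. Importing the module that
        # defines this function stays cheap because this import only runs
        # when the function is actually called.
        import optional_backend
        version = getattr(optional_backend, "VERSION", 2)
        assert version in (2, 3)
        return version
    except (ImportError, AssertionError):
        return None


class AttnImpl:
    def __init__(self) -> None:
        # Resolve once per instance, mirroring how the patch stores
        # self.vllm_flash_attn_version in each attention implementation.
        self.backend_version = get_backend_version()

    def forward(self) -> Optional[int]:
        # Use the cached value at the call site instead of a module-level
        # constant evaluated when the module was imported.
        return self.backend_version

The behavioral difference from the old module-level try/except is only when
the import happens: in the sketch (as in the patch), nothing under the
optional dependency is touched until an implementation object is constructed,
e.g. AttnImpl().forward() returns None when optional_backend is absent.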