From d8bccde686348b38fc4590dffe8069c68627ab67 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Sun, 27 Apr 2025 20:27:56 -0400
Subject: [PATCH] [BugFix] Fix vllm_flash_attn install issues (#17267)

Signed-off-by: Lucas Wilkinson
Co-authored-by: Jee Jee Li
Co-authored-by: Aaron Pham
---
 .github/CODEOWNERS                            |   1 +
 .gitignore                                    |   2 -
 setup.py                                      |  26 +-
 vllm/attention/backends/flash_attn.py         |   6 +-
 vllm/attention/backends/mla/common.py         |   2 +-
 .../utils}/fa_utils.py                        |   0
 vllm/engine/arg_utils.py                      |   2 +-
 vllm/v1/attention/backends/flash_attn.py      |   4 +-
 vllm/v1/attention/backends/mla/common.py      |   2 +-
 vllm/vllm_flash_attn/__init__.py              |  22 --
 vllm/vllm_flash_attn/flash_attn_interface.pyi | 245 ------------------
 11 files changed, 28 insertions(+), 284 deletions(-)
 rename vllm/{vllm_flash_attn => attention/utils}/fa_utils.py (100%)
 delete mode 100644 vllm/vllm_flash_attn/__init__.py
 delete mode 100644 vllm/vllm_flash_attn/flash_attn_interface.pyi

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 860c5c6cd5374..76aa5f7a35d5a 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -12,6 +12,7 @@
 /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
 /vllm/model_executor/guided_decoding @mgoin @russellb
 /vllm/multimodal @DarkLight1337 @ywang96
+/vllm/vllm_flash_attn @LucasWilkinson
 CMakeLists.txt @tlrmchlsmth
 
 # vLLM V1
diff --git a/.gitignore b/.gitignore
index 49330717640bb..728213ceb74f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,8 +3,6 @@
 
 # vllm-flash-attn built from source
 vllm/vllm_flash_attn/*
-!vllm/vllm_flash_attn/__init__.py
-!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/setup.py b/setup.py
index ed4b88364a6d7..a1867960e5930 100755
--- a/setup.py
+++ b/setup.py
@@ -269,15 +269,17 @@ class cmake_build_ext(build_ext):
         # First, run the standard build_ext command to compile the extensions
         super().run()
 
-        # copy vllm/vllm_flash_attn/*.py from self.build_lib to current
+        # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
         # directory so that they can be included in the editable build
         import glob
-        files = glob.glob(
-            os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
+        files = glob.glob(os.path.join(self.build_lib, "vllm",
+                                       "vllm_flash_attn", "**", "*.py"),
+                          recursive=True)
         for file in files:
             dst_file = os.path.join("vllm/vllm_flash_attn",
-                                    os.path.basename(file))
+                                    file.split("vllm/vllm_flash_attn/")[-1])
             print(f"Copying {file} to {dst_file}")
+            os.makedirs(os.path.dirname(dst_file), exist_ok=True)
             self.copy_file(file, dst_file)
 
 
@@ -377,12 +379,22 @@ class repackage_wheel(build_ext):
             "vllm/_flashmla_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
-            "vllm/vllm_flash_attn/flash_attn_interface.py",
             "vllm/cumem_allocator.abi3.so",
             # "vllm/_version.py", # not available in nightly wheels yet
         ]
-        file_members = filter(lambda x: x.filename in files_to_copy,
-                              wheel.filelist)
+
+        file_members = list(
+            filter(lambda x: x.filename in files_to_copy, wheel.filelist))
+
+        # vllm_flash_attn python code:
+        # Regex from
+        # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
+        import re
+        compiled_regex = re.compile(
+            r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
+        file_members += list(
+            filter(lambda x: compiled_regex.match(x.filename),
+                   wheel.filelist))
 
         for file in file_members:
             print(f"Extracting and including {file.filename} "
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index f9c5ad4df54ea..7f8f720eee0ae 100755
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
         assert output is not None, "Output tensor must be provided."
 
         # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
-        if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
+        if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
             assert (
                 layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
                 "key/v_scale is only supported in FlashAttention 3 with "
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index a3dec0dbda9f8..382a9a6d44d84 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,7 +215,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 if HAS_TRITON:
     from vllm.attention.ops.triton_flash_attention import triton_attention
diff --git a/vllm/vllm_flash_attn/fa_utils.py b/vllm/attention/utils/fa_utils.py
similarity index 100%
rename from vllm/vllm_flash_attn/fa_utils.py
rename to vllm/attention/utils/fa_utils.py
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d5b87a2ce2aac..1c966469db008 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1377,7 +1377,7 @@ class EngineArgs:
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
             if fp8_attention and will_use_fa:
-                from vllm.vllm_flash_attn.fa_utils import (
+                from vllm.attention.utils.fa_utils import (
                     flash_attn_supports_fp8)
                 supported = flash_attn_supports_fp8()
             if not supported:
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 51ae386d33898..0c86ad8a828a6 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -11,11 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index f826f8a21789e..e6e483bae2bc8 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -197,6 +197,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
@@ -204,7 +205,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
diff --git a/vllm/vllm_flash_attn/__init__.py b/vllm/vllm_flash_attn/__init__.py
deleted file mode 100644
index cf8f1207a65a4..0000000000000
--- a/vllm/vllm_flash_attn/__init__.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import importlib.metadata
-
-try:
-    __version__ = importlib.metadata.version("vllm-flash-attn")
-except importlib.metadata.PackageNotFoundError:
-    # in this case, vllm-flash-attn is built from installing vllm editable
-    __version__ = "0.0.0.dev0"
-
-from .flash_attn_interface import (fa_version_unsupported_reason,
-                                   flash_attn_varlen_func,
-                                   flash_attn_with_kvcache,
-                                   get_scheduler_metadata,
-                                   is_fa_version_supported, sparse_attn_func,
-                                   sparse_attn_varlen_func)
-
-__all__ = [
-    'flash_attn_varlen_func', 'flash_attn_with_kvcache',
-    'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
-    'is_fa_version_supported', 'fa_version_unsupported_reason'
-]
diff --git a/vllm/vllm_flash_attn/flash_attn_interface.pyi b/vllm/vllm_flash_attn/flash_attn_interface.pyi
deleted file mode 100644
index ca8311e01358c..0000000000000
--- a/vllm/vllm_flash_attn/flash_attn_interface.pyi
+++ /dev/null
@@ -1,245 +0,0 @@
-# ruff: ignore
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from typing import Any, Literal, overload
-
-import torch
-
-def get_scheduler_metadata(
-    batch_size: int,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    num_heads_q: int,
-    num_heads_kv: int,
-    headdim: int,
-    cache_seqlens: torch.Tensor,
-    qkv_dtype: torch.dtype = ...,
-    headdim_v: int | None = ...,
-    cu_seqlens_q: torch.Tensor | None = ...,
-    cu_seqlens_k_new: torch.Tensor | None = ...,
-    cache_leftpad: torch.Tensor | None = ...,
-    page_size: int = ...,
-    max_seqlen_k_new: int = ...,
-    causal: bool = ...,
-    window_size: tuple[int, int] = ...,
-    has_softcap: bool = ...,
-    num_splits: int = ...,
-    pack_gqa: Any | None = ...,
-    sm_margin: int = ...,
-): ...
-@overload
-def flash_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    max_seqlen_q: int,
-    cu_seqlens_q: torch.Tensor | None,
-    max_seqlen_k: int,
-    cu_seqlens_k: torch.Tensor | None = ...,
-    seqused_k: Any | None = ...,
-    q_v: Any | None = ...,
-    dropout_p: float = ...,
-    causal: bool = ...,
-    window_size: list[int] | None = ...,
-    softmax_scale: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    block_table: Any | None = ...,
-    return_softmax_lse: Literal[False] = ...,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[int, int, int]: ...
-@overload
-def flash_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    max_seqlen_q: int,
-    cu_seqlens_q: torch.Tensor | None,
-    max_seqlen_k: int,
-    cu_seqlens_k: torch.Tensor | None = ...,
-    seqused_k: Any | None = ...,
-    q_v: Any | None = ...,
-    dropout_p: float = ...,
-    causal: bool = ...,
-    window_size: list[int] | None = ...,
-    softmax_scale: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    block_table: Any | None = ...,
-    return_softmax_lse: Literal[True] = ...,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-@overload
-def flash_attn_with_kvcache(
-    q: tuple[int, int, int, int],
-    k_cache: tuple[int, int, int, int],
-    v_cache: tuple[int, int, int, int],
-    k: tuple[int, int, int, int] | None = ...,
-    v: tuple[int, int, int, int] | None = ...,
-    rotary_cos: tuple[int, int] | None = ...,
-    rotary_sin: tuple[int, int] | None = ...,
-    cache_seqlens: int | torch.Tensor | None = None,
-    cache_batch_idx: torch.Tensor | None = None,
-    cache_leftpad: torch.Tensor | None = ...,
-    block_table: torch.Tensor | None = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    window_size: tuple[int, int] = ...,  # -1 means infinite context window
-    softcap: float = ...,
-    rotary_interleaved: bool = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    num_splits: int = ...,
-    return_softmax_lse: Literal[False] = ...,
-    *,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[int, int, int, int]: ...
-@overload
-def flash_attn_with_kvcache(
-    q: tuple[int, int, int, int],
-    k_cache: tuple[int, int, int, int],
-    v_cache: tuple[int, int, int, int],
-    k: tuple[int, int, int, int] | None = ...,
-    v: tuple[int, int, int, int] | None = ...,
-    rotary_cos: tuple[int, int] | None = ...,
-    rotary_sin: tuple[int, int] | None = ...,
-    cache_seqlens: int | torch.Tensor | None = None,
-    cache_batch_idx: torch.Tensor | None = None,
-    cache_leftpad: torch.Tensor | None = ...,
-    block_table: torch.Tensor | None = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    window_size: tuple[int, int] = ...,  # -1 means infinite context window
-    softcap: float = ...,
-    rotary_interleaved: bool = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    num_splits: int = ...,
-    return_softmax_lse: Literal[True] = ...,
-    *,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-@overload
-def sparse_attn_func(
-    q: tuple[int, int, int, int],
-    k: tuple[int, int, int, int],
-    v: tuple[int, int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[False] = ...,
-    out: Any = ...,
-) -> tuple[int, int, int]: ...
-@overload
-def sparse_attn_func(
-    q: tuple[int, int, int, int],
-    k: tuple[int, int, int, int],
-    v: tuple[int, int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[True] = ...,
-    out: Any = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-@overload
-def sparse_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    cu_seqlens_q: torch.Tensor | None,
-    cu_seqlens_k: torch.Tensor | None,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[False] = ...,
-    out: Any = ...,
-) -> tuple[int, int, int]: ...
-@overload
-def sparse_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    cu_seqlens_q: torch.Tensor | None,
-    cu_seqlens_k: torch.Tensor | None,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[True] = ...,
-    out: Any = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-def is_fa_version_supported(
-    fa_version: int, device: torch.device | None = None
-) -> bool: ...
-def fa_version_unsupported_reason(
-    fa_version: int, device: torch.device | None = None
-) -> str | None: ...