[Chore] added stubs for vllm_flash_attn during development mode (#17228)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Author: Aaron Pham, 2025-04-26 10:45:26 -04:00 (committed by GitHub)
parent dc2ceca5c5
commit e782e0a170
4 changed files with 269 additions and 2 deletions


@@ -58,7 +58,8 @@ ignore_patterns = [
line-length = 80
exclude = [
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py"
"examples/other/fp8/quantizer/quantize.py",
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
[tool.ruff.lint.per-file-ignores]


@@ -378,7 +378,6 @@ class repackage_wheel(build_ext):
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",
"vllm/vllm_flash_attn/__init__.py",
"vllm/cumem_allocator.abi3.so",
# "vllm/_version.py", # not available in nightly wheels yet
]
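The list above names the prebuilt artifacts that repackage_wheel copies out of an existing vLLM wheel instead of building them locally; with vllm/vllm_flash_attn/__init__.py now checked into the tree (added below), it presumably no longer needs to come from the wheel. As a rough sketch of that copy step, not the actual setup.py logic (the wheel filename and member list here are placeholders):

import zipfile

def copy_from_wheel(wheel_path: str, members: list[str], dest: str = ".") -> None:
    # Hypothetical helper: extract a whitelist of members from a prebuilt wheel.
    with zipfile.ZipFile(wheel_path) as wheel:
        for name in members:
            # Raises KeyError if the wheel does not contain the member.
            wheel.extract(name, path=dest)

copy_from_wheel(
    "vllm-0.0.0.dev0-cp38-abi3-manylinux1_x86_64.whl",  # placeholder filename
    ["vllm/vllm_flash_attn/flash_attn_interface.py",
     "vllm/cumem_allocator.abi3.so"],
)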


@@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.metadata
try:
    __version__ = importlib.metadata.version("vllm-flash-attn")
except importlib.metadata.PackageNotFoundError:
    # This happens when vllm-flash-attn is provided by an editable install
    # of vllm rather than installed as its own distribution.
    __version__ = "0.0.0.dev0"

from .flash_attn_interface import (fa_version_unsupported_reason,
                                   flash_attn_varlen_func,
                                   flash_attn_with_kvcache,
                                   get_scheduler_metadata,
                                   is_fa_version_supported, sparse_attn_func,
                                   sparse_attn_varlen_func)

__all__ = [
    'flash_attn_varlen_func', 'flash_attn_with_kvcache',
    'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
    'is_fa_version_supported', 'fa_version_unsupported_reason'
]
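The version lookup above degrades gracefully when the vllm-flash-attn distribution is not installed on its own, which is the editable-install case this PR targets. A standalone sketch of the same fallback pattern (the helper name and fallback string are illustrative, not part of the package):

import importlib.metadata

def distribution_version(dist_name: str, fallback: str = "0.0.0.dev0") -> str:
    # Use the installed distribution's version if present, otherwise a
    # development placeholder, mirroring the __init__.py above.
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return fallback

print(distribution_version("vllm-flash-attn"))  # "0.0.0.dev0" without a real install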


@@ -0,0 +1,245 @@
# ruff: noqa
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any, Literal, overload
import torch
def get_scheduler_metadata(
batch_size: int,
max_seqlen_q: int,
max_seqlen_k: int,
num_heads_q: int,
num_heads_kv: int,
headdim: int,
cache_seqlens: torch.Tensor,
qkv_dtype: torch.dtype = ...,
headdim_v: int | None = ...,
cu_seqlens_q: torch.Tensor | None = ...,
cu_seqlens_k_new: torch.Tensor | None = ...,
cache_leftpad: torch.Tensor | None = ...,
page_size: int = ...,
max_seqlen_k_new: int = ...,
causal: bool = ...,
window_size: tuple[int, int] = ...,
has_softcap: bool = ...,
num_splits: int = ...,
pack_gqa: Any | None = ...,
sm_margin: int = ...,
): ...
@overload
def flash_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
max_seqlen_q: int,
cu_seqlens_q: torch.Tensor | None,
max_seqlen_k: int,
cu_seqlens_k: torch.Tensor | None = ...,
seqused_k: Any | None = ...,
q_v: Any | None = ...,
dropout_p: float = ...,
causal: bool = ...,
window_size: list[int] | None = ...,
softmax_scale: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
block_table: Any | None = ...,
return_softmax_lse: Literal[False] = ...,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[int, int, int]: ...
@overload
def flash_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
max_seqlen_q: int,
cu_seqlens_q: torch.Tensor | None,
max_seqlen_k: int,
cu_seqlens_k: torch.Tensor | None = ...,
seqused_k: Any | None = ...,
q_v: Any | None = ...,
dropout_p: float = ...,
causal: bool = ...,
window_size: list[int] | None = ...,
softmax_scale: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
block_table: Any | None = ...,
return_softmax_lse: Literal[True] = ...,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def flash_attn_with_kvcache(
q: tuple[int, int, int, int],
k_cache: tuple[int, int, int, int],
v_cache: tuple[int, int, int, int],
k: tuple[int, int, int, int] | None = ...,
v: tuple[int, int, int, int] | None = ...,
rotary_cos: tuple[int, int] | None = ...,
rotary_sin: tuple[int, int] | None = ...,
cache_seqlens: int | torch.Tensor | None = None,
cache_batch_idx: torch.Tensor | None = None,
cache_leftpad: torch.Tensor | None = ...,
block_table: torch.Tensor | None = ...,
softmax_scale: float = ...,
causal: bool = ...,
window_size: tuple[int, int] = ..., # -1 means infinite context window
softcap: float = ...,
rotary_interleaved: bool = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
num_splits: int = ...,
return_softmax_lse: Literal[False] = ...,
*,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[int, int, int, int]: ...
@overload
def flash_attn_with_kvcache(
q: tuple[int, int, int, int],
k_cache: tuple[int, int, int, int],
v_cache: tuple[int, int, int, int],
k: tuple[int, int, int, int] | None = ...,
v: tuple[int, int, int, int] | None = ...,
rotary_cos: tuple[int, int] | None = ...,
rotary_sin: tuple[int, int] | None = ...,
cache_seqlens: int | torch.Tensor | None = None,
cache_batch_idx: torch.Tensor | None = None,
cache_leftpad: torch.Tensor | None = ...,
block_table: torch.Tensor | None = ...,
softmax_scale: float = ...,
causal: bool = ...,
window_size: tuple[int, int] = ..., # -1 means infinite context window
softcap: float = ...,
rotary_interleaved: bool = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
num_splits: int = ...,
return_softmax_lse: Literal[True] = ...,
*,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def sparse_attn_func(
q: tuple[int, int, int, int],
k: tuple[int, int, int, int],
v: tuple[int, int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[False] = ...,
out: Any = ...,
) -> tuple[int, int, int]: ...
@overload
def sparse_attn_func(
q: tuple[int, int, int, int],
k: tuple[int, int, int, int],
v: tuple[int, int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[True] = ...,
out: Any = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def sparse_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
cu_seqlens_q: torch.Tensor | None,
cu_seqlens_k: torch.Tensor | None,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[False] = ...,
out: Any = ...,
) -> tuple[int, int, int]: ...
@overload
def sparse_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
cu_seqlens_q: torch.Tensor | None,
cu_seqlens_k: torch.Tensor | None,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[True] = ...,
out: Any = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
def is_fa_version_supported(
fa_version: int, device: torch.device | None = None
) -> bool: ...
def fa_version_unsupported_reason(
fa_version: int, device: torch.device | None = None
) -> str | None: ...
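
The paired overloads above key the return type off a Literal-typed return_softmax_lse flag, so a type checker can tell whether a call yields just the attention output or an (output, lse) pair. A self-contained sketch of that pattern, with made-up names and lists standing in for tensors:

from typing import Literal, overload

@overload
def toy_attn(x: list[float], return_lse: Literal[False] = ...) -> list[float]: ...
@overload
def toy_attn(x: list[float], return_lse: Literal[True] = ...) -> tuple[list[float], float]: ...
def toy_attn(x: list[float], return_lse: bool = False):
    # Stand-in computation; the real stubs describe torch.Tensor kernels.
    out = [v * 0.5 for v in x]
    return (out, max(x)) if return_lse else out

o = toy_attn([1.0, 2.0])               # checker infers list[float]
o, lse = toy_attn([1.0, 2.0], True)    # checker infers the (output, lse) pair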