Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Chore] added stubs for vllm_flash_attn during development mode (#17228)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>

parent dc2ceca5c5
commit e782e0a170
pyproject.toml

@@ -58,7 +58,8 @@ ignore_patterns = [
 line-length = 80
 exclude = [
     # External file, leaving license intact
-    "examples/other/fp8/quantizer/quantize.py"
+    "examples/other/fp8/quantizer/quantize.py",
+    "vllm/vllm_flash_attn/flash_attn_interface.pyi"
 ]

 [tool.ruff.lint.per-file-ignores]
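The new exclude entry keeps ruff from linting the hand-written stub file. As a quick sanity check (a minimal sketch, not part of the commit; assumes ruff is installed and this is run from the repo root — note --force-exclude, since ruff otherwise lints paths passed explicitly):

import subprocess

# Sketch: verify ruff honors the exclude for the stub file.
result = subprocess.run(
    ["ruff", "check", "--force-exclude",
     "vllm/vllm_flash_attn/flash_attn_interface.pyi"],
    capture_output=True, text=True,
)
# With the exclude in place, ruff reports no findings for the file.
print(result.returncode, result.stdout or result.stderr)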
setup.py
@@ -378,7 +378,6 @@ class repackage_wheel(build_ext):
     "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
     "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
     "vllm/vllm_flash_attn/flash_attn_interface.py",
-    "vllm/vllm_flash_attn/__init__.py",
     "vllm/cumem_allocator.abi3.so",
     # "vllm/_version.py", # not available in nightly wheels yet
 ]
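For context, repackage_wheel copies prebuilt artifacts out of a nightly wheel into the source tree; __init__.py drops out of the list because this commit adds it in-tree (next hunk). A minimal sketch of that copy step, using the file list from the hunk above (the copy_artifacts helper and the unpacked-wheel layout are assumptions, not vllm's actual implementation):

import shutil
from pathlib import Path

# File list taken from the diff above (after this commit, __init__.py is
# kept in-tree rather than copied out of the wheel).
FILES_TO_COPY = [
    "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
    "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
    "vllm/vllm_flash_attn/flash_attn_interface.py",
    "vllm/cumem_allocator.abi3.so",
]

def copy_artifacts(wheel_unpack_dir: Path, repo_root: Path) -> None:
    """Copy prebuilt artifacts from an unpacked wheel into the repo."""
    for rel in FILES_TO_COPY:
        src = wheel_unpack_dir / rel
        dst = repo_root / rel
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)  # copy2 keeps timestamps/permissions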
vllm/vllm_flash_attn/__init__.py (new file)

@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib.metadata
+
+try:
+    __version__ = importlib.metadata.version("vllm-flash-attn")
+except importlib.metadata.PackageNotFoundError:
+    # in this case, vllm-flash-attn is built from installing vllm editable
+    __version__ = "0.0.0.dev0"
+
+from .flash_attn_interface import (fa_version_unsupported_reason,
+                                   flash_attn_varlen_func,
+                                   flash_attn_with_kvcache,
+                                   get_scheduler_metadata,
+                                   is_fa_version_supported, sparse_attn_func,
+                                   sparse_attn_varlen_func)
+
+__all__ = [
+    'flash_attn_varlen_func', 'flash_attn_with_kvcache',
+    'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
+    'is_fa_version_supported', 'fa_version_unsupported_reason'
+]
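The try/except above is what makes editable installs work before any vllm-flash-attn distribution metadata exists: importlib.metadata.version raises PackageNotFoundError when no such distribution is installed, and the module then falls back to a dev version. The same pattern in isolation (the placeholder distribution name is deliberately one that will not be installed):

import importlib.metadata

try:
    # Placeholder name; any distribution that is not installed triggers
    # the same fallback path as the stub above.
    version = importlib.metadata.version("definitely-not-installed-pkg")
except importlib.metadata.PackageNotFoundError:
    version = "0.0.0.dev0"

print(version)  # -> 0.0.0.dev0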
vllm/vllm_flash_attn/flash_attn_interface.pyi (new file, 245 lines)
@@ -0,0 +1,245 @@
# ruff: ignore
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import Any, Literal, overload

import torch

def get_scheduler_metadata(
    batch_size: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    num_heads_q: int,
    num_heads_kv: int,
    headdim: int,
    cache_seqlens: torch.Tensor,
    qkv_dtype: torch.dtype = ...,
    headdim_v: int | None = ...,
    cu_seqlens_q: torch.Tensor | None = ...,
    cu_seqlens_k_new: torch.Tensor | None = ...,
    cache_leftpad: torch.Tensor | None = ...,
    page_size: int = ...,
    max_seqlen_k_new: int = ...,
    causal: bool = ...,
    window_size: tuple[int, int] = ...,
    has_softcap: bool = ...,
    num_splits: int = ...,
    pack_gqa: Any | None = ...,
    sm_margin: int = ...,
): ...
@overload
def flash_attn_varlen_func(
    q: tuple[int, int, int],
    k: tuple[int, int, int],
    v: tuple[int, int, int],
    max_seqlen_q: int,
    cu_seqlens_q: torch.Tensor | None,
    max_seqlen_k: int,
    cu_seqlens_k: torch.Tensor | None = ...,
    seqused_k: Any | None = ...,
    q_v: Any | None = ...,
    dropout_p: float = ...,
    causal: bool = ...,
    window_size: list[int] | None = ...,
    softmax_scale: float = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    block_table: Any | None = ...,
    return_softmax_lse: Literal[False] = ...,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> tuple[int, int, int]: ...
@overload
def flash_attn_varlen_func(
    q: tuple[int, int, int],
    k: tuple[int, int, int],
    v: tuple[int, int, int],
    max_seqlen_q: int,
    cu_seqlens_q: torch.Tensor | None,
    max_seqlen_k: int,
    cu_seqlens_k: torch.Tensor | None = ...,
    seqused_k: Any | None = ...,
    q_v: Any | None = ...,
    dropout_p: float = ...,
    causal: bool = ...,
    window_size: list[int] | None = ...,
    softmax_scale: float = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    block_table: Any | None = ...,
    return_softmax_lse: Literal[True] = ...,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def flash_attn_with_kvcache(
    q: tuple[int, int, int, int],
    k_cache: tuple[int, int, int, int],
    v_cache: tuple[int, int, int, int],
    k: tuple[int, int, int, int] | None = ...,
    v: tuple[int, int, int, int] | None = ...,
    rotary_cos: tuple[int, int] | None = ...,
    rotary_sin: tuple[int, int] | None = ...,
    cache_seqlens: int | torch.Tensor | None = None,
    cache_batch_idx: torch.Tensor | None = None,
    cache_leftpad: torch.Tensor | None = ...,
    block_table: torch.Tensor | None = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    window_size: tuple[int, int] = ...,  # -1 means infinite context window
    softcap: float = ...,
    rotary_interleaved: bool = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    num_splits: int = ...,
    return_softmax_lse: Literal[False] = ...,
    *,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> tuple[int, int, int, int]: ...
@overload
def flash_attn_with_kvcache(
    q: tuple[int, int, int, int],
    k_cache: tuple[int, int, int, int],
    v_cache: tuple[int, int, int, int],
    k: tuple[int, int, int, int] | None = ...,
    v: tuple[int, int, int, int] | None = ...,
    rotary_cos: tuple[int, int] | None = ...,
    rotary_sin: tuple[int, int] | None = ...,
    cache_seqlens: int | torch.Tensor | None = None,
    cache_batch_idx: torch.Tensor | None = None,
    cache_leftpad: torch.Tensor | None = ...,
    block_table: torch.Tensor | None = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    window_size: tuple[int, int] = ...,  # -1 means infinite context window
    softcap: float = ...,
    rotary_interleaved: bool = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    num_splits: int = ...,
    return_softmax_lse: Literal[True] = ...,
    *,
    out: Any = ...,
    # FA3 Only
    scheduler_metadata: Any | None = ...,
    q_descale: Any | None = ...,
    k_descale: Any | None = ...,
    v_descale: Any | None = ...,
    # Version selector
    fa_version: int = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def sparse_attn_func(
    q: tuple[int, int, int, int],
    k: tuple[int, int, int, int],
    v: tuple[int, int, int, int],
    block_count: tuple[int, int, float],
    block_offset: tuple[int, int, float, int],
    column_count: tuple[int, int, float],
    column_index: tuple[int, int, float, int],
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[False] = ...,
    out: Any = ...,
) -> tuple[int, int, int]: ...
@overload
def sparse_attn_func(
    q: tuple[int, int, int, int],
    k: tuple[int, int, int, int],
    v: tuple[int, int, int, int],
    block_count: tuple[int, int, float],
    block_offset: tuple[int, int, float, int],
    column_count: tuple[int, int, float],
    column_index: tuple[int, int, float, int],
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[True] = ...,
    out: Any = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def sparse_attn_varlen_func(
    q: tuple[int, int, int],
    k: tuple[int, int, int],
    v: tuple[int, int, int],
    block_count: tuple[int, int, float],
    block_offset: tuple[int, int, float, int],
    column_count: tuple[int, int, float],
    column_index: tuple[int, int, float, int],
    cu_seqlens_q: torch.Tensor | None,
    cu_seqlens_k: torch.Tensor | None,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[False] = ...,
    out: Any = ...,
) -> tuple[int, int, int]: ...
@overload
def sparse_attn_varlen_func(
    q: tuple[int, int, int],
    k: tuple[int, int, int],
    v: tuple[int, int, int],
    block_count: tuple[int, int, float],
    block_offset: tuple[int, int, float, int],
    column_count: tuple[int, int, float],
    column_index: tuple[int, int, float, int],
    cu_seqlens_q: torch.Tensor | None,
    cu_seqlens_k: torch.Tensor | None,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float = ...,
    softmax_scale: float = ...,
    causal: bool = ...,
    softcap: float = ...,
    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
    deterministic: bool = ...,
    return_attn_probs: bool = ...,
    *,
    return_softmax_lse: Literal[True] = ...,
    out: Any = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
def is_fa_version_supported(
    fa_version: int, device: torch.device | None = None
) -> bool: ...
def fa_version_unsupported_reason(
    fa_version: int, device: torch.device | None = None
) -> str | None: ...
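The paired overloads exist so a type checker can pick the return type from the call site: return_softmax_lse typed as Literal[False] maps to the attention output alone, Literal[True] to an (output, lse) pair. A self-contained sketch of the same pattern with simplified types (not the real signatures):

from typing import Literal, overload

@overload
def attn(return_softmax_lse: Literal[False] = ...) -> int: ...
@overload
def attn(return_softmax_lse: Literal[True]) -> tuple[int, float]: ...
def attn(return_softmax_lse: bool = False):
    out, lse = 42, 0.5  # dummy stand-ins for the output and LSE tensors
    return (out, lse) if return_softmax_lse else out

pair = attn(return_softmax_lse=True)  # checker infers tuple[int, float]
single = attn()                       # checker infers int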