mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 03:40:13 +08:00
[BugFix] Fix vllm_flash_attn install issues (#17267)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
parent
20e489eaa1
commit
d8bccde686
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@@ -12,6 +12,7 @@
|
|||||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||||
/vllm/model_executor/guided_decoding @mgoin @russellb
|
/vllm/model_executor/guided_decoding @mgoin @russellb
|
||||||
/vllm/multimodal @DarkLight1337 @ywang96
|
/vllm/multimodal @DarkLight1337 @ywang96
|
||||||
|
/vllm/vllm_flash_attn @LucasWilkinson
|
||||||
CMakeLists.txt @tlrmchlsmth
|
CMakeLists.txt @tlrmchlsmth
|
||||||
|
|
||||||
# vLLM V1
|
# vLLM V1
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -3,8 +3,6 @@
|
|||||||
|
|
||||||
# vllm-flash-attn built from source
|
# vllm-flash-attn built from source
|
||||||
vllm/vllm_flash_attn/*
|
vllm/vllm_flash_attn/*
|
||||||
!vllm/vllm_flash_attn/__init__.py
|
|
||||||
!vllm/vllm_flash_attn/fa_utils.py
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
|||||||
26
setup.py
26
setup.py
@@ -269,15 +269,17 @@ class cmake_build_ext(build_ext):
|
|||||||
# First, run the standard build_ext command to compile the extensions
|
# First, run the standard build_ext command to compile the extensions
|
||||||
super().run()
|
super().run()
|
||||||
|
|
||||||
# copy vllm/vllm_flash_attn/*.py from self.build_lib to current
|
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
|
||||||
# directory so that they can be included in the editable build
|
# directory so that they can be included in the editable build
|
||||||
import glob
|
import glob
|
||||||
files = glob.glob(
|
files = glob.glob(os.path.join(self.build_lib, "vllm",
|
||||||
os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
|
"vllm_flash_attn", "**", "*.py"),
|
||||||
|
recursive=True)
|
||||||
for file in files:
|
for file in files:
|
||||||
dst_file = os.path.join("vllm/vllm_flash_attn",
|
dst_file = os.path.join("vllm/vllm_flash_attn",
|
||||||
os.path.basename(file))
|
file.split("vllm/vllm_flash_attn/")[-1])
|
||||||
print(f"Copying {file} to {dst_file}")
|
print(f"Copying {file} to {dst_file}")
|
||||||
|
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
|
||||||
self.copy_file(file, dst_file)
|
self.copy_file(file, dst_file)
|
||||||
|
|
||||||
|
|
||||||
@@ -377,12 +379,22 @@ class repackage_wheel(build_ext):
|
|||||||
"vllm/_flashmla_C.abi3.so",
|
"vllm/_flashmla_C.abi3.so",
|
||||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||||
"vllm/vllm_flash_attn/flash_attn_interface.py",
|
|
||||||
"vllm/cumem_allocator.abi3.so",
|
"vllm/cumem_allocator.abi3.so",
|
||||||
# "vllm/_version.py", # not available in nightly wheels yet
|
# "vllm/_version.py", # not available in nightly wheels yet
|
||||||
]
|
]
|
||||||
file_members = filter(lambda x: x.filename in files_to_copy,
|
|
||||||
wheel.filelist)
|
file_members = list(
|
||||||
|
filter(lambda x: x.filename in files_to_copy, wheel.filelist))
|
||||||
|
|
||||||
|
# vllm_flash_attn python code:
|
||||||
|
# Regex from
|
||||||
|
# `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
|
||||||
|
import re
|
||||||
|
compiled_regex = re.compile(
|
||||||
|
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
|
||||||
|
file_members += list(
|
||||||
|
filter(lambda x: compiled_regex.match(x.filename),
|
||||||
|
wheel.filelist))
|
||||||
|
|
||||||
for file in file_members:
|
for file in file_members:
|
||||||
print(f"Extracting and including {file.filename} "
|
print(f"Extracting and including {file.filename} "
|
||||||
|
|||||||
@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
|
|||||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||||
|
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||||
|
get_flash_attn_version)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.multimodal import MultiModalPlaceholderMap
|
from vllm.multimodal import MultiModalPlaceholderMap
|
||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||||
flash_attn_with_kvcache)
|
flash_attn_with_kvcache)
|
||||||
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
|
|
||||||
get_flash_attn_version)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||||
@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
|
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
|
||||||
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
|
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
|
||||||
assert (
|
assert (
|
||||||
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
|
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
|
||||||
"key/v_scale is only supported in FlashAttention 3 with "
|
"key/v_scale is only supported in FlashAttention 3 with "
|
||||||
|
|||||||
@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
|||||||
compute_slot_mapping_start_idx,
|
compute_slot_mapping_start_idx,
|
||||||
is_block_tables_empty)
|
is_block_tables_empty)
|
||||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||||
|
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
LinearBase, RowParallelLinear,
|
LinearBase, RowParallelLinear,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
@@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import HAS_TRITON
|
from vllm.triton_utils import HAS_TRITON
|
||||||
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
|
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
|
||||||
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
|
|
||||||
|
|
||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
from vllm.attention.ops.triton_flash_attention import triton_attention
|
from vllm.attention.ops.triton_flash_attention import triton_attention
|
||||||
|
|||||||
@@ -1377,7 +1377,7 @@ class EngineArgs:
|
|||||||
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
|
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
|
||||||
supported = False
|
supported = False
|
||||||
if fp8_attention and will_use_fa:
|
if fp8_attention and will_use_fa:
|
||||||
from vllm.vllm_flash_attn.fa_utils import (
|
from vllm.attention.utils.fa_utils import (
|
||||||
flash_attn_supports_fp8)
|
flash_attn_supports_fp8)
|
||||||
supported = flash_attn_supports_fp8()
|
supported = flash_attn_supports_fp8()
|
||||||
if not supported:
|
if not supported:
|
||||||
|
|||||||
@@ -11,11 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
AttentionMetadata, AttentionType,
|
AttentionMetadata, AttentionType,
|
||||||
is_quantized_kv_cache)
|
is_quantized_kv_cache)
|
||||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||||
|
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||||
|
get_flash_attn_version)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
|
|
||||||
get_flash_attn_version)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
|||||||
MLAAttentionImpl)
|
MLAAttentionImpl)
|
||||||
from vllm.attention.backends.utils import get_mla_dims
|
from vllm.attention.backends.utils import get_mla_dims
|
||||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||||
|
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
LinearBase, RowParallelLinear,
|
LinearBase, RowParallelLinear,
|
||||||
@@ -204,7 +205,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import cdiv, round_down
|
from vllm.utils import cdiv, round_down
|
||||||
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import importlib.metadata
|
|
||||||
|
|
||||||
try:
|
|
||||||
__version__ = importlib.metadata.version("vllm-flash-attn")
|
|
||||||
except importlib.metadata.PackageNotFoundError:
|
|
||||||
# in this case, vllm-flash-attn is built from installing vllm editable
|
|
||||||
__version__ = "0.0.0.dev0"
|
|
||||||
|
|
||||||
from .flash_attn_interface import (fa_version_unsupported_reason,
|
|
||||||
flash_attn_varlen_func,
|
|
||||||
flash_attn_with_kvcache,
|
|
||||||
get_scheduler_metadata,
|
|
||||||
is_fa_version_supported, sparse_attn_func,
|
|
||||||
sparse_attn_varlen_func)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'flash_attn_varlen_func', 'flash_attn_with_kvcache',
|
|
||||||
'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
|
|
||||||
'is_fa_version_supported', 'fa_version_unsupported_reason'
|
|
||||||
]
|
|
||||||
@@ -1,245 +0,0 @@
|
|||||||
# ruff: ignore
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Literal, overload
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
def get_scheduler_metadata(
|
|
||||||
batch_size: int,
|
|
||||||
max_seqlen_q: int,
|
|
||||||
max_seqlen_k: int,
|
|
||||||
num_heads_q: int,
|
|
||||||
num_heads_kv: int,
|
|
||||||
headdim: int,
|
|
||||||
cache_seqlens: torch.Tensor,
|
|
||||||
qkv_dtype: torch.dtype = ...,
|
|
||||||
headdim_v: int | None = ...,
|
|
||||||
cu_seqlens_q: torch.Tensor | None = ...,
|
|
||||||
cu_seqlens_k_new: torch.Tensor | None = ...,
|
|
||||||
cache_leftpad: torch.Tensor | None = ...,
|
|
||||||
page_size: int = ...,
|
|
||||||
max_seqlen_k_new: int = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
window_size: tuple[int, int] = ...,
|
|
||||||
has_softcap: bool = ...,
|
|
||||||
num_splits: int = ...,
|
|
||||||
pack_gqa: Any | None = ...,
|
|
||||||
sm_margin: int = ...,
|
|
||||||
): ...
|
|
||||||
@overload
|
|
||||||
def flash_attn_varlen_func(
|
|
||||||
q: tuple[int, int, int],
|
|
||||||
k: tuple[int, int, int],
|
|
||||||
v: tuple[int, int, int],
|
|
||||||
max_seqlen_q: int,
|
|
||||||
cu_seqlens_q: torch.Tensor | None,
|
|
||||||
max_seqlen_k: int,
|
|
||||||
cu_seqlens_k: torch.Tensor | None = ...,
|
|
||||||
seqused_k: Any | None = ...,
|
|
||||||
q_v: Any | None = ...,
|
|
||||||
dropout_p: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
window_size: list[int] | None = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
deterministic: bool = ...,
|
|
||||||
return_attn_probs: bool = ...,
|
|
||||||
block_table: Any | None = ...,
|
|
||||||
return_softmax_lse: Literal[False] = ...,
|
|
||||||
out: Any = ...,
|
|
||||||
# FA3 Only
|
|
||||||
scheduler_metadata: Any | None = ...,
|
|
||||||
q_descale: Any | None = ...,
|
|
||||||
k_descale: Any | None = ...,
|
|
||||||
v_descale: Any | None = ...,
|
|
||||||
# Version selector
|
|
||||||
fa_version: int = ...,
|
|
||||||
) -> tuple[int, int, int]: ...
|
|
||||||
@overload
|
|
||||||
def flash_attn_varlen_func(
|
|
||||||
q: tuple[int, int, int],
|
|
||||||
k: tuple[int, int, int],
|
|
||||||
v: tuple[int, int, int],
|
|
||||||
max_seqlen_q: int,
|
|
||||||
cu_seqlens_q: torch.Tensor | None,
|
|
||||||
max_seqlen_k: int,
|
|
||||||
cu_seqlens_k: torch.Tensor | None = ...,
|
|
||||||
seqused_k: Any | None = ...,
|
|
||||||
q_v: Any | None = ...,
|
|
||||||
dropout_p: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
window_size: list[int] | None = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
deterministic: bool = ...,
|
|
||||||
return_attn_probs: bool = ...,
|
|
||||||
block_table: Any | None = ...,
|
|
||||||
return_softmax_lse: Literal[True] = ...,
|
|
||||||
out: Any = ...,
|
|
||||||
# FA3 Only
|
|
||||||
scheduler_metadata: Any | None = ...,
|
|
||||||
q_descale: Any | None = ...,
|
|
||||||
k_descale: Any | None = ...,
|
|
||||||
v_descale: Any | None = ...,
|
|
||||||
# Version selector
|
|
||||||
fa_version: int = ...,
|
|
||||||
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
|
|
||||||
@overload
|
|
||||||
def flash_attn_with_kvcache(
|
|
||||||
q: tuple[int, int, int, int],
|
|
||||||
k_cache: tuple[int, int, int, int],
|
|
||||||
v_cache: tuple[int, int, int, int],
|
|
||||||
k: tuple[int, int, int, int] | None = ...,
|
|
||||||
v: tuple[int, int, int, int] | None = ...,
|
|
||||||
rotary_cos: tuple[int, int] | None = ...,
|
|
||||||
rotary_sin: tuple[int, int] | None = ...,
|
|
||||||
cache_seqlens: int | torch.Tensor | None = None,
|
|
||||||
cache_batch_idx: torch.Tensor | None = None,
|
|
||||||
cache_leftpad: torch.Tensor | None = ...,
|
|
||||||
block_table: torch.Tensor | None = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
window_size: tuple[int, int] = ..., # -1 means infinite context window
|
|
||||||
softcap: float = ...,
|
|
||||||
rotary_interleaved: bool = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
num_splits: int = ...,
|
|
||||||
return_softmax_lse: Literal[False] = ...,
|
|
||||||
*,
|
|
||||||
out: Any = ...,
|
|
||||||
# FA3 Only
|
|
||||||
scheduler_metadata: Any | None = ...,
|
|
||||||
q_descale: Any | None = ...,
|
|
||||||
k_descale: Any | None = ...,
|
|
||||||
v_descale: Any | None = ...,
|
|
||||||
# Version selector
|
|
||||||
fa_version: int = ...,
|
|
||||||
) -> tuple[int, int, int, int]: ...
|
|
||||||
@overload
|
|
||||||
def flash_attn_with_kvcache(
|
|
||||||
q: tuple[int, int, int, int],
|
|
||||||
k_cache: tuple[int, int, int, int],
|
|
||||||
v_cache: tuple[int, int, int, int],
|
|
||||||
k: tuple[int, int, int, int] | None = ...,
|
|
||||||
v: tuple[int, int, int, int] | None = ...,
|
|
||||||
rotary_cos: tuple[int, int] | None = ...,
|
|
||||||
rotary_sin: tuple[int, int] | None = ...,
|
|
||||||
cache_seqlens: int | torch.Tensor | None = None,
|
|
||||||
cache_batch_idx: torch.Tensor | None = None,
|
|
||||||
cache_leftpad: torch.Tensor | None = ...,
|
|
||||||
block_table: torch.Tensor | None = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
window_size: tuple[int, int] = ..., # -1 means infinite context window
|
|
||||||
softcap: float = ...,
|
|
||||||
rotary_interleaved: bool = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
num_splits: int = ...,
|
|
||||||
return_softmax_lse: Literal[True] = ...,
|
|
||||||
*,
|
|
||||||
out: Any = ...,
|
|
||||||
# FA3 Only
|
|
||||||
scheduler_metadata: Any | None = ...,
|
|
||||||
q_descale: Any | None = ...,
|
|
||||||
k_descale: Any | None = ...,
|
|
||||||
v_descale: Any | None = ...,
|
|
||||||
# Version selector
|
|
||||||
fa_version: int = ...,
|
|
||||||
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
|
|
||||||
@overload
|
|
||||||
def sparse_attn_func(
|
|
||||||
q: tuple[int, int, int, int],
|
|
||||||
k: tuple[int, int, int, int],
|
|
||||||
v: tuple[int, int, int, int],
|
|
||||||
block_count: tuple[int, int, float],
|
|
||||||
block_offset: tuple[int, int, float, int],
|
|
||||||
column_count: tuple[int, int, float],
|
|
||||||
column_index: tuple[int, int, float, int],
|
|
||||||
dropout_p: float = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
softcap: float = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
deterministic: bool = ...,
|
|
||||||
return_attn_probs: bool = ...,
|
|
||||||
*,
|
|
||||||
return_softmax_lse: Literal[False] = ...,
|
|
||||||
out: Any = ...,
|
|
||||||
) -> tuple[int, int, int]: ...
|
|
||||||
@overload
|
|
||||||
def sparse_attn_func(
|
|
||||||
q: tuple[int, int, int, int],
|
|
||||||
k: tuple[int, int, int, int],
|
|
||||||
v: tuple[int, int, int, int],
|
|
||||||
block_count: tuple[int, int, float],
|
|
||||||
block_offset: tuple[int, int, float, int],
|
|
||||||
column_count: tuple[int, int, float],
|
|
||||||
column_index: tuple[int, int, float, int],
|
|
||||||
dropout_p: float = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
softcap: float = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
deterministic: bool = ...,
|
|
||||||
return_attn_probs: bool = ...,
|
|
||||||
*,
|
|
||||||
return_softmax_lse: Literal[True] = ...,
|
|
||||||
out: Any = ...,
|
|
||||||
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
|
|
||||||
@overload
|
|
||||||
def sparse_attn_varlen_func(
|
|
||||||
q: tuple[int, int, int],
|
|
||||||
k: tuple[int, int, int],
|
|
||||||
v: tuple[int, int, int],
|
|
||||||
block_count: tuple[int, int, float],
|
|
||||||
block_offset: tuple[int, int, float, int],
|
|
||||||
column_count: tuple[int, int, float],
|
|
||||||
column_index: tuple[int, int, float, int],
|
|
||||||
cu_seqlens_q: torch.Tensor | None,
|
|
||||||
cu_seqlens_k: torch.Tensor | None,
|
|
||||||
max_seqlen_q: int,
|
|
||||||
max_seqlen_k: int,
|
|
||||||
dropout_p: float = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
softcap: float = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
deterministic: bool = ...,
|
|
||||||
return_attn_probs: bool = ...,
|
|
||||||
*,
|
|
||||||
return_softmax_lse: Literal[False] = ...,
|
|
||||||
out: Any = ...,
|
|
||||||
) -> tuple[int, int, int]: ...
|
|
||||||
@overload
|
|
||||||
def sparse_attn_varlen_func(
|
|
||||||
q: tuple[int, int, int],
|
|
||||||
k: tuple[int, int, int],
|
|
||||||
v: tuple[int, int, int],
|
|
||||||
block_count: tuple[int, int, float],
|
|
||||||
block_offset: tuple[int, int, float, int],
|
|
||||||
column_count: tuple[int, int, float],
|
|
||||||
column_index: tuple[int, int, float, int],
|
|
||||||
cu_seqlens_q: torch.Tensor | None,
|
|
||||||
cu_seqlens_k: torch.Tensor | None,
|
|
||||||
max_seqlen_q: int,
|
|
||||||
max_seqlen_k: int,
|
|
||||||
dropout_p: float = ...,
|
|
||||||
softmax_scale: float = ...,
|
|
||||||
causal: bool = ...,
|
|
||||||
softcap: float = ...,
|
|
||||||
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
|
|
||||||
deterministic: bool = ...,
|
|
||||||
return_attn_probs: bool = ...,
|
|
||||||
*,
|
|
||||||
return_softmax_lse: Literal[True] = ...,
|
|
||||||
out: Any = ...,
|
|
||||||
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
|
|
||||||
def is_fa_version_supported(
|
|
||||||
fa_version: int, device: torch.device | None = None
|
|
||||||
) -> bool: ...
|
|
||||||
def fa_version_unsupported_reason(
|
|
||||||
fa_version: int, device: torch.device | None = None
|
|
||||||
) -> str | None: ...
|
|
||||||
Loading…
x
Reference in New Issue
Block a user