# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if current_platform.is_cuda():
    from vllm.vllm_flash_attn import flash_attn_varlen_func

logger = init_logger(__name__)

class FlashAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        return use_cascade_attention(*args, **kwargs)

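# A small illustration (hypothetical sizes) of the layout returned by
# get_kv_cache_shape above: the leading dimension of 2 packs the key and
# value caches into a single tensor.
#
#   FlashAttentionBackend.get_kv_cache_shape(
#       num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128)
#   # -> (2, 1024, 16, 8, 128)
#
# The forward pass later splits this with kv_cache.unbind(0) into a key cache
# and a value cache, each of shape (1024, 16, 8, 128).
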
@dataclass
class FlashAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    # For local attention.
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int

    local_attn_metadata: Optional[LocalAttentionMetadata] = None

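# Concretely, reading the diagram above with hypothetical numbers: a request
# that already has 16 tokens cached (context_len = 16) and is scheduled for 4
# new tokens this step (query_len = 4) has seq_len = 16 + 4 = 20.
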
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of
# 3 sequences:
#   q_seqlens  = [4, 10, 5]
#   kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
#  for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 | 1 1 1 1 1
#               3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
#   batch idx: 0  (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
#        k_toks >   0 1 2 3 4 5
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#               2 |         1
#               3 |         1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
#   local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4)  (batch 0)
#        k_toks >   0 1 2 3
#        q_toks v  _____________
#               0 | 1 1 1
#               1 | 1 1 1 1
#   local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2)  (batch 0)
#        k_toks >   4 5
#        q_toks v  _____________
#               2 | 1
#               3 | 1 1
#
# e.g. if we have:
#   attn_chunk_size = 4
#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
#   seq_lens_np = [6, 17, 9]
# Then this function would return:
#                           __b0__  ______b1______  __b2__ < orig batch indices
#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
#   cu_seqlens_q_local = [0, 2,  4,  5,  9, 13, 14, 18, 19]
#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle starting in the middle of a local attention block: we assume
    # q_seqlens > 0 (for all elements). For each batch idx we compute the
    # number of query tokens that fall into the first local attention block
    # and then simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    # for each batch idx and figure out how many "virtual" requests we have to
    # make.
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First get the batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    # first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
        .astype(np.int32)

    # compute the seqlens_k_local: basically a full local attention block for
    # all but the last block in each batch.
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size +
         np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // page_size
    assert attn_chunk_size % page_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by page_size {page_size}"
    pages_per_local_batch = attn_chunk_size // page_size

    # Create a block_table for the local attention blocks
    # For our example, if we have a block-table like (assuming page_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [ 0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [ 2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch)) \
        + np.expand_dims(block_starts, axis=1)
    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local

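# A minimal usage sketch (not executed here) reproducing the worked example
# from the comments above; the block_table values assume page_size = 2:
#
#   seqlens_q, cu_seqlens_q, seqlens_k, block_table_local = \
#       make_local_attention_virtual_batches(
#           attn_chunk_size=4,
#           query_start_loc_np=np.array([0, 4, 14, 19]),
#           seq_lens_np=np.array([6, 17, 9]),
#           block_table=torch.arange(30, dtype=torch.int32).view(3, 10),
#           page_size=2,
#       )
#   # seqlens_q  -> [2, 2, 1, 4, 4, 1, 4, 1]
#   # seqlens_k  -> [4, 2, 4, 4, 4, 1, 4, 1]
#   # block_table_local.shape -> (8, 2), i.e. one row per virtual batch.
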
class FlashAttentionMetadataBuilder:

    def __init__(self, runner: "GPUModelRunner"):
        self.runner = runner

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int):
        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
        query_start_loc = query_start_loc_cpu.to(self.runner.device,
                                                 non_blocking=True)
        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
        seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
        block_table = (
            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True).long()

        # for local attention
        local_attn_metadata = None
        if self.runner.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table = make_local_attention_virtual_batches(
                    self.runner.attention_chunk_size,
                    self.runner.query_start_loc_np[:num_reqs + 1],
                    self.runner.seq_lens_np[:num_reqs],
                    block_table,
                    self.runner.block_size,
                )
            local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                local_query_start_loc=torch.from_numpy(
                    virt_q_cu_seqlens_np).to(self.runner.device,
                                             non_blocking=True),
                local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
                    self.runner.device, non_blocking=True),
                local_block_table=virt_block_table,
                local_max_query_len=seqlens_q_local_np.max(),
                local_max_seq_len=virt_k_seqlens_np.max(),
            )

        use_cascade = common_prefix_len > 0
        if use_cascade:
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.runner.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.runner.device)
            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                              common_prefix_len)
            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                self.runner.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
        )
        return attn_metadata

class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            # flash-attn takes (left, right) window sizes; (sliding_window - 1,
            # 0) lets each token attend to itself plus the previous
            # sliding_window - 1 tokens.
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=0 to use another attention backend.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
                and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: For FP8 quantization, flash-attn expects the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even when they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens] and
        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
        # the slot_mapping's shape to determine the number of actual tokens.
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(torch.float8_e4m3fn)
            value_cache = value_cache.view(torch.float8_e4m3fn)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table

            # (num_sequences, num_kv_heads); see the NOTE in the docstring.
            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
            return output

        assert not use_local_attn, (
            "Cascade attention does not support local attention.")
        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output

def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time

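# A rough worked example of the performance model above, with hypothetical
# numbers: num_reqs = 64 pure-decode requests, num_query_heads = 32,
# num_kv_heads = 8 (so num_queries_per_kv = 4), common_prefix_len = 2048,
# and num_sms = 108:
#   num_prefix_tiles    = cdiv(2048, 128)            = 16
#   cascade_ctas        = 32 * cdiv(64, 128)         = 32
#   cascade_time        = cdiv(32, 108) * 16         = 16
#   flash_decoding_ctas = 64 * 8 * cdiv(4, 128) * 16 = 8192
#   flash_decoding_time = cdiv(8192, 108)            = 76
# so cascade attention wins (16 < 76) and the function returns True.
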
def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0
    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
        q_descale=q_descale.expand(descale_shape)
        if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape)
        if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape)
        if v_descale is not None else None,
    )

    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
        q_descale=q_descale.expand(descale_shape)
        if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape)
        if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape)
        if v_descale is not None else None,
    )
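    # Conceptually (a sketch of the math, not the kernel's exact computation),
    # the merge combines the two partial results with log-sum-exp weights:
    #   out = (exp(prefix_lse) * prefix_output
    #          + exp(suffix_lse) * suffix_output)
    #         / (exp(prefix_lse) + exp(suffix_lse))
    # merge_attn_states computes this in a numerically stable way.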
    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)