Remove V0 attention backends (#25351)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Woosuk Kwon 2025-09-21 16:03:28 -07:00 committed by yewentao256
parent 319966a678
commit a815d820ee
28 changed files with 143 additions and 7376 deletions

View File

@@ -5,7 +5,6 @@ from urllib.request import urlopen
from vllm import LLM, SamplingParams
os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"

View File

@@ -334,8 +334,9 @@ else:
[7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
current_platform.is_cuda() else [_Backend.ROCM_FLASH])
@pytest.mark.parametrize("backend",
[_Backend.FLASHINFER] if current_platform.is_cuda()
else [_Backend.TRITON_ATTN_VLLM_V1])
@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])

View File

@@ -18,7 +18,7 @@ if not current_platform.is_rocm():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm.attention.backends.xformers import _make_alibi_bias
from tests.kernels.utils import make_alibi_bias
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
@@ -429,8 +429,8 @@ def test_multi_query_kv_attention(
alibi_bias = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
seq_lens)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
seq_lens)
output = torch.empty_like(query)
start = 0
# Dynamic sequence length not supported with custom attn_bias.

View File

@@ -67,6 +67,7 @@ def generate_params():
return params
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
@pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params())
def test_env(

View File

@@ -11,7 +11,7 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
from vllm.attention.backends.xformers import _make_alibi_bias
from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
@@ -470,7 +470,7 @@ def test_contexted_kv_attention_alibi(
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
@@ -479,7 +479,7 @@ def test_contexted_kv_attention_alibi(
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
# modified from: vllm/v1/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len

View File

@@ -16,6 +16,7 @@ def clear_cache():
_cached_get_attn_backend.cache_clear()
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")

View File

@@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend:
Construct the backend instance determined by the backend_name string
argument.
"XFORMERS" -> construct xformers backend
TODO: other backends
Note: at time of writing the Attention wrapper automatically selects
its own backend for Attention.forward(); so the backend instance which
you generate with this function is not meant to be used for *running*
@@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend:
* Backend instance
'''
if backend_name == STR_XFORMERS_ATTN_VAL:
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
from vllm.attention.backends.xformers import XFormersBackend
return XFormersBackend()
elif backend_name == STR_FLASH_ATTN_VAL:
from vllm.attention.backends.flash_attn import FlashAttentionBackend
if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"):
from vllm.v1.attention.backends.xformers import (
XFormersAttentionBackend)
return XFormersAttentionBackend()
if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"):
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
return FlashAttentionBackend()
if backend_name == "TRITON_ATTN_VLLM_V1":
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionBackend)
return TritonAttentionBackend()
if backend_name == "FLEX_ATTENTION":
from vllm.v1.attention.backends.flex_attention import (
FlexAttentionBackend)
return FlexAttentionBackend()
if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"):
from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
return TorchSDPABackend()
if backend_name == "FLASHINFER":
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend()
raise AssertionError(
f"Unrecognized backend_name {backend_name} for unit test")
def make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: list[int],
) -> list[Any]:
"""Create ALiBi biases compatible with xFormers attention tests."""
from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
if alibi_slopes is None:
return [None for _ in seq_lens]
attn_biases: list[Any] = []
num_heads = alibi_slopes.shape[0]
assert num_heads >= num_kv_heads, (
"ALiBi slopes expect at least as many heads as KV heads")
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
bias = bias[None, :] - bias[:, None]
padded_len = (seq_len + 7) // 8 * 8
bias_tensor = torch.empty(
1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias_tensor.mul_(alibi_slopes[:, None, None])
attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
return attn_biases
def _make_metadata_tensors(
seq_lens: Optional[list[int]],
context_lens: Optional[list[int]],
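A rough usage sketch for the two test helpers defined above; the backend-name string, head counts, and sequence lengths are illustrative, and make_alibi_bias additionally requires xformers to be installed:

import torch

from tests.kernels.utils import make_alibi_bias, make_backend

# Resolve a v1 backend class by name, as the parametrized kernel tests do.
backend = make_backend("FLASH_ATTN_VLLM_V1")
print(type(backend).__name__)  # FlashAttentionBackend

# Build per-sequence ALiBi biases for two prompts of lengths 5 and 8.
alibi_slopes = torch.randn(4, dtype=torch.float32)  # one slope per query head
biases = make_alibi_bias(alibi_slopes,
                         num_kv_heads=4,
                         dtype=torch.float16,
                         seq_lens=[5, 8])
assert len(biases) == 2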

View File

@@ -78,9 +78,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
return
if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"):
# Phi4FlashForCausalLM and MotifForCausalLM
# only support the DIFFERENTIAL_FLASH_ATTN backend
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
pytest.skip(
"Differential Flash Attention backend has been removed.")
if model_arch == "GptOssForCausalLM":
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when

View File

@@ -1,931 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""" An implementation of https://arxiv.org/pdf/2410.05258 """
from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from einops import rearrange
from vllm import _custom_ops as ops
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.flash_attn import FlashAttentionBackend
# yapf: enable
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set,
is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
logger = init_logger(__name__)
class DifferentialFlashAttentionBackend(AttentionBackend):
accept_output_buffer = False
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2"
return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
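# Illustrative shape check (hypothetical sizes, not taken from this diff): with
# num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128 this allocates
# a cache of shape (2, 2, 1024, 16, 4, 128). The leading 2 separates the two
# differential KV groups (see split_kv_cache below) and the second 2 the
# key/value planes; each group keeps num_kv_heads // 2 = 4 heads.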
@staticmethod
def get_name() -> str:
return "DIFFERENTIAL_FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]:
return DifferentialFlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]:
return DifferentialFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]:
return DifferentialFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
class DifferentialFlashAttentionMetadata(AttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decode-only batch.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
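# Illustrative example (hypothetical numbers, not part of the original code):
# if a request has 8 tokens computed in earlier iterations and 3 new tokens
# scheduled in this step, then context_len = 8, query_len = 3, and
# seq_len = context_len + query_len = 11.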
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional[
"DifferentialFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[
"DifferentialFlashAttentionMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
# Cross-layer shared attention block tables
cross_layer_shared_block_tables: Optional[torch.Tensor] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return is_all_encoder_attn_metadata_set(self)
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return is_all_cross_attn_metadata_set(self)
@property
def prefill_metadata(
self) -> Optional["DifferentialFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
cross_layer_shared_block_tables = (
None if self.cross_layer_shared_block_tables is None else
self.cross_layer_shared_block_tables[:self.num_prefills])
self._cached_prefill_metadata = DifferentialFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@property
def decode_metadata(
self) -> Optional["DifferentialFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
cross_layer_shared_block_tables = (
None if self.cross_layer_shared_block_tables is None else
self.cross_layer_shared_block_tables[self.num_prefills:])
self._cached_decode_metadata = DifferentialFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata
class DifferentialFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):
def __init__(self, input_builder):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.cross_layer_shared_block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
# TODO: add support for chunked prefill and prefix caching.
assert not chunked_prefill_enabled, \
"chunked prefill is not supported for now"
assert not prefix_cache_hit, "prefix caching is not supported for now"
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
cross_layer_shared_block_table = []
if prefix_cache_hit:
cross_layer_shared_block_table = block_tables[seq_id]
elif block_tables is not None:
if curr_sliding_window_block == 0:
cross_layer_shared_block_table = block_tables[seq_id]
else:
cross_layer_shared_block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.cross_layer_shared_block_tables.append(
cross_layer_shared_block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def _get_graph_runner_block_tables(self, num_seqs: int,
block_tables: List[List[int]],
graph_block_tables) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# max_batch_size, max_blocks = self.runner.graph_block_tables.shape
max_batch_size, max_blocks = graph_block_tables.shape
assert max_batch_size >= num_seqs
# graph_block_tables = self.runner.graph_block_tables[:num_seqs]
graph_block_tables = graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due to
# lookahead slots of multi-step scheduling; however, they are
# not used anyway, so they can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]
return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The possibly padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The possibly padded batch size.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
self.cross_layer_shared_block_tables.extend([] *
cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables, self.runner.graph_block_tables)
cross_layer_shared_block_tables = \
self._get_graph_runner_block_tables(
num_seqs, self.cross_layer_shared_block_tables,
self.runner.cross_layer_shared_graph_block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
cross_layer_shared_block_tables = make_tensor_with_pad(
self.cross_layer_shared_block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return DifferentialFlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
use_cuda_graph=use_captured_graph,
)
class DifferentialFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
differential_flash_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
if differential_flash_attention_config is None:
differential_flash_attention_config = {}
self.differential_flash_attention_config = \
differential_flash_attention_config
self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if use_irope:
logger.warning(
"Using irope in V0 is not supported yet, it will fall back "
"to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=self.alibi_slopes is not None)
if is_quantized_kv_cache(self.kv_cache_dtype) and (
not self.kv_cache_dtype.startswith("fp8")
or not flash_attn_supports_fp8()):
raise NotImplementedError(
f"FlashAttention does not support {self.kv_cache_dtype} "
"kv-cache on this device "
f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type
self.lambda_full = None
self.subln = self.differential_flash_attention_config["subln"]
def split_heads(self, x):
# Split along the head dimension; the striped head layout is friendly to tensor parallelism.
x = rearrange(x, "... (H two) D -> ... H two D", two=2)
x1 = x[..., 0, :]
x2 = x[..., 1, :]
return x1.contiguous(), x2.contiguous()
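# Illustrative sketch (hypothetical shapes, not part of the original code):
# with 8 interleaved heads of dim 64, e.g. x = torch.randn(2, 16, 8, 64), the
# rearrange above produces shape (2, 16, 4, 2, 64), so split_heads returns two
# (2, 16, 4, 64) tensors holding the even- and odd-indexed heads of each pair.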
def split_kv_cache(self, x):
# Split the KV cache into its two halves along the leading dimension.
if x.numel() == 0:
return torch.empty(0), torch.empty(0)
x1, x2 = x[0], x[1]
return x1, x2
def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor,
value: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata):
if kv_cache.numel() > 0 and key is not None and value is not None:
updated_slot_mapping = attn_metadata.slot_mapping
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(),
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def forward_generate_kv_cache(
self, query: torch.Tensor, key: Optional[torch.Tensor],
value: Optional[torch.Tensor], k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor:
head_size = self.head_size
num_heads = self.num_heads // 2
num_kv_heads = self.num_kv_heads // 2
query = query.view(-1, num_heads, head_size)
if key is not None:
assert value is not None
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
else:
assert value is None
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[
0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch"
assert value.shape[
0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch"
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
if key is not None and value is not None:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens, "query shape mismatch"
assert decode_query.shape[
0] == num_decode_tokens, "decode query shape mismatch"
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if k_cache.numel() == 0 \
or prefill_meta.block_tables is None \
or prefill_meta.block_tables.numel() == 0:
# normal attention
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
)
assert prefill_output.shape == output[:
num_prefill_tokens].shape
output[:num_prefill_tokens] = prefill_output
else:
raise Exception("prefix caching not supported")
if decode_meta := attn_metadata.decode_metadata:
block_tables_arg = decode_meta.block_tables
try:
output[num_prefill_tokens:] = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=k_cache,
v_cache=v_cache,
block_table=block_tables_arg,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
).squeeze(1)
except Exception as e:
logger.error("Error in PagedAttention.forward_decode: %s",
str(e))
raise e
# Reshape the output tensor.
return output.view(-1, num_heads, head_size)
def forward_with_kv_cache_only(
self,
query: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata,
):
if not attn_metadata.decode_metadata:
block_tables_arg = attn_metadata.cross_layer_shared_block_tables
else:
block_tables_arg = attn_metadata.block_tables
output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=k_cache,
v_cache=v_cache,
block_table=block_tables_arg,
cache_seqlens=attn_metadata.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
).squeeze(1)
return output
def forward(
self,
layer: AttentionLayer,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
layer: Attention layer instance.
q: Query tensor with shape = [num_tokens, num_heads, head_size]
k: Key tensor with shape = [num_tokens, num_kv_heads, head_size]
v: Value tensor with shape = [num_tokens, num_kv_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Output tensor with shape [num_tokens, num_heads, head_size]
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
NOTE: It updates the output tensor in place.
NOTE: For FP8 quantization, flash-attn expects the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values.
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for DifferentialFlashAttentionImpl")
if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[
"lambda_init"]
lambda_q1 = self.differential_flash_attention_config["lambda_q1"]
lambda_k1 = self.differential_flash_attention_config["lambda_k1"]
lambda_q2 = self.differential_flash_attention_config["lambda_q2"]
lambda_k2 = self.differential_flash_attention_config["lambda_k2"]
lambda_1 = torch.exp(
torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q)
lambda_2 = torch.exp(
torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q)
self.lambda_full = lambda_1 - lambda_2 + self.lambda_init
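# For reference, the computation above and the combination below implement
# (per the paper cited in the module docstring):
#   lambda_full = exp(sum(lambda_q1 * lambda_k1))
#                 - exp(sum(lambda_q2 * lambda_k2)) + lambda_init
#   attn = attn1 - lambda_full * attn2
# followed by sublayer norm and scaling by (1 - lambda_init).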
if not self.used_shared_kv_cache: # need to generate kv-cache
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size)
q1, q2 = self.split_heads(q)
k1, k2 = self.split_heads(k)
v1, v2 = self.split_heads(v)
# kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501
# Split by half along the first dimension.
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous"
if kv_cache1.numel() != 0:
self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata)
self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata)
key_cache1, value_cache1 = self.split_kv_cache(kv_cache1)
key_cache2, value_cache2 = self.split_kv_cache(kv_cache2)
else:
key_cache1, value_cache1 = torch.empty(0), torch.empty(0)
key_cache2, value_cache2 = torch.empty(0), torch.empty(0)
attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1,
value_cache1,
attn_metadata)
attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1,
value_cache2,
attn_metadata)
attn11 = attn11.view(q1.shape)
attn12 = attn12.view(q1.shape)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2,
value_cache1,
attn_metadata)
attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2,
value_cache2,
attn_metadata)
attn21 = attn21.view(q2.shape)
attn22 = attn22.view(q2.shape)
attn2 = torch.cat([attn21, attn22], dim=-1)
attn = attn1 - self.lambda_full * attn2
# attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# reshape back to 2 * num_head
attn_output = rearrange(attn,
"... H (two D) -> ... (H two) D",
two=2)
else: # reuse the kv cache, full attention
q = q.view(-1, self.num_heads, self.head_size)
q1, q2 = self.split_heads(q)
# kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]
attn11 = self.forward_with_kv_cache_only(q1, key_cache1,
value_cache1,
attn_metadata)
attn12 = self.forward_with_kv_cache_only(q1, key_cache1,
value_cache2,
attn_metadata)
attn11 = attn11.view(q1.shape)
attn12 = attn12.view(q1.shape)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = self.forward_with_kv_cache_only(q2, key_cache2,
value_cache1,
attn_metadata)
attn22 = self.forward_with_kv_cache_only(q2, key_cache2,
value_cache2,
attn_metadata)
attn21 = attn21.view(q2.shape)
attn22 = attn22.view(q2.shape)
attn2 = torch.cat([attn21, attn22], dim=-1)
attn = attn1 - self.lambda_full * attn2
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
# reshape back to 2 * num_head
attn_output = rearrange(attn,
"... H (two D) -> ... (H two) D",
two=2)
attn_output = attn_output.view(-1, self.num_heads * self.head_size)
return attn_output

File diff suppressed because it is too large

View File

@@ -1,929 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from collections import defaultdict
from dataclasses import dataclass
from itertools import accumulate
from typing import Dict, List, Optional, Tuple, Type
import torch
from vllm import _custom_ops as ops
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1]
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
class FlashAttentionMetadata(AttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decode-only batch.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return is_all_encoder_attn_metadata_set(self)
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return is_all_cross_attn_metadata_set(self)
@property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
self._cached_prefill_metadata = FlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
self._cached_decode_metadata = FlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
def __init__(self, input_builder):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def _get_graph_runner_block_tables(
self, num_seqs: int,
block_tables: List[List[int]]) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due to
# lookahead slots of multi-step scheduling; however, they are
# not used anyway, so they can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]
return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The possibly padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The possibly padded batch size.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0 "
"FLASH_ATTN backend.")
if use_irope:
logger.warning(
"Using irope in V0 is not supported yet, it will fall back "
"to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=self.alibi_slopes is not None)
if is_quantized_kv_cache(self.kv_cache_dtype) and (
not self.kv_cache_dtype.startswith("fp8")
or not flash_attn_supports_fp8()):
raise NotImplementedError(
f"FlashAttention does not support {self.kv_cache_dtype} "
"kv-cache on this device "
f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type
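# Note on the sliding_window convention above (illustrative): flash-attn takes
# a (left, right) window tuple, so a causal sliding window of 1024 tokens is
# passed as (1023, 0) -- each query attends to itself plus the previous 1023
# keys -- while (-1, -1) disables windowing entirely.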
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
NOTE: This method updates the output tensor in place.
NOTE: For FP8 quantization, flash-attn expects the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values.
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert (
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
"key/v_scale is only supported in FlashAttention 3 with "
"base dtype bfloat16")
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
kv_cache_dtype: str = self.kv_cache_dtype
softmax_scale: float = self.scale
window_size = self.sliding_window
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
logits_soft_cap: Optional[float] = self.logits_soft_cap
fp8_attention = kv_cache_dtype.startswith("fp8")
if fp8_attention and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support FP8 kv-cache on this device.")
if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the
# KV cache is already populated with the cross-attention
# tensor. Thus, we skip cache updates during this time.
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if fp8_attention:
kv_cache = kv_cache.view(torch.float8_e4m3fn)
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
if fp8_attention:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
decode_query = query[num_prefill_query_tokens:]
decode_output = output[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
prefill_output = output[:num_prefill_query_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
if fp8_attention:
num_kv_tokens, num_kv_heads, head_size = key.shape
key, _ = ops.scaled_fp8_quant(
key.reshape((num_kv_tokens,
num_kv_heads * head_size)).contiguous(),
layer._k_scale)
key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
value, _ = ops.scaled_fp8_quant(
value.reshape((num_kv_tokens,
num_kv_heads * head_size)).contiguous(),
layer._v_scale)
value = value.reshape(
(num_kv_tokens, num_kv_heads, head_size))
descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None
assert prefill_meta.query_start_loc is not None
max_seq_len = max(prefill_meta.seq_lens)
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
key.shape[1])
flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
# use only for actual varlen decoding
if decode_meta.max_decode_query_len > 1:
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1"
)
assert decode_meta.query_start_loc is not None
descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
key.shape[1])
flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg,
_,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=decode_output.unsqueeze(1),
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output
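# Minimal sketch of the descale expansion used in forward() above (illustrative
# only; the tensor sizes are hypothetical): flash-attn wants per-(sequence,
# kv-head) descale factors, so a scalar per-layer scale is broadcast with
# .expand(), which creates a view instead of copying data.
def _example_descale_expand():
    import torch
    num_seqs, num_kv_heads = 8, 4
    q_scale = torch.tensor([[0.5]])                      # scalar scale, shape (1, 1)
    q_descale = q_scale.expand(num_seqs, num_kv_heads)   # view with shape (8, 4)
    return q_descale.shape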
def _get_query_key_seq_metadata(
attn_metadata: FlashAttentionMetadata,
is_prompt: bool,
attn_type: str,
) -> tuple:
"""
Returns sequence metadata for key and query based on the specified
attention type and whether input is a prompt.
This function computes the starting locations and maximum sequence lengths
for key and query sequences for different attention types.
Args:
attn_metadata: The attention metadata object
is_prompt (bool): A flag indicating if the input is a prompt
attn_type (AttentionType): The type of attention being used.
Returns:
tuple: A tuple containing four integers:
- Starting location for the query sequence.
- Maximum sequence length for the query sequence.
- Starting location for the key sequence.
- Maximum sequence length for the key sequence.
Raises:
AttributeError: If an invalid attention type is provided.
"""
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_start_loc, max_seq_len,
attn_metadata.seq_start_loc, max_seq_len)
elif attn_type == AttentionType.ENCODER_DECODER:
# This is cross-attention, where the key comes from the
# precomputed encoder output and the query comes from the
# decoder input sequence.
# Choose the query max length based on whether this is a
# prompt run or not.
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_start_loc, max_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len)
elif attn_type == AttentionType.ENCODER:
# For encoder attention both the query and the key are same i.e. the
# encoder sequence.
return (attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_start_loc,
attn_metadata.max_encoder_seq_len)
elif attn_type == AttentionType.ENCODER_ONLY:
assert is_prompt, "Should not have decode for encoder only model."
return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _get_causal_option(attn_type: str) -> bool:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.
Args:
attn_type (AttentionType): The type of attention being evaluated
Returns:
bool: Returns `True` if the attention type is suitable for causal
attention (i.e., not encoder, encoder-only, or encoder-decoder),
otherwise returns `False`.
"""
return not (attn_type == AttentionType.ENCODER
or attn_type == AttentionType.ENCODER_ONLY
or attn_type == AttentionType.ENCODER_DECODER)
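# Illustrative usage of the helper above (not part of the original module):
# decoder self-attention is the only case that runs with a causal mask here.
#   _get_causal_option(AttentionType.DECODER)          -> True
#   _get_causal_option(AttentionType.ENCODER)          -> False
#   _get_causal_option(AttentionType.ENCODER_ONLY)     -> False
#   _get_causal_option(AttentionType.ENCODER_DECODER)  -> False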

View File

@ -1,227 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
MLACommonState)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASHMLA"
@staticmethod
def get_impl_cls() -> Type["FlashMLAImpl"]:
return FlashMLAImpl
@staticmethod
def get_metadata_cls() -> Type["FlashMLAMetadata"]:
return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["FlashMLAState"]:
return FlashMLAState
@dataclass
class FlashMLAMetadata(MLACommonMetadata):
decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None
decode_num_splits: Optional[torch.Tensor] = None
@property
def decode_metadata(self):
decode_metadata = super().decode_metadata
# TODO: cache assignment?
if decode_metadata is not None:
decode_metadata.decode_tile_scheduler_metadata=\
self.decode_tile_scheduler_metadata
decode_metadata.decode_num_splits=\
self.decode_num_splits
return decode_metadata
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
m = super().build(seq_lens, query_lens, cuda_graph_pad_size,
batch_size)
if m.num_decode_tokens > 0:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens_tensor[m.num_prefills:],
self.num_q_heads,
1, # MQA for the decode path
)
return m
class FlashMLAState(MLACommonState[FlashMLAMetadata]):
def __init__(self, *args, **kwds):
super().__init__(*args, **kwds)
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
@contextmanager
def graph_capture(self, max_batch_size: int):
# Run a dummy `get_mla_metadata` so we can get the right shapes
self._graph_decoder_tile_scheduler_metadata, \
self._graph_decode_num_splits = get_mla_metadata(
torch.ones(
max_batch_size, dtype=torch.int32, device=self.runner.device),
self.num_q_heads,
1, # MQA for the decode path
)
with super().graph_capture(max_batch_size):
yield
del self._graph_decoder_tile_scheduler_metadata
del self._graph_decode_num_splits
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
metadata = super().graph_capture_get_metadata_for_batch(
batch_size, is_encoder_decoder_model)
assert metadata.num_decode_tokens > 0
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
self._graph_seq_lens[:batch_size],
self.num_q_heads,
1, # MQA for the decode path
)
self._graph_decoder_tile_scheduler_metadata.copy_(
decoder_tile_scheduler_metadata)
self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits)
metadata.decode_tile_scheduler_metadata=\
self._graph_decoder_tile_scheduler_metadata
metadata.decode_num_splits=\
self._graph_decode_num_splits[:batch_size + 1]
return metadata
def get_graph_input_buffers(self,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_buffers = super().get_graph_input_buffers(
attn_metadata, is_encoder_decoder_model)
input_buffers["decode_tile_scheduler_metadata"] = \
attn_metadata.decode_metadata.decode_tile_scheduler_metadata
input_buffers["decode_num_splits"] = \
attn_metadata.decode_metadata.decode_num_splits
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
is_encoder_decoder_model)
input_buffers["decode_tile_scheduler_metadata"].copy_(
attn_metadata.decode_metadata.decode_tile_scheduler_metadata)
input_buffers["decode_num_splits"].copy_(
attn_metadata.decode_metadata.decode_num_splits)
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str] = None,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
is_supported, reason = is_flashmla_supported()
assert is_supported, reason
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
)
return self._v_up_proj(o)
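# Shape sketch for the decode path above (illustrative; the head dims shown are
# assumptions in the style of DeepSeek-like MLA models, not values taken from
# this file): q_nope carries the latent "no-position" part and q_pe the RoPE
# part, so for batch B and H query heads:
#   q_nope: [B, H, kv_lora_rank]        e.g. [B, H, 512]
#   q_pe:   [B, H, qk_rope_head_dim]    e.g. [B, H, 64]
#   q = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1)  -> [B, 1, H, 576]
# flash_mla_with_kvcache then attends against the compressed KV cache and the
# result is mapped back to the value space by self._v_up_proj(o).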

File diff suppressed because it is too large

View File

@ -1,407 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional, Type, Union
import torch
import vllm.envs as envs
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
MLACommonState)
from vllm.attention.backends.utils import (compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
get_aiter_mla_metadata)
def is_aiter_mla_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
@staticmethod
def get_impl_cls() -> Type["AiterMLAImpl"]:
return AiterMLAImpl
@staticmethod
def get_metadata_cls() -> Type["AiterMLAMetadata"]:
return AiterMLAMetadata
@staticmethod
def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
return AiterMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["AiterMLAState"]:
return AiterMLAState
@dataclass
class AiterMLAMetadata(MLACommonMetadata):
# The following 5 tensors are for current version of AITER MLA
block_table_bound: Optional[torch.Tensor] = None
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr: Optional[torch.Tensor] = None
# The page indices of the paged kv cache
paged_kv_indices: Optional[torch.Tensor] = None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens: Optional[torch.Tensor] = None
# This is just to make new AITER MLA API work
# -- MTP support is not added yet.
qo_indptr: Optional[torch.Tensor] = None
@property
def prefill_metadata(self):
prefill_metadata = super().prefill_metadata
self._cached_prefill_metadata = prefill_metadata
if prefill_metadata is not None:
prefill_metadata.paged_kv_indptr = self.paged_kv_indptr
prefill_metadata.paged_kv_indices = self.paged_kv_indices
prefill_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
prefill_metadata.block_table_bound = self.block_table_bound
prefill_metadata.qo_indptr = self.qo_indptr
# update the cache
self._cached_prefill_metadata = self.__class__(
**prefill_metadata.__dict__)
return self._cached_prefill_metadata
@property
def decode_metadata(self):
decode_metadata = super().decode_metadata
self._cached_decode_metadata = decode_metadata
if decode_metadata is not None:
decode_metadata.paged_kv_indptr = self.paged_kv_indptr
decode_metadata.paged_kv_indices = self.paged_kv_indices
decode_metadata\
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
decode_metadata.block_table_bound = self.block_table_bound
decode_metadata.qo_indptr = self.qo_indptr
# update the cache
self._cached_decode_metadata = self.__class__(
**decode_metadata.__dict__)
return self._cached_decode_metadata
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
def __init__(self, input_builder):
super().__init__(input_builder)
assert self.block_size == 1, "AITER MLA requires only block size 1."
def prepare(self):
super().prepare()
self.paged_kv_indices: list[int] = []
self.paged_kv_indptr: list[int] = [0]
self.paged_kv_last_page_lens: list[int] = []
self.total_blocks = 0
self.qo_indptr: list[int] = [0]
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
if is_profile_run:
return
# Update paged_kv_* tensors only for non-profile run
block_table = block_tables[seq_id]
self._update_paged_kv_tensors(block_table, seq_len)
def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
self.qo_indptr.append(self.qo_indptr[-1] + 1)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_lens.append(last_page_len)
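# Worked example of the paged-KV bookkeeping above (illustrative values; AITER
# MLA runs with block_size == 1, see the assert in __init__):
#   request 0: seq_len = 3, block_table = [10, 11, 12, 13]
#       block_table_bound = 3   -> paged_kv_indices += [10, 11, 12]
#   request 1: seq_len = 5, block_table = [20, 21, 22, 23, 24]
#       block_table_bound = 5   -> paged_kv_indices += [20, 21, 22, 23, 24]
# After both updates:
#   paged_kv_indptr         = [0, 3, 8]
#   paged_kv_last_page_lens = [1, 1]    (seq_len % 1 == 0, so block_size is used)
#   qo_indptr               = [0, 1, 2] (one query token per request, no MTP)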
def build(self, seq_lens: list[int], query_lens: list[int],
cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
batch_size)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
if use_captured_graph:
last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size)
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
last_qo_indptr = self.qo_indptr[-1]
self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
# For current version of AITER MLA
if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device=device,
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
device=device,
dtype=torch.int)
paged_kv_last_page_lens_tensor = torch.tensor(
self.paged_kv_last_page_lens, device=device, dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device=device,
dtype=torch.int)
qo_indptr = torch.tensor(self.qo_indptr,
device=device,
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_lens_tensor = None
block_table_bound_tensor = None
qo_indptr = None
metadata.paged_kv_indptr = paged_kv_indptr_tensor
metadata.paged_kv_indices = paged_kv_indices_tensor
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
metadata.block_table_bound = block_table_bound_tensor
metadata.qo_indptr = qo_indptr
return metadata
class AiterMLAState(MLACommonState[AiterMLAMetadata]):
@contextmanager
def graph_capture(self, max_batch_size: int):
kv_indices, kv_indptr, last_page_lens, qo_indptr = \
get_aiter_mla_metadata(
max_batch_size=max_batch_size,
block_size=self.runner.block_size,
max_block_per_batch=\
self.runner.get_max_block_per_batch(),
device=self.runner.device)
self._paged_kv_indices_tensor = kv_indices
self._paged_kv_indptr_tensor = kv_indptr
self._paged_kv_last_page_lens_tensor = last_page_lens
self._qo_indptr_tensor = qo_indptr
with super().graph_capture(max_batch_size):
yield
del self._paged_kv_indices_tensor
del self._paged_kv_indptr_tensor
del self._paged_kv_last_page_lens_tensor
del self._qo_indptr_tensor
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
metadata = super().graph_capture_get_metadata_for_batch(
batch_size, is_encoder_decoder_model)
paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
paged_kv_indices = self._paged_kv_indices_tensor
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
batch_size]
qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
metadata.paged_kv_indptr = paged_kv_indptr
metadata.paged_kv_indices = paged_kv_indices
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
metadata.qo_indptr = qo_indptr
return metadata
def get_graph_input_buffers(self,
attn_metadata: AiterMLAMetadata,
is_encoder_decoder_model: bool = False):
input_buffers = super().get_graph_input_buffers(
attn_metadata, is_encoder_decoder_model)
input_buffers[
'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr
input_buffers[
"paged_kv_indices"] = attn_metadata.\
decode_metadata.paged_kv_indices
input_buffers[
"paged_kv_last_page_lens"] = attn_metadata.\
decode_metadata.paged_kv_last_page_lens
input_buffers['qo_indptr'] = attn_metadata.qo_indptr
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata: AiterMLAMetadata,
is_encoder_decoder_model: bool = False):
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
is_encoder_decoder_model)
num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[
0]
input_buffers["paged_kv_indptr"].copy_(
attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True)
input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True)
input_buffers["paged_kv_last_page_lens"].copy_(
attn_metadata.decode_metadata.paged_kv_last_page_lens,
non_blocking=True)
input_buffers["qo_indptr"].copy_(
attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")
from aiter import flash_attn_varlen_func
self.flash_attn_varlen_func = flash_attn_varlen_func
def _flash_attn_varlen_diff_headdims(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
softmax_scale: float, return_softmax_lse: bool,
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
output = self.flash_attn_varlen_func(
q,
k,
v,
**kwargs,
)
return output
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AiterMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.empty(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
attn_metadata.qo_indptr,
attn_metadata.max_query_len,
attn_metadata.paged_kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_lens)
return self._v_up_proj(o)
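# Shape note for the decode path above (illustrative, assuming the common MLA
# cache layout of [num_blocks, block_size, kv_lora_rank + rope_dim]):
# unsqueeze(2) inserts a singleton KV-head axis, so the kernel receives the
# cache as [num_blocks, block_size, 1, head_dim], mirroring the "add head dim
# of 1" trick used by the FlashMLA decode path above.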

View File

@ -1,953 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer ROCm GPUs."""
import itertools
from dataclasses import dataclass
from functools import cache
from typing import List, Optional, Tuple, Type
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 256
@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
and envs.VLLM_ROCM_USE_AITER
@cache
def _get_paged_attn_module() -> PagedAttention:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as a helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
"""
if is_rocm_aiter_paged_attn_enabled():
# Import AITERPagedAttention only when the flag is enabled
from vllm.attention.ops.rocm_aiter_paged_attn import (
AITERPagedAttention)
return AITERPagedAttention()
return PagedAttention()
class ROCmFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ROCM_FLASH"
@staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
return ROCmFlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
return ROCmFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
paged_attn = _get_paged_attn_module()
return paged_attn.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
paged_attn = _get_paged_attn_module()
paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
paged_attn = _get_paged_attn_module()
paged_attn.copy_blocks(kv_caches, src_to_dists)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding batch.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.block_tables is not None
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = ROCmFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if self._cached_decode_metadata.query_start_loc is not None:
qs = self._cached_decode_metadata.query_start_loc
self._cached_decode_metadata.query_start_loc = qs - qs[0]
return self._cached_decode_metadata
class ROCmFlashAttentionMetadataBuilder(
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
_metadata_cls = ROCmFlashAttentionMetadata
def _make_alibi_bias(alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: Optional[List[int]],
make_attn_mask: bool = True) -> List[torch.Tensor]:
attn_biases = []
if seq_lens:
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat(
(num_heads, 1, 1)).to(alibi_slopes.device)
bias.mul_(alibi_slopes[:, None, None])
if make_attn_mask:
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
alibi_slopes.device)
attn_biases.append((bias + inf_mask).to(dtype))
else:
attn_biases.append(bias.to(dtype))
return attn_biases
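# Worked example of the ALiBi bias above (illustrative): for seq_len = 3 the
# relative-position matrix bias[None, :] - bias[:, None] is
#   [[ 0,  1,  2],
#    [-1,  0,  1],
#    [-2, -1,  0]]
# Each head scales this matrix by its slope; with make_attn_mask=True the upper
# triangle (future positions) is additionally pushed to -inf so the mask stays
# causal.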
def _get_seq_len_block_table_args(
attn_metadata: ROCmFlashAttentionMetadata,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths
Encoder attn -> select encoder sequence lengths fields
Encoder-only attn -> select prefill sequence lengths with
bidirectional attention
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention, encoder-only
Returns:
* Appropriate sequence-lengths tensors for query and key
* Appropriate max sequence-length scalar
* Causal masking flag
'''
if attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
causal_mask = False
# No block tables associated with encoder attention
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.encoder_seq_lens, causal_mask)
elif attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, we use the prefill sequence lengths
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
# Encoder-only models typically use bidirectional attention
causal_mask = False
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
max_seq_len, attn_metadata.seq_lens, causal_mask)
elif attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
assert attn_metadata.seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
query_seq_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
max_seq_len = attn_metadata.max_prefill_seq_len
causal_mask = True
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
max_seq_len, attn_metadata.seq_lens, causal_mask)
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens_tensor is not None
query_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
device=attn_metadata.encoder_seq_lens_tensor.device,
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
assert attn_metadata.encoder_seq_lens is not None
assert attn_metadata.seq_lens_tensor is not None
key_seq_start_loc = torch.tensor(
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
device=attn_metadata.seq_lens_tensor.device,
dtype=attn_metadata.seq_lens_tensor.dtype)
causal_mask = False
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (query_start_loc, attn_metadata.max_prefill_seq_len,
key_seq_start_loc, attn_metadata.max_encoder_seq_len,
attn_metadata.seq_lens, causal_mask)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class ROCmFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0 "
"ROCM_FLASH backend.")
if use_irope:
logger.warning_once(
"Using irope in ROCm Flash Attention is not supported yet; it "
"will fall back to global attention for long context.")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
self.logits_soft_cap = 0.0
else:
self.logits_soft_cap = logits_soft_cap
self.attn_type = attn_type
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.paged_attn_module = _get_paged_attn_module()
supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
)
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Triton FlashAttention does not support attention"
" logits soft capping."
" please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.triton_attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
if not current_platform.has_device_capability(90):
self.use_naive_attn = True
else:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.fa_attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
if self.use_naive_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Naive FlashAttention does not support "
"attention logits soft capping.")
self.sdpa_attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend")
self.aiter_kv_scales_initialized = False
self.force_fp8_attention = (
get_current_vllm_config() is not None
and get_current_vllm_config().model_config.override_attention_dtype
== "fp8")
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
return (x[:, :,
None, :].expand(tokens, n_kv_heads, n_rep,
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
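# Example of the GQA/MQA expansion above (illustrative shapes): with
# num_kv_heads = 2, num_heads = 8 and head_dim = 128, a key tensor of shape
# (tokens, 2, 128) becomes (tokens, 8, 128) after repeat_kv(x, n_rep=4), so the
# naive SDPA path can treat it as ordinary multi-head attention.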
def fused_output_quant_supported(self, quant_key: QuantKey):
if self.use_triton_flash_attn:
return quant_key == kFp8StaticTensorSym
# Only supported in the Triton backend
return False
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* ROCmFlashAttentionImpl.forward() may be invoked for both self- and
cross-attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
* ENCODER_ONLY: bidirectional attention with no KV caching;
use prefill sequence attributes
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None and not self.use_triton_flash_attn:
raise NotImplementedError(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
if output_block_scale is not None:
raise NotImplementedError(
"fused nvfp4 output quantization is not supported"
" for ROCMFlashAttentionImpl")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
paged_attn = self.paged_attn_module
# Reshaping kv tensors is required for AITER paged attention kernel
# because it works on a different tensor shape,
# when the size of one element is one byte (int8/fp8 dtypes).
# This reshaping is only required on the first forward call
# and the kv cache must not be empty.
if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
and not self.aiter_kv_scales_initialized
and kv_cache.shape != torch.Size([0])):
num_blocks = kv_cache.shape[1]
block_size = kv_cache.shape[2] // (self.num_kv_heads *
self.head_size)
k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
self.aiter_kv_scales_initialized = True
k_scale.fill_(layer._k_scale.item())
v_scale.fill_(layer._v_scale.item())
layer._k_scale = k_scale
layer._v_scale = v_scale
# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if self.attn_type not in [
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
] and kv_cache.numel() > 0:
key_cache, value_cache = paged_attn.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
if key is not None and value is not None:
# Reshape the input keys and values and store them in the
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# memory profiling run.
paged_attn.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping
if self.attn_type != AttentionType.ENCODER_DECODER else
attn_metadata.cross_slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only models, all tokens are processed in one go
num_prefill_tokens = query.shape[0]
elif self.attn_type == AttentionType.ENCODER:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_tokens = attn_metadata.num_encoder_tokens
else:
num_prefill_tokens = attn_metadata.num_prefill_tokens
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
# For encoder-only and encoder models,
# we process all tokens at once
# For decoder and encoder-decoder,
# we may need to limit key/value to prefill tokens
if key is not None and value is not None \
and self.attn_type not in [AttentionType.ENCODER_DECODER,
AttentionType.ENCODER_ONLY]:
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
# normal attention and DECODER
if self.attn_type == AttentionType.DECODER and (
kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
key_max_seq_len, seq_lens,
causal_mask) = (prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
prefill_meta.seq_start_loc,
prefill_meta.max_prefill_seq_len,
attn_metadata.seq_lens, True)
# prefix-enabled attention and ENCODER/ENCODER_DECODER
else:
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
key_max_seq_len, seq_lens,
causal_mask) = _get_seq_len_block_table_args(
prefill_meta, self.attn_type)
# Prompt run.
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_masks = None
if self.use_triton_flash_attn:
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
seq_lens,
make_attn_mask=causal_mask) # type: ignore
use_fp8_scales = (layer._q_scale and layer._k_scale
and layer._v_scale and layer._prob_scale
and (self.kv_cache_dtype == "fp8"
or self.force_fp8_attention))
full_scales = (
layer._q_scale.item(), layer._k_scale.item(),
layer._v_scale.item(),
layer._prob_scale.item()) if use_fp8_scales else None
self.triton_attn_func(
query,
key,
value,
output[:num_prefill_tokens],
query_seq_start_loc,
key_seq_start_loc,
query_max_seq_len,
key_max_seq_len,
causal_mask,
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
full_scales,
output_scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.alibi_slopes is not None:
attn_masks = _make_alibi_bias(
self.alibi_slopes,
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=causal_mask) # type: ignore
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
# sdpa math backend attention
self.sdpa_attn_func(
query,
key,
value,
output[:num_prefill_tokens],
query_seq_start_loc,
num_prefill_tokens,
self.num_heads,
self.head_size,
self.scale,
attn_masks,
)
else:
# upstream FA does not support an output arg, copy
output[:num_prefill_tokens] = self.fa_attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=query_seq_start_loc,
cu_seqlens_k=key_seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=key_max_seq_len,
softmax_scale=self.scale,
causal=causal_mask,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)
else:
# prefix-enabled attention -
# not applicable for encoder-only models
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
# Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY):
# Decoding run.
# Whether to use rocm custom paged attention or not
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
from vllm.platforms.rocm import use_rocm_custom_paged_attention
use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len)
assert max_seq_len is not None
max_num_partitions = (
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
assert _PARTITION_SIZE_ROCM % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
query_start_loc = None
ops.paged_attention_rocm(
output[num_prefill_tokens:],
exp_sums,
max_logits,
tmp_output,
decode_query,
key_cache,
value_cache,
self.num_kv_heads,
self.scale,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
query_start_loc,
block_size,
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
output_scale,
)
else:
# PagedAttention does not support fused output quantization; quantize manually
if output_scale is None:
out_pa = output[num_prefill_tokens:]
else:
out_pa = torch.empty_like(output[num_prefill_tokens:],
dtype=query.dtype)
out_pa[:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
)
# Manually perform quantization
if output_scale is not None:
out_uq = out_pa.view(-1, self.num_heads * self.head_size)
out_q = output.view(-1, self.num_heads * self.head_size)
ops.scaled_fp8_quant(out_uq,
output_scale,
output=out_q[num_prefill_tokens:])
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _sdpa_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
seq_lens: torch.Tensor,
num_tokens: int,
num_heads: int,
head_size: int,
scale: float,
attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
start = 0
assert output.shape == (num_tokens, num_heads, head_size)
assert output.dtype == query.dtype
assert output.device == query.device
for i, seq_len in enumerate(seq_lens):
end = start + seq_len
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.MATH):
sub_out = torch.nn.functional.scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
value[:, start:end, :],
dropout_p=0.0,
is_causal=attn_masks is None,
attn_mask=attn_masks[i] if attn_masks else None,
scale=scale).movedim(query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
return output
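# --- Illustrative sketch (not part of this commit) ---
# A self-contained example of the per-sequence, math-backend SDPA call that
# _sdpa_attention above performs for each prompt slice. All shapes and the
# causal setting here are made-up assumptions for illustration.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

num_heads, seq_len, head_size = 4, 8, 64
q = torch.randn(num_heads, seq_len, head_size)
k = torch.randn(num_heads, seq_len, head_size)
v = torch.randn(num_heads, seq_len, head_size)

with sdpa_kernel(SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, dropout_p=0.0, is_causal=True, scale=head_size**-0.5)
print(out.shape)  # torch.Size([4, 8, 64])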

View File

@ -1,111 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import List, Optional, Type
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
class TritonMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
@staticmethod
def get_impl_cls() -> Type["TritonMLAImpl"]:
return TritonMLAImpl
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
num_kv_splits = 4 # TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
self.num_heads,
num_kv_splits,
# NOTE(lucas): not sure why the +1 is needed, but sglang's
# implementation has it, so we mirror that
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device=q.device,
)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj(o)
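# --- Illustrative sketch (not part of this commit) ---
# The q_nope / q_pe concatenation performed in _forward_decode above; the
# head count and rank sizes below are made-up assumptions for illustration.
import torch

B, num_heads, kv_lora_rank, rope_dim = 2, 16, 512, 64
q_nope = torch.randn(B, num_heads, kv_lora_rank)
q_pe = torch.randn(B, num_heads, rope_dim)
q = torch.cat([q_nope, q_pe], dim=-1)   # decode query fed to the MQA kernel
print(q.shape)  # torch.Size([2, 16, 576])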

View File

@ -338,10 +338,9 @@ class CommonAttentionState(AttentionState):
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'"
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or " \
f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
@ -360,10 +359,9 @@ class CommonAttentionState(AttentionState):
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'"
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or " \
f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
self._add_additional_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers

View File

@ -1,805 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (
CommonAttentionState, CommonMetadataBuilder,
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
logger = init_logger(__name__)
class XFormersBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "XFORMERS"
@staticmethod
def get_impl_cls() -> Type["XFormersImpl"]:
return XFormersImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata
@staticmethod
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
return XFormersMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor] = None
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
_cached_decode_metadata: Optional["XFormersMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc: Optional[torch.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None
def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it needs to be set per prompt
# when alibi slopes are used, due to a limitation of the
# xFormers API.
# Will not appear in the __repr__ and __init__.
self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
@property
def is_all_encoder_attn_metadata_set(self):
'''
All attention metadata required for encoder attention is set.
'''
return is_all_encoder_attn_metadata_set(self)
@property
def is_all_cross_attn_metadata_set(self):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return is_all_cross_attn_metadata_set(self)
@property
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
# Recover cached prefill-phase attention
# metadata structure
return self._cached_prefill_metadata
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
# Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = XFormersMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["XFormersMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
# Recover cached decode-phase attention
# metadata structure
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if self._cached_decode_metadata.query_start_loc is not None:
qs = self._cached_decode_metadata.query_start_loc
self._cached_decode_metadata.query_start_loc = qs - qs[0]
return self._cached_decode_metadata
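# --- Illustrative sketch (not part of this commit) ---
# The query_start_loc rebasing done above when the decode half is split off a
# mixed prefill|decode batch: offsets are shifted so the first decode starts
# at 0 (the [3, 9] -> [0, 6] example from the comment above).
import torch

qs = torch.tensor([3, 9])   # cumulative query starts after the prefill tokens
print(qs - qs[0])           # tensor([0, 6])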
def _get_attn_bias(
attn_metadata: XFormersMetadata,
attn_type: str,
) -> Optional[AttentionBias]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return attn_metadata.attn_bias
elif attn_type == AttentionType.ENCODER:
return attn_metadata.encoder_attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
return attn_metadata.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _set_attn_bias(
attn_metadata: XFormersMetadata,
attn_bias: List[Optional[AttentionBias]],
attn_type: str,
) -> None:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
attn_metadata.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
attn_metadata.encoder_attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER_DECODER:
attn_metadata.cross_attn_bias = attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
_metadata_cls = XFormersMetadata
class XFormersImpl(AttentionImpl[XFormersMetadata]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported in V0 "
"XFORMERS backend.")
if logits_soft_cap is not None:
logger.warning_once("XFormers does not support logits soft cap. "
"Outputs may be slightly off.")
if use_irope:
logger.warning_once(
"Using irope in XFormers is not supported yet, it will fall"
" back to global attention for long context.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
self.attn_type = attn_type
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* XFormersImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
Used for encoder branch of encoder-decoder models.
* ENCODER_ONLY: no kv_caching, uses the normal attention
attributes (seq_lens/seq_lens_tensor/max_seq_len).
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")
attn_type = self.attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes.")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
else:
assert value is None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen data structures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
if (key is not None) and (value is not None):
if attn_type == AttentionType.ENCODER_DECODER:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping = attn_metadata.cross_slot_mapping
else:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention.write_to_paged_cache(
key, value, key_cache, value_cache, updated_slot_mapping,
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_query_tokens:]
# QKV for prefill.
query = query[:num_prefill_query_tokens]
if key is not None and value is not None:
key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens]
assert query.shape[0] == num_prefill_query_tokens
assert decode_query.shape[0] == num_decode_query_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta, attn_type=attn_type)
assert out.shape == output[:num_prefill_query_tokens].shape
output[:num_prefill_query_tokens] = out
else:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have prefix attention.")
assert prefill_meta.query_start_loc is not None
assert prefill_meta.max_query_len is not None
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
out = PagedAttention.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
layer._k_scale,
layer._v_scale,
)
assert output[:num_prefill_query_tokens].shape == out.shape
output[:num_prefill_query_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
(
seq_lens_arg,
max_seq_len_arg,
block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
block_tables_arg,
seq_lens_arg,
max_seq_len_arg,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformers_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: XFormersMetadata,
attn_type: str = AttentionType.DECODER,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened into the `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args:
query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default in general.
"""
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which is different
# from a spec from the doc).
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
attn_bias = _get_attn_bias(attn_metadata, attn_type)
if attn_bias is None:
if self.alibi_slopes is None:
# Cross attention block of decoder branch of encoder-decoder
# model uses seq_lens for dec / encoder_seq_lens for enc
if (attn_type == AttentionType.ENCODER_DECODER):
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens is not None
# Cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens,
device=query.device)
# Encoder branch of encoder-decoder model uses
# attn_metadata.encoder_seq_lens
elif attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens, device=query.device)
# Self-attention block of encoder-only model just
# uses the seq_lens directly.
elif attn_type == AttentionType.ENCODER_ONLY:
assert attn_metadata.seq_lens is not None
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, device=query.device)
# Self-attention block of decoder branch just
# uses the seq_lens directly
elif attn_type == AttentionType.DECODER:
assert attn_metadata.seq_lens is not None
# Decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens, device=query.device)
else:
raise ValueError("Unknown AttentionType: %s", attn_type)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
attn_bias = [attn_bias]
else:
assert attn_type == AttentionType.DECODER
assert attn_metadata.seq_lens is not None
attn_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads, query.dtype,
attn_metadata.seq_lens)
_set_attn_bias(attn_metadata, attn_bias, attn_type)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=attn_bias[0],
p=0.0,
scale=self.scale)
return out.view_as(original_query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
assert attn_metadata.seq_lens is not None
output = torch.empty_like(original_query)
start = 0
for i, seq_len in enumerate(attn_metadata.seq_lens):
end = start + seq_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=attn_bias[i],
p=0.0,
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.view_as(original_query[start:end]))
start += seq_len
return output
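# --- Illustrative sketch (not part of this commit) ---
# The GQA/MQA reshape used above: queries are grouped per KV head and K/V are
# expanded (a view, not a copy) so xFormers sees the 5-D [B, M, G, H, K]
# layout. Head counts and sizes below are made-up assumptions.
import torch

num_tokens, num_heads, num_kv_heads, head_size = 3, 8, 2, 16
queries_per_kv = num_heads // num_kv_heads  # 4 query heads share each KV head

q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_kv_heads, head_size)

q_grouped = q.view(num_tokens, num_kv_heads, queries_per_kv, head_size)
k_expanded = k[:, :, None, :].expand(num_tokens, num_kv_heads,
                                     queries_per_kv, head_size)
print(q_grouped.shape, k_expanded.shape)
# torch.Size([3, 2, 4, 16]) torch.Size([3, 2, 4, 16])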
def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_lens: List[int],
) -> List[AttentionBias]:
attn_biases: List[AttentionBias] = []
for seq_len in seq_lens:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents the distance between
# the i-th and j-th positions.
bias = bias[None, :] - bias[:, None]
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
return attn_biases
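# --- Illustrative sketch (not part of this commit) ---
# The relative-position term that _make_alibi_bias builds before scaling by
# the per-head slopes. seq_len and the slope values are made-up assumptions.
import torch

seq_len = 4
pos = torch.arange(seq_len, dtype=torch.float32)
rel = pos[None, :] - pos[:, None]
# rel[i, j] == j - i:
# tensor([[ 0.,  1.,  2.,  3.],
#         [-1.,  0.,  1.,  2.],
#         [-2., -1.,  0.,  1.],
#         [-3., -2., -1.,  0.]])
slopes = torch.tensor([0.5, 0.25])                 # hypothetical per-head slopes
bias = rel[None, :, :] * slopes[:, None, None]     # [num_heads, seq_len, seq_len]
print(bias.shape)  # torch.Size([2, 4, 4])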

View File

@ -32,8 +32,7 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
is_runai_obj_uri)
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType,
LazyLoader, common_broadcastable_dtype)
from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype
if TYPE_CHECKING:
from transformers import PretrainedConfig
@ -1103,10 +1102,6 @@ class ModelConfig:
self.hf_config.dual_chunk_attention_config[
"sparse_attention_enabled"] = True
if envs.VLLM_ATTENTION_BACKEND != STR_DUAL_CHUNK_FLASH_ATTN_VAL:
raise ValueError("please set VLLM_ATTENTION_BACKEND to "
f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}")
def verify_with_parallel_config(
self,
parallel_config: ParallelConfig,

View File

@ -44,7 +44,7 @@ class model_aware_kv_ops_helper:
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
# For more details, see vllm/v1/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim

View File

@ -44,8 +44,8 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import (get_model_path, is_interleaved,
maybe_override_with_speculators)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip,
is_in_ray_actor)
from vllm.v1.sample.logits_processor import LogitsProcessor
# yapf: enable
@ -1163,17 +1163,6 @@ class EngineArgs:
self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
assert self.enforce_eager, (
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), (
"DualChunkFlashAttention is only supported on CUDA platform.")
assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
sliding_window: Optional[int] = None
if not is_interleaved(model_config.hf_text_config):
# Only set CacheConfig.sliding_window if the model is all sliding

View File

@ -529,7 +529,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "TORCH_SDPA": use torch.nn.MultiheadAttention
# - "FLASH_ATTN": use FlashAttention
# - "XFORMERS": use XFormers
# - "ROCM_FLASH": use ROCmFlashAttention
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA

View File

@ -53,13 +53,18 @@ class Mamba2Metadata:
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
"""Returns the appropriate metadata classes for the current platform."""
if current_platform.is_rocm():
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata)
return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
elif current_platform.is_cuda():
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
return (FlashAttentionMetadata, XFormersMetadata,
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionMetadata)
return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
PlaceholderAttentionMetadata)
if current_platform.is_cuda():
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata)
from vllm.v1.attention.backends.xformers import (
XFormersAttentionMetadata)
return (FlashAttentionMetadata, XFormersAttentionMetadata,
PlaceholderAttentionMetadata)
raise ValueError(
f"Unsupported platform for Mamba2: {current_platform.device_type}")

View File

@ -478,7 +478,8 @@ class DeepseekV2MLAAttention(nn.Module):
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
For more info see MLACommonImpl in:
vllm/v1/attention/backends/mla/utils.py
"""
def __init__(

View File

@ -226,8 +226,10 @@ class CudaPlatformBase(Platform):
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if not use_v1:
raise RuntimeError(
"MLA attention backends require the V1 engine. "
"Set VLLM_USE_V1=1 to enable them.")
from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
@ -246,35 +248,17 @@ class CudaPlatformBase(Platform):
use_triton = selected_backend == _Backend.TRITON_MLA or (
selected_backend is None)
def _get_version(name, import_suffix) -> str:
if use_v1:
logger.info_once(f"Using {name} backend on V1 engine.")
return f"vllm.v1.attention.backends.mla.{import_suffix}"
else:
logger.info_once(f"Using {name} backend.")
return f"vllm.attention.backends.{import_suffix}"
if use_cutlassmla:
if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"cutlass_mla.CutlassMLABackend")
else:
logger.warning(
"Cutlass MLA backend is only supported on V1 engine")
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"cutlass_mla.CutlassMLABackend")
if use_flashinfermla:
if use_v1:
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
set_kv_cache_layout("HND")
logger.info_once(
"Using FlashInfer MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashinfer_mla.FlashInferMLABackend")
else:
logger.warning(
"FlashInfer MLA backend is only supported on V1 engine"
)
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
set_kv_cache_layout("HND")
logger.info_once("Using FlashInfer MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashinfer_mla.FlashInferMLABackend")
if use_flashmla:
if block_size != 64:
logger.warning(
@ -282,20 +266,18 @@ class CudaPlatformBase(Platform):
" (currently only supports block size 64).",
block_size)
else:
return _get_version("FlashMLA", "flashmla.FlashMLABackend")
if use_flashattn:
if use_v1:
logger.info_once(
"Using FlashAttention MLA backend on V1 engine.")
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashattn_mla.FlashAttnMLABackend")
else:
logger.warning(
"FlashAttention MLA backend is only supported on V1 "
"engine.")
"flashmla.FlashMLABackend")
if use_flashattn:
logger.info_once(
"Using FlashAttention MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashattn_mla.FlashAttnMLABackend")
if use_triton:
return _get_version("Triton MLA",
"triton_mla.TritonMLABackend")
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
@ -382,78 +364,9 @@ class CudaPlatformBase(Platform):
)
return FLEX_ATTENTION_V1
# Backends for V0 engine
if selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
logger.info("Using DifferentialFlashAttention backend.")
return ("vllm.attention.backends.differential_flash_attn."
"DifferentialFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "
"GPUs.")
target_backend = _Backend.XFORMERS
elif dtype not in (torch.float16, torch.bfloat16):
logger.info(
"Cannot use FlashAttention-2 backend for dtype other than "
"torch.float16 or torch.bfloat16.")
target_backend = _Backend.XFORMERS
elif block_size % 16 != 0:
logger.info(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
target_backend = _Backend.XFORMERS
# FlashAttn is valid for the model, checking if the package is
# installed.
if target_backend == _Backend.FLASH_ATTN:
try:
import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend, flash_attn_supports_fp8)
supported_sizes = \
FlashAttentionBackend.get_supported_head_sizes()
if head_size not in supported_sizes:
logger.info(
"Cannot use FlashAttention-2 backend for head size %d.",
head_size)
target_backend = _Backend.XFORMERS
fp8_kv_cache = (kv_cache_dtype is not None
and kv_cache_dtype.startswith("fp8"))
if (fp8_kv_cache and not flash_attn_supports_fp8()):
logger.info(
"Cannot use FlashAttention backend for FP8 KV cache.")
target_backend = _Backend.XFORMERS
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
target_backend = _Backend.XFORMERS
if target_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
logger.info("Using Flash Attention backend.")
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend.")
@classmethod
def get_punica_wrapper(cls) -> str:

View File

@ -191,6 +191,11 @@ class RocmPlatform(Platform):
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
if use_mla:
if not use_v1:
raise RuntimeError(
"MLA attention backends require the V1 engine. "
"Set VLLM_USE_V1=1 to enable them.")
from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
is_aiter_mla_enabled)
@ -201,39 +206,24 @@ class RocmPlatform(Platform):
if selected_backend == _Backend.TRITON_MLA:
if block_size != 1:
if use_v1:
logger.info_once(
"Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
elif selected_backend == _Backend.ROCM_AITER_MLA \
or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
if block_size == 1:
if use_v1:
logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
logger.info("Using AITER MLA backend")
return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501
else:
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)")
else:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
if selected_backend is None or selected_backend == _Backend.FLASH_ATTN:
selected_backend = _Backend.ROCM_FLASH
f"does not support block size {block_size}.")
if selected_backend in (_Backend.ROCM_AITER_MLA,
_Backend.ROCM_AITER_MLA_VLLM_V1):
if block_size == 1:
logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
@ -245,14 +235,9 @@ class RocmPlatform(Platform):
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend.")
@classmethod
def set_device(cls, device: torch.device) -> None:

View File

@ -157,10 +157,8 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
MB_bytes = 1_000_000