Remove V0 attention backends (#25351)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 319966a678
commit a815d820ee
@@ -5,7 +5,6 @@ from urllib.request import urlopen

from vllm import LLM, SamplingParams

os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN"
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
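The hunk above is the usual pattern for pinning vLLM to one attention backend: export VLLM_ATTENTION_BACKEND before the engine is constructed. A minimal sketch of that flow, with an illustrative backend name and model (both are assumptions, not part of this diff):

import os

# The backend is resolved during engine initialization, so the variable
# must be set before LLM(...) is constructed.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # illustrative choice

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # small model chosen only for the sketch
params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)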
@@ -334,8 +334,9 @@ else:
                         [7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
                         current_platform.is_cuda() else [_Backend.ROCM_FLASH])
@pytest.mark.parametrize("backend",
                         [_Backend.FLASHINFER] if current_platform.is_cuda()
                         else [_Backend.TRITON_ATTN_VLLM_V1])
@pytest.mark.parametrize(
    "split_attention",
    [False, True] if current_platform.is_rocm() else [False])
@@ -18,7 +18,7 @@ if not current_platform.is_rocm():
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    from vllm.attention.backends.xformers import _make_alibi_bias
    from tests.kernels.utils import make_alibi_bias

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
@@ -429,8 +429,8 @@ def test_multi_query_kv_attention(
    alibi_bias = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
        attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                     seq_lens)
        attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
                                    seq_lens)
    output = torch.empty_like(query)
    start = 0
    # Dynamic sequence length not supported with custom attn_bias.
@@ -67,6 +67,7 @@ def generate_params():
    return params


@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
@pytest.mark.parametrize("device, name, use_mla, block_size",
                         generate_params())
def test_env(
@@ -11,7 +11,7 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

from vllm.attention.backends.xformers import _make_alibi_bias
from tests.kernels.utils import make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import (
    chunked_prefill_paged_decode)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
@@ -470,7 +470,7 @@ def test_contexted_kv_attention_alibi(
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
    output_ref = torch.empty_like(output)
    seq_start = 0
    query_start = 0
@@ -479,7 +479,7 @@ def test_contexted_kv_attention_alibi(
    # FIXME(DefTruth): Because xformers does not support dynamic sequence
    # lengths with custom attention bias, we process each prompt one by
    # one. This is inefficient, especially when we have many short prompts.
    # modified from: vllm/attention/backends/xformers.py#L343
    # modified from: vllm/v1/attention/backends/xformers.py#L343
    for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
        seq_end = seq_start + seq_len
        query_end = query_start + query_len
@@ -16,6 +16,7 @@ def clear_cache():
    _cached_get_attn_backend.cache_clear()


@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
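Because the selector memoizes its choice, tests that flip the backend environment variable clear that cache between cases; that is what the clear_cache fixture above does. A minimal sketch of the pattern, using only the names visible in this hunk (the import path of STR_BACKEND_ENV_VAR and the concrete selection call are assumptions that depend on the vLLM version under test):

import pytest

from vllm.attention.selector import _cached_get_attn_backend
from vllm.utils import STR_BACKEND_ENV_VAR  # assumed import path


@pytest.fixture(autouse=True)
def clear_cache():
    # Reset the memoized backend choice so each case re-resolves it.
    _cached_get_attn_backend.cache_clear()


def test_selector(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv(STR_BACKEND_ENV_VAR, "FLASH_ATTN")
        # ...call the selector here and assert on the backend it returns...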
@@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend:
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
@@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend:

    * Backend instance
    '''
    if backend_name == STR_XFORMERS_ATTN_VAL:
        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
        from vllm.attention.backends.xformers import XFormersBackend
        return XFormersBackend()
    elif backend_name == STR_FLASH_ATTN_VAL:
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
    if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"):
        from vllm.v1.attention.backends.xformers import (
            XFormersAttentionBackend)
        return XFormersAttentionBackend()
    if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"):
        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend()
    if backend_name == "TRITON_ATTN_VLLM_V1":
        from vllm.v1.attention.backends.triton_attn import (
            TritonAttentionBackend)
        return TritonAttentionBackend()
    if backend_name == "FLEX_ATTENTION":
        from vllm.v1.attention.backends.flex_attention import (
            FlexAttentionBackend)
        return FlexAttentionBackend()
    if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"):
        from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend
        return TorchSDPABackend()
    if backend_name == "FLASHINFER":
        from vllm.v1.attention.backends.flashinfer import FlashInferBackend
        return FlashInferBackend()

    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")

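Every branch of the rewritten helper above now returns a V1 AttentionBackend subclass, so test code can inspect the class-level hooks without running a forward pass. A minimal usage sketch (the backend string is one of the names accepted above; the asserted properties are deliberately generic):

from tests.kernels.utils import make_backend


def check_backend_wiring(backend_name: str) -> None:
    backend = make_backend(backend_name)
    # Static hooks shared by AttentionBackend implementations.
    assert isinstance(backend.get_name(), str)
    assert backend.get_impl_cls() is not None


check_backend_wiring("FLASH_ATTN_VLLM_V1")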
def make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: list[int],
) -> list[Any]:
    """Create ALiBi biases compatible with xFormers attention tests."""
    from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias

    if alibi_slopes is None:
        return [None for _ in seq_lens]

    attn_biases: list[Any] = []
    num_heads = alibi_slopes.shape[0]
    assert num_heads >= num_kv_heads, (
        "ALiBi slopes expect at least as many heads as KV heads")

    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
        bias = bias[None, :] - bias[:, None]

        padded_len = (seq_len + 7) // 8 * 8
        bias_tensor = torch.empty(
            1,
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias_tensor.mul_(alibi_slopes[:, None, None])
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))

    return attn_biases
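As the prefix-prefill test noted earlier, xFormers does not accept dynamic sequence lengths together with a custom attention bias, so reference outputs are computed one prompt at a time using the biases returned by make_alibi_bias. A rough sketch of that loop, assuming query/key/value are packed as [1, total_tokens, num_heads, head_size] and that num_kv_heads equals num_heads (shapes and scale are illustrative, not taken from the diff):

import torch
from xformers import ops as xops

from tests.kernels.utils import make_alibi_bias


def reference_alibi_attention(query, key, value, alibi_slopes, seq_lens, scale):
    attn_biases = make_alibi_bias(alibi_slopes, key.shape[2], query.dtype, seq_lens)
    outputs = []
    start = 0
    for seq_len, attn_bias in zip(seq_lens, attn_biases):
        end = start + seq_len
        # One xFormers call per prompt; the custom bias fixes the length.
        out = xops.memory_efficient_attention_forward(
            query[:, start:end],
            key[:, start:end],
            value[:, start:end],
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
        )
        outputs.append(out)
        start = end
    return torch.cat(outputs, dim=1)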


def _make_metadata_tensors(
    seq_lens: Optional[list[int]],
    context_lens: Optional[list[int]],
@@ -78,9 +78,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
        return

    if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"):
        # Phi4FlashForCausalLM and MotifForCausalLM
        # only supports DIFFERENTIAL_FLASH_ATTN backend
        m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
        pytest.skip(
            "Differential Flash Attention backend has been removed.")
    if model_arch == "GptOssForCausalLM":
        # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
        # has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
@ -1,931 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""" An implementation of https://arxiv.org/pdf/2410.05258 """
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionBackend
|
||||
# yapf: enable
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DifferentialFlashAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer = False
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2"
|
||||
return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "DIFFERENTIAL_FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]:
|
||||
return DifferentialFlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]:
|
||||
return DifferentialFlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]:
|
||||
return DifferentialFlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DifferentialFlashAttentionMetadata(AttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional[
|
||||
"DifferentialFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional[
|
||||
"DifferentialFlashAttentionMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# Cross-layer shared attention block tables
|
||||
cross_layer_shared_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return is_all_encoder_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return is_all_cross_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def prefill_metadata(
|
||||
self) -> Optional["DifferentialFlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
cross_layer_shared_block_tables = (
|
||||
None if self.cross_layer_shared_block_tables is None else
|
||||
self.cross_layer_shared_block_tables[:self.num_prefills])
|
||||
|
||||
self._cached_prefill_metadata = DifferentialFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_query_len=0,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
|
||||
use_cuda_graph=False,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(
|
||||
self) -> Optional["DifferentialFlashAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
cross_layer_shared_block_tables = (
|
||||
None if self.cross_layer_shared_block_tables is None else
|
||||
self.cross_layer_shared_block_tables[self.num_prefills:])
|
||||
self._cached_decode_metadata = DifferentialFlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class DifferentialFlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.cross_layer_shared_block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
3. slot mapping.
|
||||
"""
|
||||
# TODO: add support for chunked prefill and prefix caching.
|
||||
assert not chunked_prefill_enabled, \
|
||||
"chunked prefill is not supported for now"
|
||||
assert not prefix_cache_hit, "prefix caching is not supported for now"
|
||||
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table = []
|
||||
if prefix_cache_hit:
|
||||
# NOTE(woosuk): For flash-attn, the block table should
|
||||
# include the entries for the incoming prefill tokens.
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
cross_layer_shared_block_table = []
|
||||
if prefix_cache_hit:
|
||||
cross_layer_shared_block_table = block_tables[seq_id]
|
||||
elif block_tables is not None:
|
||||
if curr_sliding_window_block == 0:
|
||||
cross_layer_shared_block_table = block_tables[seq_id]
|
||||
else:
|
||||
cross_layer_shared_block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.cross_layer_shared_block_tables.append(
|
||||
cross_layer_shared_block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def _get_graph_runner_block_tables(self, num_seqs: int,
|
||||
block_tables: List[List[int]],
|
||||
graph_block_tables) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
# max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
max_batch_size, max_blocks = graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
# graph_block_tables = self.runner.graph_block_tables[:num_seqs]
|
||||
graph_block_tables = graph_block_tables[:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
# It may be possible to have more blocks allocated due
|
||||
# to lookahead slots of multi-step, however, they are
|
||||
# not used anyway, so can be safely ignored.
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
prefix_cache_hit = any([
|
||||
inter_data.prefix_cache_hit
|
||||
for inter_data in self.input_builder.inter_data_list
|
||||
])
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled,
|
||||
prefix_cache_hit)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
decode_query_lens = query_lens[self.num_prefills:]
|
||||
if len(decode_query_lens) > 0:
|
||||
max_decode_query_len = max(decode_query_lens)
|
||||
else:
|
||||
max_decode_query_len = 1
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
|
||||
self.cross_layer_shared_block_tables.extend([] *
|
||||
cuda_graph_pad_size)
|
||||
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables, self.runner.graph_block_tables)
|
||||
cross_layer_shared_block_tables = \
|
||||
self._get_graph_runner_block_tables(
|
||||
num_seqs, self.cross_layer_shared_block_tables,
|
||||
self.runner.cross_layer_shared_graph_block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
cross_layer_shared_block_tables = make_tensor_with_pad(
|
||||
self.cross_layer_shared_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return DifferentialFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
cross_layer_shared_block_tables=cross_layer_shared_block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
|
||||
|
||||
|
||||
class DifferentialFlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prefill_tokens ----------------->|
|
||||
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<----------------- num_decode_tokens ------------------>|
|
||||
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|
||||
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
differential_flash_attention_config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if differential_flash_attention_config is None:
|
||||
differential_flash_attention_config = {}
|
||||
self.differential_flash_attention_config = \
|
||||
differential_flash_attention_config
|
||||
self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = ((sliding_window - 1,
|
||||
0) if sliding_window is not None else (-1, -1))
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||
requires_alibi=self.alibi_slopes is not None)
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) and (
|
||||
not self.kv_cache_dtype.startswith("fp8")
|
||||
or not flash_attn_supports_fp8()):
|
||||
raise NotImplementedError(
|
||||
f"FlashAttention does not support {self.kv_cache_dtype} "
|
||||
"kv-cache on this device "
|
||||
f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by FlashAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
self.attn_type = attn_type
|
||||
|
||||
self.lambda_full = None
|
||||
self.subln = self.differential_flash_attention_config["subln"]
|
||||
|
||||
def split_heads(self, x):
|
||||
# split by num_heads, the stripe pattern is friendly to tensor parallel.
|
||||
x = rearrange(x, "... (H two) D -> ... H two D", two=2)
|
||||
x1 = x[..., 0, :]
|
||||
x2 = x[..., 1, :]
|
||||
return x1.contiguous(), x2.contiguous()
|
||||
|
||||
def split_kv_cache(self, x):
|
||||
# split by num_heads, the stripe pattern is friendly to tensor parallel.
|
||||
if x.numel() == 0:
|
||||
return torch.empty(0), torch.empty(0)
|
||||
|
||||
x1, x2 = x[0], x[1]
|
||||
return x1, x2
|
||||
|
||||
def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor,
|
||||
value: torch.Tensor, kv_cache: torch.Tensor,
|
||||
attn_metadata: DifferentialFlashAttentionMetadata):
|
||||
if kv_cache.numel() > 0 and key is not None and value is not None:
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
updated_slot_mapping.flatten(),
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def forward_generate_kv_cache(
|
||||
self, query: torch.Tensor, key: Optional[torch.Tensor],
|
||||
value: Optional[torch.Tensor], k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor:
|
||||
|
||||
head_size = self.head_size
|
||||
num_heads = self.num_heads // 2
|
||||
num_kv_heads = self.num_kv_heads // 2
|
||||
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
else:
|
||||
assert value is None
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[
|
||||
0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch"
|
||||
assert value.shape[
|
||||
0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch"
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
if key is not None and value is not None:
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens, "query shape mismatch"
|
||||
assert decode_query.shape[
|
||||
0] == num_decode_tokens, "decode query shape mismatch"
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if k_cache.numel() == 0 \
|
||||
or prefill_meta.block_tables is None \
|
||||
or prefill_meta.block_tables.numel() == 0:
|
||||
# normal attention
|
||||
prefill_output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
)
|
||||
assert prefill_output.shape == output[:
|
||||
num_prefill_tokens].shape
|
||||
output[:num_prefill_tokens] = prefill_output
|
||||
else:
|
||||
raise Exception("prefix caching not supported")
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
block_tables_arg = decode_meta.block_tables
|
||||
try:
|
||||
output[num_prefill_tokens:] = flash_attn_with_kvcache(
|
||||
q=decode_query.unsqueeze(1),
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
block_table=block_tables_arg,
|
||||
cache_seqlens=decode_meta.seq_lens_tensor,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
).squeeze(1)
|
||||
except Exception as e:
|
||||
logger.error("Error in PagedAttention.forward_decode: %s",
|
||||
str(e))
|
||||
raise e
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, num_heads, head_size)
|
||||
|
||||
def forward_with_kv_cache_only(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
attn_metadata: DifferentialFlashAttentionMetadata,
|
||||
):
|
||||
if not attn_metadata.decode_metadata:
|
||||
block_tables_arg = attn_metadata.cross_layer_shared_block_tables
|
||||
else:
|
||||
block_tables_arg = attn_metadata.block_tables
|
||||
|
||||
output = flash_attn_with_kvcache(
|
||||
q=query.unsqueeze(1),
|
||||
k_cache=k_cache,
|
||||
v_cache=v_cache,
|
||||
block_table=block_tables_arg,
|
||||
cache_seqlens=attn_metadata.seq_lens_tensor,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
).squeeze(1)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: DifferentialFlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
layer: Attention layer instance.
|
||||
q: Query tensor with shape = [num_tokens, num_heads, head_size]
|
||||
k: Key tensor with shape = [num_tokens, num_kv_heads, head_size]
|
||||
v: Value tensor with shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: KV cache tensor with shape
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size].
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
output: Output tensor with shape [num_tokens, num_heads, head_size]
|
||||
output_scale: Optional output scale tensor.
|
||||
output_block_scale: Optional output block scale tensor.
|
||||
NOTE: It in-place updates the output tensor.
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for DifferentialFlashAttentionImpl")
|
||||
|
||||
if self.lambda_full is None:
|
||||
self.lambda_init = self.differential_flash_attention_config[
|
||||
"lambda_init"]
|
||||
lambda_q1 = self.differential_flash_attention_config["lambda_q1"]
|
||||
lambda_k1 = self.differential_flash_attention_config["lambda_k1"]
|
||||
lambda_q2 = self.differential_flash_attention_config["lambda_q2"]
|
||||
lambda_k2 = self.differential_flash_attention_config["lambda_k2"]
|
||||
lambda_1 = torch.exp(
|
||||
torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q)
|
||||
lambda_2 = torch.exp(
|
||||
torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q)
|
||||
self.lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
||||
|
||||
if not self.used_shared_kv_cache: # need to generate kv-cache
|
||||
q = q.view(-1, self.num_heads, self.head_size)
|
||||
k = k.view(-1, self.num_kv_heads, self.head_size)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
q1, q2 = self.split_heads(q)
|
||||
k1, k2 = self.split_heads(k)
|
||||
v1, v2 = self.split_heads(v)
|
||||
|
||||
# kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501
|
||||
# Split by half along the first dimension.
|
||||
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
|
||||
assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
|
||||
assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous"
|
||||
|
||||
if kv_cache1.numel() != 0:
|
||||
self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata)
|
||||
self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata)
|
||||
|
||||
key_cache1, value_cache1 = self.split_kv_cache(kv_cache1)
|
||||
key_cache2, value_cache2 = self.split_kv_cache(kv_cache2)
|
||||
else:
|
||||
key_cache1, value_cache1 = torch.empty(0), torch.empty(0)
|
||||
key_cache2, value_cache2 = torch.empty(0), torch.empty(0)
|
||||
attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1,
|
||||
value_cache1,
|
||||
attn_metadata)
|
||||
attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1,
|
||||
value_cache2,
|
||||
attn_metadata)
|
||||
attn11 = attn11.view(q1.shape)
|
||||
attn12 = attn12.view(q1.shape)
|
||||
attn1 = torch.cat([attn11, attn12], dim=-1)
|
||||
|
||||
attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2,
|
||||
value_cache1,
|
||||
attn_metadata)
|
||||
attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2,
|
||||
value_cache2,
|
||||
attn_metadata)
|
||||
attn21 = attn21.view(q2.shape)
|
||||
attn22 = attn22.view(q2.shape)
|
||||
attn2 = torch.cat([attn21, attn22], dim=-1)
|
||||
|
||||
attn = attn1 - self.lambda_full * attn2
|
||||
# attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
|
||||
attn = self.subln(attn)
|
||||
attn = attn * (1 - self.lambda_init)
|
||||
# reshape back to 2 * num_head
|
||||
attn_output = rearrange(attn,
|
||||
"... H (two D) -> ... (H two) D",
|
||||
two=2)
|
||||
|
||||
else: # reuse the kv cache, full attention
|
||||
q = q.view(-1, self.num_heads, self.head_size)
|
||||
q1, q2 = self.split_heads(q)
|
||||
# kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501
|
||||
kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
|
||||
key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
|
||||
key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]
|
||||
|
||||
attn11 = self.forward_with_kv_cache_only(q1, key_cache1,
|
||||
value_cache1,
|
||||
attn_metadata)
|
||||
attn12 = self.forward_with_kv_cache_only(q1, key_cache1,
|
||||
value_cache2,
|
||||
attn_metadata)
|
||||
attn11 = attn11.view(q1.shape)
|
||||
attn12 = attn12.view(q1.shape)
|
||||
attn1 = torch.cat([attn11, attn12], dim=-1)
|
||||
|
||||
attn21 = self.forward_with_kv_cache_only(q2, key_cache2,
|
||||
value_cache1,
|
||||
attn_metadata)
|
||||
attn22 = self.forward_with_kv_cache_only(q2, key_cache2,
|
||||
value_cache2,
|
||||
attn_metadata)
|
||||
attn21 = attn21.view(q2.shape)
|
||||
attn22 = attn22.view(q2.shape)
|
||||
attn2 = torch.cat([attn21, attn22], dim=-1)
|
||||
|
||||
attn = attn1 - self.lambda_full * attn2
|
||||
attn = self.subln(attn)
|
||||
attn = attn * (1 - self.lambda_init)
|
||||
# reshape back to 2 * num_head
|
||||
attn_output = rearrange(attn,
|
||||
"... H (two D) -> ... (H two) D",
|
||||
two=2)
|
||||
attn_output = attn_output.view(-1, self.num_heads * self.head_size)
|
||||
return attn_output
|
||||
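For readers skimming the removed DifferentialFlashAttentionImpl above, the combination it computes (split the heads into two groups, form two attention maps, take attn1 - lambda_full * attn2, apply the sublayer norm, then scale by 1 - lambda_init) matches the Differential Transformer formulation of the paper the file cites (arXiv 2410.05258). A compact editorial restatement in standard notation, not text from the diff:

    \lambda = \exp(\lambda_{q_1} \cdot \lambda_{k_1}) - \exp(\lambda_{q_2} \cdot \lambda_{k_2}) + \lambda_{\text{init}}

    \mathrm{DiffAttn}(X) = \left(\operatorname{softmax}\!\left(\frac{Q_1 K_1^\top}{\sqrt{d}}\right) - \lambda\,\operatorname{softmax}\!\left(\frac{Q_2 K_2^\top}{\sqrt{d}}\right)\right) V

    \text{output} = (1 - \lambda_{\text{init}})\,\mathrm{subln}\!\left(\mathrm{DiffAttn}(X)\right)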
File diff suppressed because it is too large
@ -1,929 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlashAttention."""
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
# yapf: enable
|
||||
from vllm.attention.backends.utils import (
|
||||
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
|
||||
get_flash_attn_version)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASH_ATTN"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["FlashAttentionImpl"]:
|
||||
return FlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return FlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
|
||||
return FlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not if cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among request in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return is_all_encoder_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return is_all_cross_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
self._cached_prefill_metadata = FlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_query_len=0,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_prefill_metadata
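    # Editor's note (illustrative, not from the original file): with, say,
    # num_prefills=2 and num_prefill_tokens=7, the slices above reduce to
    # seq_lens[:2], slot_mapping[:7] and query_start_loc[:3], i.e. only the
    # prefill portion of the batched metadata is kept and every decode-only
    # field is zeroed out.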
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
self._cached_decode_metadata = FlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
encoder_seq_start_loc=self.encoder_seq_start_loc,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_decode_metadata
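# --- Editor's sketch (illustrative, not part of the original backend) ---
# A minimal, hypothetical helper mirroring the query_start_loc re-basing done
# for the decode-only metadata above; only the torch API is assumed.
import torch


def rebase_query_start_loc(query_start_loc: torch.Tensor,
                           num_prefills: int) -> torch.Tensor:
    # Drop the prefill prefix and shift so the first decode query starts at 0.
    return query_start_loc[num_prefills:] - query_start_loc[num_prefills]


# Example: 3 prefill and 6 decode sequences, one token each:
# rebase_query_start_loc(torch.arange(10), 3) -> tensor([0, 1, 2, 3, 4, 5, 6])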
|
||||
|
||||
|
||||
class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
3. slot mapping.
|
||||
"""
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table = []
|
||||
if prefix_cache_hit:
|
||||
# NOTE(woosuk): For flash-attn, the block table should
|
||||
# include the entries for the incoming prefill tokens.
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int,
|
||||
block_tables: List[List[int]]) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
                    # It may be possible to have more blocks allocated due to
                    # the lookahead slots of multi-step scheduling; however,
                    # they are not used anyway, so they can be safely ignored.
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
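    # Editor's note (illustrative): the graph block-table buffer keeps a fixed
    # [max_batch_size, max_blocks] shape so that CUDA graph replay never needs
    # a reallocation; shorter block tables simply leave the tail zero-padded.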
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
prefix_cache_hit = any([
|
||||
inter_data.prefix_cache_hit
|
||||
for inter_data in self.input_builder.inter_data_list
|
||||
])
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled,
|
||||
prefix_cache_hit)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
decode_query_lens = query_lens[self.num_prefills:]
|
||||
if len(decode_query_lens) > 0:
|
||||
max_decode_query_len = max(decode_query_lens)
|
||||
else:
|
||||
max_decode_query_len = 1
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
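        # For example, query_lens=[4, 6] accumulates to [0, 4, 10]; these
        # cumulative offsets are what the varlen kernels use to index into the
        # flattened token batch (see the field comments on the metadata class).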
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return FlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
)
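# --- Editor's sketch (hypothetical helper, not the vLLM utility) ---
# Illustrates what the slot_mapping built above encodes: each token position
# is mapped to a flat KV-cache slot,
#   block_table[pos // block_size] * block_size + pos % block_size.
def slots_for_token_range(block_table: list[int], start: int, end: int,
                          block_size: int) -> list[int]:
    # One flat cache slot per token position in [start, end).
    return [
        block_table[pos // block_size] * block_size + pos % block_size
        for pos in range(start, end)
    ]


# Example: slots_for_token_range([7, 2], 2, 6, 4) == [30, 31, 8, 9]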
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"FLASH_ATTN backend.")
|
||||
if use_irope:
|
||||
logger.warning(
|
||||
"Using irope in V0 is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = ((sliding_window - 1,
|
||||
0) if sliding_window is not None else (-1, -1))
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.vllm_flash_attn_version = get_flash_attn_version(
|
||||
requires_alibi=self.alibi_slopes is not None)
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) and (
|
||||
not self.kv_cache_dtype.startswith("fp8")
|
||||
or not flash_attn_supports_fp8()):
|
||||
raise NotImplementedError(
|
||||
f"FlashAttention does not support {self.kv_cache_dtype} "
|
||||
"kv-cache on this device "
|
||||
f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
logits_soft_cap = 0
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
|
||||
if head_size not in support_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by FlashAttention. "
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
self.attn_type = attn_type
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
output: shape = [num_tokens, num_heads, head_size]
|
||||
kv_cache: KV cache tensor with shape
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size].
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
NOTE: It in-place updates the output tensor.
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for FlashAttentionImpl")
|
||||
|
||||
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
|
||||
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
|
||||
assert (
|
||||
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
|
||||
"key/v_scale is only supported in FlashAttention 3 with "
|
||||
"base dtype bfloat16")
|
||||
|
||||
attn_type = self.attn_type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
raise AttributeError("Encoder attention requires setting "
|
||||
"encoder metadata attributes.")
|
||||
elif (attn_type == AttentionType.ENCODER_DECODER
|
||||
and (not attn_metadata.is_all_cross_attn_metadata_set)):
|
||||
raise AttributeError("Encoder/decoder cross-attention "
|
||||
"requires setting cross-attention "
|
||||
"metadata attributes.")
|
||||
|
||||
kv_cache_dtype: str = self.kv_cache_dtype
|
||||
softmax_scale: float = self.scale
|
||||
window_size = self.sliding_window
|
||||
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
|
||||
logits_soft_cap: Optional[float] = self.logits_soft_cap
|
||||
fp8_attention = kv_cache_dtype.startswith("fp8")
|
||||
|
||||
if fp8_attention and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
"FlashAttention does not support FP8 kv-cache on this device.")
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
# We skip updating the KV cache under two conditions:
|
||||
# a. When the Attention Type is ENCODER. In this phase, we compute
|
||||
# only the encoder attention without updating the cache.
|
||||
# b. When both Key and Value are None. This occurs during
|
||||
# cross-attention computation in the decoding phase, where the
|
||||
# KV cache is already populated with the cross-attention
|
||||
# tensor. Thus, we skip cache updates during this time.
|
||||
if (attn_type != AttentionType.ENCODER) and (key is not None) and (
|
||||
value is not None):
|
||||
if attn_type == AttentionType.ENCODER_DECODER:
|
||||
# Update cross-attention KV cache (prefill-only)
|
||||
updated_slot_mapping = attn_metadata.cross_slot_mapping
|
||||
else:
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory
|
||||
# profiling run.
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||
kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if fp8_attention:
|
||||
kv_cache = kv_cache.view(torch.float8_e4m3fn)
|
||||
key_cache = key_cache.view(torch.float8_e4m3fn)
|
||||
value_cache = value_cache.view(torch.float8_e4m3fn)
|
||||
|
||||
if fp8_attention:
|
||||
num_tokens, num_heads, head_size = query.shape
|
||||
query, _ = ops.scaled_fp8_quant(
|
||||
query.reshape(
|
||||
(num_tokens, num_heads * head_size)).contiguous(),
|
||||
layer._q_scale)
|
||||
query = query.reshape((num_tokens, num_heads, head_size))
|
||||
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens) = \
|
||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||
decode_query = query[num_prefill_query_tokens:]
|
||||
decode_output = output[num_prefill_query_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_query_tokens]
|
||||
prefill_output = output[:num_prefill_query_tokens]
|
||||
assert query.shape[0] == num_prefill_query_tokens
|
||||
assert decode_query.shape[0] == num_decode_query_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
|
||||
_get_query_key_seq_metadata(prefill_meta, True, attn_type)
|
||||
|
||||
key = key[:num_prefill_kv_tokens]
|
||||
value = value[:num_prefill_kv_tokens]
|
||||
|
||||
if fp8_attention:
|
||||
num_kv_tokens, num_kv_heads, head_size = key.shape
|
||||
|
||||
key, _ = ops.scaled_fp8_quant(
|
||||
key.reshape((num_kv_tokens,
|
||||
num_kv_heads * head_size)).contiguous(),
|
||||
layer._k_scale)
|
||||
key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
|
||||
|
||||
value, _ = ops.scaled_fp8_quant(
|
||||
value.reshape((num_kv_tokens,
|
||||
num_kv_heads * head_size)).contiguous(),
|
||||
layer._v_scale)
|
||||
value = value.reshape(
|
||||
(num_kv_tokens, num_kv_heads, head_size))
|
||||
|
||||
descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
|
||||
flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=q_seq_start_loc,
|
||||
cu_seqlens_k=k_seq_start_loc,
|
||||
max_seqlen_q=q_seq_len,
|
||||
max_seqlen_k=k_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=_get_causal_option(attn_type),
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
out=prefill_output,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support prefix caching")
|
||||
assert prefill_meta.seq_lens is not None
|
||||
assert prefill_meta.query_start_loc is not None
|
||||
max_seq_len = max(prefill_meta.seq_lens)
|
||||
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
|
||||
key.shape[1])
|
||||
flash_attn_varlen_func( # noqa
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
seqused_k=prefill_meta.seq_lens_tensor,
|
||||
max_seqlen_k=max_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
block_table=prefill_meta.block_tables,
|
||||
softcap=logits_soft_cap,
|
||||
out=prefill_output,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
# Use flash_attn_varlen_func kernel for speculative decoding
|
||||
# because different queries might have different lengths.
|
||||
|
||||
assert decode_meta.max_decode_query_len is not None
|
||||
# use only for actual varlen decoding
|
||||
if decode_meta.max_decode_query_len > 1:
|
||||
assert attn_type == AttentionType.DECODER, (
|
||||
"Only decoder-only models support max_decode_query_len > 1"
|
||||
)
|
||||
assert decode_meta.query_start_loc is not None
|
||||
descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
|
||||
key.shape[1])
|
||||
flash_attn_varlen_func(
|
||||
q=decode_query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=decode_meta.query_start_loc,
|
||||
max_seqlen_q=decode_meta.max_decode_query_len,
|
||||
seqused_k=decode_meta.seq_lens_tensor,
|
||||
max_seqlen_k=decode_meta.max_decode_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
block_table=decode_meta.block_tables,
|
||||
out=decode_output,
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
else:
|
||||
# Use flash_attn_with_kvcache for normal decoding.
|
||||
(
|
||||
seq_lens_arg,
|
||||
_,
|
||||
block_tables_arg,
|
||||
) = get_seq_len_block_table_args(decode_meta, False, attn_type)
|
||||
descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
|
||||
flash_attn_with_kvcache(
|
||||
q=decode_query.unsqueeze(1),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
block_table=block_tables_arg,
|
||||
cache_seqlens=seq_lens_arg,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
out=decode_output.unsqueeze(1),
|
||||
fa_version=self.vllm_flash_attn_version,
|
||||
q_descale=layer._q_scale.expand(descale_shape),
|
||||
k_descale=layer._k_scale.expand(descale_shape),
|
||||
v_descale=layer._v_scale.expand(descale_shape),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def _get_query_key_seq_metadata(
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
is_prompt: bool,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
"""
|
||||
Returns sequence metadata for key and query based on the specified
|
||||
attention type and whether input is a prompt.
|
||||
|
||||
This function computes the starting locations and maximum sequence lengths
|
||||
for key and query sequences for different attention types.
|
||||
|
||||
Args:
|
||||
attn_metadata: The attention metadata object
|
||||
is_prompt (bool): A flag indicating if the input is a prompt
|
||||
attn_type (AttentionType): The type of attention being used.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing four integers:
|
||||
- Starting location for the query sequence.
|
||||
- Maximum sequence length for the query sequence.
|
||||
- Starting location for the key sequence.
|
||||
- Maximum sequence length for the key sequence.
|
||||
|
||||
Raises:
|
||||
AttributeError: If an invalid attention type is provided.
|
||||
"""
|
||||
if attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
if is_prompt:
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_start_loc, max_seq_len,
|
||||
attn_metadata.seq_start_loc, max_seq_len)
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
        # This is cross-attention between the decoder and the encoder, where
        # the key comes from the precomputed encoder sequence and the query
        # comes from the decoder input sequence.
|
||||
# Choose query max length based on whether it is prompt
|
||||
# or not.
|
||||
if is_prompt:
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
else:
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
return (attn_metadata.seq_start_loc, max_seq_len,
|
||||
attn_metadata.encoder_seq_start_loc,
|
||||
attn_metadata.max_encoder_seq_len)
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
# For encoder attention both the query and the key are same i.e. the
|
||||
# encoder sequence.
|
||||
return (attn_metadata.encoder_seq_start_loc,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.encoder_seq_start_loc,
|
||||
attn_metadata.max_encoder_seq_len)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
assert is_prompt, "Should not have decode for encoder only model."
|
||||
return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len,
|
||||
attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
def _get_causal_option(attn_type: str) -> bool:
|
||||
"""
|
||||
Determine whether the given attention type is suitable for causal
|
||||
attention mechanisms.
|
||||
|
||||
Args:
|
||||
attn_type (AttentionType): The type of attention being evaluated
|
||||
|
||||
Returns:
|
||||
bool: Returns `True` if the attention type is suitable for causal
|
||||
attention (i.e., not encoder, encoder-only, or encoder-decoder),
|
||||
otherwise returns `False`.
|
||||
"""
|
||||
return not (attn_type == AttentionType.ENCODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY
|
||||
or attn_type == AttentionType.ENCODER_DECODER)
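# Editor's note (illustrative): _get_causal_option restricts causal masking to
# decoder self-attention:
#   DECODER          -> True  (autoregressive, causal mask)
#   ENCODER          -> False (full bidirectional attention)
#   ENCODER_ONLY     -> False
#   ENCODER_DECODER  -> False (cross-attention attends to the whole encoder)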
|
||||
@ -1,227 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
MLACommonState)
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["FlashMLAImpl"]:
|
||||
return FlashMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["FlashMLAMetadata"]:
|
||||
return FlashMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
|
||||
return FlashMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["FlashMLAState"]:
|
||||
return FlashMLAState
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata):
|
||||
decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
|
||||
torch.Tensor]] = None
|
||||
decode_num_splits: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def decode_metadata(self):
|
||||
decode_metadata = super().decode_metadata
|
||||
# TODO: cache assignment?
|
||||
if decode_metadata is not None:
|
||||
decode_metadata.decode_tile_scheduler_metadata=\
|
||||
self.decode_tile_scheduler_metadata
|
||||
decode_metadata.decode_num_splits=\
|
||||
self.decode_num_splits
|
||||
return decode_metadata
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
m = super().build(seq_lens, query_lens, cuda_graph_pad_size,
|
||||
batch_size)
|
||||
|
||||
if m.num_decode_tokens > 0:
|
||||
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
|
||||
get_mla_metadata(
|
||||
m.seq_lens_tensor[m.num_prefills:],
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class FlashMLAState(MLACommonState[FlashMLAMetadata]):
|
||||
|
||||
def __init__(self, *args, **kwds):
|
||||
super().__init__(*args, **kwds)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
# Run a dummy `get_mla_metadata` so we can get the right shapes
|
||||
self._graph_decoder_tile_scheduler_metadata, \
|
||||
self._graph_decode_num_splits = get_mla_metadata(
|
||||
torch.ones(
|
||||
max_batch_size, dtype=torch.int32, device=self.runner.device),
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
with super().graph_capture(max_batch_size):
|
||||
yield
|
||||
|
||||
del self._graph_decoder_tile_scheduler_metadata
|
||||
del self._graph_decode_num_splits
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
metadata = super().graph_capture_get_metadata_for_batch(
|
||||
batch_size, is_encoder_decoder_model)
|
||||
assert metadata.num_decode_tokens > 0
|
||||
|
||||
decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata(
|
||||
self._graph_seq_lens[:batch_size],
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
self._graph_decoder_tile_scheduler_metadata.copy_(
|
||||
decoder_tile_scheduler_metadata)
|
||||
self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits)
|
||||
|
||||
metadata.decode_tile_scheduler_metadata=\
|
||||
self._graph_decoder_tile_scheduler_metadata
|
||||
metadata.decode_num_splits=\
|
||||
self._graph_decode_num_splits[:batch_size + 1]
|
||||
|
||||
return metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = super().get_graph_input_buffers(
|
||||
attn_metadata, is_encoder_decoder_model)
|
||||
input_buffers["decode_tile_scheduler_metadata"] = \
|
||||
attn_metadata.decode_metadata.decode_tile_scheduler_metadata
|
||||
input_buffers["decode_num_splits"] = \
|
||||
attn_metadata.decode_metadata.decode_num_splits
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
|
||||
is_encoder_decoder_model)
|
||||
|
||||
input_buffers["decode_tile_scheduler_metadata"].copy_(
|
||||
attn_metadata.decode_metadata.decode_tile_scheduler_metadata)
|
||||
input_buffers["decode_num_splits"].copy_(
|
||||
attn_metadata.decode_metadata.decode_num_splits)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
is_supported, reason = is_flashmla_supported()
|
||||
assert is_supported, reason
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashMLA with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)\
|
||||
.unsqueeze(1) # Add seqlen dim of 1 (decode)
|
||||
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=decode_meta.block_tables,
|
||||
cache_seqlens=decode_meta.seq_lens_tensor,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata,
|
||||
num_splits=decode_meta.decode_num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
File diff suppressed because it is too large
@ -1,407 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
MLACommonState)
|
||||
from vllm.attention.backends.utils import (compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
|
||||
get_aiter_mla_metadata)
|
||||
|
||||
|
||||
def is_aiter_mla_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER \
|
||||
and envs.VLLM_ROCM_USE_AITER_MLA
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_AITER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["AiterMLAImpl"]:
|
||||
return AiterMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AiterMLAMetadata"]:
|
||||
return AiterMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
|
||||
return AiterMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["AiterMLAState"]:
|
||||
return AiterMLAState
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLAMetadata(MLACommonMetadata):
|
||||
    # The following 5 tensors are for the current version of AITER MLA
|
||||
block_table_bound: Optional[torch.Tensor] = None
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: Optional[torch.Tensor] = None
|
||||
# The page indices of the paged kv cache
|
||||
paged_kv_indices: Optional[torch.Tensor] = None
|
||||
# The number of entries in the last page of each request in
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_lens: Optional[torch.Tensor] = None
|
||||
|
||||
# This is just to make new AITER MLA API work
|
||||
# -- MTP support is not added yet.
|
||||
qo_indptr: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self):
|
||||
prefill_metadata = super().prefill_metadata
|
||||
self._cached_prefill_metadata = prefill_metadata
|
||||
|
||||
if prefill_metadata is not None:
|
||||
prefill_metadata.paged_kv_indptr = self.paged_kv_indptr
|
||||
prefill_metadata.paged_kv_indices = self.paged_kv_indices
|
||||
prefill_metadata\
|
||||
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
|
||||
prefill_metadata.block_table_bound = self.block_table_bound
|
||||
prefill_metadata.qo_indptr = self.qo_indptr
|
||||
|
||||
# update the cache
|
||||
self._cached_prefill_metadata = self.__class__(
|
||||
**prefill_metadata.__dict__)
|
||||
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self):
|
||||
decode_metadata = super().decode_metadata
|
||||
|
||||
self._cached_decode_metadata = decode_metadata
|
||||
|
||||
if decode_metadata is not None:
|
||||
decode_metadata.paged_kv_indptr = self.paged_kv_indptr
|
||||
decode_metadata.paged_kv_indices = self.paged_kv_indices
|
||||
decode_metadata\
|
||||
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
|
||||
decode_metadata.block_table_bound = self.block_table_bound
|
||||
decode_metadata.qo_indptr = self.qo_indptr
|
||||
|
||||
# update the cache
|
||||
self._cached_decode_metadata = self.__class__(
|
||||
**decode_metadata.__dict__)
|
||||
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]
|
||||
|
||||
def __init__(self, input_builder):
|
||||
super().__init__(input_builder)
|
||||
assert self.block_size == 1, "AITER MLA requires only block size 1."
|
||||
|
||||
def prepare(self):
|
||||
super().prepare()
|
||||
self.paged_kv_indices: list[int] = []
|
||||
self.paged_kv_indptr: list[int] = [0]
|
||||
self.paged_kv_last_page_lens: list[int] = []
|
||||
self.total_blocks = 0
|
||||
self.qo_indptr: list[int] = [0]
|
||||
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
3. slot mapping.
|
||||
"""
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks):
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table = []
|
||||
if prefix_cache_hit:
|
||||
# NOTE(woosuk): For flash-attn, the block table should
|
||||
# include the entries for the incoming prefill tokens.
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
if is_profile_run:
|
||||
return
|
||||
|
||||
# Update paged_kv_* tensors only for non-profile run
|
||||
block_table = block_tables[seq_id]
|
||||
self._update_paged_kv_tensors(block_table, seq_len)
|
||||
|
||||
def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
|
||||
# Get the number of valid blocks based on sequence length.
|
||||
# If seq_len = 16, block_size = 16,
|
||||
# block_table_bound is 1 with 1 valid block.
|
||||
# If seq_len = 15, block_size = 16,
|
||||
# block_table_bound is 0 + 1 with 1 valid block.
|
||||
self.total_blocks += len(block_table)
|
||||
block_table_bound = seq_len // self.block_size + 1 \
|
||||
if seq_len % self.block_size != 0 \
|
||||
else seq_len // self.block_size
|
||||
self.paged_kv_indices.extend(block_table[:block_table_bound])
|
||||
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
|
||||
block_table_bound)
|
||||
self.qo_indptr.append(self.qo_indptr[-1] + 1)
|
||||
|
||||
last_page_len = seq_len % self.block_size
|
||||
if last_page_len == 0:
|
||||
last_page_len = self.block_size
|
||||
self.paged_kv_last_page_lens.append(last_page_len)
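    # Editor's note (illustrative, not from the original file): the
    # bookkeeping above reduces to
    #   num_valid_blocks = ceil(seq_len / block_size)
    #   last_page_len    = seq_len % block_size or block_size
    # e.g. with block_size=16: seq_len=15 -> (1, 15), seq_len=16 -> (1, 16),
    # seq_len=17 -> (2, 1).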
|
||||
|
||||
def build(self, seq_lens: list[int], query_lens: list[int],
|
||||
cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
|
||||
metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
|
||||
batch_size)
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
if use_captured_graph:
|
||||
last_paged_kv_indptr = self.paged_kv_indptr[-1]
|
||||
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
|
||||
cuda_graph_pad_size)
|
||||
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
|
||||
last_qo_indptr = self.qo_indptr[-1]
|
||||
self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
|
||||
|
||||
# For current version of AITER MLA
|
||||
if len(self.paged_kv_indptr) > 0:
|
||||
# extend to the maximum number of blocks as returned by the
|
||||
# scheduler
|
||||
self.paged_kv_indices.extend(
|
||||
[0] * (self.total_blocks - len(self.paged_kv_indices)))
|
||||
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
paged_kv_last_page_lens_tensor = torch.tensor(
|
||||
self.paged_kv_last_page_lens, device=device, dtype=torch.int)
|
||||
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
|
||||
1,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
|
||||
qo_indptr = torch.tensor(self.qo_indptr,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
else:
|
||||
paged_kv_indices_tensor = None
|
||||
paged_kv_indptr_tensor = None
|
||||
paged_kv_last_page_lens_tensor = None
|
||||
block_table_bound_tensor = None
|
||||
qo_indptr = None
|
||||
|
||||
metadata.paged_kv_indptr = paged_kv_indptr_tensor
|
||||
metadata.paged_kv_indices = paged_kv_indices_tensor
|
||||
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
|
||||
metadata.block_table_bound = block_table_bound_tensor
|
||||
metadata.qo_indptr = qo_indptr
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
class AiterMLAState(MLACommonState[AiterMLAMetadata]):
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
kv_indices, kv_indptr, last_page_lens, qo_indptr = \
|
||||
get_aiter_mla_metadata(
|
||||
max_batch_size=max_batch_size,
|
||||
block_size=self.runner.block_size,
|
||||
max_block_per_batch=\
|
||||
self.runner.get_max_block_per_batch(),
|
||||
device=self.runner.device)
|
||||
self._paged_kv_indices_tensor = kv_indices
|
||||
self._paged_kv_indptr_tensor = kv_indptr
|
||||
self._paged_kv_last_page_lens_tensor = last_page_lens
|
||||
self._qo_indptr_tensor = qo_indptr
|
||||
|
||||
with super().graph_capture(max_batch_size):
|
||||
yield
|
||||
|
||||
del self._paged_kv_indices_tensor
|
||||
del self._paged_kv_indptr_tensor
|
||||
del self._paged_kv_last_page_lens_tensor
|
||||
del self._qo_indptr_tensor
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
|
||||
|
||||
metadata = super().graph_capture_get_metadata_for_batch(
|
||||
batch_size, is_encoder_decoder_model)
|
||||
|
||||
paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1]
|
||||
paged_kv_indices = self._paged_kv_indices_tensor
|
||||
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
|
||||
batch_size]
|
||||
qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
|
||||
|
||||
metadata.paged_kv_indptr = paged_kv_indptr
|
||||
metadata.paged_kv_indices = paged_kv_indices
|
||||
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
|
||||
metadata.qo_indptr = qo_indptr
|
||||
|
||||
return metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = super().get_graph_input_buffers(
|
||||
attn_metadata, is_encoder_decoder_model)
|
||||
input_buffers[
|
||||
'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr
|
||||
input_buffers[
|
||||
"paged_kv_indices"] = attn_metadata.\
|
||||
decode_metadata.paged_kv_indices
|
||||
input_buffers[
|
||||
"paged_kv_last_page_lens"] = attn_metadata.\
|
||||
decode_metadata.paged_kv_last_page_lens
|
||||
input_buffers['qo_indptr'] = attn_metadata.qo_indptr
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
super().prepare_graph_input_buffers(input_buffers, attn_metadata,
|
||||
is_encoder_decoder_model)
|
||||
|
||||
num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[
|
||||
0]
|
||||
input_buffers["paged_kv_indptr"].copy_(
|
||||
attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True)
|
||||
input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
|
||||
attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True)
|
||||
input_buffers["paged_kv_last_page_lens"].copy_(
|
||||
attn_metadata.decode_metadata.paged_kv_last_page_lens,
|
||||
non_blocking=True)
|
||||
input_buffers["qo_indptr"].copy_(
|
||||
attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"Aiter MLA does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
from aiter import flash_attn_varlen_func
|
||||
self.flash_attn_varlen_func = flash_attn_varlen_func
|
||||
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
softmax_scale: float, return_softmax_lse: bool,
|
||||
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
output = self.flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: AiterMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.empty(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
||||
attn_metadata.qo_indptr,
|
||||
attn_metadata.max_query_len,
|
||||
attn_metadata.paged_kv_indptr,
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_lens)
|
||||
|
||||
return self._v_up_proj(o)
|
||||
@ -1,953 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer ROCm GPUs."""
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8StaticTensorSym)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
|
||||
|
||||
@cache
|
||||
def is_rocm_aiter_paged_attn_enabled() -> bool:
|
||||
return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
|
||||
        and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
@cache
|
||||
def _get_paged_attn_module() -> PagedAttention:
|
||||
"""
|
||||
Initializes the appropriate PagedAttention module from `attention/ops`,
|
||||
which is used as helper function
|
||||
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
|
||||
|
||||
The choice of attention module depends on whether
|
||||
AITER paged attention is enabled:
|
||||
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
|
||||
- Otherwise, it defaults to using the original `PagedAttention`.
|
||||
"""
|
||||
if is_rocm_aiter_paged_attn_enabled():
|
||||
# Import AITERPagedAttention only when the flag is enabled
|
||||
from vllm.attention.ops.rocm_aiter_paged_attn import (
|
||||
AITERPagedAttention)
|
||||
return AITERPagedAttention()
|
||||
return PagedAttention()
|
||||
|
||||
|
||||
class ROCmFlashAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "ROCM_FLASH"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
|
||||
return ROCmFlashAttentionImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return ROCmFlashAttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
|
||||
return ROCmFlashAttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
paged_attn = _get_paged_attn_module()
|
||||
return paged_attn.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
paged_attn = _get_paged_attn_module()
|
||||
paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
paged_attn = _get_paged_attn_module()
|
||||
paged_attn.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
    # Whether or not cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
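    # In other words, seq_len == context_len + query_len: during prefill,
    # query_len counts the new prompt tokens being processed, while during a
    # normal decode step query_len == 1 and seq_len grows by one per iteration.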
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
    # Max number of query tokens among requests in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
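    # Illustrative sketch (assumed toy values): query_start_loc and
    # seq_start_loc are cumulative sums over per-request lengths, as the
    # field comments above describe, e.g. lengths [4, 6] -> [0, 4, 10].
    # import itertools, torch
    #
    #   query_lens = [4, 6]
    #   query_start_loc = torch.tensor(
    #       list(itertools.accumulate([0] + query_lens)), dtype=torch.int32)
    #   assert query_start_loc.tolist() == [0, 4, 10]
    #   # Tokens of request i are
    #   # query[query_start_loc[i]:query_start_loc[i + 1]].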
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
|
||||
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
self._cached_decode_metadata = ROCmFlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
# Batch may be composed of prefill|decodes, adjust query start indices
|
||||
# to refer to the start of decodes when the two are split apart.
|
||||
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
if self._cached_decode_metadata.query_start_loc is not None:
|
||||
qs = self._cached_decode_metadata.query_start_loc
|
||||
self._cached_decode_metadata.query_start_loc = qs - qs[0]
|
||||
return self._cached_decode_metadata
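# Illustrative sketch of the re-basing done in decode_metadata above (toy
# values): when a batch of [3 prefill tokens | 6 decode tokens] is split,
# the decode-only view of query_start_loc must start at 0 again.
import torch

qs = torch.tensor([3, 9])   # decode slice of the combined query_start_loc
assert (qs - qs[0]).tolist() == [0, 6]
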

class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):

    _metadata_cls = ROCmFlashAttentionMetadata


def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases

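# Illustrative sketch (assumed toy sizes): the per-head ALiBi bias built
# above is slope * (j - i), with future positions masked to -inf when
# make_attn_mask=True, so it can be added directly to attention scores.
import torch

num_heads, seq_len = 2, 4
slopes = torch.tensor([0.5, 0.25])
pos = torch.arange(seq_len, dtype=torch.float32)
bias = (pos[None, :] - pos[:, None])[None, :, :] * slopes[:, None, None]
causal = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
masked_bias = bias + causal          # shape: (num_heads, seq_len, seq_len)
assert masked_bias[0, 2, 1].item() == 0.5 * (1 - 2)
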
def _get_seq_len_block_table_args(
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
attn_type: str,
|
||||
) -> tuple:
|
||||
'''
|
||||
The particular choice of sequence-length
|
||||
attributes which should be extracted from attn_metadata is dependent
|
||||
on the type of attention operation.
|
||||
|
||||
Decoder attn -> select entirely decoder self-attention-related fields
|
||||
Encoder/decoder cross-attn -> select encoder sequence lengths
|
||||
Encoder attn -> select encoder sequence lengths fields
|
||||
Encoder-only attn -> select prefill sequence lengths with
|
||||
bidirectional attention
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention op
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention, encoder-only
|
||||
|
||||
Returns:
|
||||
|
||||
* Appropriate sequence-lengths tensors for query and key
|
||||
* Appropriate max sequence-length scalar
|
||||
* Causal masking flag
|
||||
'''
|
||||
|
||||
if attn_type == AttentionType.ENCODER:
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
assert attn_metadata.encoder_seq_lens_tensor is not None
|
||||
query_seq_start_loc = torch.tensor(
|
||||
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
|
||||
device=attn_metadata.encoder_seq_lens_tensor.device,
|
||||
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
|
||||
causal_mask = False
|
||||
|
||||
# No block tables associated with encoder attention
|
||||
return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
query_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.encoder_seq_lens, causal_mask)
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
# For encoder-only models, we use the prefill sequence lengths
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.seq_lens_tensor is not None
|
||||
query_seq_start_loc = torch.tensor(
|
||||
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
|
||||
device=attn_metadata.seq_lens_tensor.device,
|
||||
dtype=attn_metadata.seq_lens_tensor.dtype)
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
# Encoder-only models typically use bidirectional attention
|
||||
causal_mask = False
|
||||
|
||||
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
|
||||
max_seq_len, attn_metadata.seq_lens, causal_mask)
|
||||
|
||||
elif attn_type == AttentionType.DECODER:
|
||||
# Decoder self-attention
|
||||
# Choose max_seq_len based on whether we are in prompt_run
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.seq_lens_tensor is not None
|
||||
query_seq_start_loc = torch.tensor(
|
||||
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
|
||||
device=attn_metadata.seq_lens_tensor.device,
|
||||
dtype=attn_metadata.seq_lens_tensor.dtype)
|
||||
max_seq_len = attn_metadata.max_prefill_seq_len
|
||||
causal_mask = True
|
||||
|
||||
return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
|
||||
max_seq_len, attn_metadata.seq_lens, causal_mask)
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
assert attn_metadata.encoder_seq_lens_tensor is not None
|
||||
query_start_loc = torch.tensor(
|
||||
list(itertools.accumulate([0] + attn_metadata.seq_lens)),
|
||||
device=attn_metadata.encoder_seq_lens_tensor.device,
|
||||
dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
|
||||
|
||||
assert attn_metadata.encoder_seq_lens is not None
|
||||
assert attn_metadata.seq_lens_tensor is not None
|
||||
key_seq_start_loc = torch.tensor(
|
||||
list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
|
||||
device=attn_metadata.seq_lens_tensor.device,
|
||||
dtype=attn_metadata.seq_lens_tensor.dtype)
|
||||
causal_mask = False
|
||||
|
||||
# Enc/dec cross-attention KVs match encoder sequence length;
|
||||
# cross-attention utilizes special "cross" block tables
|
||||
return (query_start_loc, attn_metadata.max_prefill_seq_len,
|
||||
key_seq_start_loc, attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.seq_lens, causal_mask)
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<------------------ num_generation_tokens (M) ----------------->|
|
||||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|
||||
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"ROCM_FLASH backend.")
|
||||
        if use_irope:
            logger.warning_once(
                "Using irope in ROCm Flash Attention is not supported yet, "
                "it will fall back to global attention for long context.")
|
||||
if logits_soft_cap is None:
|
||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||
self.logits_soft_cap = 0.0
|
||||
else:
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
self.attn_type = attn_type
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = ((sliding_window, sliding_window)
|
||||
if sliding_window is not None else (-1, -1))
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.paged_attn_module = _get_paged_attn_module()
|
||||
supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
|
||||
)
|
||||
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
|
||||
self.use_naive_attn = False
|
||||
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
||||
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
|
||||
if self.use_triton_flash_attn:
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCm Triton FlashAttention does not support attention"
|
||||
" logits soft capping."
|
||||
" please try using the ROCm CK "
|
||||
"FA backend instead by setting the env var "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
|
||||
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
|
||||
triton_attention)
|
||||
self.triton_attn_func = triton_attention
|
||||
logger.debug("Using Triton FA in ROCmBackend")
|
||||
if self.sliding_window != (-1, -1):
|
||||
logger.warning("ROCm Triton FA does not currently support "
|
||||
"sliding window attention. If using half "
|
||||
"precision, please try using the ROCm CK "
|
||||
"FA backend instead by setting the env var "
|
||||
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
|
||||
else:
|
||||
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
|
||||
# either
|
||||
if not current_platform.has_device_capability(90):
|
||||
self.use_naive_attn = True
|
||||
else:
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
self.fa_attn_func = flash_attn_varlen_func
|
||||
logger.debug("Using CK FA in ROCmBackend")
|
||||
except ModuleNotFoundError:
|
||||
self.use_naive_attn = True
|
||||
|
||||
if self.use_naive_attn:
|
||||
if logits_soft_cap is not None:
|
||||
raise ValueError(
|
||||
"ROCm Naive FlashAttention does not support "
|
||||
"attention logits soft capping.")
|
||||
|
||||
self.sdpa_attn_func = _sdpa_attention
|
||||
logger.debug("Using naive (SDPA) attention in ROCmBackend")
|
||||
|
||||
self.aiter_kv_scales_initialized = False
|
||||
self.force_fp8_attention = (
|
||||
get_current_vllm_config() is not None
|
||||
and get_current_vllm_config().model_config.override_attention_dtype
|
||||
== "fp8")
|
||||
|
||||
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||
tokens, n_kv_heads, head_dim = x.shape
|
||||
return (x[:, :,
|
||||
None, :].expand(tokens, n_kv_heads, n_rep,
|
||||
head_dim).reshape(tokens, n_kv_heads * n_rep,
|
||||
head_dim))
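    # Illustrative sketch (toy sizes assumed): the expand/reshape trick used
    # by repeat_kv above is equivalent to torch.repeat_interleave along the
    # KV-head dimension.
    #
    #   import torch
    #   tokens, n_kv_heads, n_rep, head_dim = 5, 2, 3, 4
    #   x = torch.randn(tokens, n_kv_heads, head_dim)
    #   expanded = (x[:, :, None, :]
    #               .expand(tokens, n_kv_heads, n_rep, head_dim)
    #               .reshape(tokens, n_kv_heads * n_rep, head_dim))
    #   assert torch.equal(expanded,
    #                      torch.repeat_interleave(x, n_rep, dim=1))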
|
||||
|
||||
def fused_output_quant_supported(self, quant_key: QuantKey):
|
||||
if self.use_triton_flash_attn:
|
||||
return quant_key == kFp8StaticTensorSym
|
||||
|
||||
# Only supported in the Triton backend
|
||||
return False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
|
||||
For decoder-only models: query, key and value must be non-None.
|
||||
|
||||
For encoder/decoder models:
|
||||
* ROCmFlashAttentionImpl.forward() may be invoked for both self- and
|
||||
cross-attention layers.
|
||||
* For self-attention: query, key and value must be non-None.
|
||||
* For cross-attention:
|
||||
* Query must be non-None
|
||||
* During prefill, key and value must be non-None; key and value
|
||||
get cached for use during decode.
|
||||
* During decode, key and value may be None, since:
|
||||
(1) key and value tensors were cached during prefill, and
|
||||
(2) cross-attention key and value tensors do not grow during
|
||||
decode
|
||||
|
||||
A note on how the attn_type (attention type enum) argument impacts
|
||||
attention forward() behavior:
|
||||
|
||||
* DECODER: normal decoder-only behavior;
|
||||
use decoder self-attention block table
|
||||
* ENCODER: no KV caching; pass encoder sequence
|
||||
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len) to kernel, in lieu of decoder
|
||||
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
|
||||
* ENCODER_DECODER: cross-attention behavior;
|
||||
use cross-attention block table for caching KVs derived
|
||||
from encoder hidden states; since KV sequence lengths
|
||||
will match encoder sequence lengths, pass encoder sequence
|
||||
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
|
||||
max_encoder_seq_len)
|
||||
* ENCODER_ONLY: bidirectional attention with no KV caching;
|
||||
use prefill sequence attributes
|
||||
|
||||
Args:
|
||||
layer: Attention layer instance.
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache: KV cache tensor with shape
|
||||
[2, num_blocks, block_size * num_kv_heads * head_size].
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
output: Optional output tensor.
|
||||
output_scale: Optional output scale tensor.
|
||||
output_block_scale: Optional output block scale tensor.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None and not self.use_triton_flash_attn:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization only supported for Triton"
|
||||
" implementation in ROCMFlashAttentionImpl for now")
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused nvfp4 output quantization is not supported"
|
||||
" for ROCMFlashAttentionImpl")
|
||||
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
if key is not None:
|
||||
assert value is not None
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
else:
|
||||
assert value is None
|
||||
|
||||
paged_attn = self.paged_attn_module
|
||||
|
||||
# Reshaping kv tensors is required for AITER paged attention kernel
|
||||
# because it works on a different tensor shape,
|
||||
# when the size of one element is one byte (int8/fp8 dtypes).
|
||||
# This reshaping is only required on the first forward call
|
||||
# and the kv cache must not be empty.
|
||||
if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
|
||||
and not self.aiter_kv_scales_initialized
|
||||
and kv_cache.shape != torch.Size([0])):
|
||||
num_blocks = kv_cache.shape[1]
|
||||
block_size = kv_cache.shape[2] // (self.num_kv_heads *
|
||||
self.head_size)
|
||||
k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
|
||||
dtype=torch.float32,
|
||||
device=kv_cache.device)
|
||||
v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
|
||||
dtype=torch.float32,
|
||||
device=kv_cache.device)
|
||||
self.aiter_kv_scales_initialized = True
|
||||
k_scale.fill_(layer._k_scale.item())
|
||||
v_scale.fill_(layer._v_scale.item())
|
||||
layer._k_scale = k_scale
|
||||
layer._v_scale = v_scale
|
||||
|
||||
# Only update KV cache for decoder self-attention
|
||||
# and encoder-decoder cross-attention
|
||||
if self.attn_type not in [
|
||||
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
|
||||
] and kv_cache.numel() > 0:
|
||||
key_cache, value_cache = paged_attn.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
|
||||
if key is not None and value is not None:
|
||||
# Reshape the input keys and values and store them in the
|
||||
# cache. If kv_cache is not provided, the new key and value
|
||||
# tensors are not cached. This happens during the initial
|
||||
# memory profiling run.
|
||||
paged_attn.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
attn_metadata.cross_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.attn_type != AttentionType.ENCODER:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
elif self.attn_type == AttentionType.ENCODER_ONLY:
|
||||
# For encoder-only models, all tokens are processed in one go
|
||||
num_prefill_tokens = query.shape[0]
|
||||
else:
|
||||
assert attn_metadata.num_encoder_tokens is not None
|
||||
num_prefill_tokens = attn_metadata.num_encoder_tokens
|
||||
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
|
||||
# For encoder-only and encoder models,
|
||||
# we process all tokens at once
|
||||
# For decoder and encoder-decoder,
|
||||
# we may need to limit key/value to prefill tokens
|
||||
if key is not None and value is not None \
|
||||
and self.attn_type not in [AttentionType.ENCODER_DECODER,
|
||||
AttentionType.ENCODER_ONLY]:
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
# normal attention and DECODER
|
||||
if self.attn_type == AttentionType.DECODER and (
|
||||
kv_cache.numel() == 0 or prefill_meta.block_tables is None
|
||||
or prefill_meta.block_tables.numel() == 0):
|
||||
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
|
||||
key_max_seq_len, seq_lens,
|
||||
causal_mask) = (prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
attn_metadata.seq_lens, True)
|
||||
# prefix-enabled attention and ENCODER/ENCODER_DECODER
|
||||
else:
|
||||
(query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
|
||||
key_max_seq_len, seq_lens,
|
||||
causal_mask) = _get_seq_len_block_table_args(
|
||||
prefill_meta, self.attn_type)
|
||||
# Prompt run.
|
||||
if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
|
||||
# triton attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
attn_masks = None
|
||||
if self.use_triton_flash_attn:
|
||||
if self.alibi_slopes is not None:
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes,
|
||||
query.dtype,
|
||||
seq_lens,
|
||||
make_attn_mask=causal_mask) # type: ignore
|
||||
|
||||
use_fp8_scales = (layer._q_scale and layer._k_scale
|
||||
and layer._v_scale and layer._prob_scale
|
||||
and (self.kv_cache_dtype == "fp8"
|
||||
or self.force_fp8_attention))
|
||||
|
||||
full_scales = (
|
||||
layer._q_scale.item(), layer._k_scale.item(),
|
||||
layer._v_scale.item(),
|
||||
layer._prob_scale.item()) if use_fp8_scales else None
|
||||
self.triton_attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output[:num_prefill_tokens],
|
||||
query_seq_start_loc,
|
||||
key_seq_start_loc,
|
||||
query_max_seq_len,
|
||||
key_max_seq_len,
|
||||
causal_mask,
|
||||
self.scale,
|
||||
attn_masks[0][None]
|
||||
if attn_masks is not None else None,
|
||||
full_scales,
|
||||
output_scale,
|
||||
)
|
||||
elif self.use_naive_attn:
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# Interleave for MQA workaround.
|
||||
key = self.repeat_kv(key, self.num_queries_per_kv)
|
||||
value = self.repeat_kv(value, self.num_queries_per_kv)
|
||||
if self.alibi_slopes is not None:
|
||||
attn_masks = _make_alibi_bias(
|
||||
self.alibi_slopes,
|
||||
query.dtype,
|
||||
attn_metadata.seq_lens,
|
||||
make_attn_mask=causal_mask) # type: ignore
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
key = key.movedim(0, key.dim() - 2)
|
||||
value = value.movedim(0, value.dim() - 2)
|
||||
# sdpa math backend attention
|
||||
self.sdpa_attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output[:num_prefill_tokens],
|
||||
query_seq_start_loc,
|
||||
num_prefill_tokens,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
self.scale,
|
||||
attn_masks,
|
||||
)
|
||||
else:
|
||||
# upstream FA does not support an output arg, copy
|
||||
output[:num_prefill_tokens] = self.fa_attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=query_seq_start_loc,
|
||||
cu_seqlens_k=key_seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=key_max_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=causal_mask,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
softcap=self.logits_soft_cap,
|
||||
)
|
||||
|
||||
else:
|
||||
# prefix-enabled attention -
|
||||
# not applicable for encoder-only models
|
||||
if self.attn_type != AttentionType.ENCODER_ONLY:
|
||||
output[:num_prefill_tokens] = paged_attn.forward_prefix(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
self.kv_cache_dtype,
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.seq_lens_tensor,
|
||||
prefill_meta.max_query_len,
|
||||
self.alibi_slopes,
|
||||
self.sliding_window[0],
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
# Skip decode phase for encoder-only models
|
||||
if (decode_meta := attn_metadata.decode_metadata) and (
|
||||
self.attn_type != AttentionType.ENCODER_ONLY):
|
||||
# Decoding run.
|
||||
# Whether to use rocm custom paged attention or not
|
||||
num_seqs, num_heads, head_size = decode_query.shape
|
||||
block_size = value_cache.shape[3]
|
||||
gqa_ratio = num_heads // self.num_kv_heads
|
||||
from vllm.platforms.rocm import use_rocm_custom_paged_attention
|
||||
use_custom = use_rocm_custom_paged_attention(
|
||||
decode_query.dtype, head_size, block_size, gqa_ratio,
|
||||
decode_meta.max_decode_seq_len, self.sliding_window,
|
||||
self.kv_cache_dtype, self.alibi_slopes)
|
||||
|
||||
if use_custom:
|
||||
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
|
||||
!= AttentionType.ENCODER_DECODER else
|
||||
decode_meta.max_encoder_seq_len)
|
||||
assert max_seq_len is not None
|
||||
max_num_partitions = (
|
||||
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
|
||||
_PARTITION_SIZE_ROCM)
|
||||
assert _PARTITION_SIZE_ROCM % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=query.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
query_start_loc = None
|
||||
ops.paged_attention_rocm(
|
||||
output[num_prefill_tokens:],
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
decode_meta.block_tables
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.cross_block_tables,
|
||||
decode_meta.seq_lens_tensor
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.encoder_seq_lens_tensor,
|
||||
query_start_loc,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
output_scale,
|
||||
)
|
||||
else:
|
||||
# PagedAttention does not support fused quant, manually quantize
|
||||
if output_scale is None:
|
||||
out_pa = output[num_prefill_tokens:]
|
||||
else:
|
||||
out_pa = torch.empty_like(output[num_prefill_tokens:],
|
||||
dtype=query.dtype)
|
||||
|
||||
out_pa[:] = paged_attn.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
decode_meta.block_tables
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.cross_block_tables,
|
||||
decode_meta.seq_lens_tensor
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.encoder_seq_lens_tensor,
|
||||
decode_meta.max_decode_seq_len
|
||||
if self.attn_type != AttentionType.ENCODER_DECODER else
|
||||
decode_meta.max_encoder_seq_len,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# Manually perform quantization
|
||||
if output_scale is not None:
|
||||
out_uq = out_pa.view(-1, self.num_heads * self.head_size)
|
||||
out_q = output.view(-1, self.num_heads * self.head_size)
|
||||
ops.scaled_fp8_quant(out_uq,
|
||||
output_scale,
|
||||
output=out_q[num_prefill_tokens:])
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
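
# Illustrative sketch (toy values): the custom paged-attention path in
# forward() above splits each sequence into
# ceil(max_seq_len / _PARTITION_SIZE_ROCM) partitions; the
# (x + p - 1) // p expression is integer ceiling division.
partition_size = 256                  # stands in for _PARTITION_SIZE_ROCM
for max_seq_len, expected in [(1, 1), (256, 1), (257, 2), (1024, 4)]:
    assert (max_seq_len + partition_size - 1) // partition_size == expected
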

def _sdpa_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    seq_lens: torch.Tensor,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    scale: float,
    attn_masks: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    start = 0
    assert output.shape == (num_tokens, num_heads, head_size)
    assert output.dtype == query.dtype
    assert output.device == query.device

    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.MATH):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
        start = end

    return output

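# Illustrative sketch (toy shapes, not part of the backend): forcing the
# MATH backend of scaled_dot_product_attention, as _sdpa_attention does
# above. Inputs are (num_heads, seq_len, head_size) purely for demonstration.
import torch

q = torch.randn(4, 8, 64)
k = torch.randn(4, 8, 64)
v = torch.randn(4, 8, 64)
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=True, scale=1.0 / 64 ** 0.5)
assert out.shape == (4, 8, 64)
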
@ -1,111 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List, Optional, Type

import torch

from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd


class TritonMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

    @staticmethod
    def get_impl_cls() -> Type["TritonMLAImpl"]:
        return TritonMLAImpl


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        num_kv_splits = 4  # TODO: heuristic

        # TODO(lucas) Allocate ahead of time
        attn_logits = torch.empty(
            (
                B,
                self.num_heads,
                num_kv_splits,
                # NOTE(lucas) idk why the +1 is here but sglang has it so we
                # just mirror that
                self.kv_lora_rank + 1,
            ),
            dtype=torch.float32,
            device=q.device,
        )

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

        # Run MQA
        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                             decode_meta.block_tables,
                             decode_meta.seq_lens_tensor, attn_logits,
                             num_kv_splits, self.scale, PAGE_SIZE)

        return self._v_up_proj(o)

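# Illustrative sketch (toy sizes, not part of the backend): the MLA decode
# path above concatenates the "nope" and rotary parts of the query and
# slices the latent (kv_lora_rank) part back out of the combined cache entry.
import torch

B, num_heads, kv_lora_rank, rope_dim = 2, 4, 16, 8
q_nope = torch.randn(B, num_heads, kv_lora_rank)
q_pe = torch.randn(B, num_heads, rope_dim)
q = torch.cat([q_nope, q_pe], dim=-1)
assert q.shape == (B, num_heads, kv_lora_rank + rope_dim)

cache_entry = torch.randn(kv_lora_rank + rope_dim)
kv_c = cache_entry[:kv_lora_rank]        # latent part only
assert kv_c.shape == (kv_lora_rank,)
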
@ -338,10 +338,9 @@ class CommonAttentionState(AttentionState):
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
                f"Expected attn_backend name to be either 'XFORMERS'," \
                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                f"got '{self.runner.attn_backend.get_name()}'"
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or " \
                f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
            self._update_captured_metadata_for_enc_dec_model(
                batch_size=batch_size, attn_metadata=attn_metadata)

@ -360,10 +359,9 @@ class CommonAttentionState(AttentionState):
            # The encoder decoder model works only with XFormers and
            # Flash Attention backend. Assert the same.
            assert self.runner.attn_backend.get_name() in \
                ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
                f"Expected attn_backend name to be either 'XFORMERS'," \
                f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
                f"got '{self.runner.attn_backend.get_name()}'"
                ["XFORMERS", "FLASH_ATTN"], \
                f"Expected attn_backend name to be either 'XFORMERS' or " \
                f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'"
            self._add_additional_input_buffers_for_enc_dec_model(
                attn_metadata=attn_metadata, input_buffers=input_buffers)
            return input_buffers

@ -1,805 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (AttentionBias,
                                         BlockDiagonalCausalMask,
                                         BlockDiagonalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (
    CommonAttentionState, CommonMetadataBuilder,
    get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args,
    is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)

class XFormersBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "XFORMERS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["XFormersImpl"]:
|
||||
return XFormersImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return XFormersMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
|
||||
return XFormersMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: Dict[int, int],
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for XFormersbackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
|
||||
    # Whether or not cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
|
||||
seq_lens: Optional[List[int]] = None
|
||||
|
||||
# FIXME: It is for flash attn.
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
    # Max number of query tokens among requests in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Self-attention prefill/decode metadata cache
|
||||
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
|
||||
_cached_decode_metadata: Optional["XFormersMetadata"] = None
|
||||
|
||||
# Begin encoder attn & enc/dec cross-attn fields...
|
||||
|
||||
# Encoder sequence lengths representation
|
||||
encoder_seq_lens: Optional[List[int]] = None
|
||||
encoder_seq_lens_tensor: Optional[torch.Tensor] = None
|
||||
# FIXME: It is for flash attn.
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
encoder_seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
# Maximum sequence length among encoder sequences
|
||||
max_encoder_seq_len: Optional[int] = None
|
||||
|
||||
# Number of tokens input to encoder
|
||||
num_encoder_tokens: Optional[int] = None
|
||||
|
||||
# Cross-attention memory-mapping data structures: slot mapping
|
||||
# and block tables
|
||||
cross_slot_mapping: Optional[torch.Tensor] = None
|
||||
cross_block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
# when alibi slopes is used. It is because of the limitation
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for encoder attention is set.
|
||||
'''
|
||||
return is_all_encoder_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def is_all_cross_attn_metadata_set(self):
|
||||
'''
|
||||
All attention metadata required for enc/dec cross-attention is set.
|
||||
|
||||
Superset of encoder attention required metadata.
|
||||
'''
|
||||
return is_all_cross_attn_metadata_set(self)
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = XFormersMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["XFormersMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = XFormersMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
# Begin encoder & cross attn fields below...
|
||||
encoder_seq_lens=self.encoder_seq_lens,
|
||||
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len=self.max_encoder_seq_len,
|
||||
cross_slot_mapping=self.cross_slot_mapping,
|
||||
cross_block_tables=self.cross_block_tables)
|
||||
|
||||
# Batch may be composed of prefill|decodes, adjust query start indices
|
||||
# to refer to the start of decodes when the two are split apart.
|
||||
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
if self._cached_decode_metadata.query_start_loc is not None:
|
||||
qs = self._cached_decode_metadata.query_start_loc
|
||||
self._cached_decode_metadata.query_start_loc = qs - qs[0]
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
def _get_attn_bias(
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_type: str,
|
||||
) -> Optional[AttentionBias]:
|
||||
'''
|
||||
Extract appropriate attention bias from attention metadata
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
|
||||
Returns:
|
||||
* Appropriate attention bias value given the attention type
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
return attn_metadata.attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
return attn_metadata.encoder_attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
return attn_metadata.cross_attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
def _set_attn_bias(
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_bias: List[Optional[AttentionBias]],
|
||||
attn_type: str,
|
||||
) -> None:
|
||||
'''
|
||||
Update appropriate attention bias field of attention metadata,
|
||||
according to attention type.
|
||||
|
||||
Arguments:
|
||||
|
||||
* attn_metadata: Attention metadata structure associated with attention
|
||||
* attn_bias: The desired attention bias value
|
||||
* attn_type: encoder attention, decoder self-attention,
|
||||
encoder/decoder cross-attention
|
||||
'''
|
||||
|
||||
if (attn_type == AttentionType.DECODER
|
||||
or attn_type == AttentionType.ENCODER_ONLY):
|
||||
attn_metadata.attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER:
|
||||
attn_metadata.encoder_attn_bias = attn_bias
|
||||
elif attn_type == AttentionType.ENCODER_DECODER:
|
||||
attn_metadata.cross_attn_bias = attn_bias
|
||||
else:
|
||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
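
# Illustrative sketch (toy sizes, assumes the from_seqlens/materialize
# helpers available in recent xFormers releases): the biases selected by
# _get_attn_bias above are xFormers AttentionBias objects; a block-diagonal
# causal mask built from per-request sequence lengths keeps requests in a
# flattened batch from attending to each other.
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

seq_lens = [3, 5]
total = sum(seq_lens)
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
dense = attn_bias.materialize((total, total))
# Token 0 of request 0 (row 0) may not attend to request 1 (column 3).
assert torch.isinf(dense[0, 3])
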
|
||||
|
||||
|
||||
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
|
||||
|
||||
_metadata_cls = XFormersMetadata
|
||||
|
||||
|
||||
class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prefill_tokens ----------------->|
|
||||
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<----------------- num_decode_tokens ------------------>|
|
||||
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|
||||
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0 "
|
||||
"XFORMERS backend.")
|
||||
if logits_soft_cap is not None:
|
||||
logger.warning_once("XFormers does not support logits soft cap. "
|
||||
"Outputs may be slightly off.")
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in XFormers is not supported yet, it will fall"
|
||||
" back to global attention for long context.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
|
        self.attn_type = attn_type

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: "XFormersMetadata",
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with xFormers and PagedAttention.

        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * XFormersImpl.forward() may be invoked for both self- and cross-
          attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode

        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:

        * DECODER: normal decoder-only behavior;
          use decoder self-attention block table
        * ENCODER: no KV caching; pass encoder sequence
          attributes (encoder_seq_lens/encoder_seq_lens_tensor/
          max_encoder_seq_len) to kernel, in lieu of decoder
          sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
          Used for encoder branch of encoder-decoder models.
        * ENCODER_ONLY: no kv_caching, uses the normal attention
          attributes (seq_lens/seq_lens_tensor/max_seq_len).
        * ENCODER_DECODER: cross-attention behavior;
          use cross-attention block table for caching KVs derived
          from encoder hidden states; since KV sequence lengths
          will match encoder sequence lengths, pass encoder sequence
          attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
          max_encoder_seq_len)

        Args:
            layer: Attention layer instance.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: KV cache tensor with shape
                [2, num_blocks, block_size * num_kv_heads * head_size].
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Optional output tensor.
            output_scale: Optional output scale tensor.
            output_block_scale: Optional output block scale tensor.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for XFormersImpl")

        attn_type = self.attn_type
        # Check that appropriate attention metadata attributes are
        # selected for the desired attention type
        if (attn_type == AttentionType.ENCODER
                and (not attn_metadata.is_all_encoder_attn_metadata_set)):
            raise AttributeError("Encoder attention requires setting "
                                 "encoder metadata attributes.")

        elif (attn_type == AttentionType.ENCODER_DECODER
              and (not attn_metadata.is_all_cross_attn_metadata_set)):
            raise AttributeError("Encoder/decoder cross-attention "
                                 "requires setting cross-attention "
                                 "metadata attributes.")

        query = query.view(-1, self.num_heads, self.head_size)
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None

        # Self-attention vs. cross-attention will impact
        # which KV cache memory-mapping & which
        # seqlen datastructures we utilize

        if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
            # KV-cache during decoder-self- or
            # encoder-decoder-cross-attention, but not
            # during encoder attention.
            #
            # Even if there are no new key/value pairs to cache,
            # we still need to break out key_cache and value_cache
            # i.e. for later use by paged attention
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            if (key is not None) and (value is not None):

                if attn_type == AttentionType.ENCODER_DECODER:
                    # Update cross-attention KV cache (prefill-only)
                    # During cross-attention decode, key & value will be None,
                    # preventing this IF-statement branch from running
                    updated_slot_mapping = attn_metadata.cross_slot_mapping
                else:
                    # Update self-attention KV cache (prefill/decode)
                    updated_slot_mapping = attn_metadata.slot_mapping

                # Reshape the input keys and values and store them in the cache.
                # If kv_cache is not provided, the new key and value tensors are
                # not cached. This happens during the initial memory
                # profiling run.
                PagedAttention.write_to_paged_cache(
                    key, value, key_cache, value_cache, updated_slot_mapping,
                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
        (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
        if key is not None and value is not None:
            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                # normal attention.
                # block tables are empty if the prompt does not have a cached
                # prefix.
                out = self._run_memory_efficient_xformers_forward(
                    query, key, value, prefill_meta, attn_type=attn_type)
                assert out.shape == output[:num_prefill_query_tokens].shape
                output[:num_prefill_query_tokens] = out
            else:
                assert attn_type != AttentionType.ENCODER_ONLY, (
                    "Encoder-only models should not have prefix attention.")

                assert prefill_meta.query_start_loc is not None
                assert prefill_meta.max_query_len is not None

                # prefix-enabled attention
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                out = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window,
                    layer._k_scale,
                    layer._v_scale,
                )
                assert output[:num_prefill_query_tokens].shape == out.shape
                output[:num_prefill_query_tokens] = out

        if decode_meta := attn_metadata.decode_metadata:
            assert attn_type != AttentionType.ENCODER_ONLY, (
                "Encoder-only models should not have decode metadata.")

            (
                seq_lens_arg,
                max_seq_len_arg,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)

            output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                block_tables_arg,
                seq_lens_arg,
                max_seq_len_arg,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                layer._k_scale,
                layer._v_scale,
            )

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)

    def _run_memory_efficient_xformers_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: XFormersMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Attention for 1D query of multiple prompts. Multiple prompt
        tokens are flattened in to `query` input.

        See https://facebookresearch.github.io/xformers/components/ops.html
        for API spec.

        Args:
            query: shape = [num_prefill_tokens, num_heads, head_size]
            key: shape = [num_prefill_tokens, num_kv_heads, head_size]
            value: shape = [num_prefill_tokens, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
        """

        original_query = query
        if self.num_kv_heads != self.num_heads:
            # GQA/MQA requires the shape [B, M, G, H, K].
            # Note that the output also has the same shape (which is different
            # from a spec from the doc).
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # Set attention bias if not provided. This typically happens at
        # the very attention layer of every iteration.
        # FIXME(woosuk): This is a hack.
        attn_bias = _get_attn_bias(attn_metadata, attn_type)
        if attn_bias is None:
            if self.alibi_slopes is None:

                # Cross attention block of decoder branch of encoder-decoder
                # model uses seq_lens for dec / encoder_seq_lens for enc
                if (attn_type == AttentionType.ENCODER_DECODER):
                    assert attn_metadata.seq_lens is not None
                    assert attn_metadata.encoder_seq_lens is not None

                    # Cross-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens,
                        attn_metadata.encoder_seq_lens,
                        device=query.device)

                # Encoder branch of encoder-decoder model uses
                # attn_metadata.encoder_seq_lens
                elif attn_type == AttentionType.ENCODER:

                    assert attn_metadata.encoder_seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.encoder_seq_lens, device=query.device)

                # Self-attention block of encoder-only model just
                # uses the seq_lens directly.
                elif attn_type == AttentionType.ENCODER_ONLY:
                    assert attn_metadata.seq_lens is not None

                    # Encoder self-attention mask is non-causal
                    attn_bias = BlockDiagonalMask.from_seqlens(
                        attn_metadata.seq_lens, device=query.device)

                # Self-attention block of decoder branch just
                # uses the seq_lens directly
                elif attn_type == AttentionType.DECODER:
                    assert attn_metadata.seq_lens is not None

                    # Decoder self-attention mask is causal
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        attn_metadata.seq_lens, device=query.device)
                else:
                    raise ValueError("Unknown AttentionType: %s", attn_type)

                if self.sliding_window is not None:
                    attn_bias = attn_bias.make_local_attention(
                        self.sliding_window)
                attn_bias = [attn_bias]
            else:
                assert attn_type == AttentionType.DECODER
                assert attn_metadata.seq_lens is not None
                attn_bias = _make_alibi_bias(self.alibi_slopes,
                                             self.num_kv_heads, query.dtype,
                                             attn_metadata.seq_lens)

            _set_attn_bias(attn_metadata, attn_bias, attn_type)

        # No alibi slopes.
        # TODO(woosuk): Too many view operations. Let's try to reduce
        # them in the future for code readability.
        if self.alibi_slopes is None:
            # Add the batch dimension.
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
            value = value.unsqueeze(0)
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=attn_bias[0],
                p=0.0,
                scale=self.scale)
            return out.view_as(original_query)

        # Attention with alibi slopes.
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        assert attn_metadata.seq_lens is not None
        output = torch.empty_like(original_query)
        start = 0
        for i, seq_len in enumerate(attn_metadata.seq_lens):
            end = start + seq_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=attn_bias[i],
                p=0.0,
                scale=self.scale)
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.view_as(original_query[start:end]))
            start += seq_len
        return output
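
# Illustrative aside (not part of the removed file above): a minimal sketch of
# the no-ALiBi path in _run_memory_efficient_xformers_forward. Two prompts of
# lengths 3 and 5 are flattened into a single batch-1 call and kept independent
# by a block-diagonal causal mask. Shapes, dtype and device are assumptions
# chosen only for this example.
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

num_heads, head_size = 4, 64
seq_lens = [3, 5]
total_tokens = sum(seq_lens)
q = torch.randn(1, total_tokens, num_heads, head_size,
                device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# One mask covers both prompts; tokens cannot attend across prompt boundaries.
mask = BlockDiagonalCausalMask.from_seqlens(seq_lens)
out = xops.memory_efficient_attention_forward(
    q, k, v, attn_bias=mask, p=0.0, scale=head_size**-0.5)
# out keeps the flattened layout: [1, total_tokens, num_heads, head_size].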


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_lens: List[int],
) -> List[AttentionBias]:
    attn_biases: List[AttentionBias] = []
    for seq_len in seq_lens:
        bias = torch.arange(seq_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(seq_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        # Calculate a matrix where each element represents ith element- jth
        # element.
        bias = bias[None, :] - bias[:, None]

        padded_len = (seq_len + 7) // 8 * 8
        num_heads = alibi_slopes.shape[0]
        bias = torch.empty(
            1,  # batch size
            num_heads,
            seq_len,
            padded_len,
            device=alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :seq_len].copy_(bias)
        bias.mul_(alibi_slopes[:, None, None])
        attn_biases.append(LowerTriangularMaskWithTensorBias(bias))

    return attn_biases
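
For reference, the per-sequence ALiBi bias built above is just the relative-position matrix (j - i) scaled per head by its slope; a minimal standalone sketch, with made-up sequence length and slope values:

import torch

seq_len = 4
alibi_slopes = torch.tensor([0.5, 0.25])        # one slope per head

pos = torch.arange(seq_len, dtype=torch.float32)
rel = pos[None, :] - pos[:, None]               # rel[i, j] = j - i
bias = rel[None, :, :] * alibi_slopes[:, None, None]   # [num_heads, seq_len, seq_len]
# The backend above additionally pads the last dimension to a multiple of 8
# and wraps the tensor in LowerTriangularMaskWithTensorBias so that causal
# masking is applied on top of the bias before the softmax.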

@ -32,8 +32,7 @@ from vllm.transformers_utils.config import (
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
                                                 is_runai_obj_uri)
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType,
                        LazyLoader, common_broadcastable_dtype)
from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype

if TYPE_CHECKING:
    from transformers import PretrainedConfig
@ -1103,10 +1102,6 @@ class ModelConfig:
                self.hf_config.dual_chunk_attention_config[
                    "sparse_attention_enabled"] = True

            if envs.VLLM_ATTENTION_BACKEND != STR_DUAL_CHUNK_FLASH_ATTN_VAL:
                raise ValueError("please set VLLM_ATTENTION_BACKEND to "
                                 f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}")

    def verify_with_parallel_config(
        self,
        parallel_config: ParallelConfig,

@ -44,7 +44,7 @@ class model_aware_kv_ops_helper:
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/attention/backends/mla/common.py.
        # For more details, see vllm/v1/attention/backends/mla/common.py.
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = model_config.kv_lora_rank + \
                        model_config.qk_rope_head_dim
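
As a concrete illustration of the two head sizes the comment above distinguishes (the numbers are DeepSeek-V2's commonly cited config values, used here purely as an assumption):

# Assumed DeepSeek-V2 values, for illustration only.
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128

mla_head_size = kv_lora_rank + qk_rope_head_dim      # 576: compressed latent + rope part
fa_head_size = qk_nope_head_dim + qk_rope_head_dim   # 192: per-head K/V when VLLM_MLA_DISABLE=1
print(mla_head_size, fa_head_size)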

@ -44,8 +44,8 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import (get_model_path, is_interleaved,
                                            maybe_override_with_speculators)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                        GiB_bytes, get_ip, is_in_ray_actor)
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip,
                        is_in_ray_actor)
from vllm.v1.sample.logits_processor import LogitsProcessor

# yapf: enable
@ -1163,17 +1163,6 @@ class EngineArgs:
            self._set_default_args_v0(model_config)
        assert self.enable_chunked_prefill is not None

        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
            assert self.enforce_eager, (
                "Cuda graph is not supported with DualChunkFlashAttention. "
                "To run the model in eager mode, set 'enforce_eager=True' "
                "or use '--enforce-eager' in the CLI.")
            assert current_platform.is_cuda(), (
                "DualChunkFlashAttention is only supported on CUDA platform.")
            assert not use_v1, (
                "DualChunkFlashAttention is not supported on V1 engine. "
                "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

        sliding_window: Optional[int] = None
        if not is_interleaved(model_config.hf_text_config):
            # Only set CacheConfig.sliding_window if the model is all sliding

@ -529,7 +529,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # - "TORCH_SDPA": use torch.nn.MultiheadAttention
    # - "FLASH_ATTN": use FlashAttention
    # - "XFORMERS": use XFormers
    # - "ROCM_FLASH": use ROCmFlashAttention
    # - "FLASHINFER": use flashinfer
    # - "FLASHMLA": use FlashMLA
    # - "FLASH_ATTN_MLA": use FlashAttention for MLA

@ -53,13 +53,18 @@ class Mamba2Metadata:
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
    """Returns the appropriate metadata classes for the current platform."""
    if current_platform.is_rocm():
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionMetadata)
        return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
    elif current_platform.is_cuda():
        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
        from vllm.attention.backends.xformers import XFormersMetadata
        return (FlashAttentionMetadata, XFormersMetadata,
        from vllm.v1.attention.backends.rocm_aiter_fa import (
            AiterFlashAttentionMetadata)
        from vllm.v1.attention.backends.triton_attn import (
            TritonAttentionMetadata)
        return (AiterFlashAttentionMetadata, TritonAttentionMetadata,
                PlaceholderAttentionMetadata)
    if current_platform.is_cuda():
        from vllm.v1.attention.backends.flash_attn import (
            FlashAttentionMetadata)
        from vllm.v1.attention.backends.xformers import (
            XFormersAttentionMetadata)
        return (FlashAttentionMetadata, XFormersAttentionMetadata,
                PlaceholderAttentionMetadata)
    raise ValueError(
        f"Unsupported platform for Mamba2: {current_platform.device_type}")

@ -478,7 +478,8 @@ class DeepseekV2MLAAttention(nn.Module):
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
    For more info see MLACommonImpl in:
    vllm/v1/attention/backends/mla/utils.py
    """

    def __init__(

@ -226,8 +226,10 @@ class CudaPlatformBase(Platform):
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
        if use_mla:
            # TODO(lucas): refactor to be more concise
            # we should probably consider factoring out V1 here
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla
@ -246,35 +248,17 @@ class CudaPlatformBase(Platform):
            use_triton = selected_backend == _Backend.TRITON_MLA or (
                selected_backend is None)

            def _get_version(name, import_suffix) -> str:
                if use_v1:
                    logger.info_once(f"Using {name} backend on V1 engine.")
                    return f"vllm.v1.attention.backends.mla.{import_suffix}"
                else:
                    logger.info_once(f"Using {name} backend.")
                    return f"vllm.attention.backends.{import_suffix}"

            if use_cutlassmla:
                if use_v1:
                    logger.info_once("Using Cutlass MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "cutlass_mla.CutlassMLABackend")
                else:
                    logger.warning(
                        "Cutlass MLA backend is only supported on V1 engine")
                logger.info_once("Using Cutlass MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "cutlass_mla.CutlassMLABackend")
            if use_flashinfermla:
                if use_v1:
                    from vllm.v1.attention.backends.utils import (
                        set_kv_cache_layout)
                    set_kv_cache_layout("HND")
                    logger.info_once(
                        "Using FlashInfer MLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashinfer_mla.FlashInferMLABackend")
                else:
                    logger.warning(
                        "FlashInfer MLA backend is only supported on V1 engine"
                    )
                from vllm.v1.attention.backends.utils import (
                    set_kv_cache_layout)
                set_kv_cache_layout("HND")
                logger.info_once("Using FlashInfer MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashinfer_mla.FlashInferMLABackend")
            if use_flashmla:
                if block_size != 64:
                    logger.warning(
@ -282,20 +266,18 @@ class CudaPlatformBase(Platform):
                        " (currently only supports block size 64).",
                        block_size)
                else:
                    return _get_version("FlashMLA", "flashmla.FlashMLABackend")
            if use_flashattn:
                if use_v1:
                    logger.info_once(
                        "Using FlashAttention MLA backend on V1 engine.")
                    logger.info_once("Using FlashMLA backend on V1 engine.")
                    return ("vllm.v1.attention.backends.mla."
                            "flashattn_mla.FlashAttnMLABackend")
                else:
                    logger.warning(
                        "FlashAttention MLA backend is only supported on V1 "
                        "engine.")
                            "flashmla.FlashMLABackend")
            if use_flashattn:
                logger.info_once(
                    "Using FlashAttention MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "flashattn_mla.FlashAttnMLABackend")
            if use_triton:
                return _get_version("Triton MLA",
                                    "triton_mla.TritonMLABackend")
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
        if use_v1:
            FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
            FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
@ -382,78 +364,9 @@ class CudaPlatformBase(Platform):
            )
            return FLEX_ATTENTION_V1

        # Backends for V0 engine
        if selected_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"
        elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
            logger.info("Using DualChunkFlashAttention backend.")
            return ("vllm.attention.backends.dual_chunk_flash_attn."
                    "DualChunkFlashAttentionBackend")
        elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN:
            logger.info("Using DifferentialFlashAttention backend.")
            return ("vllm.attention.backends.differential_flash_attn."
                    "DifferentialFlashAttentionBackend")
        elif selected_backend == _Backend.FLASH_ATTN:
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):
            # Volta and Turing NVIDIA GPUs.
            logger.info(
                "Cannot use FlashAttention-2 backend for Volta and Turing "
                "GPUs.")
            target_backend = _Backend.XFORMERS
        elif dtype not in (torch.float16, torch.bfloat16):
            logger.info(
                "Cannot use FlashAttention-2 backend for dtype other than "
                "torch.float16 or torch.bfloat16.")
            target_backend = _Backend.XFORMERS
        elif block_size % 16 != 0:
            logger.info(
                "Cannot use FlashAttention-2 backend for block size not "
                "divisible by 16.")
            target_backend = _Backend.XFORMERS

        # FlashAttn is valid for the model, checking if the package is
        # installed.
        if target_backend == _Backend.FLASH_ATTN:
            try:
                import vllm.vllm_flash_attn  # noqa: F401
                from vllm.attention.backends.flash_attn import (  # noqa: F401
                    FlashAttentionBackend, flash_attn_supports_fp8)

                supported_sizes = \
                    FlashAttentionBackend.get_supported_head_sizes()
                if head_size not in supported_sizes:
                    logger.info(
                        "Cannot use FlashAttention-2 backend for head size %d.",
                        head_size)
                    target_backend = _Backend.XFORMERS
                fp8_kv_cache = (kv_cache_dtype is not None
                                and kv_cache_dtype.startswith("fp8"))
                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                    logger.info(
                        "Cannot use FlashAttention backend for FP8 KV cache.")
                    target_backend = _Backend.XFORMERS
            except ImportError:
                logger.info(
                    "Cannot use FlashAttention-2 backend because the "
                    "vllm.vllm_flash_attn package is not found. "
                    "Make sure that vllm_flash_attn was built and installed "
                    "(on by default).")
                target_backend = _Backend.XFORMERS

        if target_backend == _Backend.XFORMERS:
            logger.info("Using XFormers backend.")
            return "vllm.attention.backends.xformers.XFormersBackend"

        logger.info("Using Flash Attention backend.")
        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")

    @classmethod
    def get_punica_wrapper(cls) -> str:

@ -191,6 +191,11 @@ class RocmPlatform(Platform):
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
        if use_mla:
            if not use_v1:
                raise RuntimeError(
                    "MLA attention backends require the V1 engine. "
                    "Set VLLM_USE_V1=1 to enable them.")

            from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
                is_aiter_mla_enabled)

@ -201,39 +206,24 @@ class RocmPlatform(Platform):

            if selected_backend == _Backend.TRITON_MLA:
                if block_size != 1:
                    if use_v1:
                        logger.info_once(
                            "Using Triton MLA backend on V1 engine.")
                        return ("vllm.v1.attention.backends.mla."
                                "triton_mla.TritonMLABackend")
                    else:
                        logger.info("Using Triton MLA backend.")
                        return "vllm.attention.backends.triton_mla.TritonMLABackend"  # noqa: E501
                else:
                    raise ValueError(
                        f" The selected backend, {selected_backend.name},"
                        f"does not support block size {block_size}.")
            elif selected_backend == _Backend.ROCM_AITER_MLA \
                    or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1:
                if block_size == 1:
                    if use_v1:
                        logger.info("Using AITER MLA backend on V1 engine.")
                        return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                    else:
                        logger.info("Using AITER MLA backend")
                        return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                else:
                    raise ValueError(
                        f" The selected backend, {selected_backend.name},"
                        f"does not support block size {block_size}."
                        "(currently only supports block size 1)")
            else:
                logger.info_once("Using Triton MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla."
                        "triton_mla.TritonMLABackend")
                raise ValueError(
                    f" The selected backend, {selected_backend.name},"
                    f"is not MLA type while requested for MLA backend.")

        if selected_backend is None or selected_backend == _Backend.FLASH_ATTN:
            selected_backend = _Backend.ROCM_FLASH
                    f"does not support block size {block_size}.")
            if selected_backend in (_Backend.ROCM_AITER_MLA,
                                    _Backend.ROCM_AITER_MLA_VLLM_V1):
                if block_size == 1:
                    logger.info("Using AITER MLA backend on V1 engine.")
                    return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"  # noqa: E501
                raise ValueError(
                    f" The selected backend, {selected_backend.name},"
                    f"does not support block size {block_size}."
                    "(currently only supports block size 1)")
            raise ValueError(
                f" The selected backend, {selected_backend.name},"
                f"is not MLA type while requested for MLA backend.")

        if envs.VLLM_USE_V1:
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
@ -245,14 +235,9 @@ class RocmPlatform(Platform):
                logger.info("Using Triton Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
                logger.info("flash_attn is not supported on NAVI GPUs.")
        else:
            logger.info("%s is not supported in AMD GPUs.", selected_backend)
        logger.info("Using ROCmFlashAttention backend.")
        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
        raise RuntimeError(
            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
            "to select a supported backend.")

    @classmethod
    def set_device(cls, device: torch.device) -> None:

@ -157,10 +157,8 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

MB_bytes = 1_000_000