mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 19:34:59 +08:00
448 lines
16 KiB
Python
448 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Attention layer with TreeAttention."""
|
|
|
|
import ast
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
import torch
|
|
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|
AttentionMetadata, AttentionType)
|
|
from vllm.attention.ops.triton_unified_attention import unified_attention
|
|
from vllm.config import VllmConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.attention.backends.utils import (
|
|
AttentionMetadataBuilder, CommonAttentionMetadata,
|
|
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
|
|
from vllm import _custom_ops as ops
|
|
|
|
# Module-level logger for this backend (not referenced by the code shown in
# this file; kept for consistency with other backends).
logger = init_logger(__name__)
|
|
|
|
|
|
class TreeAttentionBackend(AttentionBackend):
    """Backend descriptor for tree attention used in speculative decoding.

    Advertises the supported dtypes, head sizes and KV-cache layout, and
    points the framework at the implementation, metadata and
    metadata-builder classes of this backend.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this backend."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the accepted head sizes: every multiple of 32 up to 256."""
        return [32 * multiple for multiple in range(1, 9)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` unless ``head_size`` is supported.

        The error message suggests switching to the FlexAttention backend.
        """
        allowed = cls.get_supported_head_sizes()
        if head_size in allowed:
            return
        backend_label = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {backend_label}. "
            f"Supported head sizes are: {allowed}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return this backend's identifying name."""
        return "TREE_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        """Return the class that performs the attention computation."""
        return TreeAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-batch metadata class of this backend."""
        return TreeAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The leading dimension of size 2 holds keys and values; the
        implementation splits it with ``kv_cache.unbind(0)``.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        """Return the class that builds metadata for a batch."""
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used with this backend."""
        return False
|
|
|
|
|
|
@dataclass
class TreeAttentionMetadata:
    """Per-batch attention metadata for the tree attention backend.

    The batch is laid out with all decode requests first and all prefill
    requests after them: decode tokens occupy the first
    ``num_decode_tokens`` positions and prefill tokens follow. The
    ``prefill_metadata`` / ``decode_metadata`` properties return metadata
    for each sub-batch and memoize it on first access.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Cumulative query offsets, one more entry than there are requests.
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor  # KV sequence length per request.
    block_table: torch.Tensor  # One row per request.
    slot_mapping: torch.Tensor  # One cache slot per token.

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # Tree attention bias (0 / -inf entries); only forwarded to the decode
    # sub-batch.
    tree_attn_bias: Optional[torch.Tensor] = None

    # Memoized sub-batch metadata, filled lazily by the properties below.
    _cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
    _cached_decode_metadata: Optional["TreeAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the prefill requests, or ``None``.

        Prefill requests follow the ``num_decodes`` decode requests. Their
        query offsets are rebased so the sub-batch starts at 0. The tree
        bias is not carried over.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is None:
            first_req = self.num_decodes
            starts = self.query_start_loc[first_req:]
            kv_lens = self.seq_lens[first_req:]
            cached = TreeAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(torch.diff(starts).max().item()),
                query_start_loc=starts - starts[0],
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[first_req:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens:],
            )
            self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
        """Metadata restricted to the decode requests, or ``None``.

        Decode requests are the first ``num_decodes`` requests, so their
        query offsets already start at 0. The tree bias is carried over.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is None:
            req_end = self.num_decodes
            starts = self.query_start_loc[:req_end + 1]
            kv_lens = self.seq_lens[:req_end]
            cached = TreeAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(torch.diff(starts).max().item()),
                query_start_loc=starts,
                max_seq_len=int(kv_lens.max().item()),
                seq_lens=kv_lens,
                block_table=self.block_table[:req_end],
                slot_mapping=self.slot_mapping[:self.num_decode_tokens],
                tree_attn_bias=self.tree_attn_bias,
            )
            self._cached_decode_metadata = cached
        return cached
|
|
|
|
|
|
class TreeAttentionMetadataBuilder(
        AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds ``TreeAttentionMetadata`` for a batch.

    The tree attention bias is computed once, at construction, from the
    speculative config's ``speculative_token_tree``. Its side length (number
    of tree nodes including the root) doubles as the ``decode_threshold``
    used to split requests into decodes and prefills.
    """

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = (spec_config.speculative_token_tree
                           if spec_config is not None else None)
        # The tree is a string literal parsed with ast.literal_eval into a
        # list of paths (tuples of ints). Without one, fall back to a tree
        # with a single node below the root.
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree)
            if spec_token_tree is not None else [(0, )])

        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Reorder the batch so decode requests come before prefills.

        The decode threshold passed to the helper is the tree size (number
        of nodes including the root).
        """
        return reorder_batch_to_split_decodes_and_prefills(
            input_batch,
            scheduler_output,
            decode_threshold=self.tree_attn_bias.shape[0])

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Build tree attention metadata for the current batch.

        ``common_prefix_len`` and ``fast_build`` are not used by this
        builder. The current ``self.tree_attn_bias`` is attached to the
        result and its side length is used as the decode threshold.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=decode_threshold))

        return TreeAttentionMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=common_attn_metadata.query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=common_attn_metadata.seq_lens,
            block_table=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting step.

        At ``draft_index == 0`` the bias is temporarily replaced by an empty
        tensor, so the decode threshold becomes 0 and drafting at the root
        level goes through the prefill path. At later levels the bias is
        sliced to the non-root nodes covered by ``max_query_len``.

        The builder's original ``tree_attn_bias`` is always restored before
        returning, including when building raises.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias
        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting. Exclude
                # the root level.
                start, end = 1, 1 + common_attn_metadata.max_query_len
                self.tree_attn_bias = orig_tree_attn_bias[
                    start:end, start:end].contiguous()

            return self.build(0, common_attn_metadata, fast_build=True)
        finally:
            # Restore the full bias even if build() raised; otherwise every
            # later build()/reorder_batch() would see the temporary bias.
            self.tree_attn_bias = orig_tree_attn_bias
|
|
|
|
|
|
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
|
|
# Count the number of choices at each depth of the tree.
|
|
depth_counts = []
|
|
prev_depth = 0
|
|
for path in sorted_tree_choices:
|
|
depth = len(path)
|
|
if depth != prev_depth:
|
|
depth_counts.append(0)
|
|
depth_counts[depth - 1] += 1
|
|
prev_depth = depth
|
|
return depth_counts
|
|
|
|
|
|
def _prepare_tree_attn_bias(
|
|
sorted_tree_choices: list[tuple[int, ...]],
|
|
depth_counts: list[int],
|
|
dtype: Optional[torch.dtype],
|
|
device: Optional[torch.device],
|
|
) -> torch.Tensor:
|
|
# +1 comes from the additional root node.
|
|
tree_len = len(sorted_tree_choices) + 1
|
|
tree_attn_mask = torch.full((tree_len, tree_len),
|
|
-torch.inf,
|
|
device=device,
|
|
dtype=dtype)
|
|
|
|
# Set diagonal to all zeros. Each token should
|
|
# attend to itself.
|
|
mask_val = 0
|
|
for i in range(tree_len):
|
|
tree_attn_mask[i, i] = mask_val
|
|
|
|
# Set root to all zeros. All tokens attend to it.
|
|
tree_attn_mask[:, 0] = mask_val
|
|
|
|
# Set all ancestors to zeros.
|
|
start = 0
|
|
for i in range(len(depth_counts)):
|
|
for j in range(depth_counts[i]):
|
|
cur_tree_choice = sorted_tree_choices[start + j]
|
|
# Retrieve ancestor position.
|
|
if len(cur_tree_choice) == 1:
|
|
continue
|
|
ancestor_idx = []
|
|
for c in range(len(cur_tree_choice) - 1):
|
|
ancestor_idx.append(
|
|
sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1)
|
|
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
|
|
start += depth_counts[i]
|
|
return tree_attn_mask
|
|
|
|
|
|
class TreeAttentionImpl(AttentionImpl):
    """Tree attention implementation built on the Triton unified kernel.

    Decode tokens (the first ``num_decode_tokens`` of the batch) are run
    with the tree attention bias as ``qq_bias``. Prefill tokens (the rest)
    are run with plain causal attention.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))
        # Setting logits_soft_cap to 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)
        # (-1, -1) disables the sliding window; otherwise look back
        # sliding_window - 1 tokens and never ahead.
        self.sliding_window = ((-1, -1) if sliding_window is None else
                               (sliding_window - 1, 0))

        TreeAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TreeAttentionImpl.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            layer: Attention layer; its ``_k_scale`` / ``_v_scale`` are used
                for the cache write and as K/V descales.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention; ``None`` in profiling runs.
            output: Preallocated buffer written in place (required).
            output_scale: Must be ``None``; fused output quantization is not
                supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for TreeAttentionImpl")

        if attn_metadata is None:
            # Profiling run: nothing to compute.
            return output

        key_cache, value_cache = kv_cache.unbind(0)

        # Layers sharing an earlier layer's KV cache skip the cache write.
        if self.kv_sharing_target_layer_name is None:
            # NOTE(woosuk): key/value are padded while slot_mapping is not;
            # reshape_and_cache_flash uses slot_mapping's shape to determine
            # the number of actual tokens, so no slicing is needed here.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        total_tokens = attn_metadata.num_actual_tokens
        decode_tokens = attn_metadata.num_decode_tokens
        # K/V descales are expanded to (num requests in the whole batch,
        # num kv heads) for both calls below.
        # NOTE(review): the prefill/decode calls get the full-batch request
        # count, not their sub-batch count; presumably the kernel only reads
        # a scalar scale — confirm.
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])

        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            # Prefill tokens sit after the decode tokens; plain causal.
            unified_attention(
                q=query[decode_tokens:total_tokens],
                k=key_cache,
                v=value_cache,
                out=output[decode_tokens:total_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        decode_meta = attn_metadata.decode_metadata
        if decode_meta:
            # Decode tokens lead the batch; the tree bias is applied
            # between query tokens.
            unified_attention(
                q=query[:decode_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:decode_tokens],
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_query_len,
                seqused_k=decode_meta.seq_lens,
                max_seqlen_k=decode_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                qq_bias=decode_meta.tree_attn_bias,
                window_size=self.sliding_window,
                block_table=decode_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        return output
|