# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
import numpy as np
import torch
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
PAD_SLOT_ID = -1
# Switch to numpy implementation of compute_slot_mapping
# if we have at least this many elements. Could be tuned further.
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
def is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
return (isinstance(block_tables, dict)
and all(value is None for value in block_tables.values()))
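# Illustrative behavior of the helper above (hypothetical inputs):
#   is_block_tables_empty(None)                -> True
#   is_block_tables_empty({0: None, 1: None})  -> True
#   is_block_tables_empty({0: [7, 1, 3]})      -> False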
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
start_idx = max(0, query_len - sliding_window)
return start_idx
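# Worked example for the helper above (hypothetical values): for a prompt
# with query_len=10 and sliding_window=8 only the last 8 tokens need real
# slots, so compute_slot_mapping_start_idx(True, 10, 2, 8) == max(0, 10 - 8)
# == 2. For decode steps (is_prompt=False) or when sliding_window is None,
# the start index is always 0.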
def _compute_slot_mapping_python(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
for i in range(range_start, range_end):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
def _compute_slot_mapping_numpy(slot_mapping: List[int],
block_table: List[int], range_start: int,
range_end: int, block_size: int):
block_table_array = np.array(block_table)
idx = np.arange(range_start, range_end)
block_offset = idx % block_size
idx //= block_size
seq_slot_mapping_array = block_table_array[idx]
seq_slot_mapping_array *= block_size
seq_slot_mapping_array += block_offset
slot_mapping.extend(seq_slot_mapping_array)
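# Sketch of the slot arithmetic shared by both implementations above
# (hypothetical values): with block_size=4 and block_table=[7, 1, 3],
# token index i=5 lives in block block_table[5 // 4] == 1 at offset
# 5 % 4 == 1, so its slot is 1 * 4 + 1 == 5; token i=9 maps to
# 3 * 4 + 1 == 13.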
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
block_tables: Dict[int, List[int]]):
"""
Compute slot mapping.
"""
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
        # For embedding models, the block tables are {seq_id: None}.
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
padding_mask_len = max(0, start_idx - context_len)
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
range_start = max(start_idx, context_len)
range_end = seq_len
numel = range_end - range_start
block_table = block_tables[seq_id]
    # The numpy implementation is faster than the pure-Python one when there
    # are many elements; otherwise it is slower.
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
range_end, block_size)
else:
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
range_end, block_size)
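# Minimal usage sketch for compute_slot_mapping (hypothetical values,
# mirroring the masking example in the comment above): for a prompt with
# seq_len=10, context_len=0, start_idx=2 (sliding_window=8), block_size=4
# and block_tables={0: [0, 1, 2]},
#
#     compute_slot_mapping(False, slot_mapping, 0, 10, 0, 2, 4,
#                          {0: [0, 1, 2]})
#
# appends [-1, -1, 2, 3, 4, 5, 6, 7, 8, 9] to slot_mapping.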
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
_metadata_cls: Type[TAttentionMetadata]
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
            seq_lens: The (possibly padded) sequence lengths of the input
                sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for CUDA graphs;
                -1 if CUDA graphs are not used.
            batch_size: The (possibly padded) batch size.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with empty block tables for the CUDA-graph padded entries.
            self.block_tables.extend(
                [[] for _ in range(cuda_graph_pad_size)])
num_decode_tokens = batch_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
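# Minimal sketch (hypothetical backend, not part of this module) of how a
# concrete backend plugs into CommonMetadataBuilder: it only has to point
# `_metadata_cls` at its own AttentionMetadata subclass, and the `build()`
# method above assembles the on-device tensors for it.
#
#     class MyMetadataBuilder(CommonMetadataBuilder[MyMetadata]):
#         _metadata_cls = MyMetadata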
class CommonAttentionState(AttentionState):
def __init__(self, runner: "ModelRunnerBase"):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
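    # Usage sketch (illustrative; the model runner drives this in practice):
    # the context manager above allocates persistent slot-mapping, seq-len
    # and block-table buffers sized for max_batch_size, and the per-batch
    # metadata produced below merely slices into them:
    #
    #     with attn_state.graph_capture(max_batch_size):
    #         attn_metadata = \
    #             attn_state.graph_capture_get_metadata_for_batch(batch_size)
    #         ...  # capture the CUDA graph with this metadata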
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
)
if is_encoder_decoder_model:
            # Encoder/decoder models are only supported by the XFormers,
            # Flash Attention, and ROCm Flash Attention backends; the
            # assertion below enforces that.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
return attn_metadata
def get_graph_input_buffers(
self,
attn_metadata,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
if is_encoder_decoder_model:
            # Encoder/decoder models are only supported by the XFormers,
            # Flash Attention, and ROCm Flash Attention backends; the
            # assertion below enforces that.
assert self.runner.attn_backend.get_name() in \
["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \
f"Expected attn_backend name to be either 'XFORMERS'," \
f"'ROCM_FLASH', or 'FLASH_ATTN', but " \
f"got '{self.runner.attn_backend.get_name()}'"
self._add_additional_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers
def prepare_graph_input_buffers(
self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False) -> None:
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
            # Encoder/decoder models are only supported by the XFormers and
            # Flash Attention backends; the assertion below enforces that.
assert self.runner.attn_backend.get_name() in\
["XFORMERS", "FLASH_ATTN"], \
f"Expected attn_backend name to be either 'XFORMERS' or "\
f"'FLASH_ATTN', but "\
f"got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
def begin_forward(self, model_input) -> None:
return
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
attn_metadata):
"""
Updates the attention metadata parameters for CUDA graph capture in an
encoder-decoder model.
This method modifies attention-related tensors and metadata required
for CUDA graph capture in encoder-decoder models. Specifically, it
updates the cross-attention and encoder sequence tensors in the
AttentionMetadata object.
"""
        # During the decode phase the cross_slot_mapping is empty, so set
        # an empty tensor for CUDA graph capture.
attn_metadata.cross_slot_mapping = torch.tensor(
[], dtype=torch.int).cuda()
attn_metadata.cross_block_tables = torch.full(
(batch_size, self.runner.get_max_block_per_batch()),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda()
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
attn_metadata.num_encoder_tokens = 0
def _add_additional_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]):
"""
Saves additional input buffers specific to the encoder-decoder model
from the attention metadata.
This method extracts and stores encoder-decoder related input buffers
from the `attn_metadata` into the `input_buffers` dictionary. The
buffers include encoder sequence lengths, cross-slot mappings, and
cross-block tables, which are essential for the encoder-decoder model
during CUDA graph replay.
"""
input_buffers["encoder_seq_lens_tensor"] = (
attn_metadata.decode_metadata.encoder_seq_lens_tensor)
input_buffers["cross_slot_mapping"] = (
attn_metadata.decode_metadata.cross_slot_mapping)
input_buffers["cross_block_tables"] = (
attn_metadata.decode_metadata.cross_block_tables)
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
input_buffers: Dict[str,
Any]):
"""
Populates input buffers with data from the encoder-decoder model's
attention metadata.
        This method fills the input buffers with encoder-decoder specific
        tensors. It copies the encoder sequence lengths, the cross-slot
        mapping, and the cross-block tables from `attn_metadata` into the
        corresponding buffers in the `input_buffers` dictionary.
"""
input_buffers["encoder_seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.encoder_seq_lens_tensor,
non_blocking=True)
input_buffers["cross_slot_mapping"].copy_(
attn_metadata.decode_metadata.cross_slot_mapping,
non_blocking=True)
input_buffers["cross_block_tables"].copy_(
attn_metadata.decode_metadata.cross_block_tables,
non_blocking=True)
def is_all_encoder_attn_metadata_set(attn_metadata):
'''
    Return True if all attention metadata required for encoder
    attention is set.
'''
return ((attn_metadata.encoder_seq_lens is not None)
and (attn_metadata.encoder_seq_lens_tensor is not None)
and (attn_metadata.max_encoder_seq_len is not None))
def is_all_cross_attn_metadata_set(attn_metadata):
'''
    Return True if all attention metadata required for enc/dec
    cross-attention is set (a superset of the encoder-attention metadata).
'''
return (attn_metadata.is_all_encoder_attn_metadata_set
and (attn_metadata.cross_slot_mapping is not None)
and (attn_metadata.cross_block_tables is not None))
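# Note: is_all_cross_attn_metadata_set reads
# attn_metadata.is_all_encoder_attn_metadata_set as an attribute/property of
# the metadata object rather than calling the helper above directly, so
# cross-attention metadata only counts as complete once the encoder metadata
# is set as well.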
def get_seq_len_block_table_args(
attn_metadata,
is_prompt: bool,
attn_type: str,
) -> tuple:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if attn_type == AttentionType.DECODER:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if is_prompt:
max_seq_len = attn_metadata.max_prefill_seq_len
else:
max_seq_len = attn_metadata.max_decode_seq_len
return (attn_metadata.seq_lens_tensor, max_seq_len,
attn_metadata.block_tables)
elif attn_type == AttentionType.ENCODER_DECODER:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len,
attn_metadata.cross_block_tables)
elif attn_type == AttentionType.ENCODER:
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
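# Illustrative summary (hypothetical metadata object `m`) of what the helper
# above returns for each attention type:
#   DECODER         -> (m.seq_lens_tensor,
#                       m.max_prefill_seq_len if prompt else
#                       m.max_decode_seq_len,
#                       m.block_tables)
#   ENCODER_DECODER -> (m.encoder_seq_lens_tensor, m.max_encoder_seq_len,
#                       m.cross_block_tables)
#   ENCODER         -> (m.encoder_seq_lens_tensor, m.max_encoder_seq_len,
#                       None)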
def get_num_prefill_decode_query_kv_tokens(
attn_metadata,
attn_type: str,
) -> Tuple[int, int, int]:
"""
Calculate the number of prefill and decode tokens for query, key/value
based on the attention metadata and the specified attention type.
Args:
attn_metadata (AttentionMetadata): Attention Metadata object.
attn_type (AttentionType): The type of attention being used.
Returns:
Tuple[int, int, int]: A tuple containing three integers:
- The number of prefill query tokens.
- The number of prefill key/value tokens.
- The number of decode query tokens.
Raises:
AssertionError: If the number of encoder tokens in `attn_metadata`
is `None` when required for the calculations.
"""
num_prefill_query_tokens = 0
num_decode_query_tokens = 0
num_prefill_kv_tokens = 0
if attn_type == AttentionType.ENCODER:
        # Encoder attention is only invoked during the prefill phase.
        # The same input serves as both query and key.
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_encoder_tokens
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = 0
elif attn_type == AttentionType.ENCODER_DECODER:
assert attn_metadata.num_encoder_tokens is not None
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
        # The key/value tokens come from the encoder (cross-attention).
num_prefill_kv_tokens = attn_metadata.num_encoder_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
else: # attn_type == AttentionType.DECODER or
# attn_type == AttentionType.ENCODER_ONLY
num_prefill_query_tokens = attn_metadata.num_prefill_tokens
num_prefill_kv_tokens = attn_metadata.num_prefill_tokens
num_decode_query_tokens = attn_metadata.num_decode_tokens
return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)
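# Worked example (hypothetical counts): for an ENCODER_DECODER batch with
# num_prefill_tokens=6, num_decode_tokens=3 and num_encoder_tokens=10, the
# helper above returns (6, 10, 3): prefill queries come from the decoder
# prompt, prefill keys/values come from the encoder sequence, and decode
# queries come from the decoder's generation steps.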
@dataclass
class MLADims:
q_lora_rank: Optional[int]
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
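# Minimal usage sketch (assuming an MLA model config, e.g. DeepSeek-style,
# whose hf_text_config carries the fields read above):
#
#     mla = get_mla_dims(model_config)
#     qk_head_dim = mla.qk_nope_head_dim + mla.qk_rope_head_dim
#
# q_lora_rank may be None for models that do not use a low-rank query
# projection, which is why it is fetched with getattr above.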