Woosuk Kwon 8407fa02ed fix
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 14:52:23 -07:00

841 lines
32 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, fields, make_dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
Protocol, TypeVar, Union, get_args)
import numpy as np
import torch
from typing_extensions import runtime_checkable
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.ubatch_utils import UBatchSlice
# Module-level logger obtained from vLLM's `init_logger` helper.
logger = init_logger(__name__)
# The KV cache layout names accepted by this module.
KVCacheLayoutType = Literal["NHD", "HND"]
# Layout forced from code through `set_kv_cache_layout`. While it is `None`,
# `get_kv_cache_layout` falls back to its other sources.
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True when `value` is one of the `KVCacheLayoutType` literals."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
@dataclass
class CommonAttentionMetadata:
    """
    Attention metadata computed once per batch and shared by every layer and
    backend. AttentionMetadataBuilder instances consume it to construct their
    own per-layer metadata.

    For many of the tensors both a GPU and a CPU version are kept.
    """

    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor
    """(batch_size + 1,), the start location of each request in the query
    tensor"""

    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    """(batch_size,), the length of each request, counting both the computed
    tokens and the newly scheduled tokens"""

    num_computed_tokens_cpu: torch.Tensor
    """(batch_size,), the number of computed tokens for each request"""

    num_reqs: int
    """Number of requests"""
    num_actual_tokens: int
    """Total number of tokens in batch"""
    max_query_len: int
    """Longest query in batch"""
    max_seq_len: int
    """Longest context length in batch"""

    # The slicing helpers in this file index `block_table_tensor` by request
    # along its first dimension and `slot_mapping` by token.
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: Optional[np.ndarray] = None
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Build the query_start_loc for the requests selected by request_slice.

    The entries covering the slice (including the end offset of its last
    request) are rebased so that the first entry becomes 0.

    Note: the result is held in a newly created tensor, which will break
    cudagraph compatibility.
    """
    first_req = request_slice.start
    end_req = request_slice.stop
    # One extra element so the end offset of the last request is included.
    window = query_start_loc[first_req:end_req + 1]
    return window - query_start_loc[first_req]
def _make_metadata_with_slice(
        ubatch_slice: UBatchSlice,
        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
    """
    Build a new CommonAttentionMetadata restricted to the requests (and their
    tokens) covered by ubatch_slice.
    """
    req_range = ubatch_slice.request_slice
    tok_range = ubatch_slice.token_slice

    # Rebased start locations, GPU copy first and then the CPU copy.
    sliced_qsl = slice_query_start_locs(attn_metadata.query_start_loc,
                                        req_range)
    assert len(sliced_qsl) >= 2, (
        f"query_start_loc must have at least 2 elements, "
        f"got {len(sliced_qsl)}")
    sliced_qsl_cpu = slice_query_start_locs(attn_metadata.query_start_loc_cpu,
                                            req_range)

    sliced_seq_lens = attn_metadata.seq_lens[req_range]
    sliced_seq_lens_cpu = attn_metadata.seq_lens_cpu[req_range]
    longest_seq = int(sliced_seq_lens_cpu.max())

    sliced_num_computed_cpu = attn_metadata.num_computed_tokens_cpu[req_range]

    request_count = req_range.stop - req_range.start
    token_count = tok_range.stop - tok_range.start

    # Per-request query lengths are the gaps between consecutive start
    # locations; the longest one becomes max_query_len.
    query_len_gaps = sliced_qsl_cpu[1:] - sliced_qsl_cpu[:-1]
    longest_query = int(torch.max(torch.abs(query_len_gaps)).item())
    if longest_query == 0:
        # In a dummy run query_start_loc_cpu is full of 0s, so fall back to
        # the value of the un-sliced metadata.
        longest_query = attn_metadata.max_query_len

    sliced_block_table = attn_metadata.block_table_tensor[req_range]
    sliced_slot_mapping = attn_metadata.slot_mapping[tok_range]

    return CommonAttentionMetadata(
        query_start_loc=sliced_qsl,
        query_start_loc_cpu=sliced_qsl_cpu,
        seq_lens=sliced_seq_lens,
        seq_lens_cpu=sliced_seq_lens_cpu,
        num_computed_tokens_cpu=sliced_num_computed_cpu,
        num_reqs=request_count,
        num_actual_tokens=token_count,
        max_query_len=longest_query,
        max_seq_len=longest_seq,
        block_table_tensor=sliced_block_table,
        slot_mapping=sliced_slot_mapping,
    )
def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Produce one CommonAttentionMetadata per UBatchSlice in ubatch_slices, each
    describing only the requests of that slice.

    Note: common_attn_metadata itself is not modified.
    """
    return [
        _make_metadata_with_slice(ubatch_slice, common_attn_metadata)
        for ubatch_slice in ubatch_slices
    ]
# Type variable for the metadata type produced by an AttentionMetadataBuilder:
# it parameterizes the `Generic[M]` builder and is the return type of `build`.
M = TypeVar("M")
class AttentionCGSupport(enum.Enum):
    """Constants for the cudagraph support of the attention backend.

    Cascade attention is not considered here, as currently it is never
    cudagraph supported.
    """

    ALWAYS = 3
    """Cudagraph always supported; supports mixed-prefill-decode"""

    UNIFORM_BATCH = 2
    """Cudagraph supported for batches that only contain query lengths that
    are the same; this can be used for spec-decode,
    i.e. "decodes" are 1 + num_speculative_tokens"""

    UNIFORM_SINGLE_TOKEN_DECODE = 1
    """Cudagraph supported for batches that only contain query_len==1
    decodes"""

    NEVER = 0
    """NO cudagraph support"""
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """
    Abstract base class for objects that turn a per-batch
    CommonAttentionMetadata into backend-specific attention metadata of
    type `M`.

    Subclasses must implement `__init__` and `build`; the remaining methods
    have defaults defined here.
    """
    # Does this backend/builder support CUDA Graphs for attention (default: no).
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER
    # Does this backend/builder reorder the batch?
    # If not, set this to None. Otherwise set it to the query
    # length that will be pulled into the front of the batch.
    reorder_batch_threshold: ClassVar[Optional[int]] = None

    @abstractmethod
    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """
        Store the constructor arguments on the instance.

        Declared abstract, so concrete builders must define their own
        `__init__`.
        """
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @abstractmethod
    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> M:
        """
        Central method that builds attention metadata.
        Some builders (MLA) require reorder_batch to be called prior to build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: The meta-data will prioritize speed of building over
                the speed at execution. Can be used for spec-decode where the
                result of a build call may only be used for few layers/iters.
        """
        raise NotImplementedError

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        The base implementation raises NotImplementedError.

        Args:
            input_batch: input batch
            scheduler_output: scheduler output.

        Returns:
            True if the batch was modified, False otherwise.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture. Uses build by default
        (with common_prefix_len=0).
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata)

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for draft model. Uses build by default
        (with common_prefix_len=0 and fast_build=True); `draft_index` is not
        used by this default implementation.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        return self.build(common_prefix_len=0,
                          common_attn_metadata=common_attn_metadata,
                          fast_build=True)

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
    ) -> bool:
        """
        Report whether cascade attention should be used for this batch.

        The base implementation ignores all arguments and returns False.
        """
        return False
@functools.lru_cache
def get_kv_cache_layout():
    """
    Resolve the KV cache layout to use.

    Sources are consulted in this order:
      1. the in-code override stored in `_KV_CACHE_LAYOUT_OVERRIDE`,
      2. the `VLLM_KV_CACHE_LAYOUT` value read from `envs`,
      3. the layout returned by `get_kv_connector_cache_layout()`.

    The result is memoized by `functools.lru_cache`, so the resolution and
    its logging only run again after the cache has been cleared.
    """
    # Format specified by the code.
    code_override = _KV_CACHE_LAYOUT_OVERRIDE
    if code_override is not None:
        logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \
                         "Setting KV cache layout to %s.", code_override)
        return code_override

    # Format specified by the user.
    user_layout = envs.VLLM_KV_CACHE_LAYOUT
    if user_layout is None:
        # Neither the user nor the override specified a layout: use the
        # default.
        return get_kv_connector_cache_layout()

    assert is_valid_kv_cache_layout(user_layout)
    logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
                     "detected. Setting KV cache layout to %s.", user_layout)
    return user_layout
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """
    Force the KV cache layout from code.

    Stores `cache_layout` in `_KV_CACHE_LAYOUT_OVERRIDE`, which
    `get_kv_cache_layout` checks before any other source.

    Args:
        cache_layout: The layout to force.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    # `get_kv_cache_layout` is memoized with `functools.lru_cache`. Without
    # clearing its cache, an override set after the first lookup would be
    # silently ignored and the stale layout returned forever.
    get_kv_cache_layout.cache_clear()
@dataclass
class PerLayerParameters:
    """
    Attention hyperparameters collected from a single attention layer.

    Currently, FlashInfer backend only support models in which all layers share
    the same values for the following hyperparameters. Should not be used for
    trtllm-gen backend since it supports different values for the following
    hyperparameters.
    """

    # `get_per_layer_parameters` sets this to -1 when the layer has no
    # sliding window, otherwise to the first element of the window.
    window_left: int
    # None when the layer implementation defines no soft cap.
    logits_soft_cap: Optional[float]
    # Taken from the layer implementation's `scale` attribute.
    sm_scale: float
    # True when the layer implementation has a non-None `sinks` attribute.
    has_sinks: bool = False
def get_per_layer_parameters(
        vllm_config: VllmConfig, layer_names: list[str],
        cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
    """
    Collect, for every attention layer named in `layer_names`, the
    hyperparameters to use during `plan`.

    Each layer's implementation object must be an instance of `cls_`.
    Returns a dict keyed like the one produced by
    `get_layers_from_vllm_config`.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention,
                                              layer_names)

    collected: dict[str, PerLayerParameters] = {}
    for layer_key, attn_layer in attn_layers.items():
        layer_impl = attn_layer.impl
        assert isinstance(layer_impl, cls_)

        # Hyperparameters are read from the implementation; optional ones
        # fall back to None when the attribute is missing.
        sliding_window = getattr(layer_impl, "sliding_window", None)
        if sliding_window is None:
            window_left = -1
        else:
            window_left = sliding_window[0]

        collected[layer_key] = PerLayerParameters(
            window_left=window_left,
            logits_soft_cap=getattr(layer_impl, "logits_soft_cap", None),
            sm_scale=layer_impl.scale,
            has_sinks=getattr(layer_impl, "sinks", None) is not None,
        )
    return collected
def infer_global_hyperparameters(
        per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
    """
    Currently, FlashInfer backend other than trtllm-gen
    only support models in which all layers share
    the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`

    So this function asserts that all layers share the same values for these
    hyperparameters and returns the global values. The comparison uses
    dataclass equality, so every field of PerLayerParameters takes part in it.

    The consistency check is skipped when `envs.VLLM_USE_TRTLLM_ATTENTION`
    is truthy; the parameters of the first layer are then returned as-is.

    Raises:
        AssertionError: If `per_layer_params` is empty, or if a layer's
            parameters differ from those of the first layer.
        ValueError: If a layer's `window_left` differs from the first
            layer's.
    """
    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]

    # trtllm attention doesn't need global hyper params so disable the check
    if not envs.VLLM_USE_TRTLLM_ATTENTION:
        for params in param_sets:
            if params.window_left != global_params.window_left:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True")
            # The message fragments carry explicit separating spaces; they
            # used to run together ("alllayers", "hyperparameters:`...`").
            assert params == global_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values "
                "for the following hyperparameters: "
                "`window_left`, `logits_soft_cap`, `sm_scale`.")

    return global_params
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata:
    """
    Split every request into local attention blocks of `attn_chunk_size`
    tokens and describe each block as an independent "virtual" batch item.

    Args:
        attn_chunk_size: Number of tokens in one local attention block.
        common_attn_metadata: Metadata of the original batch.
        block_size: Number of tokens per entry of the block table. Must be
            positive and divide `attn_chunk_size`; the default of 0 is not a
            usable value and is rejected.

    Returns:
        A new CommonAttentionMetadata with one entry per virtual batch.
        `num_actual_tokens` and `slot_mapping` are carried over unchanged
        from `common_attn_metadata`.
    """
    # Validate before block_size is used as a divisor below.
    assert block_size > 0, \
        f"block_size must be a positive integer, got {block_size}"
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    # the number of tokens that are not in the first local attention block and
    # then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    # for each batch idx, we can figure out the number of "virtual" requests we
    # have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    # first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    # basically a full local attention block for all but the last block in each
    # batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(virtual_batches,
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \
            np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = (block_starts[:, None] +
                     np.arange(pages_per_local_batch, dtype=np.int32))
    # Clip so indices past the last block-table column stay in bounds.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] -
                                                   1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)
    block_table_local = block_table[batch_indices_torch, block_indices_torch]\
        .view(virtual_batches, -1)

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())
    # Convert to a Python int, like max_seq_len, so the field does not hold
    # a NumPy scalar.
    max_query_len = int(seqlens_q_local.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device,
                                               non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """
    Build the common attention metadata for the KV-sharing fast prefill path,
    in which each request only keeps the query tokens selected by the logits
    indices.

    When `max_query_len` is 1 the input metadata is returned unchanged.
    Otherwise `logits_indices_padded` and `num_logits_indices` must be set.
    """
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    padded_indices = common_attn_metadata.logits_indices_padded
    valid_count = common_attn_metadata.num_logits_indices
    assert padded_indices is not None
    assert valid_count is not None

    # Get rid of CUDAGraph padding, if any
    logits_indices = padded_indices[:valid_count]

    num_reqs = common_attn_metadata.num_reqs
    orig_query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens

    # Worked example:
    #   num_reqs:        3
    #   logits_indices:  [14, 18, 19, 27]
    #   query_start_loc: [0, 15, 20, 28]
    #   seq_lens:        [41, 31, 40]

    # Map every logits index to the request owning it
    #   request_ids: [0, 1, 1, 2]
    request_ids = torch.bucketize(logits_indices,
                                  orig_query_start_loc[1:],
                                  right=True)

    # Count the kept tokens of each request
    #   tokens_per_req: [1, 2, 1]
    tokens_per_req = torch.bincount(request_ids, minlength=num_reqs)

    # Cumulative start locations over the kept tokens only
    #   new_query_start_loc: [0, 1, 3, 4]
    new_query_start_loc = torch.empty(num_reqs + 1,
                                      device=orig_query_start_loc.device,
                                      dtype=orig_query_start_loc.dtype)
    new_query_start_loc[0] = 0
    new_query_start_loc[1:] = torch.cumsum(tokens_per_req, dim=0)

    longest_query = int(tokens_per_req.max().item())
    kept_token_total = int(tokens_per_req.sum().item())

    return CommonAttentionMetadata(
        query_start_loc=new_query_start_loc,
        query_start_loc_cpu=new_query_start_loc.to("cpu", non_blocking=True),
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.to("cpu", non_blocking=True),
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=kept_token_total,
        max_query_len=longest_query,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
def subclass_attention_backend(
        name_prefix: str, attention_backend_cls: type[AttentionBackend],
        builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
    """
    Return a new subclass where `get_builder_cls` returns `builder_cls`.

    Args:
        name_prefix: Prepended to the name of `attention_backend_cls` to form
            the name of the new class.
        attention_backend_cls: The backend class to derive from.
        builder_cls: The builder class the new backend should report.
    """
    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    # Wrapped in `staticmethod` so the zero-argument override also works when
    # looked up on an instance; a bare function stored on the class would be
    # bound and receive `self`.
    return type(name, (attention_backend_cls, ),
                {"get_builder_cls": staticmethod(lambda: builder_cls)})
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Locate the boundary between decode and prefill requests in a batch that
    has already been reordered so that decodes come first.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    # Result used whenever no request exceeds the decode threshold.
    decode_only = (num_reqs, 0, num_tokens, 0)

    if common_attn_metadata.max_query_len <= decode_threshold:
        return decode_only

    query_start_loc = common_attn_metadata.query_start_loc_cpu
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return decode_only

    # argmax over the 0/1 mask gives the index of the first prefill request.
    first_prefill = is_prefill.int().argmax(dim=-1).item()
    # The batch must be partitioned: decodes before, prefills from there on.
    assert torch.all(query_lens[first_prefill:] > decode_threshold)
    assert torch.all(query_lens[:first_prefill] <= decode_threshold)

    num_decode_tokens = query_start_loc[first_prefill].item()
    return (
        first_prefill,
        num_reqs - first_prefill,
        num_decode_tokens,
        num_tokens - num_decode_tokens,
    )
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch to split into prefill and decode requests; places all
    requests with <= decode_threshold tokens at the front of the batch.

    Args:
        input_batch: Batch whose `req_ids` are inspected and whose
            `swap_states` is called to exchange two positions.
        scheduler_output: Provides `num_scheduled_tokens` per request id.
        decode_threshold: The maximum number of scheduled tokens for a
            request to be treated as a decode.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We want to reorder the batch so that the "decode" requests are at the
    # front and the "prefill" requests are at the back using the least amount
    # of swaps possible. (NOTE for now we loosely use "decode" to mean
    # requests where attention is likely memory-bound and "prefill" to mean
    # requests where attention is likely compute-bound, TODO(lucas): figure out
    # a better naming here)
    decodes = []
    prefills = []
    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        # Requests with at most `decode_threshold` scheduled tokens are
        # treated as "decode" even if they are not strictly decodes.
        if num_tokens <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # We hope that the number of swaps is fairly minimal since decodes
    # should be around for a number of iterations so hopefully they are
    # relatively stationary (and new requests are generally appended to the
    # persistent batch so already should be at the back).
    # To achieve this we walk the decodes in descending order and the
    # prefills in ascending order, swapping decodes from the "back" (i.e.
    # past where the last decode should be in the reordered batch) with
    # prefills from the front of the batch.
    # `decodes` and `prefills` are already in ascending order just based on
    # the above loop.
    num_decodes = len(decodes)
    modified_batch = False
    for decode_idx, prefill_idx in zip(reversed(decodes), prefills):
        # Once a decode already sits inside the first `num_decodes` slots,
        # so do all remaining (smaller-index) decodes.
        if decode_idx < num_decodes:
            break
        input_batch.swap_states(prefill_idx, decode_idx)
        modified_batch = True

    return modified_batch
# Extra field specs, as (name, type, default) tuples, for the attributes
# that KV-sharing fast prefill adds to attention metadata.
# NOTE(review): the tuple shape matches the `fields` argument of
# `subclass_attention_metadata`; presumably that is the consumer - confirm.
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ('logits_indices_padded', Optional[torch.Tensor], None),
    ('num_logits_indices', int, 0),
]
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Create a dataclass that derives from `metadata_cls` and adds `fields`.

    The new class is named `name_prefix` followed by the name of
    `metadata_cls`; `fields` is forwarded unchanged to `make_dataclass`.
    """
    subclass_name: str = name_prefix + metadata_cls.__name__  # type: ignore
    return make_dataclass(subclass_name, fields, bases=(metadata_cls, ))
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    """
    Structural type for attention metadata objects that carry the two extra
    attributes used by KV-sharing fast prefill.

    The protocol is decorated with `runtime_checkable`, so it can be used in
    `isinstance` checks.
    """
    # Logits indices, possibly including padding.
    # NOTE(review): presumably only the first `num_logits_indices` entries
    # are valid, as in make_kv_sharing_fast_prefill_common_attn_metadata -
    # confirm.
    logits_indices_padded: torch.Tensor
    num_logits_indices: int
def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """
    Derive an attention backend from `underlying_attn_backend` whose builder
    produces KV-sharing fast prefill metadata.

    The derived builder rewrites the common attention metadata with
    `make_kv_sharing_fast_prefill_common_attn_metadata`, delegates to the
    underlying builder, and then wraps the result in a subclass of its type
    that also carries `logits_indices_padded` and `num_logits_indices`.
    """
    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(base_builder_cls):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            fast_prefill_common = \
                make_kv_sharing_fast_prefill_common_attn_metadata(
                    common_attn_metadata)
            metadata = super().build(common_prefix_len, fast_prefill_common,
                                     fast_build)

            class KVSharingFastPrefillAttentionMetadata(
                    metadata.__class__,  # type: ignore
                    KVSharingFastPrefillMetadata):

                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy of every dataclass field of the wrapped
                    # metadata object
                    for field in fields(metadata.__class__):
                        setattr(self, field.name,
                                getattr(metadata, field.name))

                    # Extra fields that will be used in model code
                    padded = common_attn_metadata.logits_indices_padded
                    count = common_attn_metadata.num_logits_indices
                    assert (padded is not None and count is not None)
                    self.logits_indices_padded = padded
                    self.num_logits_indices = count

            # The extra fields come from the original (not the rewritten)
            # common metadata.
            return KVSharingFastPrefillAttentionMetadata(
                metadata, common_attn_metadata)

    return subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder)