mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 19:57:07 +08:00
841 lines
32 KiB
Python
841 lines
32 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import abc
|
|
import enum
|
|
import functools
|
|
from abc import abstractmethod
|
|
from dataclasses import dataclass, fields, make_dataclass
|
|
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
|
|
Protocol, TypeVar, Union, get_args)
|
|
|
|
import numpy as np
|
|
import torch
|
|
from typing_extensions import runtime_checkable
|
|
|
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
|
from vllm.utils import cdiv
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.attention.backends.abstract import AttentionImpl
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
|
|
import vllm.envs as envs
|
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
|
AttentionMetadata)
|
|
from vllm.attention.layer import Attention
|
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|
get_kv_connector_cache_layout)
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
from vllm.v1.worker.ubatch_utils import UBatchSlice
|
|
|
|
logger = init_logger(__name__)
|
|
# Names of the KV-cache layouts accepted by `is_valid_kv_cache_layout`.
KVCacheLayoutType = Literal["NHD", "HND"]
# Layout override set from code through `set_kv_cache_layout`; `None` means
# that no override has been requested.
_KV_CACHE_LAYOUT_OVERRIDE: Optional[KVCacheLayoutType] = None
|
|
|
|
|
|
def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True when `value` is one of the `KVCacheLayoutType` literals."""
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
|
|
|
|
|
|
@dataclass
class CommonAttentionMetadata:
    """
    Backend-agnostic attention metadata describing a single batch.

    One instance is shared across layers and backends;
    AttentionMetadataBuilder instances consume it to construct their own
    per-layer metadata.

    Several of the tensors are kept twice: once on the GPU and once on the
    CPU (the `*_cpu` fields).
    """

    # (batch_size + 1,): start location of each request in the query tensor.
    query_start_loc: torch.Tensor
    query_start_loc_cpu: torch.Tensor

    # (batch_size,): length of each request, counting both the already
    # computed tokens and the newly scheduled tokens.
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor

    # (batch_size,): number of computed tokens for each request.
    num_computed_tokens_cpu: torch.Tensor

    # Number of requests.
    num_reqs: int
    # Total number of tokens in the batch.
    num_actual_tokens: int
    # Longest query in the batch.
    max_query_len: int
    # Longest context length in the batch.
    max_seq_len: int

    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

    causal: bool = True

    # Needed by FastPrefillAttentionBuilder.
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None

    # Needed by CrossAttentionBuilder.
    encoder_seq_lens: Optional[np.ndarray] = None
|
|
|
|
|
|
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Return the query_start_loc of the requests selected by request_slice,
    rebased so that the first selected request starts at offset 0.

    Note: the result is a newly created tensor, which will break cudagraph
    compatibility.
    """
    first_req = request_slice.start
    last_req = request_slice.stop
    # Keep one entry past the slice: it is the end offset of the last
    # selected request.
    selected = query_start_loc[first_req:last_req + 1]
    return selected - query_start_loc[first_req]
|
|
|
|
|
|
def _make_metadata_with_slice(
        ubatch_slice: UBatchSlice,
        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
    """
    Create a new CommonAttentionMetadata that corresponds to the requests
    included in ubatch_slice.

    Args:
        ubatch_slice: Provides `request_slice` (which requests to keep) and
            `token_slice` (which tokens to keep).
        attn_metadata: Metadata of the full batch; it is only read.

    Returns:
        Metadata restricted to the selected requests/tokens. The `causal`
        flag is inherited from attn_metadata. The optional
        `logits_indices_padded`, `num_logits_indices` and `encoder_seq_lens`
        fields are not carried over and keep their defaults.
    """

    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
                                             request_slice)
    assert len(query_start_loc) >= 2, (
        f"query_start_loc must have at least 2 elements, "
        f"got {len(query_start_loc)}")
    query_start_loc_cpu = slice_query_start_locs(
        attn_metadata.query_start_loc_cpu, request_slice)

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    max_seq_len = int(seq_lens_cpu.max())
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
        request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    # Longest query among the selected requests, derived from the gaps
    # between consecutive start offsets.
    max_query_len = int(
        torch.max(torch.abs(query_start_loc_cpu[1:] -
                            query_start_loc_cpu[:-1])).item())

    # This is to account for the case where we are in a dummy
    # run and query_start_loc_cpu is full of 0s
    if max_query_len == 0:
        max_query_len = attn_metadata.max_query_len

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
        # Slicing a batch must not change its masking mode; without this
        # the slice would silently fall back to the default (causal=True).
        causal=attn_metadata.causal,
    )
|
|
|
|
|
|
def split_attn_metadata(
    ubatch_slices: list[UBatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Build one CommonAttentionMetadata per entry of ubatch_slices, each
    restricted to the requests of that UBatchSlice. The results are returned
    in the same order as ubatch_slices.

    Note: common_attn_metadata itself is left unmodified.
    """
    return [
        _make_metadata_with_slice(one_slice, common_attn_metadata)
        for one_slice in ubatch_slices
    ]
|
|
|
|
|
|
# Type variable for the metadata type that an AttentionMetadataBuilder's
# `build` method returns.
M = TypeVar("M")
|
|
|
|
|
|
class AttentionCGSupport(enum.Enum):
    """Levels of cudagraph support an attention backend can declare.

    Cascade attention is deliberately not considered here, as currently it
    is never cudagraph supported."""

    # Cudagraph always supported; supports mixed-prefill-decode.
    ALWAYS = 3
    # Cudagraph supported for batches that only contain query lengths that
    # are the same. This can be used for spec-decode, i.e. "decodes" are
    # 1 + num_speculative_tokens.
    UNIFORM_BATCH = 2
    # Cudagraph supported for batches that only contain query_len==1
    # decodes.
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    # No cudagraph support.
    NEVER = 0
|
|
|
|
|
|
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """
    Abstract base class for builders that turn a CommonAttentionMetadata
    into backend-specific attention metadata of type M.
    """

    # Level of CUDA Graph support offered by this backend/builder
    # (default: none).
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.NEVER)
    # Whether this backend/builder reorders the batch. Leave as None if it
    # does not; otherwise set it to the query length that will be pulled
    # into the front of the batch.
    reorder_batch_threshold: ClassVar[Optional[int]] = None

    @abstractmethod
    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Store the constructor arguments on the instance."""
        self.device = device
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        self.kv_cache_spec = kv_cache_spec

    @abstractmethod
    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> M:
        """
        Build the attention metadata; this is the central method of the
        builder. Some builders (MLA) require reorder_batch to be called
        before build.

        Args:
            common_prefix_len: The length of the common prefix of the batch.
            common_attn_metadata: The common attention metadata.
            fast_build: When set, the speed of building the metadata is
                prioritized over the speed at execution. Can be used for
                spec-decode, where the result of a build call may only be
                used for a few layers/iters.
        """
        raise NotImplementedError

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """
        Reorder the requests in the batch according to the needs of the
        attention backend. For example, some attention backends (namely MLA)
        may want to separate requests based on whether the attention
        computation will be compute-bound or memory-bound.

        This base implementation raises NotImplementedError.

        Args:
            input_batch: input batch
            scheduler_output: scheduler output.

        Returns:
            True if the batch was modified, False otherwise.
        """
        raise NotImplementedError

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        Build attention metadata for CUDA graph capture.

        The default delegates to build with a common prefix length of 0.
        Subclasses that override this method should call self.build or
        super().build_for_cudagraph_capture.
        """
        capture_metadata = self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
        return capture_metadata

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> M:
        """
        Build attention metadata for the draft model.

        The default delegates to build with a common prefix length of 0 and
        fast_build=True; draft_index is not used by this default.

        Args:
            common_attn_metadata: The common attention metadata.
            draft_index: The index of the current draft operation.
                When speculating a chain of tokens, this index refers to the
                draft attempt for the i-th token.
                For tree-based attention, this index instead refers to the
                draft attempt for the i-th level in the tree of tokens.
        """
        draft_metadata = self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
            fast_build=True,
        )
        return draft_metadata

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        use_local_attention: bool,
        num_sms: int,
    ) -> bool:
        """Return whether to use cascade attention; always False here."""
        return False
|
|
|
|
|
|
@functools.lru_cache
def get_kv_cache_layout():
    """
    Resolve the KV cache layout to use.

    Sources are consulted in this order:
      1. the in-code override `_KV_CACHE_LAYOUT_OVERRIDE`,
      2. the `VLLM_KV_CACHE_LAYOUT` value exposed by `envs`,
      3. the default returned by `get_kv_connector_cache_layout()`.

    The function is wrapped in `functools.lru_cache`, so the resolution
    below only runs when no cached result is available.
    """
    # Format specified by the code.
    override = _KV_CACHE_LAYOUT_OVERRIDE
    if override is not None:
        logger.info_once(
            "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. "
            "Setting KV cache layout to %s.", override)
        return override

    # Format specified by the user.
    env_layout = envs.VLLM_KV_CACHE_LAYOUT
    if env_layout is None:
        # Neither the override nor the user specified a layout: use the
        # default.
        return get_kv_connector_cache_layout()

    assert is_valid_kv_cache_layout(env_layout)
    logger.info_once(
        "`VLLM_KV_CACHE_LAYOUT` environment variable "
        "detected. Setting KV cache layout to %s.", env_layout)
    return env_layout
|
|
|
|
|
|
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """
    Install `cache_layout` as the in-code KV cache layout override.

    Args:
        cache_layout: The layout that `get_kv_cache_layout` should return
            from now on.
    """
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
    # `get_kv_cache_layout` is memoized with functools.lru_cache. Drop any
    # cached result, otherwise an override installed after the layout has
    # already been queried once would silently be ignored.
    get_kv_cache_layout.cache_clear()
|
|
|
|
|
|
@dataclass
class PerLayerParameters:
    """
    Hyperparameters collected from one attention layer.

    Currently, the FlashInfer backend only supports models in which all
    layers share the same values for these hyperparameters. This should not
    be used for the trtllm-gen backend, since it supports different values
    for them.
    """

    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float
    has_sinks: bool = False
|
|
|
|
|
|
def get_per_layer_parameters(
        vllm_config: VllmConfig, layer_names: list[str],
        cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]:
    """
    Scan the layers named in `layer_names` and collect, per layer, the
    hyperparameters to use during `plan`.

    Every layer's `impl` must be an instance of `cls_`; this is asserted.

    Returns:
        A dict from layer key to the PerLayerParameters read off that
        layer's `impl`.
    """

    attn_layers = get_layers_from_vllm_config(vllm_config, Attention,
                                              layer_names)
    params_by_layer: dict[str, PerLayerParameters] = {}

    for layer_name, attn_layer in attn_layers.items():
        impl = attn_layer.impl
        assert isinstance(impl, cls_)

        # Infer hyperparameters from the attention layer. Attributes that
        # the impl does not define are treated as None.
        sliding_window = getattr(impl, "sliding_window", None)
        if sliding_window is None:
            window_left = -1
        else:
            window_left = sliding_window[0]

        params_by_layer[layer_name] = PerLayerParameters(
            window_left=window_left,
            logits_soft_cap=getattr(impl, "logits_soft_cap", None),
            sm_scale=impl.scale,
            has_sinks=getattr(impl, "sinks", None) is not None,
        )

    return params_by_layer
|
|
|
|
|
|
def infer_global_hyperparameters(
        per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
    """
    Currently, FlashInfer backend other than trtllm-gen
    only support models in which all layers share
    the same values for the following hyperparameters:
    - `window_left`
    - `logits_soft_cap`
    - `sm_scale`
    - `has_sinks`

    So this function asserts that all layers share the same values for these
    hyperparameters and returns the global values.

    The consistency check is skipped when `envs.VLLM_USE_TRTLLM_ATTENTION`
    is set; the parameters of the first layer are then returned as-is.

    Raises:
        ValueError: if `window_left` differs between layers.
        AssertionError: if `per_layer_params` is empty, or if any other
            hyperparameter differs between layers.
    """

    assert len(per_layer_params) > 0, "No attention layers found in the model."

    param_sets = list(per_layer_params.values())
    global_params = param_sets[0]

    # trtllm attention doesn't need global hyper params so disable the check
    if not envs.VLLM_USE_TRTLLM_ATTENTION:
        for params in param_sets:
            if params.window_left != global_params.window_left:
                raise ValueError(
                    "Window left is not the same for all layers. "
                    "One potential fix is to set disable_sliding_window=True")
            # Dataclass equality compares every field, `has_sinks` included.
            assert params == global_params, (
                "FlashInfer backend currently only supports models in which "
                "all layers share the same values "
                "for the following hyperparameters: "
                "`window_left`, `logits_soft_cap`, `sm_scale`, `has_sinks`.")

    return global_params
|
|
|
|
|
|
#
|
|
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
|
# local attention blocks, where each block is passed to the attention kernel
|
|
# as an independent local ("virtual") batch item.
|
|
#
|
|
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
|
|
# q_seqlens = [4, 10, 5]
|
|
# kv_seqlens = [6, 17, 9]
|
|
# Then normally for regular attention we would compute with an attention mask
|
|
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
|
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
|
# k_toks > 0 1 2 3 4 5
|
|
# q_toks v _____________
|
|
# 0 | 1 1 1
|
|
# 1 | 1 1 1 1
|
|
# 2 | 1 1 1 1 1
|
|
# 3 | 1 1 1 1 1 1
|
|
#
|
|
# for local attention (with attn_chunk_size = 4) we would compute with an
|
|
# attention mask like:
|
|
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
|
# k_toks > 0 1 2 3 4 5
|
|
# q_toks v _____________
|
|
# 0 | 1 1 1
|
|
# 1 | 1 1 1 1
|
|
# 2 | 1
|
|
# 3 | 1 1
|
|
#
|
|
# We can simulate this mask using standard flash-attention by breaking the
|
|
# sequences into local ("virtual") batches, where each local batch item is a
|
|
# local attention block, so in this case batch idx 0 would be broken up into:
|
|
#
|
|
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
|
# k_toks > 0 1 2 3
|
|
# q_toks v _____________
|
|
# 0 | 1 1 1
|
|
# 1 | 1 1 1 1
|
|
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
|
# k_toks > 4 5
|
|
# q_toks v _____________
|
|
# 2 | 1
|
|
# 3 | 1 1
|
|
#
|
|
# e.g. if we have:
|
|
# attn_chunk_size = 4
|
|
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
|
# Then this function would return:
|
|
# __b0__ ______b1______ __b2__ < orig batch indices
|
|
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
|
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
|
|
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
|
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
|
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    common_attn_metadata: CommonAttentionMetadata,
    block_size: int = 0,
) -> CommonAttentionMetadata:
    """
    Break every sequence of the batch into local attention blocks of
    `attn_chunk_size` tokens and return metadata in which each block is an
    independent local ("virtual") batch item.

    Args:
        attn_chunk_size: Size of one local attention block, in tokens. Must
            be divisible by `block_size`.
        common_attn_metadata: Metadata of the original batch. Its
            `query_start_loc_cpu`, `seq_lens_cpu`, `block_table_tensor`,
            `num_actual_tokens` and `slot_mapping` are read; it is assumed
            that every query length is > 0.
        block_size: Number of tokens covered by one block-table entry. Must
            be positive; the default of 0 is rejected.

    Returns:
        A new CommonAttentionMetadata with one request per local attention
        block. `num_actual_tokens` and `slot_mapping` are passed through
        unchanged and `causal` is set to True.

    Raises:
        ValueError: if `block_size` is not positive.
        AssertionError: if `attn_chunk_size` is not divisible by
            `block_size`.
    """
    # Validate the divisor before it is used anywhere below; previously the
    # divisibility check only ran after `block_size` had already been used
    # in a floor division.
    if block_size <= 0:
        raise ValueError(
            f"block_size must be a positive integer, got {block_size}")
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"
    pages_per_local_batch = attn_chunk_size // block_size

    query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy()
    seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy()
    block_table = common_attn_metadata.block_table_tensor
    device = common_attn_metadata.query_start_loc.device

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    #  the number of tokens that are not in the first local attention block and
    #  then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    #  for each batch idx, we can figure out the number of "virtual" requests we
    #  have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    #  first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32)
    np.cumsum(seqlens_q_local, out=cu_seqlens_q_local[1:])
    cu_seqlens_q_local[0] = 0

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    num_computed_tokens_local = seqlens_k_local - seqlens_q_local

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size + \
            np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ], < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ], < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = (block_starts[:, None] +
                     np.arange(pages_per_local_batch, dtype=np.int32))
    # Clip so that indices never run past the last block-table column.
    block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] -
                                                   1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)

    # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance
    # regression when using numpy arrays (batch and block indices) to index into
    # torch tensor (block_table). As a workaround, convert numpy arrays to torch
    # tensor first, which recovers perf.
    batch_indices_torch = torch.from_numpy(batch_indices)
    block_indices_torch = torch.from_numpy(block_indices)
    block_table_local = block_table[batch_indices_torch, block_indices_torch]\
        .view(virtual_batches, -1)

    query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local)
    seq_lens_cpu = torch.from_numpy(seqlens_k_local)
    max_seq_len = int(seq_lens_cpu.max())

    return CommonAttentionMetadata(
        query_start_loc_cpu=query_start_loc_cpu,
        query_start_loc=query_start_loc_cpu.to(device=device,
                                               non_blocking=True),
        seq_lens_cpu=seq_lens_cpu,
        seq_lens=seq_lens_cpu.to(device=device, non_blocking=True),
        num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local),
        num_reqs=len(seq_lens_cpu),
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        # The field is declared as `int`; convert the numpy scalar so that
        # it matches `max_seq_len` and the other integer fields.
        max_query_len=int(seqlens_q_local.max()),
        max_seq_len=max_seq_len,
        block_table_tensor=block_table_local,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
|
|
|
|
|
|
def make_kv_sharing_fast_prefill_common_attn_metadata(
    common_attn_metadata: CommonAttentionMetadata,
) -> CommonAttentionMetadata:
    """
    Derive metadata whose queries only cover the tokens selected by
    `logits_indices_padded[:num_logits_indices]`.

    If `max_query_len` is 1 the input is returned unchanged. Otherwise a new
    CommonAttentionMetadata is returned in which `query_start_loc`,
    `query_start_loc_cpu`, `num_actual_tokens` and `max_query_len` are
    recomputed from the number of selected tokens per request; the sequence
    lengths, computed-token counts, block table and slot mapping are taken
    over from the input and `causal` is set to True.

    Raises:
        AssertionError: if the fast prefill path is taken while
            `logits_indices_padded` or `num_logits_indices` is None.
    """
    if common_attn_metadata.max_query_len == 1:
        # All requests are decode (assume 1 token for now)
        # Skip computing fast prefill path
        return common_attn_metadata

    assert common_attn_metadata.logits_indices_padded is not None
    assert common_attn_metadata.num_logits_indices is not None

    logits_indices_padded = common_attn_metadata.logits_indices_padded
    num_logits_indices = common_attn_metadata.num_logits_indices
    # Get rid of CUDAGraph padding, if any
    logits_indices = logits_indices_padded[:num_logits_indices]
    num_reqs = common_attn_metadata.num_reqs
    query_start_loc = common_attn_metadata.query_start_loc
    seq_lens = common_attn_metadata.seq_lens
    # Example inputs
    # num_reqs: 3
    # generation_indices:  [14, 18, 19, 27]
    # query_start_loc: [0, 15, 20, 28]
    # seq_lens:        [41, 31, 40]

    # Find how many decode indices belong to each request
    # request_ids: [0, 1, 1, 2]
    request_ids = torch.bucketize(logits_indices,
                                  query_start_loc[1:],
                                  right=True)

    # Figure out how many tokens are in each request
    # num_decode_tokens: [1, 2, 1]
    num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

    # Calculate new query_start_loc with tokens in generation_indices
    # decode_query_start_loc: [0, 1, 3, 4]
    decode_query_start_loc = torch.empty(num_reqs + 1,
                                         device=query_start_loc.device,
                                         dtype=query_start_loc.dtype)

    decode_query_start_loc[0] = 0
    decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
    decode_max_query_len = int(num_decode_tokens.max().item())
    total_num_decode_tokens = int(num_decode_tokens.sum().item())

    common_attn_metadata = CommonAttentionMetadata(
        query_start_loc=decode_query_start_loc,
        # Use a blocking copy: the result of a non_blocking device-to-host
        # copy may be read on the CPU before the transfer has finished.
        query_start_loc_cpu=decode_query_start_loc.to("cpu"),
        seq_lens=seq_lens,
        # `seq_lens` is passed through unchanged, so the existing CPU copy
        # is reused instead of issuing another device-to-host transfer.
        seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
        num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_num_decode_tokens,
        max_query_len=decode_max_query_len,
        max_seq_len=common_attn_metadata.max_seq_len,
        block_table_tensor=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        causal=True,
    )
    return common_attn_metadata
|
|
|
|
|
|
def subclass_attention_backend(
        name_prefix: str, attention_backend_cls: type[AttentionBackend],
        builder_cls: type[AttentionMetadataBuilder[M]]
) -> type[AttentionBackend]:
    """
    Create a subclass of `attention_backend_cls` whose `get_builder_cls`
    returns `builder_cls`.

    The new class is named `name_prefix` followed by the name of
    `attention_backend_cls`.
    """
    subclass_name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
    overrides = {"get_builder_cls": lambda: builder_cls}
    return type(subclass_name, (attention_backend_cls, ), overrides)
|
|
|
|
|
|
def split_decodes_and_prefills(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
    """
    Locate the boundary between decode and prefill requests in a batch that
    is assumed to be reordered already (decodes first, prefills last). That
    ordering is verified with assertions.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        num_decodes: The number of decode requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu

    # Result used whenever the batch turns out to contain no prefill.
    decode_only = (num_reqs, 0, num_tokens, 0)

    # Cheap check first: no query exceeds the threshold.
    if max_query_len <= decode_threshold:
        return decode_only

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return decode_only

    # argmax over the 0/1 mask yields the index of the first prefill.
    first_prefill = is_prefill.int().argmax(dim=-1).item()
    assert torch.all(query_lens[first_prefill:] > decode_threshold)
    assert torch.all(query_lens[:first_prefill] <= decode_threshold)

    num_decodes = first_prefill
    num_decode_tokens = query_start_loc[first_prefill].item()
    return (
        num_decodes,
        num_reqs - num_decodes,
        num_decode_tokens,
        num_tokens - num_decode_tokens,
    )
|
|
|
|
|
|
def reorder_batch_to_split_decodes_and_prefills(
    input_batch: "InputBatch",
    scheduler_output: "SchedulerOutput",
    decode_threshold: int = 1,
) -> bool:
    """
    Reorders the batch to split into prefill and decode requests; places all
    requests with <= decode_threshold tokens at the front of the batch.

    Args:
        input_batch: The batch to reorder. Its `req_ids` are read and its
            `swap_states(i, j)` method is called to exchange two requests.
        scheduler_output: Provides `num_scheduled_tokens`, indexed by
            request id.
        decode_threshold: The maximum number of scheduled tokens for a
            request to be treated as a decode.

    Returns:
        True if the batch was modified, False otherwise.
    """
    # We now want to reorder the batch so that the "decode" requests are at
    #  the front and the "prefill" requests are at the back using the least
    #  amount of swaps possible. (NOTE for now we loosely use "decode" to mean
    #  requests where attention is likely memory-bound and "prefill" to mean
    #  requests where attention is likely compute-bound, TODO(lucas): figure out
    #  a better naming here)
    decodes = []
    prefills = []

    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        # for now treat 1 scheduled token as "decode" even if it's not,
        # we should update this to something like < 8 in the future but
        # currently the TritonMLA._forward_decode only supports
        # num_tokens = 1
        if num_tokens <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # We hope that this is fairly minimal since decodes
    # should be around for a number of iterations so hopefully they are
    # relatively stationary (and new request are generally appended to the
    # persistent batch so already should be at the back)
    # To achieve this we loop over the decodes in descending order and
    # the prefills in ascending order. We swap decodes from the "back"
    # i.e. past where the last decode should be in the reordered batch with
    # prefills from the front of the batch.
    # `decodes` and `prefills` are already in ascending order just based on
    # the above loop
    num_decodes = len(decodes)
    num_prefills = len(prefills)
    modified_batch = False

    for i in range(1, min(num_decodes, num_prefills) + 1):
        # If the decode is at the "back" of the batch, i, we can swap it
        # with the prefill closest to the front of the batch
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            # This decode, and every remaining one, already sits inside
            # the leading decode region.
            break

        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True

    return modified_batch
|
|
|
|
|
|
# Extra fields for KV-sharing fast prefill metadata, given as
# (name, type, default) triples. NOTE(review): presumably meant to be passed
# as the `fields` argument of `subclass_attention_metadata` - confirm against
# callers.
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
    ("logits_indices_padded", Optional[torch.Tensor], None),
    ("num_logits_indices", int, 0),
]
|
|
|
|
|
|
def subclass_attention_metadata(
    name_prefix: str,
    metadata_cls: Any,
    fields: list[tuple[str, Any, Any]],
) -> Any:
    """
    Create a dataclass that derives from `metadata_cls` and adds `fields`.

    The new class is named `name_prefix` followed by the name of
    `metadata_cls`; `fields` is forwarded unchanged to
    `dataclasses.make_dataclass`.
    """
    subclass_name: str = name_prefix + metadata_cls.__name__  # type: ignore
    return make_dataclass(subclass_name, fields, bases=(metadata_cls, ))
|
|
|
|
|
|
@runtime_checkable
class KVSharingFastPrefillMetadata(Protocol):
    """
    Structural type for metadata objects that expose the two fast-prefill
    attributes below. Being runtime-checkable, it can be used with
    `isinstance`.
    """

    # Logits indices, possibly padded.
    logits_indices_padded: torch.Tensor
    # Number of leading entries of `logits_indices_padded` that are not
    # padding - NOTE(review): inferred from the attribute names; confirm.
    num_logits_indices: int
|
|
|
|
|
|
def create_fast_prefill_custom_backend(
    prefix: str,
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    """
    Wrap `underlying_attn_backend` in a backend whose metadata builder
    implements KV-sharing fast prefill.

    The returned backend subclass is named `prefix` followed by the name of
    the underlying backend. Its builder derives from the underlying builder
    and overrides `build` to:
      1. rewrite the common metadata with
         `make_kv_sharing_fast_prefill_common_attn_metadata`,
      2. delegate to the underlying builder's `build`,
      3. return a copy of the result that additionally carries
         `logits_indices_padded` and `num_logits_indices`, taken from the
         common metadata originally passed to `build`.
    """

    base_builder_cls = underlying_attn_backend.get_builder_cls()

    class FastPrefillAttentionBuilder(base_builder_cls):  # type: ignore

        def build(self,
                  common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata,
                  fast_build: bool = False) -> AttentionMetadata:
            fast_prefill_common = \
                make_kv_sharing_fast_prefill_common_attn_metadata(
                    common_attn_metadata)
            base_metadata = super().build(common_prefix_len,
                                          fast_prefill_common, fast_build)

            # The subclass is created per call because its base is the
            # concrete class of the metadata that was just built.
            class KVSharingFastPrefillAttentionMetadata(
                    base_metadata.__class__,  #  type: ignore
                    KVSharingFastPrefillMetadata):

                def __init__(self, metadata, common_attn_metadata):
                    # Shallow copy all fields in metadata cls
                    for f in fields(metadata.__class__):
                        setattr(self, f.name, getattr(metadata, f.name))

                    # Set additional fields that will be used in model code
                    padded = common_attn_metadata.logits_indices_padded
                    count = common_attn_metadata.num_logits_indices
                    assert (padded is not None and count is not None)
                    self.logits_indices_padded = padded
                    self.num_logits_indices = count

            return KVSharingFastPrefillAttentionMetadata(
                base_metadata, common_attn_metadata)

    return subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=FastPrefillAttentionBuilder)
|