mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-13 14:27:04 +08:00
317 lines
13 KiB
Python
317 lines
13 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import abc
|
|
import functools
|
|
from abc import abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from vllm.utils import cdiv
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
|
|
|
import vllm.envs as envs
|
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
|
get_kv_connector_cache_layout)
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
@dataclass
class CommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.

    AttentionMetadataBuilder instances consume one of these per batch and
    derive their own per-layer metadata from it.
    """

    # (batch_size + 1,), the start location of each request in query Tensor
    query_start_loc: torch.Tensor
    # (batch_size,), the length of each request including both computed
    # tokens and newly scheduled tokens
    seq_lens: torch.Tensor

    # Number of requests
    num_reqs: int
    # Total number of tokens in batch
    num_actual_tokens: int
    # Longest query in batch
    max_query_len: int
|
|
|
|
# Type of the backend-specific metadata object produced by a builder.
M = TypeVar("M")


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    """
    Abstract base class for objects that turn a CommonAttentionMetadata into
    backend-specific attention metadata of type ``M``.

    Only ``build`` is abstract; every other hook has a conservative default
    that subclasses may override.
    """

    # Does this backend/builder support CUDA Graphs for attention.
    full_cudagraph_supported: ClassVar[bool] = False

    @abstractmethod
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> M:
        """
        Central method that builds attention metadata.

        Some builders (MLA) require reorder_batch to be called prior to build.

        :param common_prefix_len: Prefix length passed through by the caller
            (``build_for_cudagraph_capture`` passes 0).
        :param common_attn_metadata: Per-batch metadata to build from.
        :return: The backend-specific metadata object.
        """
        raise NotImplementedError

    def can_run_in_cudagraph(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> bool:
        """
        Report whether this batch (described by the given metadata) can use
        CUDA Graphs for attention.

        :return: Always False in the base class.
        """
        return False

    def build_for_cudagraph_capture(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> M:
        """
        Build attention metadata for CUDA graph capture.

        The default simply delegates to ``build`` with a common prefix length
        of zero. Subclasses that override this method should call self.build
        or super().build_for_cudagraph_capture.
        """
        return self.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )

    def use_cascade_attention(
        self,
        common_prefix_len: int,
        query_lens: np.ndarray,
        num_query_heads: int,
        num_kv_heads: int,
        use_alibi: bool,
        use_sliding_window: bool,
        num_sms: int,
    ) -> bool:
        """
        Decide whether cascade attention should be used.

        :return: Always False in the base class; none of the arguments are
            inspected here.
        """
        return False

    def reorder_batch(
        self,
        input_batch: "InputBatch",
        scheduler_output: "SchedulerOutput",
    ) -> bool:
        """
        Hook that lets a backend reorder the batch if it wants to.

        :return: Has the batch been reordered (default False).
        """
        return False
|
|
|
|
|
|
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    req_slice: slice,
) -> torch.Tensor:
    """
    Extract the query start locations for a contiguous range of requests and
    rebase them so the first entry is zero.

    :param query_start_loc: Cumulative query start offsets.
    :param req_slice: Range of requests to keep; both ``start`` and ``stop``
        are read directly, and one extra entry past ``stop`` is included so
        the end offset of the last request is retained.
    :return: ``query_start_loc[start:stop + 1]`` minus the offset at ``start``.
    """
    first_req = req_slice.start
    # stop + 1: N requests are delimited by N + 1 offsets.
    window = query_start_loc[first_req:req_slice.stop + 1]
    base_offset = query_start_loc[first_req]
    return window - base_offset
|
|
|
|
def validate_kv_sharing_target(current_layer_name, target_layer_name,
                               static_forward_context):
    """
    Check that ``target_layer_name`` is an acceptable KV sharing target for
    ``current_layer_name``.

    :param current_layer_name: Name of the layer that wants to share KV.
    :param target_layer_name: Name of the layer whose KV would be shared.
    :param static_forward_context: Container supporting ``in`` and item
        lookup by layer name; each entry must expose ``attn_type``.
    :raises ValueError: If the target is the current layer itself, is absent
        from ``static_forward_context``, or has a different ``attn_type``
        than the current layer.
    """
    prefix = (f"Specified KV sharing target layer for {current_layer_name} "
              f"is not valid: target layer {target_layer_name} ")

    if current_layer_name == target_layer_name:
        raise ValueError(prefix + "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        current_idx = extract_layer_index(current_layer_name)
        target_idx = extract_layer_index(target_layer_name)
        reason = ("must come before the current layer."
                  if current_idx <= target_idx else
                  "is not a valid Attention layer in the model.")
        raise ValueError(prefix + reason)

    # Currently KV sharing is only supported between layers of the same type
    target_attn_type = static_forward_context[target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_attn_type != expected:
        raise ValueError(
            prefix +
            f"must be the same type as the current layer ({expected}).")
|
|
|
|
|
|
@functools.lru_cache
def get_kv_cache_layout():
    """
    Return the KV cache layout to use.

    The value of ``envs.VLLM_KV_CACHE_LAYOUT`` takes precedence when it is
    set; otherwise the layout reported by ``get_kv_connector_cache_layout()``
    is used. The function takes no arguments and is wrapped in
    ``functools.lru_cache``, so the lookup (and the log line) happens only on
    the first call.

    :return: The selected cache layout.
    """
    # Override with format specified by the user.
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    else:
        # Name the variable that is actually read above so users can tell
        # which setting took effect.
        logger.info_once(
            "`VLLM_KV_CACHE_LAYOUT` environment variable "
            "detected. Setting KV cache layout to %s.", cache_layout)

    return cache_layout
|
|
|
|
|
|
#
|
|
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
|
|
# local attention blocks, where each block is passed to the attention kernel
|
|
# as an independent local ("virtual") batch item.
|
|
#
|
|
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
|
|
# q_seqlens = [4, 10, 5]
|
|
# kv_seqlens = [6, 17, 9]
|
|
# Then normally for regular attention we would compute with an attention mask
|
|
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
|
|
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
|
|
# k_toks > 0 1 2 3 4 5
|
|
# q_toks v _____________
|
|
# 0 | 1 1 1
|
|
# 1 | 1 1 1 1
|
|
# 2 | 1 1 1 1 1
|
|
# 3 | 1 1 1 1 1 1
|
|
#
|
|
# for local attention (with attn_chunk_size = 4) we would compute with an
|
|
# attention mask like:
|
|
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
|
|
# k_toks > 0 1 2 3 4 5
|
|
# q_toks v _____________
|
|
# 0 | 1 1 1
|
|
# 1 | 1 1 1 1
|
|
# 2 | 1
|
|
# 3 | 1 1
|
|
#
|
|
# We can simulate this mask using standard flash-attention by breaking the
|
|
# sequences into local ("virtual") batches, where each local batch item is a
|
|
# local attention block, so in this case batch idx 0 would be broken up into:
|
|
#
|
|
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
|
|
# k_toks > 0 1 2 3
|
|
# q_toks v _____________
|
|
# 0 | 1 1 1
|
|
# 1 | 1 1 1 1
|
|
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
|
|
# k_toks > 4 5
|
|
# q_toks v _____________
|
|
# 2 | 1
|
|
# 3 | 1 1
|
|
#
|
|
# e.g. if we have:
|
|
# attn_chunk_size = 4
|
|
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
|
|
# Then this function would return:
|
|
# __b0__ ______b1______ __b2__ < orig batch indices
|
|
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
|
|
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
|
|
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
|
|
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
|
|
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    block_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    """
    Break each request into local attention blocks ("virtual" batch items).

    Every request is split along ``attn_chunk_size``-aligned boundaries of its
    key sequence; each resulting block becomes one entry of the returned
    arrays and one row of the returned block table.

    :param attn_chunk_size: Size of a local attention block, in tokens.
    :param query_start_loc_np: Cumulative query start offsets; consecutive
        differences are used as the per-request query lengths.
    :param seq_lens_np: Per-request sequence lengths; its first dimension is
        taken as the batch size.
    :param block_table: 2-D tensor indexed as ``[request, block]``.
    :param block_size: Divisor used to convert token positions into
        ``block_table`` columns. Must be positive and must divide
        ``attn_chunk_size``; the default of 0 is only a placeholder and is
        rejected.
    :return: ``(seqlens_q_local, cu_seqlens_q_local, seqlens_k_local,
        block_table_local)`` where ``cu_seqlens_q_local`` is the int32
        cumulative sum of ``seqlens_q_local`` with a leading 0,
        ``seqlens_k_local`` is int32, and ``block_table_local`` has one row
        per virtual batch.
    :raises ValueError: If ``block_size`` is not positive.
    :raises AssertionError: If ``attn_chunk_size`` is not divisible by
        ``block_size``.
    """
    # Validate block_size before it is used as a divisor; previously the
    # default of 0 reached the NumPy floor division below and surfaced as a
    # divide warning followed by an opaque ZeroDivisionError.
    if block_size <= 0:
        raise ValueError(
            f"block_size must be a positive integer, got {block_size}")
    # Kept as an assert so the exception type seen by existing callers does
    # not change.
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"

    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]

    # Handle if we are starting in the middle of a local attention block,
    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    # the number of tokens that are not in the first local attention block and
    # then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   q_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    # `x % -c` lies in (-c, 0], so this yields a value in [1, c]: a sequence
    # that ends exactly on a chunk boundary has a full last block.
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)

    # Once we know the number of local blocks we can compute the request spans
    # for each batch idx, we can figure out the number of "virtual" requests we
    # have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    #   (TODO: make a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    # first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]

    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
        .astype(np.int32)

    # compute the seqlens_k_local,
    #  basically a full local attention block for all but the last block in each
    #  batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(virtual_batches,
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block

    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size +
         np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                           _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    pages_per_local_batch = attn_chunk_size // block_size

    # Create a block_table for the local attention blocks
    # For our example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],  < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],  < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [  0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
    #     [  2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ],  < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ],  < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ],  < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ],  < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ],  < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ],  < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch)) \
        + np.expand_dims(block_starts, axis=1)
    # Column indices past the end of the table are clamped to its last column
    # so the gather below never indexes out of range.
    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)

    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local
|