# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm.v1.attention.backends.mamba_attn import (
    BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (PAD_SLOT_ID,
                                              CommonAttentionMetadata,
                                              compute_causal_conv1d_metadata,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec


def compute_varlen_chunk_metadata(
        query_start_loc: torch.Tensor,
        chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.

    Given per-sequence cumulative token starts `query_start_loc` of shape
    [B+1] and a physical `chunk_size`, returns three tensors on the same
    device:
      - cu_chunk_seqlens: (nchunks+1,) int32 exclusive prefix-sum of
        logical-chunk lengths (each logical chunk never crosses a sequence
        or physical-chunk boundary).
      - last_chunk_indices: (B,) int32 index of the last logical chunk
        for each sequence (-1 for empty sequences).
      - seq_idx_chunks: (nchunks,) int32 sequence index for each logical
        chunk, in order.

    This is intentionally lightweight and CPU-side; it mirrors the metadata
    produced by the V1 Mamba2 metadata builder and is exported so tests
    (and other callers) can avoid duplicating the logic.
    """
    assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
    assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
    device = query_start_loc.device

    qsl64 = query_start_loc.to(torch.int64)
    starts = qsl64[:-1].tolist()
    ends = qsl64[1:].tolist()
    total = int(qsl64[-1].item())

    chunk_lens: list[int] = []
    seq_idx_chunks: list[int] = []
    last_chunk_indices: list[int] = [-1] * len(starts)

    for b, (s, e) in enumerate(zip(starts, ends)):
        if e <= s:
            # empty sequence
            continue
        pos = s
        while pos < e:
            # split at both sequence boundaries and physical chunk boundaries
            room = chunk_size - (pos % chunk_size)
            take = min(room, e - pos)
            chunk_lens.append(int(take))
            seq_idx_chunks.append(b)
            last_chunk_indices[b] = len(chunk_lens) - 1
            pos += take

    # Exclusive prefix sum over logical-chunk lengths
    if chunk_lens:
        cu_chunk_seqlens = torch.tensor([0] +
                                        list(itertools.accumulate(chunk_lens)),
                                        device=device,
                                        dtype=torch.int32)
        # Final boundary must equal total tokens
        assert int(cu_chunk_seqlens[-1].item()) == total
    else:
        cu_chunk_seqlens = torch.tensor([0], device=device, dtype=torch.int32)

    last_chunk_indices_t = (torch.tensor(
        last_chunk_indices, device=device, dtype=torch.int32)
                            if len(starts) > 0 else torch.empty(
                                (0, ), device=device, dtype=torch.int32))
    seq_idx_chunks_t = torch.tensor(seq_idx_chunks,
                                    device=device,
                                    dtype=torch.int32)
    return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
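

# Illustrative usage sketch (added for documentation; the helper name
# `_example_varlen_chunk_metadata` is hypothetical and not part of the
# upstream API). It shows the metadata produced for two sequences of
# lengths 5 and 3 with chunk_size=4: sequence 0 is split at the physical
# chunk boundary into logical chunks [0:4] and [4:5], while sequence 1
# yields a single logical chunk [5:8].
def _example_varlen_chunk_metadata() -> None:
    query_start_loc = torch.tensor([0, 5, 8], dtype=torch.int32)
    cu, last, seq = compute_varlen_chunk_metadata(query_start_loc,
                                                  chunk_size=4)
    assert cu.tolist() == [0, 4, 5, 8]  # exclusive prefix sum of [4, 1, 3]
    assert last.tolist() == [1, 2]  # last logical chunk of each sequence
    assert seq.tolist() == [0, 0, 1]  # owning sequence of each logical chunk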


class Mamba2AttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        return Mamba2AttentionMetadataBuilder


@dataclass
class Mamba2AttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    query_start_loc_p: torch.Tensor
    seq_lens: torch.Tensor

    prep_initial_states: bool
    chunk_size: int

    # The following tensors only contain prefill requests and will be None if
    # the batch has no prefill request.
    has_initial_states_p: Optional[torch.Tensor]
    seq_idx_p: Optional[torch.Tensor]

    # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
    # each chunk, its offset into the varlen sequence dimension. It is defined
    # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
    # cu_chunk_seqlen_p[i+1].
    cu_chunk_seqlen_p: Optional[torch.Tensor]

    # last_chunk_indices_p is a tensor of shape (batch,) that contains the
    # index of the last chunk for every sequence in the (prefill) batch.
    last_chunk_indices_p: Optional[torch.Tensor]
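
    # Illustrative example (added for documentation, not used by the code):
    # with chunk_size=4 and two prefill sequences of 5 and 3 new tokens and
    # no computed prefix, the builder below produces
    #   cu_chunk_seqlen_p    = [0, 4, 5, 8]
    #   last_chunk_indices_p = [1, 2]
    #   seq_idx_p            = [0, 0, 1]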

    state_indices_tensor: torch.Tensor  # shape: [batch,]

    # The following attributes are for the Triton implementation of
    # causal_conv1d.
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None


class Mamba2AttentionMetadataBuilder(
        BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]):

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
        assert self.chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models")

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> Mamba2AttentionMetadata:
        num_reqs = common_attn_metadata.num_reqs
        seq_lens = common_attn_metadata.seq_lens

        query_start_loc_p = None
        seq_idx_p = None
        cu_chunk_seqlen_p = None
        last_chunk_indices_p = None

        # Need flags to indicate if there are initial states
        has_initial_states_p = None
        prep_initial_states = False

        # for causal_conv1d
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold))

        # Compute seq_idx for prefill only
        if num_prefills > 0:
            # [batch,]
            has_initial_states_cpu = (
                common_attn_metadata.
                num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states_p = has_initial_states_cpu.to(
                common_attn_metadata.query_start_loc.device)

            query_start_loc_p = common_attn_metadata.query_start_loc[
                -num_prefills - 1:] - num_decode_tokens

            num_computed_tokens_p = \
                common_attn_metadata.num_computed_tokens_cpu[
                    num_reqs - num_prefills:num_reqs]
            query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[
                -num_prefills - 1:] - num_decode_tokens
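
            # Worked example (illustrative, added for documentation): decode
            # requests are ordered before prefill requests in the batch, so
            # the prefill slice of query_start_loc is re-based by subtracting
            # the decode tokens. E.g. with query_start_loc = [0, 1, 2, 7, 12]
            # and two single-token decodes, query_start_loc_p = [0, 5, 10].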

            # The code below carefully constructs the chunks such that:
            # 1. Chunks contain tokens from a *single* sequence only.
            # 2. For every sequence, we are guaranteed that we can
            #    retrieve the mamba state *every* chunk_size tokens.
            # Constraint (1) dramatically simplifies the mamba2 kernels.
            # Constraint (2) dramatically simplifies the implementation
            # of prefix caching for mamba2 (wip). We need to take care
            # of the interaction with chunked prefill in order to
            # satisfy constraint (2).
            # TODO (tdoublep): This code could probably be optimized.
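            #
            # Worked example (illustrative, added for documentation): with
            # chunk_size = 8, a prefill request that already has 10 computed
            # tokens and 20 new tokens first emits a logical chunk of 6
            # tokens (to reach the next multiple of chunk_size, i.e. 16
            # computed tokens), followed by chunks of 8 and 6 tokens.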
            cu_chunk_seqlen = []
            seq_idx = []
            last_chunk_indices = []
            seqlen_pos = 0
            for req_idx in range(num_prefills):
                this_num_computed = num_computed_tokens_p[req_idx].item()
                this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item(
                ) - query_start_loc_p_cpu[req_idx].item()

                # if computed tokens are not chunk-aligned, use the first
                # chunk to finish it off
                if this_num_computed % self.chunk_size != 0:
                    seq_idx.append(req_idx)
                    cu_chunk_seqlen.append(seqlen_pos)
                    # how many tokens to finish the chunk?
                    chunk_len = cdiv(this_num_computed, self.chunk_size
                                     ) * self.chunk_size - this_num_computed
                    # we can only use at most this_new_tokens
                    chunk_len = min(chunk_len, this_new_tokens)
                    seqlen_pos += chunk_len
                    this_new_tokens -= chunk_len

                n_chunks = cdiv(this_new_tokens, self.chunk_size)
                for chunk in range(n_chunks):
                    seq_idx.append(req_idx)
                    cu_chunk_seqlen.append(seqlen_pos)
                    chunk_len = min(self.chunk_size, this_new_tokens)
                    seqlen_pos += chunk_len
                    this_new_tokens -= chunk_len

                assert this_new_tokens == 0
                last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

            cu_chunk_seqlen.append(seqlen_pos)

            seq_idx_p = torch.as_tensor(seq_idx,
                                        device=query_start_loc_p.device,
                                        dtype=torch.int32)
            cu_chunk_seqlen_p = torch.as_tensor(
                cu_chunk_seqlen,
                device=query_start_loc_p.device,
                dtype=torch.int32)
            last_chunk_indices_p = torch.as_tensor(
                last_chunk_indices,
                device=query_start_loc_p.device,
                dtype=torch.int32)

            nums_dict, batch_ptr, token_chunk_offset_ptr = \
                compute_causal_conv1d_metadata(query_start_loc_p)

        elif num_decodes <= self.decode_cudagraph_max_bs:
            # Pad state tensor for CUDA graph
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
                                                          non_blocking=True)
            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
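            # Note (added for documentation): the padded tail entries point
            # at PAD_SLOT_ID so that a CUDA graph captured at the padded
            # batch size can be replayed safely when fewer decode requests
            # are present.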

        attn_metadata = Mamba2AttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc_p=query_start_loc_p,
            seq_lens=seq_lens,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            has_initial_states_p=has_initial_states_p,
            seq_idx_p=seq_idx_p,
            state_indices_tensor=state_indices_tensor,
            cu_chunk_seqlen_p=cu_chunk_seqlen_p,
            last_chunk_indices_p=last_chunk_indices_p,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata