# SPDX-License-Identifier: Apache-2.0
"""
This file implements common components for MLA implementations.

First we define:

Sq      as Q sequence length
Skv     as KV sequence length

MLA has two possible ways of computing: a data-movement friendly approach and a
compute friendly approach. We generally want to use the compute friendly
approach for "prefill" (i.e. when the ratio Skv / Sq is "small", i.e. near 1)
and the data-movement friendly approach for "decode" (i.e. when the ratio
Skv / Sq is "large").

NOTE: what we deem small and large is currently determined by whether the
scheduler labels a request as prefill or decode, but this is something we
should probably tune.

Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

Deepseek's MLA attention works as follows:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
  multi-head attention, while the compute is similar to multi-query attention.

Below is an example of both paths assuming batch size = 1.

## More Extended Definitions:

C       Context length, `Skv - Sq`
H       hidden size
N       number of attention heads
Lq      latent dimension for Q                1536 in DSV3
Lkv     latent dimension for K/V               512 in DSV3
P       nope dimension, no rope.               128 in DSV3
R       rope dimension, goes through rope.      64 in DSV3
V       V head dim.                            128 in DSV3

## Vector/Matrix Definitions

h_t         hidden states (input to attention)   shape [Sq, H]
q_c         latent/compressed Q                  shape [Sq, Lq]
q_nope      uncompressed Q (no-rope)             shape [Sq, N, P]
q_pe        uncompressed Q (rope)                shape [Sq, N, R]
kv_c        latent/compressed KV                 shape [Skv, Lkv]
k_pe        decoupled k position embeddings      shape [Skv, R]
new_kv_c    new kv_c from current iter           shape [Sq, Lkv]
new_k_pe    new k_pe from current iter           shape [Sq, R]
cache_kv_c  cached kv_c from previous iters      shape [C, Lkv]
cache_k_pe  cached k_pe from previous iters      shape [C, R]
W_DQ        project h_t to q_c                   shape [H, Lq]
W_UQ        project q_c to q_nope                shape [Lq, N * P]
W_QR        project q_c to q_pe                  shape [Lq, N * R]
W_DKV       project h_t to kv_c                  shape [H, Lkv]
W_UK        project kv_c to k_nope               shape [Lkv, N, P]
W_KR        project h_t to k_pe                  shape [H, R]
W_UV        project kv_c to v                    shape [Lkv, N, V]
W_O         project v to h_t                     shape [N * V, H]


## Compute Friendly Approach (i.e. "_forward_prefill"):

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)

// MHA with QK headdim = P + R
//           V headdim = V
// sdpa_o shape [Sq, N, V]
sdpa_o = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    v
)
return sdpa_o @ W_O

NOTE: in the actual code,
    `kv_b_proj` is [W_UK; W_UV] concatenated per head
    `q_b_proj`  is [W_UQ; W_QR] concatenated per head
    `out_proj`  is W_O
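
As a minimal sketch (illustrative only; see `process_weights_after_loading`
below for the real version, which also handles dequantization), the per-head
W_UK / W_UV can be recovered from an unquantized `kv_b_proj` weight stored as
[output, input] = [N * (P + V), Lkv] like so:

    W_UK_UV = kv_b_proj_weight.T.view(Lkv, N, P + V)
    W_UK, W_UV = W_UK_UV.split([P, V], dim=-1)  # [Lkv, N, P], [Lkv, N, V]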


## Data-Movement Friendly Approach (i.e. "_forward_decode"):

Runtime
q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(-1, N, P)
ql_nope  = einsum("snp,lnp->snl", q_nope, W_UK)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)

// MQA with QK headdim = Lkv + R
//           V headdim = Lkv
// sdpa_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
//       but is more data-movement friendly since it's MQA vs MHA
sdpa_o = scaled_dot_product_attention(
    torch.cat([ql_nope, q_pe], dim=-1),
    torch.cat([kv_c, k_pe], dim=-1),
    kv_c
)

o = einsum("snl,lnv->snv", sdpa_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ W_O
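
The two paths above are algebraically equivalent: absorbing W_UK into the
query and W_UV into the output does not change the attention result. A
minimal, self-contained PyTorch sketch with toy sizes (illustrative only; not
the runtime code path) that checks this on random tensors:

    import torch

    Sq, Skv, N, P, R, Lkv, V = 2, 5, 3, 16, 8, 32, 16
    q_nope = torch.randn(Sq, N, P)
    q_pe = torch.randn(Sq, N, R)
    kv_c = torch.randn(Skv, Lkv)
    k_pe = torch.randn(Skv, R)
    W_UK = torch.randn(Lkv, N, P)
    W_UV = torch.randn(Lkv, N, V)

    # Compute friendly (MHA-style): up-project K/V, then attend per head
    k_nope = torch.einsum("sl,lnp->snp", kv_c, W_UK)
    v = torch.einsum("sl,lnv->snv", kv_c, W_UV)
    q = torch.cat([q_nope, q_pe], dim=-1)
    k = torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1)
    attn = torch.softmax(torch.einsum("qnd,knd->nqk", q, k), dim=-1)
    o_mha = torch.einsum("nqk,knv->qnv", attn, v)

    # Data-movement friendly (MQA-style): absorb W_UK into Q, W_UV into output
    ql_nope = torch.einsum("qnp,lnp->qnl", q_nope, W_UK)
    ql = torch.cat([ql_nope, q_pe], dim=-1)
    kl = torch.cat([kv_c, k_pe], dim=-1)
    attn2 = torch.softmax(torch.einsum("qnd,kd->nqk", ql, kl), dim=-1)
    o_latent = torch.einsum("nqk,kl->qnl", attn2, kv_c)
    o_mqa = torch.einsum("qnl,lnv->qnv", o_latent, W_UV)

    assert torch.allclose(o_mha, o_mqa, atol=1e-4)

(The softmax scale and causal mask are omitted here; they are applied
identically on both sides and do not affect the equivalence.)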


## Chunked Prefill

For chunked prefill we want to use the compute friendly algorithm. We assume
the ratio Skv / Sq is small enough to justify this; in the future we may want
to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is
small.

However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`

To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
fixed workspace size.

The chunked prefill approach is as follows:

MCC        Max chunk of context to process per iter, computed dynamically,
           used to bound the memory usage

q_c      = h_t @ W_DQ
q_nope   = (q_c @ W_UQ).view(Sq, N, P)
q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)

// MHA between queries and new KV
//   with QK headdim = P + R
//          V headdim = V
// curr_o   shape [Sq, N, V]
// curr_lse shape [N, Sq], this is just the order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
    torch.cat([q_nope, q_pe], dim=-1),
    torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
    new_v,
    causal=True,
    return_softmax_lse=True
)

// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
    chunk_start = chunk_idx * MCC
    chunk_end   = min(chunk_start + MCC, C)
    Sc          = chunk_end - chunk_start
    cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
    cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
    cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK.view(Lkv, N * P)).view(-1, N, P)
    cache_v_chunk      = (cache_kv_c_chunk @ W_UV.view(Lkv, N * V)).view(-1, N, V)

    chunk_o, chunk_lse = scaled_dot_product_attention(
        torch.cat([q_nope, q_pe], dim=-1),
        torch.cat([cache_k_nope_chunk,
                   cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
                  dim=-1),
        cache_v_chunk,
        causal=False,
        return_softmax_lse=True
    )

    curr_o, curr_lse = merge_attn_states(
        suffix_output=curr_o,
        suffix_lse=curr_lse,
        prefix_output=chunk_o,
        prefix_lse=chunk_lse,
    )

return curr_o @ W_O
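
For reference, merging two partial attention results by their softmax
log-sum-exp values can be sketched as below (illustrative only; the real
implementation is the `merge_attn_states` Triton kernel, which writes into
preallocated outputs instead of returning them):

    import torch

    def merge_attn_states_ref(prefix_output, prefix_lse,
                              suffix_output, suffix_lse):
        # *_output: [S, N, V], *_lse: [N, S] (the order FA returns)
        p_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)   # [S, N, 1]
        s_lse = suffix_lse.transpose(0, 1).unsqueeze(-1)   # [S, N, 1]
        max_lse = torch.maximum(p_lse, s_lse)
        p_w = torch.exp(p_lse - max_lse)
        s_w = torch.exp(s_lse - max_lse)
        output = (prefix_output * p_w + suffix_output * s_w) / (p_w + s_w)
        lse = (max_lse + torch.log(p_w + s_w)).squeeze(-1).transpose(0, 1)
        return output, lse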
"""
import functools
|
|
from abc import abstractmethod
|
|
from collections import defaultdict
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from itertools import accumulate
|
|
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
|
|
Type, TypeVar)
|
|
|
|
import torch
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm import envs
|
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
|
AttentionMetadata,
|
|
AttentionMetadataBuilder,
|
|
AttentionState, MLAAttentionImpl)
|
|
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
|
compute_slot_mapping_start_idx,
|
|
is_block_tables_empty)
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
LinearBase, RowParallelLinear,
|
|
UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.rotary_embedding import (
|
|
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
|
from vllm.multimodal import MultiModalPlaceholderMap
|
|
from vllm.platforms import current_platform
|
|
from vllm.triton_utils import HAS_TRITON
|
|
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
|
|
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
|
|
|
|
if HAS_TRITON:
|
|
from vllm.attention.ops.triton_flash_attention import triton_attention
|
|
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
|
|
else:
|
|
merge_attn_states = None
|
|
triton_attention = None
|
|
|
|
try:
|
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
is_vllm_fa = True
|
|
except ImportError:
|
|
is_vllm_fa = False
|
|
try:
|
|
# For rocm use upstream flash attention
|
|
from flash_attn import flash_attn_varlen_func
|
|
except ImportError:
|
|
flash_attn_varlen_func = None
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
|
ModelInputForGPUWithSamplingMetadata)
|
|
|
|
is_hip = current_platform.is_rocm()
|
|
|
|
|
|
class MLACommonBackend(AttentionBackend):
|
|
|
|
@staticmethod
|
|
def get_name() -> str:
|
|
return "TRITON_MLA"
|
|
|
|
@staticmethod
|
|
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
|
return MLACommonMetadata
|
|
|
|
@staticmethod
|
|
def get_builder_cls() -> Type["MLACommonMetadataBuilder"]:
|
|
return MLACommonMetadataBuilder
|
|
|
|
@staticmethod
|
|
def get_state_cls() -> Type["MLACommonState"]:
|
|
return MLACommonState
|
|
|
|
@staticmethod
|
|
def get_kv_cache_shape(
|
|
num_blocks: int,
|
|
block_size: int,
|
|
num_kv_heads: int, # assumed to be 1 for MLA
|
|
head_size: int,
|
|
) -> Tuple[int, ...]:
|
|
return (num_blocks, block_size, head_size)
|
|
|
|
@staticmethod
|
|
def swap_blocks(
|
|
src_kv_cache: torch.Tensor,
|
|
dst_kv_cache: torch.Tensor,
|
|
src_to_dst: torch.Tensor,
|
|
) -> None:
|
|
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
|
|
|
@staticmethod
|
|
def copy_blocks(
|
|
kv_caches: List[torch.Tensor],
|
|
src_to_dists: torch.Tensor,
|
|
) -> None:
|
|
ops.copy_blocks_mla(kv_caches, src_to_dists)
|
|
|
|
@staticmethod
|
|
def get_supported_head_sizes() -> List[int]:
|
|
return [576]
|
|
|
|
|
|
T = TypeVar("T", bound="MLACommonMetadata")
|
|
|
|
|
|
class MLACommonState(AttentionState, Generic[T]):
|
|
|
|
def __init__(self, runner):
|
|
self.runner = runner
|
|
self._is_graph_capturing = False
|
|
|
|
scheduler_config = runner.scheduler_config
|
|
self.model_config = runner.model_config
|
|
cache_config = runner.cache_config
|
|
|
|
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
|
self.enable_prefix_caching = cache_config.enable_prefix_caching
|
|
|
|
if self.chunked_prefill_enabled or self.enable_prefix_caching:
|
|
self.context_chunk_workspace_size = min(
|
|
                # Make sure there is enough for 8 full length requests or at
                # least 4 pages of cache per request
|
|
max(
|
|
8 * self.model_config.max_model_len, 4 *
|
|
scheduler_config.max_num_seqs * cache_config.block_size),
|
|
                # For long-context models, try not to over-allocate (which
                # would limit kv-cache space) by capping the workspace at
                # 128k tokens, which results in the workspace being:
                #   2*(576)*(128*1024) = 144mb
                # (assuming 576 MLA head dim, and fp16)
                # and the up-projected context being:
                #   2*(192*128)*(128*1024) = 6gb
                # (assuming 192 QK head dim, 128 heads, and fp16)
|
|
128 * 1024)
|
|
assert self.context_chunk_workspace_size >= \
|
|
scheduler_config.max_num_seqs * cache_config.block_size
|
|
|
|
@contextmanager
|
|
def graph_capture(self, max_batch_size: int):
|
|
self._is_graph_capturing = True
|
|
|
|
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
|
PAD_SLOT_ID,
|
|
dtype=torch.long,
|
|
device=self.runner.device)
|
|
self._graph_seq_lens = torch.ones(max_batch_size,
|
|
dtype=torch.int32,
|
|
device=self.runner.device)
|
|
self._graph_block_tables = torch.from_numpy(
|
|
self.runner.graph_block_tables).to(device=self.runner.device)
|
|
|
|
self._positions = torch.zeros((max_batch_size, ),
|
|
dtype=torch.long,
|
|
device=self.runner.device)
|
|
|
|
yield
|
|
|
|
self._is_graph_capturing = False
|
|
del self._graph_slot_mapping
|
|
del self._graph_seq_lens
|
|
del self._graph_block_tables
|
|
del self._positions
|
|
|
|
def graph_clone(self, batch_size: int):
|
|
assert self._is_graph_capturing
|
|
return self.__class__(self.runner)
|
|
|
|
def graph_capture_get_metadata_for_batch(
|
|
self,
|
|
batch_size: int,
|
|
is_encoder_decoder_model: bool = False) -> T:
|
|
assert self._is_graph_capturing
|
|
|
|
attn_metadata = self.runner.attn_backend.make_metadata(
|
|
multi_modal_placeholder_index_maps=None,
|
|
enable_kv_scales_calculation=False,
|
|
use_cuda_graph=True,
|
|
num_prefills=0,
|
|
num_prefill_tokens=0,
|
|
num_decode_tokens=batch_size,
|
|
slot_mapping=self._graph_slot_mapping[:batch_size],
|
|
seq_lens=None,
|
|
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
|
max_query_len=1,
|
|
max_decode_query_len=1,
|
|
max_prefill_seq_len=0,
|
|
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
|
query_start_loc=None,
|
|
seq_start_loc=None,
|
|
context_lens_tensor=None,
|
|
block_tables=self._graph_block_tables[:batch_size],
|
|
input_positions=self._positions[:batch_size],
|
|
head_dim=self.runner.model_config.get_head_size())
|
|
|
|
if is_encoder_decoder_model:
|
|
raise NotImplementedError(
|
|
"MLACommonState does not support encoder/decoder yet")
|
|
|
|
return attn_metadata
|
|
|
|
def get_graph_input_buffers(self,
|
|
attn_metadata,
|
|
is_encoder_decoder_model: bool = False):
|
|
input_buffers = {
|
|
"slot_mapping": attn_metadata.slot_mapping,
|
|
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
|
"block_tables": attn_metadata.decode_metadata.block_tables,
|
|
"input_positions": attn_metadata.decode_metadata.input_positions,
|
|
}
|
|
if is_encoder_decoder_model:
|
|
raise NotImplementedError(
|
|
"MLACommonState does not support encoder/decoder yet")
|
|
|
|
return input_buffers
|
|
|
|
def prepare_graph_input_buffers(self,
|
|
input_buffers,
|
|
attn_metadata,
|
|
is_encoder_decoder_model: bool = False):
|
|
input_positions = attn_metadata.input_positions
|
|
num_positions = input_positions.shape[0]
|
|
input_buffers["seq_lens_tensor"].copy_(
|
|
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
|
input_buffers["block_tables"].copy_(
|
|
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
|
# CUDA graph buffer is padded so only perform a partial copy based on
|
|
# num_positions
|
|
input_buffers["input_positions"][:num_positions].copy_(
|
|
input_positions, non_blocking=True)
|
|
if is_encoder_decoder_model:
|
|
raise NotImplementedError(
|
|
"TritonMLAState does not support encoder/decoder yet")
|
|
|
|
def begin_forward(self, model_input):
|
|
if self.chunked_prefill_enabled or self.enable_prefix_caching:
|
|
if not hasattr(self, "context_chunk_workspace"):
|
|
                # NOTE: `self.runner.device` does not return the correct
                # device for this process (init_device sets the correct
                # device, but only on the Worker). The only way I've figured
                # out to get the correct device is to allocate the workspace
                # on the first call to begin_forward and use the device of
                # the input tokens
|
|
assert model_input.input_tokens is not None
|
|
self.context_chunk_workspace = torch.empty(
|
|
(self.context_chunk_workspace_size,
|
|
self.model_config.get_head_size()),
|
|
dtype=self.model_config.dtype,
|
|
device=model_input.input_tokens.device,
|
|
)
|
|
|
|
model_input.attn_metadata.context_chunk_workspace = \
|
|
self.context_chunk_workspace
|
|
|
|
|
|
@dataclass
|
|
class MLACommonMetadata(AttentionMetadata):
|
|
"""Metadata for MLACommon.
|
|
|
|
NOTE: Please read the comment at the top of the file before trying to
|
|
understand this class
|
|
|
|
    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
|
|
"""
|
|
    # Whether or not cuda graph is enabled.
|
|
# Cuda-graph is currently enabled for decoding only.
|
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
|
use_cuda_graph: bool
|
|
|
|
# New for MLA (compared to FlashAttention)
|
|
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
|
|
input_positions: torch.Tensor
|
|
|
|
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
|
# |---------- N-1 iteration --------|
|
|
# |---------------- N iteration ---------------------|
|
|
# |- tokenA -|......................|-- newTokens ---|
|
|
# |---------- context_len ----------|
|
|
# |-------------------- seq_len ---------------------|
|
|
# |-- query_len ---|
|
|
|
|
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
|
|
seq_lens: Optional[List[int]]
|
|
# seq_lens stored as a tensor.
|
|
seq_lens_tensor: Optional[torch.Tensor]
|
|
|
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
|
# requests only.
|
|
max_prefill_seq_len: int
|
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
|
# requests only.
|
|
max_decode_seq_len: int
|
|
# (batch_size,) A tensor of context lengths (tokens that are computed
|
|
# so far).
|
|
context_lens_tensor: Optional[torch.Tensor]
|
|
|
|
# (batch_size, max_blocks_per_seq).
|
|
# Block addresses per sequence. (Seq id -> list of physical block)
|
|
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
|
# in the kv cache. Each block can contain up to block_size tokens.
|
|
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
|
# captured.
|
|
block_tables: Optional[torch.Tensor]
|
|
|
|
# Maximum query length in the batch.
|
|
max_query_len: Optional[int] = None
|
|
|
|
    # Max number of query tokens among requests in the batch.
|
|
max_decode_query_len: Optional[int] = None
|
|
|
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
|
# the batch, used to index into subquery. E.g., if the subquery length
|
|
# is [4, 6], it is [0, 4, 10].
|
|
query_start_loc: Optional[torch.Tensor] = None
|
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
|
# [4, 6], it is [0, 4, 10].
|
|
seq_start_loc: Optional[torch.Tensor] = None
|
|
|
|
_cached_prefill_metadata: Optional[Any] = None
|
|
_cached_decode_metadata: Optional[Any] = None
|
|
|
|
num_prefill_tokens: int
|
|
|
|
# The dimension of the attention heads
|
|
head_dim: Optional[int] = None
|
|
|
|
# Used when chunked prefill is enabled to simulate worst case workspace
|
|
# allocations, hopefully to avoid going OOM
|
|
is_profile_run: bool = False
|
|
|
|
# New for MLA (compared to FlashAttention)
|
|
# For chunked prefill
|
|
context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
|
|
context_chunk_starts: Optional[torch.Tensor] = None
|
|
context_chunk_seq_tot: Optional[List[int]] = None
|
|
context_chunk_max_seq_lens: Optional[List[int]] = None
|
|
# Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
|
|
context_chunk_workspace: Optional[torch.Tensor] = None
|
|
|
|
def __post_init__(self):
|
|
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
|
|
if self.head_dim is not None and self.head_dim \
|
|
not in supported_head_sizes:
|
|
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")
|
|
|
|
@property
|
|
def prefill_metadata(self):
|
|
if self.num_prefills == 0:
|
|
return None
|
|
|
|
if self._cached_prefill_metadata is not None:
|
|
return self._cached_prefill_metadata
|
|
|
|
assert self.seq_lens is not None
|
|
assert self.seq_lens_tensor is not None
|
|
|
|
# Compute some attn_metadata fields which default to None
|
|
query_start_loc = (None if self.query_start_loc is None else
|
|
self.query_start_loc[:self.num_prefills + 1])
|
|
slot_mapping = (None if self.slot_mapping is None else
|
|
self.slot_mapping[:self.num_prefill_tokens])
|
|
seq_lens = (None if self.seq_lens is None else
|
|
self.seq_lens[:self.num_prefills])
|
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
|
self.seq_lens_tensor[:self.num_prefills])
|
|
seq_start_loc = (None if self.seq_start_loc is None else
|
|
self.seq_start_loc[:self.num_prefills + 1])
|
|
context_lens_tensor = (None if self.context_lens_tensor is None else
|
|
self.context_lens_tensor[:self.num_prefills])
|
|
block_tables = (None if self.block_tables is None else
|
|
self.block_tables[:self.num_prefills])
|
|
input_positions = (None if self.input_positions is None else
|
|
self.input_positions[:self.num_prefill_tokens])
|
|
|
|
self._cached_prefill_metadata = self.__class__(
|
|
# Required by ModelRunner
|
|
use_cuda_graph=False, # Not Attention Related
|
|
# Required by Attention Metadata
|
|
num_prefills=self.num_prefills,
|
|
num_prefill_tokens=self.num_prefill_tokens,
|
|
num_decode_tokens=0,
|
|
slot_mapping=slot_mapping,
|
|
# Required by Attention Metadata (not used)
|
|
multi_modal_placeholder_index_maps=None,
|
|
enable_kv_scales_calculation=False,
|
|
# MLACommonMetadata
|
|
input_positions=input_positions,
|
|
seq_lens=seq_lens,
|
|
seq_lens_tensor=seq_lens_tensor,
|
|
max_query_len=self.max_query_len,
|
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
|
max_decode_query_len=0,
|
|
max_decode_seq_len=0,
|
|
query_start_loc=query_start_loc,
|
|
seq_start_loc=seq_start_loc,
|
|
context_lens_tensor=context_lens_tensor,
|
|
block_tables=block_tables,
|
|
head_dim=self.head_dim,
|
|
is_profile_run=self.is_profile_run,
|
|
# MLACommonMetadata Chunk prefill specific
|
|
context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens,
|
|
context_chunk_starts=self.context_chunk_starts,
|
|
context_chunk_seq_tot=self.context_chunk_seq_tot,
|
|
context_chunk_max_seq_lens=self.context_chunk_max_seq_lens,
|
|
)
|
|
return self._cached_prefill_metadata
|
|
|
|
@property
|
|
def decode_metadata(self):
|
|
if self.num_decode_tokens == 0:
|
|
return None
|
|
|
|
if self._cached_decode_metadata is not None:
|
|
return self._cached_decode_metadata
|
|
assert self.seq_lens_tensor is not None
|
|
|
|
# Compute some attn_metadata fields which default to None
|
|
slot_mapping = (None if self.slot_mapping is None else
|
|
self.slot_mapping[self.num_prefill_tokens:])
|
|
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
|
self.seq_lens_tensor[self.num_prefills:])
|
|
block_tables = (None if self.block_tables is None else
|
|
self.block_tables[self.num_prefills:])
|
|
input_positions = (None if self.input_positions is None else
|
|
self.input_positions[self.num_prefill_tokens:])
|
|
|
|
self._cached_decode_metadata = self.__class__(
|
|
# Required by ModelRunner
|
|
use_cuda_graph=self.use_cuda_graph, # Not Attention Related
|
|
# Required by Attention Metadata
|
|
num_prefills=0,
|
|
num_prefill_tokens=0,
|
|
num_decode_tokens=self.num_decode_tokens,
|
|
slot_mapping=slot_mapping,
|
|
# Required by Attention Metadata (not used)
|
|
multi_modal_placeholder_index_maps=None,
|
|
enable_kv_scales_calculation=False,
|
|
# MLACommonMetadata
|
|
seq_lens=None,
|
|
seq_lens_tensor=seq_lens_tensor,
|
|
max_decode_query_len=self.max_decode_query_len,
|
|
max_query_len=self.max_query_len,
|
|
max_prefill_seq_len=0,
|
|
max_decode_seq_len=self.max_decode_seq_len,
|
|
# Batch may be composed of prefill|decodes, adjust query start
|
|
# indices to refer to the start of decodes. E.g.
|
|
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
|
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
|
self.query_start_loc[self.num_prefills])
|
|
if self.query_start_loc is not None else None,
|
|
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
|
if self.seq_start_loc is not None else None,
|
|
context_lens_tensor=None,
|
|
block_tables=block_tables,
|
|
input_positions=input_positions,
|
|
head_dim=self.head_dim,
|
|
is_profile_run=self.is_profile_run)
|
|
return self._cached_decode_metadata
|
|
|
|
def advance_step(self,
|
|
model_input: "ModelInputForGPUWithSamplingMetadata",
|
|
sampled_token_ids: Optional[torch.Tensor],
|
|
block_size: int,
|
|
num_seqs: int,
|
|
num_queries: int,
|
|
turn_prefills_into_decodes: bool = False):
|
|
"""
|
|
Update metadata in-place to advance one decode step.
|
|
"""
|
|
# When using cudagraph, the num_seqs is padded to the next captured
|
|
        # batch size, but num_queries tracks the actual number of requests in
|
|
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
|
if num_seqs != num_queries:
|
|
assert num_seqs > num_queries
|
|
|
|
if turn_prefills_into_decodes:
|
|
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
|
# decodes are scheduled together. In the first step, all the
|
|
# prefills turn into decodes. This update reflects that
|
|
# conversion.
|
|
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
|
self.num_decode_tokens += self.num_prefills
|
|
self.num_prefills = 0
|
|
self.num_prefill_tokens = 0
|
|
self.max_prefill_seq_len = 0
|
|
self.max_query_len = 1
|
|
|
|
self.slot_mapping = self.slot_mapping[:num_seqs]
|
|
else:
|
|
assert self.seq_lens is not None
|
|
assert self.max_decode_seq_len == max(self.seq_lens)
|
|
|
|
assert self.num_prefills == 0
|
|
assert self.num_prefill_tokens == 0
|
|
assert self.num_decode_tokens == num_seqs
|
|
assert self.slot_mapping.shape == (num_seqs, )
|
|
|
|
assert self.seq_lens is not None
|
|
assert len(self.seq_lens) == num_seqs
|
|
assert self.seq_lens_tensor is not None
|
|
assert self.seq_lens_tensor.shape == (num_seqs, )
|
|
assert self.max_query_len == 1
|
|
assert self.max_prefill_seq_len == 0
|
|
|
|
assert self.query_start_loc is not None
|
|
assert self.query_start_loc.shape == (num_queries + 1, )
|
|
assert self.seq_start_loc is not None
|
|
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
|
|
|
assert self.context_lens_tensor is not None
|
|
assert self.context_lens_tensor.shape == (num_queries, )
|
|
|
|
assert self.block_tables is not None
|
|
assert self.block_tables.shape[0] == num_seqs
|
|
|
|
# Update query lengths. Note that we update only queries and not seqs,
|
|
# since tensors may be padded due to captured cuda graph batch size
|
|
for i in range(num_queries):
|
|
self.seq_lens[i] += 1
|
|
self.max_decode_seq_len = max(self.seq_lens)
|
|
|
|
ops.advance_step_flashattn(num_seqs=num_seqs,
|
|
num_queries=num_queries,
|
|
block_size=block_size,
|
|
input_tokens=model_input.input_tokens,
|
|
sampled_token_ids=sampled_token_ids,
|
|
input_positions=model_input.input_positions,
|
|
seq_lens=self.seq_lens_tensor,
|
|
slot_mapping=self.slot_mapping,
|
|
block_tables=self.block_tables)
|
|
|
|
|
|
class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
|
|
"""
|
|
NOTE: Please read the comment at the top of the file before trying to
|
|
understand this class
|
|
"""
|
|
|
|
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
|
self.input_builder = input_builder
|
|
self.runner = input_builder.runner
|
|
self.sliding_window = input_builder.sliding_window
|
|
self.block_size = input_builder.block_size
|
|
self.chunked_prefill_enabled = \
|
|
self.runner.scheduler_config.chunked_prefill_enabled
|
|
self.enable_prefix_caching = \
|
|
self.runner.cache_config.enable_prefix_caching
|
|
|
|
if self.chunked_prefill_enabled or self.enable_prefix_caching:
|
|
attn_state = self.input_builder.runner.attn_state
|
|
self.context_chunk_workspace_size = \
|
|
attn_state.context_chunk_workspace_size
|
|
self.page_size = self.runner.block_size
|
|
|
|
def prepare(self):
|
|
self.slot_mapping: List[int] = []
|
|
self.prefill_seq_lens: List[int] = []
|
|
self.context_lens: List[int] = []
|
|
self.block_tables: List[List[int]] = []
|
|
self.curr_seq_lens: List[int] = []
|
|
self.input_positions: List[int] = []
|
|
self.multimodal_placeholder_maps: Dict[
|
|
str,
|
|
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
|
self.num_prefills = 0
|
|
self.num_prefill_tokens = 0
|
|
self.num_decode_tokens = 0
|
|
self.has_prefix_cache_hit = False
|
|
|
|
def _add_seq_group(
|
|
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
|
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
|
"""Add a sequence group to the metadata. Specifically update/append
|
|
1. context length.
|
|
2. block table.
|
|
3. slot mapping.
|
|
"""
|
|
is_prompt = inter_data.is_prompt
|
|
block_tables = inter_data.block_tables
|
|
|
|
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
|
curr_sliding_window_block, input_positions) in zip(
|
|
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
|
inter_data.orig_seq_lens, inter_data.seq_lens,
|
|
inter_data.query_lens, inter_data.context_lens,
|
|
inter_data.curr_sliding_window_blocks,
|
|
inter_data.input_positions):
|
|
self.input_positions.extend(input_positions)
|
|
self.context_lens.append(context_len)
|
|
if is_prompt:
|
|
self.num_prefills += 1
|
|
self.num_prefill_tokens += token_len
|
|
self.prefill_seq_lens.append(seq_len)
|
|
else:
|
|
self.num_decode_tokens += query_len
|
|
self.curr_seq_lens.append(curr_seq_len)
|
|
|
|
# Compute block table.
|
|
# TODO(sang): Combine chunked prefill and prefix caching by
|
|
# only allowing multiple of block_size chunk size.
|
|
# NOTE: This only works for oooooooxxx style attention.
|
|
block_table = []
|
|
if prefix_cache_hit:
|
|
# NOTE(woosuk): For flash-attn, the block table should
|
|
# include the entries for the incoming prefill tokens.
|
|
block_table = block_tables[seq_id]
|
|
elif ((chunked_prefill_enabled or not is_prompt)
|
|
and block_tables is not None):
|
|
if curr_sliding_window_block == 0:
|
|
block_table = block_tables[seq_id]
|
|
else:
|
|
block_table = block_tables[seq_id][
|
|
-curr_sliding_window_block:]
|
|
self.block_tables.append(block_table)
|
|
|
|
# Compute slot mapping.
|
|
is_profile_run = is_block_tables_empty(block_tables)
|
|
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
|
context_len,
|
|
self.sliding_window)
|
|
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
|
seq_len, context_len, start_idx,
|
|
self.block_size, inter_data.block_tables)
|
|
|
|
def _get_graph_runner_block_tables(
|
|
self, num_seqs: int,
|
|
block_tables: List[List[int]]) -> torch.Tensor:
|
|
# The shape of graph_block_tables is
|
|
# [max batch size, max context len // block size].
|
|
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
|
assert max_batch_size >= num_seqs
|
|
|
|
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
|
|
for i, block_table in enumerate(block_tables):
|
|
if block_table:
|
|
num_blocks = len(block_table)
|
|
if num_blocks <= max_blocks:
|
|
graph_block_tables[i, :num_blocks] = block_table
|
|
else:
|
|
# It may be possible to have more blocks allocated due
|
|
# to lookahead slots of multi-step, however, they are
|
|
# not used anyway, so can be safely ignored.
|
|
graph_block_tables[
|
|
i, :max_blocks] = block_table[:max_blocks]
|
|
|
|
return torch.from_numpy(graph_block_tables).to(
|
|
device=self.runner.device, non_blocking=True)
|
|
|
|
def build(self, seq_lens: List[int], query_lens: List[int],
|
|
cuda_graph_pad_size: int, batch_size: int):
|
|
"""Build attention metadata with on-device tensors.
|
|
|
|
Args:
|
|
seq_lens: The maybe padded sequence lengths of the input sequences.
|
|
query_lens: The query lengths of the input sequences.
|
|
cuda_graph_pad_size: The padding size for cuda graph.
|
|
-1 if cuda graph is not used.
|
|
batch_size: The maybe padded batch size.
|
|
"""
|
|
prefix_cache_hit = any([
|
|
inter_data.prefix_cache_hit
|
|
for inter_data in self.input_builder.inter_data_list
|
|
])
|
|
|
|
for inter_data in self.input_builder.inter_data_list:
|
|
self._add_seq_group(inter_data,
|
|
self.input_builder.chunked_prefill_enabled,
|
|
prefix_cache_hit)
|
|
|
|
device = self.runner.device
|
|
use_captured_graph = cuda_graph_pad_size != -1
|
|
|
|
max_query_len = max(query_lens)
|
|
decode_query_lens = query_lens[self.num_prefills:]
|
|
if len(decode_query_lens) > 0:
|
|
max_decode_query_len = max(decode_query_lens)
|
|
else:
|
|
max_decode_query_len = 1
|
|
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
|
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
|
num_decode_tokens = self.num_decode_tokens
|
|
query_start_loc = list(accumulate(query_lens, initial=0))
|
|
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
|
|
|
num_seqs = len(seq_lens)
|
|
if use_captured_graph:
|
|
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
|
            self.block_tables.extend([[]] * cuda_graph_pad_size)
|
|
num_decode_tokens = batch_size - self.num_prefill_tokens
|
|
block_tables = self._get_graph_runner_block_tables(
|
|
num_seqs, self.block_tables)
|
|
else:
|
|
block_tables = make_tensor_with_pad(
|
|
self.block_tables,
|
|
pad=0,
|
|
dtype=torch.int,
|
|
device=device,
|
|
)
|
|
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
|
|
|
assert device is not None
|
|
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
|
device, self.runner.pin_memory)
|
|
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
|
self.runner.pin_memory)
|
|
input_positions = async_tensor_h2d(self.input_positions, torch.long,
|
|
device, self.runner.pin_memory)
|
|
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
|
|
device, self.runner.pin_memory)
|
|
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
|
device,
|
|
self.runner.pin_memory)
|
|
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
|
device, self.runner.pin_memory)
|
|
|
|
context_chunk_cu_seq_lens = None
|
|
context_chunk_starts = None
|
|
context_chunk_seq_tot = None
|
|
context_chunk_max_seq_lens = None
|
|
|
|
if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
|
|
and self.num_prefills > 0 \
|
|
and context_lens_tensor is not None \
|
|
and context_lens_tensor[:self.num_prefills].max() > 0:
|
|
|
|
            # NOTE: it is recommended you read the `Chunked Prefill` section
            # in the comment at the top of the file before trying to
            # understand the following code
|
|
|
|
num_prefills_with_context = \
|
|
(context_lens_tensor[:self.num_prefills] > 0).sum().item()
|
|
|
|
            # Currently we allocate an equal amount of workspace for each
            # prefill in the batch; we could probably use a more advanced
            # algorithm here and allocate more workspace to prefills with
            # longer context lengths
|
|
max_context_chunk = \
|
|
self.context_chunk_workspace_size // num_prefills_with_context
|
|
|
|
# align max_context_chunk to page_size by rounding down,
|
|
# currently the `gather_cache` kernel cannot handle
|
|
# `context_chunk_starts` that are not aligned to page_size
|
|
max_context_chunk = round_down(max_context_chunk, self.page_size)
|
|
assert max_context_chunk > 0
|
|
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
|
|
|
|
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
|
# `num_prefills_with_context = 4`, create a tensor that looks like
|
|
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
|
context_chunk_starts = \
|
|
torch.arange(num_chunks, device=device, dtype=torch.int32)\
|
|
.unsqueeze(1).expand(-1, self.num_prefills)\
|
|
* max_context_chunk
|
|
chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\
|
|
.unsqueeze(0), context_chunk_starts + max_context_chunk)
|
|
chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
|
|
_context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
|
|
torch.int32)
|
|
zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\
|
|
.unsqueeze(-1)
|
|
context_chunk_cu_seq_lens = \
|
|
torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
|
|
context_chunk_max_seq_lens = \
|
|
chunk_seq_lens.max(dim=1).values.tolist()
|
|
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
|
|
assert max(context_chunk_seq_tot) <= \
|
|
self.context_chunk_workspace_size
|
|
|
|
return self.runner.attn_backend.make_metadata(
|
|
# Required by ModelRunner
|
|
use_cuda_graph=use_captured_graph, # Not Attention Related
|
|
# Required by Attention Metadata
|
|
num_prefills=self.num_prefills,
|
|
slot_mapping=slot_mapping_tensor,
|
|
num_prefill_tokens=self.num_prefill_tokens,
|
|
num_decode_tokens=num_decode_tokens,
|
|
# Required by Attention Metadata (not used)
|
|
multi_modal_placeholder_index_maps=None, # Not Attention Related
|
|
enable_kv_scales_calculation=False,
|
|
# MLACommonMetadata
|
|
input_positions=input_positions,
|
|
seq_lens=seq_lens,
|
|
seq_lens_tensor=seq_lens_tensor,
|
|
max_query_len=max_query_len,
|
|
max_decode_query_len=max_decode_query_len,
|
|
max_prefill_seq_len=max_prefill_seq_len,
|
|
max_decode_seq_len=max_decode_seq_len,
|
|
query_start_loc=query_start_loc_tensor,
|
|
seq_start_loc=seq_start_loc_tensor,
|
|
context_lens_tensor=context_lens_tensor,
|
|
block_tables=block_tables,
|
|
head_dim=self.runner.model_config.get_head_size(),
|
|
is_profile_run=self.runner.in_profile_run,
|
|
# MLACommonMetadata Chunk prefill specific
|
|
context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
|
|
context_chunk_starts=context_chunk_starts,
|
|
context_chunk_seq_tot=context_chunk_seq_tot,
|
|
context_chunk_max_seq_lens=context_chunk_max_seq_lens,
|
|
)
|
|
|
|
|
|
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|
"""
|
|
NOTE: Please read the comment at the top of the file before trying to
|
|
understand this class
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_heads: int,
|
|
head_size: int,
|
|
scale: float,
|
|
num_kv_heads: int,
|
|
alibi_slopes: Optional[List[float]],
|
|
sliding_window: Optional[int],
|
|
kv_cache_dtype: str,
|
|
blocksparse_params: Optional[Dict[str, Any]],
|
|
logits_soft_cap: Optional[float],
|
|
attn_type: str,
|
|
# MLA Specific Arguments
|
|
q_lora_rank: Optional[int],
|
|
kv_lora_rank: int,
|
|
qk_nope_head_dim: int,
|
|
qk_rope_head_dim: int,
|
|
qk_head_dim: int,
|
|
v_head_dim: int,
|
|
rotary_emb: RotaryEmbedding,
|
|
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
|
# attention backend perspective we rely on the layer to pass in the
|
|
# correct matrix
|
|
q_proj: ColumnParallelLinear,
|
|
kv_b_proj: ColumnParallelLinear,
|
|
o_proj: RowParallelLinear,
|
|
) -> None:
|
|
self.num_heads = num_heads
|
|
self.head_size = head_size
|
|
self.scale = float(scale)
|
|
self.num_kv_heads = num_kv_heads
|
|
self.kv_cache_dtype = kv_cache_dtype
|
|
|
|
self.q_lora_rank = q_lora_rank
|
|
self.kv_lora_rank = kv_lora_rank
|
|
self.qk_nope_head_dim = qk_nope_head_dim
|
|
self.qk_rope_head_dim = qk_rope_head_dim
|
|
self.qk_head_dim = qk_head_dim
|
|
self.v_head_dim = v_head_dim
|
|
|
|
self.rotary_emb = rotary_emb
|
|
self.use_yarn_rope = isinstance(rotary_emb,
|
|
DeepseekScalingRotaryEmbedding)
|
|
self.q_proj = q_proj
|
|
self.kv_b_proj = kv_b_proj
|
|
self.o_proj = o_proj
|
|
self.triton_fa_func = triton_attention
|
|
|
|
        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on ROCm and the
        # latter has an additional parameter to control FA2 vs FA3
|
|
self.flash_attn_varlen_func = flash_attn_varlen_func
|
|
self.vllm_flash_attn_version = get_flash_attn_version()
|
|
if self.vllm_flash_attn_version is not None:
|
|
self.flash_attn_varlen_func = \
|
|
functools.partial(flash_attn_varlen_func,
|
|
fa_version=self.vllm_flash_attn_version)
|
|
|
|
def _v_up_proj_and_o_proj(self, x):
|
|
# Convert from (B, N, L) to (N, B, L)
|
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
|
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
|
x = torch.bmm(x, self.W_UV)
|
|
# Convert from (N, B, V) to (B, N * V)
|
|
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
|
return self.o_proj(x)[0]
|
|
|
|
# Return `ql_nope`, `q_pe`
|
|
def _q_proj_and_k_up_proj(self, x):
|
|
q_nope, q_pe = self.q_proj(x)[0]\
|
|
.view(-1, self.num_heads, self.qk_head_dim)\
|
|
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
|
|
|
# Convert from (B, N, P) to (N, B, P)
|
|
q_nope = q_nope.transpose(0, 1)
|
|
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
|
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
|
# Convert from (N, B, L) to (B, N, L)
|
|
return ql_nope.transpose(0, 1), q_pe
|
|
|
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
|
|
|
def get_layer_weight(layer):
|
|
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
|
for attr in WEIGHT_NAMES:
|
|
if hasattr(layer, attr):
|
|
return getattr(layer, attr)
|
|
raise AttributeError(
|
|
f"Layer '{layer}' has no recognized weight attribute:"
|
|
f" {WEIGHT_NAMES}.")
|
|
|
|
def get_and_maybe_dequant_weights(layer: LinearBase):
|
|
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
|
# NOTE: This should only be used offline, since it's O(N^3)
|
|
eye = torch.eye(layer.input_size_per_partition,
|
|
dtype=act_dtype,
|
|
device=get_layer_weight(layer).device)
|
|
dequant_weights = layer.quant_method.apply(layer,
|
|
eye,
|
|
bias=None)
|
|
del eye
|
|
# standardize to (output, input)
|
|
return dequant_weights.T
|
|
return layer.weight
|
|
|
|
        # We currently do not have quantized bmm's, which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit; the extra memory overhead of this is fairly low
|
|
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
|
assert kv_b_proj_weight.shape == (
|
|
self.kv_lora_rank,
|
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
|
f"{kv_b_proj_weight.shape=}, "
|
|
f"{self.kv_lora_rank=}, "
|
|
f"{self.num_heads=}, "
|
|
f"{self.qk_nope_head_dim=}, "
|
|
f"{self.v_head_dim=}")
|
|
kv_b_proj_weight = kv_b_proj_weight.view(
|
|
self.kv_lora_rank,
|
|
self.num_heads,
|
|
self.qk_nope_head_dim + self.v_head_dim,
|
|
)
|
|
|
|
W_UK, W_UV = kv_b_proj_weight.split(
|
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
|
|
|
# Convert from (L, N, V) to (N, L, V)
|
|
self.W_UV = W_UV.transpose(0, 1)
|
|
# Convert from (L, N, P) to (N, P, L)
|
|
self.W_UK_T = W_UK.permute(1, 2, 0)
|
|
|
|
def _compute_prefill_context(
|
|
self,
|
|
q: torch.Tensor,
|
|
kv_c_and_k_pe_cache: torch.Tensor,
|
|
attn_metadata: MLACommonMetadata,
|
|
):
|
|
prefill_metadata = attn_metadata.prefill_metadata
|
|
assert prefill_metadata is not None
|
|
assert prefill_metadata.context_chunk_seq_tot is not None
|
|
assert prefill_metadata.context_chunk_cu_seq_lens is not None
|
|
assert prefill_metadata.context_chunk_starts is not None
|
|
assert prefill_metadata.context_chunk_max_seq_lens is not None
|
|
assert prefill_metadata.context_lens_tensor is not None
|
|
|
|
output = None
|
|
iters = len(prefill_metadata.context_chunk_seq_tot)
|
|
|
|
        # Fetch from attn_metadata directly, since it is late bound by
        # MLAAttentionState; grabbing it directly from `attn_metadata` avoids
        # any weirdness around prefill_metadata caching
|
|
assert attn_metadata.context_chunk_workspace is not None
|
|
workspace = attn_metadata.context_chunk_workspace
|
|
|
|
for i in range(iters):
|
|
toks = prefill_metadata.context_chunk_seq_tot[i]
|
|
|
|
ops.gather_cache(
|
|
src_cache=kv_c_and_k_pe_cache,
|
|
dst=workspace,
|
|
block_table=prefill_metadata.block_tables,
|
|
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
|
|
batch_size=prefill_metadata.num_prefills,
|
|
seq_starts=prefill_metadata.context_chunk_starts[i],
|
|
)
|
|
|
|
kv_c_normed = workspace[:toks]\
|
|
[..., :self.kv_lora_rank]
|
|
k_pe = workspace[:toks]\
|
|
[..., self.kv_lora_rank:].unsqueeze(1)
|
|
|
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
|
k_nope, v = kv_nope\
|
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
|
|
|
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
|
dim=-1)
|
|
|
|
# For MLA the v head dim is smaller than qk head dim so we pad
|
|
# out v with 0s to match the qk head dim
|
|
v_padded = torch.nn.functional.pad(v,
|
|
[0, q.shape[-1] - v.shape[-1]],
|
|
value=0)
|
|
|
|
if is_vllm_fa:
|
|
attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
|
|
q=q,
|
|
k=k,
|
|
v=v_padded,
|
|
cu_seqlens_q=prefill_metadata.query_start_loc,
|
|
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
|
|
max_seqlen_q=prefill_metadata.max_query_len,
|
|
max_seqlen_k=prefill_metadata.
|
|
context_chunk_max_seq_lens[i],
|
|
softmax_scale=self.scale,
|
|
causal=False, # Context is unmasked
|
|
return_softmax_lse=True,
|
|
)
|
|
else:
|
|
attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
|
|
q=q,
|
|
k=k,
|
|
v=v_padded,
|
|
cu_seqlens_q=prefill_metadata.query_start_loc,
|
|
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
|
|
max_seqlen_q=prefill_metadata.max_query_len,
|
|
max_seqlen_k=prefill_metadata.
|
|
context_chunk_max_seq_lens[i],
|
|
softmax_scale=self.scale,
|
|
causal=False, # Context is unmasked
|
|
return_attn_probs=True,
|
|
)
|
|
|
|
if output is None:
|
|
output = attn_output
|
|
output_lse = attn_softmax_lse
|
|
else:
|
|
output_tmp = torch.empty_like(output)
|
|
output_lse_tmp = torch.empty_like(output_lse)
|
|
merge_attn_states(
|
|
output=output_tmp,
|
|
output_lse=output_lse_tmp,
|
|
prefix_output=output,
|
|
prefix_lse=output_lse,
|
|
suffix_output=attn_output,
|
|
suffix_lse=attn_softmax_lse,
|
|
)
|
|
output = output_tmp
|
|
output_lse = output_lse_tmp
|
|
|
|
return output, output_lse
|
|
|
|
def _forward_prefill(
|
|
self,
|
|
q: torch.Tensor,
|
|
kv_c_normed: torch.Tensor,
|
|
k_pe: torch.Tensor,
|
|
kv_c_and_k_pe_cache: torch.Tensor,
|
|
attn_metadata: MLACommonMetadata,
|
|
) -> torch.Tensor:
|
|
|
|
prefill_metadata = attn_metadata.prefill_metadata
|
|
assert prefill_metadata is not None
|
|
|
|
has_context = prefill_metadata.context_lens_tensor is not None \
|
|
and prefill_metadata.context_lens_tensor.max() > 0
|
|
|
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
|
k_nope, v = kv_nope\
|
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
|
|
|
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
|
|
|
# For MLA the v head dim is smaller than qk head dim so we pad out
|
|
# v with 0s to match the qk head dim
|
|
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
|
value=0)
|
|
|
|
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
|
|
output = self.triton_fa_func(
|
|
q,
|
|
k,
|
|
v_padded,
|
|
None,
|
|
prefill_metadata.query_start_loc,
|
|
prefill_metadata.query_start_loc,
|
|
prefill_metadata.max_prefill_seq_len,
|
|
prefill_metadata.max_prefill_seq_len,
|
|
True, # causal
|
|
self.scale,
|
|
None, # attn_mask is None unless applying ALiBi mask
|
|
)
|
|
            # triton flash attention always returns 2 objects
|
|
if not has_context:
|
|
output = output[0]
|
|
elif is_vllm_fa:
|
|
output = self.flash_attn_varlen_func(
|
|
q=q,
|
|
k=k,
|
|
v=v_padded,
|
|
cu_seqlens_q=prefill_metadata.query_start_loc,
|
|
cu_seqlens_k=prefill_metadata.query_start_loc,
|
|
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
|
|
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
|
|
softmax_scale=self.scale,
|
|
causal=True,
|
|
return_softmax_lse=has_context,
|
|
)
|
|
else:
|
|
output = self.flash_attn_varlen_func(
|
|
q=q,
|
|
k=k,
|
|
v=v_padded,
|
|
cu_seqlens_q=prefill_metadata.query_start_loc,
|
|
cu_seqlens_k=prefill_metadata.query_start_loc,
|
|
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
|
|
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
|
|
softmax_scale=self.scale,
|
|
causal=True,
|
|
return_attn_probs=has_context,
|
|
)
|
|
|
|
if has_context:
|
|
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
|
|
suffix_output, suffix_lse, *rest = output
|
|
context_output, context_lse = self._compute_prefill_context( \
|
|
q, kv_c_and_k_pe_cache, attn_metadata)
|
|
|
|
output = torch.empty_like(suffix_output)
|
|
merge_attn_states(
|
|
output=output,
|
|
prefix_output=context_output,
|
|
prefix_lse=context_lse,
|
|
suffix_output=suffix_output,
|
|
suffix_lse=suffix_lse,
|
|
)
|
|
|
|
# slice by `:v.shape[-1]` in order to remove v headdim padding
|
|
output = output\
|
|
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
|
|
.reshape(-1, self.num_heads * v.shape[-1])
|
|
|
|
return self.o_proj(output)[0]
|
|
|
|
@abstractmethod
|
|
def _forward_decode(
|
|
self,
|
|
ql_nope: torch.Tensor,
|
|
q_pe: torch.Tensor,
|
|
kv_c_and_k_pe_cache: torch.Tensor,
|
|
attn_metadata: T,
|
|
) -> torch.Tensor:
|
|
raise NotImplementedError
|
|
|
|
def forward(
|
|
self,
|
|
layer: AttentionLayer,
|
|
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
|
k_c_normed: torch.Tensor, # key in unified attn
|
|
k_pe: torch.Tensor, # value in unified attn
|
|
kv_cache: torch.Tensor,
|
|
attn_metadata: T,
|
|
output: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
if output is not None:
|
|
raise NotImplementedError(
|
|
"output is not yet supported for MLAImplBase")
|
|
|
|
if attn_metadata.is_profile_run and \
|
|
attn_metadata.context_chunk_workspace is not None:
|
|
            # During the profile run try to simulate the worst case output
            # size for `self.kv_b_proj(kv_c_normed)` in
            # `_compute_prefill_context` since this can be large
|
|
_ = torch.empty(
|
|
(attn_metadata.context_chunk_workspace.shape[0],
|
|
self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
|
|
device=k_c_normed.device,
|
|
dtype=k_c_normed.dtype,
|
|
)
|
|
|
|
has_decode = attn_metadata.decode_metadata is not None
|
|
has_prefill = attn_metadata.prefill_metadata is not None
|
|
|
|
# Restore head dim (for rotary embedding)
|
|
k_pe = k_pe.unsqueeze(1)
|
|
assert hasattr(attn_metadata, "input_positions")
|
|
|
|
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
|
|
|
|
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
|
|
decode_k_pe = k_pe[num_prefill_tokens:]
|
|
decode_input_positions = \
|
|
attn_metadata.input_positions[num_prefill_tokens:]
|
|
|
|
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
|
|
prefill_k_pe = k_pe[:num_prefill_tokens]
|
|
prefill_input_positions = \
|
|
attn_metadata.input_positions[:num_prefill_tokens]
|
|
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
|
|
|
|
if has_decode:
|
|
decode_ql_nope, decode_q_pe = \
|
|
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
|
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
|
decode_input_positions, decode_q_pe, decode_k_pe)
|
|
|
|
if has_prefill:
|
|
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
|
.view(-1, self.num_heads, self.qk_head_dim)
|
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
|
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
|
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
|
|
|
# write the latent and rope to kv cache
|
|
if kv_cache.numel() > 0:
|
|
ops.concat_and_cache_mla(
|
|
k_c_normed,
|
|
k_pe.squeeze(1),
|
|
kv_cache,
|
|
attn_metadata.slot_mapping.flatten(),
|
|
kv_cache_dtype=self.kv_cache_dtype,
|
|
scale=layer._k_scale,
|
|
)
|
|
|
|
output = torch.empty(attn_metadata.num_prefill_tokens +
|
|
attn_metadata.num_decode_tokens,
|
|
self.o_proj.output_size,
|
|
device=hidden_states_or_q_c.device,
|
|
dtype=hidden_states_or_q_c.dtype)
|
|
if has_prefill:
|
|
output[:num_prefill_tokens] = self._forward_prefill(
|
|
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
|
attn_metadata)
|
|
|
|
if has_decode:
|
|
output[num_prefill_tokens:] = self._forward_decode(
|
|
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
|
|
|
|
return output
|