yihong 04149cce27
[BugFix] fix some typos found by typos. (#16314)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2025-04-09 03:43:59 -07:00

1422 lines
59 KiB
Python

# SPDX-License-Identifier: Apache-2.0
"""
This file implements common components for MLA implementations.
First we define:
Sq as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large").
NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably
tune.
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
Below is example of both paths assuming batchsize = 1
## More Extent Definitions:
C Context length, `Skv - Sq`
H hidden size
N number of attention heads
Lq latent dimension for Q 1536 in DSV3
Lkv latent dimension for K/V 512 in DSV3
P nope dimension, no rope. 128 in DSV3
R rope dimension, goes through rope. 64 in DSV3
V V head dim. 128 in DSV3
## Vector/Matrix Definitions
h_t hidden states (input to attention) shape [Sq, H]
q_c latent/compressed Q shape [Sq, Lq]
q_nope uncompressed Q (no-rope) shape [Sq, N, P]
q_pe uncompressed Q (rope) shape [Sq, N, R]
kv_c latent/compressed KV shape [Skv, Lkv]
k_pe decoupled k position embeddings shape [Skv, R]
new_kv_c new kv_c from current iter shape [Sq, Lkv]
new_k_pe new k_pe from current iter shape [Sq, R]
cache_kv_c cached k_c from previous iters shape [C, Lkv]
cache_k_pe cached k_pe from previous iters shape [C, R]
W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H]
## Compute Friendly Approach (i.e. "_forward_prefill"):
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R
// V headdim = V
// spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v
)
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatenated per head
`q_b_proj` is [W_UQ; W_QR] concatenated per head
`out_proj` is W_O
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Runtime
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// MQA with QK headdim = Lkv + R
// V headdim = Lkv
// spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
)
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
fixed workspace size.
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV
// with QK headdim = P + R
// V headdim = V
// curr_o shape [Sq, N, V]
// curr_lse shape [N, Sq], this is just order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
new_v,
casual=True,
return_softmax_lse=True
)
// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
chunk_o, chunk_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
dim=-1),
cache_v_chunk,
casual=False,
return_softmax_lse=True
)
curr_o, curr_lse = merge_attn_states(
suffix_output=curr_o,
suffix_lse=curr_lse,
prefix_output=chunk_o,
prefix_lse=chunk_lse,
)
return curr_o @ W_O
"""
import functools
from abc import abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar)
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, MLAAttentionImpl)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
else:
merge_attn_states = None
triton_attention = None
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
is_vllm_fa = True
except ImportError:
is_vllm_fa = False
try:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
is_hip = current_platform.is_rocm()
class MLACommonBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return MLACommonMetadata
@staticmethod
def get_builder_cls() -> Type["MLACommonMetadataBuilder"]:
return MLACommonMetadataBuilder
@staticmethod
def get_state_cls() -> Type["MLACommonState"]:
return MLACommonState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
ops.copy_blocks_mla(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [576]
T = TypeVar("T", bound="MLACommonMetadata")
class MLACommonState(AttentionState, Generic[T]):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
scheduler_config = runner.scheduler_config
self.model_config = runner.model_config
cache_config = runner.cache_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.enable_prefix_caching = cache_config.enable_prefix_caching
if self.chunked_prefill_enabled or self.enable_prefix_caching:
self.context_chunk_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(
8 * self.model_config.max_model_len, 4 *
scheduler_config.max_num_seqs * cache_config.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.context_chunk_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._positions = torch.zeros((max_batch_size, ),
dtype=torch.long,
device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._positions
def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> T:
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
use_cuda_graph=True,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
input_positions=self._positions[:batch_size],
head_dim=self.runner.model_config.get_head_size())
if is_encoder_decoder_model:
raise NotImplementedError(
"MLACommonState does not support encoder/decoder yet")
return attn_metadata
def get_graph_input_buffers(self,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
"input_positions": attn_metadata.decode_metadata.input_positions,
}
if is_encoder_decoder_model:
raise NotImplementedError(
"MLACommonState does not support encoder/decoder yet")
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_positions = attn_metadata.input_positions
num_positions = input_positions.shape[0]
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# CUDA graph buffer is padded so only perform a partial copy based on
# num_positions
input_buffers["input_positions"][:num_positions].copy_(
input_positions, non_blocking=True)
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
def begin_forward(self, model_input):
if self.chunked_prefill_enabled or self.enable_prefix_caching:
if not hasattr(self, "context_chunk_workspace"):
# not self.runner.device does not return the correct device
# for this process, (init_device sets the correct device but
# only on the Worker). The only way Ive figured out to get the
# correct device is to allocate the workspace on the first call
# to begin_forward and use the device of the input tokens
assert model_input.input_tokens is not None
self.context_chunk_workspace = torch.empty(
(self.context_chunk_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=model_input.input_tokens.device,
)
model_input.attn_metadata.context_chunk_workspace = \
self.context_chunk_workspace
@dataclass
class MLACommonMetadata(AttentionMetadata):
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# New for MLA (compared to FlashAttention)
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Maximum query length in the batch.
max_query_len: Optional[int] = None
# Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional[Any] = None
_cached_decode_metadata: Optional[Any] = None
num_prefill_tokens: int
# The dimension of the attention heads
head_dim: Optional[int] = None
# Used when chunked prefill is enabled to simulate worst case workspace
# allocations, hopefully to avoid going OOM
is_profile_run: bool = False
# New for MLA (compared to FlashAttention)
# For chunked prefill
context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
context_chunk_starts: Optional[torch.Tensor] = None
context_chunk_seq_tot: Optional[List[int]] = None
context_chunk_max_seq_lens: Optional[List[int]] = None
# Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
context_chunk_workspace: Optional[torch.Tensor] = None
def __post_init__(self):
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f" received {self.head_dim}.")
@property
def prefill_metadata(self):
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])
self._cached_prefill_metadata = self.__class__(
# Required by ModelRunner
use_cuda_graph=False, # Not Attention Related
# Required by Attention Metadata
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
# Required by Attention Metadata (not used)
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
# MLACommonMetadata
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
head_dim=self.head_dim,
is_profile_run=self.is_profile_run,
# MLACommonMetadata Chunk prefill specific
context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens,
context_chunk_starts=self.context_chunk_starts,
context_chunk_seq_tot=self.context_chunk_seq_tot,
context_chunk_max_seq_lens=self.context_chunk_max_seq_lens,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self):
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.seq_lens_tensor is not None
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])
self._cached_decode_metadata = self.__class__(
# Required by ModelRunner
use_cuda_graph=self.use_cuda_graph, # Not Attention Related
# Required by Attention Metadata
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
# Required by Attention Metadata (not used)
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
# MLACommonMetadata
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
input_positions=input_positions,
head_dim=self.head_dim,
is_profile_run=self.is_profile_run)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.chunked_prefill_enabled = \
self.runner.scheduler_config.chunked_prefill_enabled
self.enable_prefix_caching = \
self.runner.cache_config.enable_prefix_caching
if self.chunked_prefill_enabled or self.enable_prefix_caching:
attn_state = self.input_builder.runner.attn_state
self.context_chunk_workspace_size = \
attn_state.context_chunk_workspace_size
self.page_size = self.runner.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.input_positions: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def _get_graph_runner_block_tables(
self, num_seqs: int,
block_tables: List[List[int]]) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]
return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
input_positions = async_tensor_h2d(self.input_positions, torch.long,
device, self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
context_chunk_cu_seq_lens = None
context_chunk_starts = None
context_chunk_seq_tot = None
context_chunk_max_seq_lens = None
if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
and self.num_prefills > 0 \
and context_lens_tensor is not None \
and context_lens_tensor[:self.num_prefills].max() > 0:
# NOTE: it is recommend you read the `Chunked Prefill` section in
# the comment at the top of the file before trying to understand
# the following code
num_prefills_with_context = \
(context_lens_tensor[:self.num_prefills] > 0).sum().item()
# currently we allocate an equal amount of workspace for each
# prefill in the batch, we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# longer context lengths
max_context_chunk = \
self.context_chunk_workspace_size // num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk, self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
context_chunk_starts = \
torch.arange(num_chunks, device=device, dtype=torch.int32)\
.unsqueeze(1).expand(-1, self.num_prefills)\
* max_context_chunk
chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\
.unsqueeze(0), context_chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
_context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
torch.int32)
zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\
.unsqueeze(-1)
context_chunk_cu_seq_lens = \
torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
context_chunk_max_seq_lens = \
chunk_seq_lens.max(dim=1).values.tolist()
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
assert max(context_chunk_seq_tot) <= \
self.context_chunk_workspace_size
return self.runner.attn_backend.make_metadata(
# Required by ModelRunner
use_cuda_graph=use_captured_graph, # Not Attention Related
# Required by Attention Metadata
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
# Required by Attention Metadata (not used)
multi_modal_placeholder_index_maps=None, # Not Attention Related
enable_kv_scales_calculation=False,
# MLACommonMetadata
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
head_dim=self.runner.model_config.get_head_size(),
is_profile_run=self.runner.in_profile_run,
# MLACommonMetadata Chunk prefill specific
context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
context_chunk_starts=context_chunk_starts,
context_chunk_seq_tot=context_chunk_seq_tot,
context_chunk_max_seq_lens=context_chunk_max_seq_lens,
)
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: ColumnParallelLinear,
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version()
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = self.q_proj(x)[0]\
.view(-1, self.num_heads, self.qk_head_dim)\
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
def _compute_prefill_context(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
):
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
assert prefill_metadata.context_chunk_seq_tot is not None
assert prefill_metadata.context_chunk_cu_seq_lens is not None
assert prefill_metadata.context_chunk_starts is not None
assert prefill_metadata.context_chunk_max_seq_lens is not None
assert prefill_metadata.context_lens_tensor is not None
output = None
iters = len(prefill_metadata.context_chunk_seq_tot)
# Fetch from attn_metadata directly, since it late bound by
# MLAAttentionState, grabbing it directly `attn_metadata` can avoid
# any weirdness around prefill_metadata caching
assert attn_metadata.context_chunk_workspace is not None
workspace = attn_metadata.context_chunk_workspace
for i in range(iters):
toks = prefill_metadata.context_chunk_seq_tot[i]
ops.gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=prefill_metadata.block_tables,
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
batch_size=prefill_metadata.num_prefills,
seq_starts=prefill_metadata.context_chunk_starts[i],
)
kv_c_normed = workspace[:toks]\
[..., :self.kv_lora_rank]
k_pe = workspace[:toks]\
[..., self.kv_lora_rank:].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad
# out v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v,
[0, q.shape[-1] - v.shape[-1]],
value=0)
if is_vllm_fa:
attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_attn_probs=True,
)
if output is None:
output = attn_output
output_lse = attn_softmax_lse
else:
output_tmp = torch.empty_like(output)
output_lse_tmp = torch.empty_like(output_lse)
merge_attn_states(
output=output_tmp,
output_lse=output_lse_tmp,
prefix_output=output,
prefix_lse=output_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
)
output = output_tmp
output_lse = output_lse_tmp
return output, output_lse
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
has_context = prefill_metadata.context_lens_tensor is not None \
and prefill_metadata.context_lens_tensor.max() > 0
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
output = self.triton_fa_func(
q,
k,
v_padded,
None,
prefill_metadata.query_start_loc,
prefill_metadata.query_start_loc,
prefill_metadata.max_prefill_seq_len,
prefill_metadata.max_prefill_seq_len,
True, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if not has_context:
output = output[0]
elif is_vllm_fa:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
else:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_attn_probs=has_context,
)
if has_context:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse, *rest = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
output = torch.empty_like(suffix_output)
merge_attn_states(
output=output,
prefix_output=context_output,
prefix_lse=context_lse,
suffix_output=suffix_output,
suffix_lse=suffix_lse,
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
output = output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
@abstractmethod
def _forward_decode(
self,
ql_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
raise NotImplementedError
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if attn_metadata.is_profile_run and \
attn_metadata.context_chunk_workspace is not None:
# During the profile run try to simulate to worse case output size
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
# since this can be large
_ = torch.empty(
(attn_metadata.context_chunk_workspace.shape[0],
self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
device=k_c_normed.device,
dtype=k_c_normed.dtype,
)
has_decode = attn_metadata.decode_metadata is not None
has_prefill = attn_metadata.prefill_metadata is not None
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
decode_k_pe = k_pe[num_prefill_tokens:]
decode_input_positions = \
attn_metadata.input_positions[num_prefill_tokens:]
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
prefill_k_pe = k_pe[:num_prefill_tokens]
prefill_input_positions = \
attn_metadata.input_positions[:num_prefill_tokens]
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
if has_decode:
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
prefill_input_positions, prefill_q_pe, prefill_k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
output = torch.empty(attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens,
self.o_proj.output_size,
device=hidden_states_or_q_c.device,
dtype=hidden_states_or_q_c.dtype)
if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
if has_decode:
output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output