[Attention] MLA move rotary embedding to cuda-graph region (#17668)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-05-08 23:14:42 -04:00 committed by GitHub
parent 760e3ecc8f
commit 5e6f939484
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 35 additions and 121 deletions

View File

@ -211,8 +211,6 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
@ -377,7 +375,6 @@ class MLACommonState(AttentionState, Generic[T]):
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
input_positions=self._positions[:batch_size],
head_dim=self.runner.model_config.get_head_size())
if is_encoder_decoder_model:
@ -393,7 +390,6 @@ class MLACommonState(AttentionState, Generic[T]):
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
"input_positions": attn_metadata.decode_metadata.input_positions,
}
if is_encoder_decoder_model:
raise NotImplementedError(
@ -405,16 +401,10 @@ class MLACommonState(AttentionState, Generic[T]):
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_positions = attn_metadata.input_positions
num_positions = input_positions.shape[0]
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# CUDA graph buffer is padded so only perform a partial copy based on
# num_positions
input_buffers["input_positions"][:num_positions].copy_(
input_positions, non_blocking=True)
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
@ -456,11 +446,6 @@ class MLACommonMetadata(AttentionMetadata):
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# New for MLA (compared to FlashAttention)
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
@ -563,8 +548,6 @@ class MLACommonMetadata(AttentionMetadata):
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])
self._cached_prefill_metadata = self.__class__(
# Required by ModelRunner
@ -578,7 +561,6 @@ class MLACommonMetadata(AttentionMetadata):
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
# MLACommonMetadata
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
@ -615,8 +597,6 @@ class MLACommonMetadata(AttentionMetadata):
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])
self._cached_decode_metadata = self.__class__(
# Required by ModelRunner
@ -646,7 +626,6 @@ class MLACommonMetadata(AttentionMetadata):
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
input_positions=input_positions,
head_dim=self.head_dim,
is_profile_run=self.is_profile_run)
return self._cached_decode_metadata
@ -765,7 +744,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.input_positions: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
@ -786,13 +764,11 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
@ -912,8 +888,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
input_positions = async_tensor_h2d(self.input_positions, torch.long,
device, self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
@ -987,7 +961,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
multi_modal_placeholder_index_maps=None, # Not Attention Related
enable_kv_scales_calculation=False,
# MLACommonMetadata
input_positions=input_positions,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
@ -1033,7 +1006,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
kv_b_proj: ColumnParallelLinear,
) -> None:
self.num_heads = num_heads
@ -1048,10 +1020,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.kv_b_proj = kv_b_proj
self.triton_fa_func = triton_attention
@ -1367,41 +1335,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
has_decode = attn_metadata.decode_metadata is not None
has_prefill = attn_metadata.prefill_metadata is not None
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
q = q.view(-1, self.num_heads, self.qk_head_dim)
decode_q = q[num_prefill_tokens:]
decode_k_pe = k_pe[num_prefill_tokens:]
decode_input_positions = \
attn_metadata.input_positions[num_prefill_tokens:]
prefill_q = q[:num_prefill_tokens]
prefill_k_pe = k_pe[:num_prefill_tokens]
prefill_input_positions = \
attn_metadata.input_positions[:num_prefill_tokens]
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
prefill_input_positions, prefill_q_pe, prefill_k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
@ -1424,6 +1366,15 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_metadata)
if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[num_prefill_tokens:] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

View File

@ -148,13 +148,11 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1

View File

@ -808,8 +808,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)

View File

@ -453,7 +453,6 @@ class DeepseekV2MLAAttention(nn.Module):
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
kv_b_proj=self.kv_b_proj,
)
@ -475,6 +474,13 @@ class DeepseekV2MLAAttention(nn.Module):
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,

View File

@ -204,7 +204,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@ -269,9 +268,6 @@ class MLACommonPrefillMetadata:
max_seq_lens: list[int]
workspace: torch.Tensor
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
query_start_loc: torch.Tensor
max_query_len: int
@ -280,9 +276,6 @@ class MLACommonPrefillMetadata:
@dataclass
class MLACommonDecodeMetadata:
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor
@ -443,10 +436,8 @@ class MLACommonMetadataBuilder(Generic[M]):
return modified_batch
def _build_decode(self, input_positions: torch.Tensor,
block_table: torch.Tensor, seq_lens: torch.Tensor):
def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
)
@ -464,8 +455,6 @@ class MLACommonMetadataBuilder(Generic[M]):
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
@ -473,7 +462,6 @@ class MLACommonMetadataBuilder(Generic[M]):
prefill_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
tokens_start = self._num_decode_tokens
context_lens_cpu = self.runner.input_batch.\
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
@ -542,7 +530,6 @@ class MLACommonMetadataBuilder(Generic[M]):
self.chunked_prefill_workspace_size
prefill_metadata = MLACommonPrefillMetadata(
input_positions=input_positions[tokens_start:],
block_table=block_table[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
@ -552,7 +539,6 @@ class MLACommonMetadataBuilder(Generic[M]):
decode_metadata = None
if self._num_decodes > 0:
decode_metadata = self._build_decode(
input_positions=input_positions[:self._num_decode_tokens],
block_table=block_table[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
)
@ -599,7 +585,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
kv_b_proj: ColumnParallelLinear,
) -> None:
self.num_heads = num_heads
@ -614,15 +599,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
# Hack for V1 for now to avoid torch library overhead (since we are
# already inside an attention custom op), pull out the forward
# method from the rotary embedding and call it directly
# TODO(lucas): we should probably find a cleaner way to do this
self.rotary_emb = rotary_emb.forward_native
if current_platform.is_cuda():
self.rotary_emb = rotary_emb.forward_cuda
self.kv_b_proj = kv_b_proj
self.vllm_flash_attn_version = get_flash_attn_version()
@ -894,9 +870,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
@ -905,35 +878,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
q = q.view(-1, self.num_heads, self.qk_head_dim)
decode_q = q[:num_decode_tokens]
decode_k_pe = k_pe[:num_decode_tokens]
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
assert attn_metadata.prefill is not None
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
@ -951,6 +901,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata)
if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
output[:num_decode_tokens] = self._forward_decode(
decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)

View File

@ -58,8 +58,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
def _build_decode(self, input_positions: torch.Tensor,
block_table: torch.Tensor,
def _build_decode(self, block_table: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
@ -69,7 +68,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
)
return FlashMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
tile_scheduler_metadata=tile_scheduler_metadata,