[Attention] MLA move o_proj q_proj into cuda-graph region (#17484)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-05-01 23:16:26 -04:00 committed by GitHub
parent afb12e4294
commit afcb3f8863
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 55 additions and 85 deletions

View File

@ -281,8 +281,7 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
# remove padding
output = output.view(-1, self.num_heads,
q.shape[-1])[..., :v.shape[-1]]
output = output.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
return output.reshape(-1, self.num_heads * v.shape[-1])
def _forward_decode(
self,
@ -303,4 +302,4 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
decode_meta.block_tables,
decode_meta.seq_lens_tensor)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -239,4 +239,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal=True,
)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -207,7 +207,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@ -1032,12 +1032,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: ColumnParallelLinear,
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@ -1055,9 +1050,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
@ -1141,27 +1134,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return attn_out, rest[0]
return attn_out
def _v_up_proj_and_o_proj(self, x):
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = self.q_proj(x)[0]\
.view(-1, self.num_heads, self.qk_head_dim)\
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
def process_weights_after_loading(self, act_dtype: torch.dtype):
@ -1345,7 +1324,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse=suffix_lse,
)
return self.o_proj(output.flatten(start_dim=-2))[0]
return output.flatten(start_dim=-2)
@abstractmethod
def _forward_decode(
@ -1360,7 +1339,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
q: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
@ -1391,27 +1370,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
assert hasattr(attn_metadata, "input_positions")
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
q = q.view(-1, self.num_heads, self.qk_head_dim)
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
decode_q = q[num_prefill_tokens:]
decode_k_pe = k_pe[num_prefill_tokens:]
decode_input_positions = \
attn_metadata.input_positions[num_prefill_tokens:]
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
prefill_q = q[:num_prefill_tokens]
prefill_k_pe = k_pe[:num_prefill_tokens]
prefill_input_positions = \
attn_metadata.input_positions[:num_prefill_tokens]
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
if has_decode:
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
prefill_input_positions, prefill_q_pe, prefill_k_pe)
@ -1429,9 +1413,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
output = torch.empty(attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens,
self.o_proj.output_size,
device=hidden_states_or_q_c.device,
dtype=hidden_states_or_q_c.dtype)
self.v_head_dim * self.num_heads,
device=q.device,
dtype=q.dtype)
if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,

View File

@ -409,4 +409,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_lens)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -110,4 +110,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
decode_meta.seq_lens_tensor, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -454,9 +454,7 @@ class DeepseekV2MLAAttention(nn.Module):
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
)
self.prefix = prefix
@ -468,17 +466,22 @@ class DeepseekV2MLAAttention(nn.Module):
hidden_states: torch.Tensor,
) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
q_c = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
hidden_states_or_q_c = hidden_states
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
class DeepseekV2DecoderLayer(nn.Module):

View File

@ -200,7 +200,7 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
@ -597,12 +597,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: ColumnParallelLinear,
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@ -625,9 +620,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if current_platform.is_cuda():
self.rotary_emb = rotary_emb.forward_cuda
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.vllm_flash_attn_version = get_flash_attn_version()
# Handle the differences between the flash_attn_varlen from flash_attn
@ -684,27 +677,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return attn_out, lse
return attn_out
def _v_up_proj_and_o_proj(self, x):
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = self.q_proj(x)[0]\
.view(-1, self.num_heads, self.qk_head_dim)\
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
def process_weights_after_loading(self, act_dtype: torch.dtype):
@ -874,7 +853,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
suffix_lse=suffix_lse,
)
return self.o_proj(output.flatten(start_dim=-2))[0]
return output.flatten(start_dim=-2)
@abstractmethod
def _forward_decode(
@ -889,7 +868,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
q: torch.Tensor,
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
@ -908,7 +887,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_toks, ...]
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
@ -923,24 +902,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
q = q.view(-1, self.num_heads, self.qk_head_dim)
decode_q = q[:num_decode_tokens]
decode_k_pe = k_pe[:num_decode_tokens]
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
prefill_q = q[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if has_decode:
assert attn_metadata.decode is not None
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
assert attn_metadata.prefill is not None
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(

View File

@ -146,4 +146,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal=True,
)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -115,4 +115,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata.decode.seq_lens, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)