Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-01-07 23:53:14 +08:00
[Attention] MLA move o_proj q_proj into cuda-graph region (#17484)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent afb12e4294
commit afcb3f8863
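Summary of the change: before this commit, the MLA attention backends applied the query up-projection (q_proj, or q_b_proj when q_lora_rank is set) and the output projection (o_proj) themselves, so those GEMMs ran outside the CUDA-graph region. This commit moves both projections into DeepseekV2MLAAttention.forward: the backend now receives an already-projected q and returns the per-head value output, and the layer applies o_proj to that result. Below is a minimal, CPU-runnable sketch of the restructuring; the module names and sizes are stand-ins (the real code uses ColumnParallelLinear/RowParallelLinear and a registered attention op), not the actual implementation.

import torch
import torch.nn as nn

hidden, num_heads, qk_head_dim, v_head_dim = 256, 4, 64, 32

q_proj = nn.Linear(hidden, num_heads * qk_head_dim, bias=False)
o_proj = nn.Linear(num_heads * v_head_dim, hidden, bias=False)


def attention_core(q: torch.Tensor) -> torch.Tensor:
    # Stand-in for the MLA kernel: per-head value output, before o_proj.
    return torch.randn(q.shape[0], num_heads * v_head_dim)


def backend_old(hidden_states_or_q_c: torch.Tensor) -> torch.Tensor:
    # Old backend: q_proj and o_proj run inside the backend, i.e. outside
    # the CUDA-graph / compiled region.
    q = q_proj(hidden_states_or_q_c)
    return o_proj(attention_core(q))


def backend_new(q: torch.Tensor) -> torch.Tensor:
    # New backend: attention only; the caller has done q_proj and will do o_proj.
    return attention_core(q)


def layer_forward_old(hidden_states: torch.Tensor) -> torch.Tensor:
    return backend_old(hidden_states)


def layer_forward_new(hidden_states: torch.Tensor) -> torch.Tensor:
    # q_proj and o_proj now live in the layer, inside the captured region.
    q = q_proj(hidden_states)
    attn_out = backend_new(q)
    return o_proj(attn_out)


x = torch.randn(8, hidden)
print(layer_forward_old(x).shape, layer_forward_new(x).shape)  # both (8, 256)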
@@ -281,8 +281,7 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
         # remove padding
         output = output.view(-1, self.num_heads,
                              q.shape[-1])[..., :v.shape[-1]]
-        output = output.reshape(-1, self.num_heads * v.shape[-1])
-        return self.o_proj(output)[0]
+        return output.reshape(-1, self.num_heads * v.shape[-1])

     def _forward_decode(
         self,
@@ -303,4 +302,4 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
         ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                    decode_meta.block_tables,
                                    decode_meta.seq_lens_tensor)
-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
@@ -239,4 +239,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             causal=True,
         )

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
@@ -207,7 +207,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+                                               LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -1032,12 +1032,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             qk_head_dim: int,
             v_head_dim: int,
             rotary_emb: RotaryEmbedding,
-            # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-            # attention backend perspective we rely on the layer to pass in the
-            # correct matrix
-            q_proj: ColumnParallelLinear,
             kv_b_proj: ColumnParallelLinear,
-            o_proj: RowParallelLinear,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -1055,9 +1050,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.rotary_emb = rotary_emb
         self.use_yarn_rope = isinstance(rotary_emb,
                                         DeepseekScalingRotaryEmbedding)
-        self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj

         self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
@@ -1141,27 +1134,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             return attn_out, rest[0]
         return attn_out

-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
-
-    # Return `ql_nope`, `q_pe`
-    def _q_proj_and_k_up_proj(self, x):
-        q_nope, q_pe = self.q_proj(x)[0]\
-            .view(-1, self.num_heads, self.qk_head_dim)\
-            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-        # Convert from (B, N, P) to (N, B, P)
-        q_nope = q_nope.transpose(0, 1)
-        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-        ql_nope = torch.bmm(q_nope, self.W_UK_T)
-        # Convert from (N, B, L) to (B, N, L)
-        return ql_nope.transpose(0, 1), q_pe
+        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

     def process_weights_after_loading(self, act_dtype: torch.dtype):
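Note on the renamed helper: _v_up_proj keeps the per-head value up-projection (the bmm with W_UV) but no longer applies o_proj, which the model layer now owns. A self-contained shape sketch of the bmm layout follows; the dimension values are illustrative, not DeepSeek's real configuration.

import torch

B, N, L, V = 4, 16, 512, 128     # tokens, heads, kv_lora_rank, v_head_dim
x = torch.randn(B, N, L)         # latent attention output per head
W_UV = torch.randn(N, L, V)      # per-head value up-projection weights

x = x.view(-1, N, L).transpose(0, 1)         # (B, N, L) -> (N, B, L)
x = torch.bmm(x, W_UV)                       # (N, B, L) x (N, L, V) -> (N, B, V)
out = x.transpose(0, 1).reshape(-1, N * V)   # (N, B, V) -> (B, N * V)
print(out.shape)                             # torch.Size([4, 2048]); o_proj is applied by the caller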
@@ -1345,7 +1324,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             suffix_lse=suffix_lse,
         )

-        return self.o_proj(output.flatten(start_dim=-2))[0]
+        return output.flatten(start_dim=-2)

     @abstractmethod
     def _forward_decode(
@@ -1360,7 +1339,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     def forward(
         self,
         layer: AttentionLayer,
-        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
+        q: torch.Tensor,  # query in unified attn
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
@@ -1391,27 +1370,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         assert hasattr(attn_metadata, "input_positions")

         num_prefill_tokens: int = attn_metadata.num_prefill_tokens
+        q = q.view(-1, self.num_heads, self.qk_head_dim)

-        decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
+        decode_q = q[num_prefill_tokens:]
         decode_k_pe = k_pe[num_prefill_tokens:]
         decode_input_positions = \
             attn_metadata.input_positions[num_prefill_tokens:]

-        prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
+        prefill_q = q[:num_prefill_tokens]
         prefill_k_pe = k_pe[:num_prefill_tokens]
         prefill_input_positions = \
             attn_metadata.input_positions[:num_prefill_tokens]
         prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

         if has_decode:
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                 decode_input_positions, decode_q_pe, decode_k_pe)

         if has_prefill:
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                 prefill_input_positions, prefill_q_pe, prefill_k_pe)
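The decode path now receives q already projected by the layer, so the body of the removed _q_proj_and_k_up_proj helper is inlined here minus the q_proj call: split q into its non-positional and rope parts, then absorb the key up-projection with a per-head bmm against W_UK_T. A small shape sketch, with illustrative dimensions:

import torch

B, N = 4, 16                          # decode tokens, heads
P, R, L = 128, 64, 512                # qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank
decode_q = torch.randn(B, N, P + R)   # already projected by q_proj / q_b_proj in the layer
W_UK_T = torch.randn(N, P, L)         # per-head absorbed key up-projection

decode_q_nope, decode_q_pe = decode_q.split([P, R], dim=-1)
decode_q_nope = decode_q_nope.transpose(0, 1)        # (B, N, P) -> (N, B, P)
decode_ql_nope = torch.bmm(decode_q_nope, W_UK_T)    # (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)      # (N, B, L) -> (B, N, L)
print(decode_ql_nope.shape, decode_q_pe.shape)       # (4, 16, 512) (4, 16, 64)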
@@ -1429,9 +1413,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

         output = torch.empty(attn_metadata.num_prefill_tokens +
                              attn_metadata.num_decode_tokens,
-                             self.o_proj.output_size,
-                             device=hidden_states_or_q_c.device,
-                             dtype=hidden_states_or_q_c.dtype)
+                             self.v_head_dim * self.num_heads,
+                             device=q.device,
+                             dtype=q.dtype)
         if has_prefill:
             output[:num_prefill_tokens] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
@@ -409,4 +409,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
                              attn_metadata.paged_kv_indices,
                              attn_metadata.paged_kv_last_page_lens)

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
@@ -110,4 +110,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                               decode_meta.seq_lens_tensor, attn_logits,
                               num_kv_splits, self.scale, PAGE_SIZE)

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
@@ -454,9 +454,7 @@ class DeepseekV2MLAAttention(nn.Module):
             qk_head_dim=self.qk_head_dim,
             v_head_dim=self.v_head_dim,
             rotary_emb=self.rotary_emb,
-            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
             kv_b_proj=self.kv_b_proj,
-            o_proj=self.o_proj,
         )

         self.prefix = prefix
@@ -468,17 +466,22 @@ class DeepseekV2MLAAttention(nn.Module):
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            q_c = self.q_a_proj(hidden_states)[0]
+            q_c = self.q_a_layernorm(q_c)
+            q = self.q_b_proj(q_c)[0]
         else:
-            hidden_states_or_q_c = hidden_states
+            q = self.q_proj(hidden_states)[0]
         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c,
-                             kv_c_normed,
-                             k_pe,
-                             output_shape=hidden_states.shape)
+
+        attn_out = self.mla_attn(
+            q,
+            kv_c_normed,
+            k_pe,
+            output_shape=(hidden_states.shape[0],
+                          self.num_local_heads * self.v_head_dim))
+        return self.o_proj(attn_out)[0]


 class DeepseekV2DecoderLayer(nn.Module):
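With the projections moved into the layer, mla_attn is asked for an output of shape (num_tokens, num_local_heads * v_head_dim), and the layer finishes with o_proj, which is exactly the GEMM that now sits inside the CUDA-graph region. A hedged sketch of that contract, using nn.Linear as a stand-in for RowParallelLinear and made-up sizes:

import torch
import torch.nn as nn

num_tokens, hidden_size = 8, 256
num_local_heads, v_head_dim = 4, 32
o_proj = nn.Linear(num_local_heads * v_head_dim, hidden_size, bias=False)

# mla_attn now fills a buffer of shape (num_tokens, num_local_heads * v_head_dim)
attn_out = torch.randn(num_tokens, num_local_heads * v_head_dim)
out = o_proj(attn_out)
print(out.shape)   # torch.Size([8, 256]), i.e. hidden_states.shape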
@@ -200,7 +200,7 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+                                               LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
@@ -597,12 +597,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             qk_head_dim: int,
             v_head_dim: int,
             rotary_emb: RotaryEmbedding,
-            # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-            # attention backend perspective we rely on the layer to pass in the
-            # correct matrix
-            q_proj: ColumnParallelLinear,
             kv_b_proj: ColumnParallelLinear,
-            o_proj: RowParallelLinear,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -625,9 +620,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         if current_platform.is_cuda():
             self.rotary_emb = rotary_emb.forward_cuda

-        self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj
         self.vllm_flash_attn_version = get_flash_attn_version()

         # Handle the differences between the flash_attn_varlen from flash_attn
@@ -684,27 +677,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             return attn_out, lse
         return attn_out

-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
-
-    # Return `ql_nope`, `q_pe`
-    def _q_proj_and_k_up_proj(self, x):
-        q_nope, q_pe = self.q_proj(x)[0]\
-            .view(-1, self.num_heads, self.qk_head_dim)\
-            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-        # Convert from (B, N, P) to (N, B, P)
-        q_nope = q_nope.transpose(0, 1)
-        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-        ql_nope = torch.bmm(q_nope, self.W_UK_T)
-        # Convert from (N, B, L) to (B, N, L)
-        return ql_nope.transpose(0, 1), q_pe
+        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

     def process_weights_after_loading(self, act_dtype: torch.dtype):
@@ -874,7 +853,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             suffix_lse=suffix_lse,
         )

-        return self.o_proj(output.flatten(start_dim=-2))[0]
+        return output.flatten(start_dim=-2)

     @abstractmethod
     def _forward_decode(
@@ -889,7 +868,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     def forward(
         self,
         layer: AttentionLayer,
-        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
+        q: torch.Tensor,
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
@@ -908,7 +887,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+        q = q[:num_actual_toks, ...]
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]

@@ -923,24 +902,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens

-        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+        q = q.view(-1, self.num_heads, self.qk_head_dim)
+        decode_q = q[:num_decode_tokens]
         decode_k_pe = k_pe[:num_decode_tokens]

-        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+        prefill_q = q[num_decode_tokens:]
         prefill_k_pe = k_pe[num_decode_tokens:]
         prefill_k_c_normed = k_c_normed[num_decode_tokens:]

         if has_decode:
             assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                 attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)

         if has_prefill:
             assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]

             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
@@ -146,4 +146,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             causal=True,
         )

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
@@ -115,4 +115,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                               attn_metadata.decode.seq_lens, attn_logits,
                               num_kv_splits, self.scale, PAGE_SIZE)

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)