diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py
index 528df2e98679c..4567893a9ef7c 100644
--- a/vllm/attention/backends/cpu_mla.py
+++ b/vllm/attention/backends/cpu_mla.py
@@ -281,8 +281,7 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
         # remove padding
         output = output.view(-1, self.num_heads,
                              q.shape[-1])[..., :v.shape[-1]]
-        output = output.reshape(-1, self.num_heads * v.shape[-1])
-        return self.o_proj(output)[0]
+        return output.reshape(-1, self.num_heads * v.shape[-1])

     def _forward_decode(
         self,
@@ -303,4 +302,4 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
         ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
                                    decode_meta.block_tables,
                                    decode_meta.seq_lens_tensor)
-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py
index 5d0c230933105..0e62748ddbee4 100644
--- a/vllm/attention/backends/flashmla.py
+++ b/vllm/attention/backends/flashmla.py
@@ -239,4 +239,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             causal=True,
         )

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 382a9a6d44d84..12d85b74244f4 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -207,7 +207,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+                                               LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -1032,12 +1032,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             qk_head_dim: int,
             v_head_dim: int,
             rotary_emb: RotaryEmbedding,
-            # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-            # attention backend perspective we rely on the layer to pass in the
-            # correct matrix
-            q_proj: ColumnParallelLinear,
             kv_b_proj: ColumnParallelLinear,
-            o_proj: RowParallelLinear,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -1055,9 +1050,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.rotary_emb = rotary_emb
         self.use_yarn_rope = isinstance(rotary_emb,
                                         DeepseekScalingRotaryEmbedding)
-        self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj

         self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
@@ -1141,27 +1134,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             return attn_out, rest[0]
         return attn_out

-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
-
-    # Return `ql_nope`, `q_pe`
-    def _q_proj_and_k_up_proj(self, x):
-        q_nope, q_pe = self.q_proj(x)[0]\
-            .view(-1, self.num_heads, self.qk_head_dim)\
-            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-        # Convert from (B, N, P) to (N, B, P)
-        q_nope = q_nope.transpose(0, 1)
-        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-        ql_nope = torch.bmm(q_nope, self.W_UK_T)
-        # Convert from (N, B, L) to (B, N, L)
-        return ql_nope.transpose(0, 1), q_pe
+        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

     def process_weights_after_loading(self, act_dtype: torch.dtype):
@@ -1345,7 +1324,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             suffix_lse=suffix_lse,
         )

-        return self.o_proj(output.flatten(start_dim=-2))[0]
+        return output.flatten(start_dim=-2)

     @abstractmethod
     def _forward_decode(
@@ -1360,7 +1339,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     def forward(
         self,
         layer: AttentionLayer,
-        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
+        q: torch.Tensor,  # query in unified attn
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
@@ -1391,27 +1370,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         assert hasattr(attn_metadata, "input_positions")

         num_prefill_tokens: int = attn_metadata.num_prefill_tokens
+        q = q.view(-1, self.num_heads, self.qk_head_dim)

-        decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
+        decode_q = q[num_prefill_tokens:]
         decode_k_pe = k_pe[num_prefill_tokens:]
         decode_input_positions = \
             attn_metadata.input_positions[num_prefill_tokens:]

-        prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
+        prefill_q = q[:num_prefill_tokens]
         prefill_k_pe = k_pe[:num_prefill_tokens]
         prefill_input_positions = \
             attn_metadata.input_positions[:num_prefill_tokens]
         prefill_k_c_normed = k_c_normed[:num_prefill_tokens]

         if has_decode:
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                 decode_input_positions, decode_q_pe, decode_k_pe)

         if has_prefill:
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
                 prefill_input_positions, prefill_q_pe, prefill_k_pe)
@@ -1429,9 +1413,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

         output = torch.empty(attn_metadata.num_prefill_tokens +
                              attn_metadata.num_decode_tokens,
-                             self.o_proj.output_size,
-                             device=hidden_states_or_q_c.device,
-                             dtype=hidden_states_or_q_c.dtype)
+                             self.v_head_dim * self.num_heads,
+                             device=q.device,
+                             dtype=q.dtype)
         if has_prefill:
             output[:num_prefill_tokens] = self._forward_prefill(
                 prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py
index 6e695b78e0e15..2984bc1dad64a 100644
--- a/vllm/attention/backends/rocm_aiter_mla.py
+++ b/vllm/attention/backends/rocm_aiter_mla.py
@@ -409,4 +409,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
                             attn_metadata.paged_kv_indices,
                             attn_metadata.paged_kv_last_page_lens)

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py
index 61e5c76d9fda3..6945c2c6e29cd 100644
--- a/vllm/attention/backends/triton_mla.py
+++ b/vllm/attention/backends/triton_mla.py
@@ -110,4 +110,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                               decode_meta.seq_lens_tensor, attn_logits,
                               num_kv_splits, self.scale, PAGE_SIZE)

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index ffa5840b46041..ce86b9b2c4f04 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -454,9 +454,7 @@ class DeepseekV2MLAAttention(nn.Module):
             qk_head_dim=self.qk_head_dim,
             v_head_dim=self.v_head_dim,
             rotary_emb=self.rotary_emb,
-            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
             kv_b_proj=self.kv_b_proj,
-            o_proj=self.o_proj,
         )

         self.prefix = prefix
@@ -468,17 +466,22 @@ class DeepseekV2MLAAttention(nn.Module):
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            q_c = self.q_a_proj(hidden_states)[0]
+            q_c = self.q_a_layernorm(q_c)
+            q = self.q_b_proj(q_c)[0]
         else:
-            hidden_states_or_q_c = hidden_states
+            q = self.q_proj(hidden_states)[0]
         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c,
-                             kv_c_normed,
-                             k_pe,
-                             output_shape=hidden_states.shape)
+
+        attn_out = self.mla_attn(
+            q,
+            kv_c_normed,
+            k_pe,
+            output_shape=(hidden_states.shape[0],
+                          self.num_local_heads * self.v_head_dim))
+        return self.o_proj(attn_out)[0]


 class DeepseekV2DecoderLayer(nn.Module):
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index fd3be901f4c38..3e77555d7f942 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -200,7 +200,7 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               LinearBase, RowParallelLinear,
+                                               LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
@@ -597,12 +597,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             qk_head_dim: int,
             v_head_dim: int,
             rotary_emb: RotaryEmbedding,
-            # q_proj should be q_b_proj if q_lora_rank is not None, but from an
-            # attention backend perspective we rely on the layer to pass in the
-            # correct matrix
-            q_proj: ColumnParallelLinear,
             kv_b_proj: ColumnParallelLinear,
-            o_proj: RowParallelLinear,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -625,9 +620,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         if current_platform.is_cuda():
             self.rotary_emb = rotary_emb.forward_cuda

-        self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
-        self.o_proj = o_proj

         self.vllm_flash_attn_version = get_flash_attn_version()
         # Handle the differences between the flash_attn_varlen from flash_attn
@@ -684,27 +677,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             return attn_out, lse
         return attn_out

-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
-        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return self.o_proj(x)[0]
-
-    # Return `ql_nope`, `q_pe`
-    def _q_proj_and_k_up_proj(self, x):
-        q_nope, q_pe = self.q_proj(x)[0]\
-            .view(-1, self.num_heads, self.qk_head_dim)\
-            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-        # Convert from (B, N, P) to (N, B, P)
-        q_nope = q_nope.transpose(0, 1)
-        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-        ql_nope = torch.bmm(q_nope, self.W_UK_T)
-        # Convert from (N, B, L) to (B, N, L)
-        return ql_nope.transpose(0, 1), q_pe
+        return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)

     def process_weights_after_loading(self, act_dtype: torch.dtype):
@@ -874,7 +853,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             suffix_lse=suffix_lse,
         )

-        return self.o_proj(output.flatten(start_dim=-2))[0]
+        return output.flatten(start_dim=-2)

     @abstractmethod
     def _forward_decode(
@@ -889,7 +868,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     def forward(
         self,
         layer: AttentionLayer,
-        hidden_states_or_q_c: torch.Tensor,  # query in unified attn
+        q: torch.Tensor,
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
@@ -908,7 +887,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_toks, ...]
-        hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
+        q = q[:num_actual_toks, ...]
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]

@@ -923,24 +902,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         has_prefill = attn_metadata.num_prefills > 0
         num_decode_tokens = attn_metadata.num_decode_tokens

-        decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
+        q = q.view(-1, self.num_heads, self.qk_head_dim)
+        decode_q = q[:num_decode_tokens]
         decode_k_pe = k_pe[:num_decode_tokens]

-        prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
+        prefill_q = q[num_decode_tokens:]
         prefill_k_pe = k_pe[num_decode_tokens:]
         prefill_k_c_normed = k_c_normed[num_decode_tokens:]

         if has_decode:
             assert attn_metadata.decode is not None
-            decode_ql_nope, decode_q_pe = \
-                self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+            decode_q_nope, decode_q_pe = decode_q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+            # Convert from (B, N, P) to (N, B, P)
+            decode_q_nope = decode_q_nope.transpose(0, 1)
+            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            # Convert from (N, B, L) to (B, N, L)
+            decode_ql_nope = decode_ql_nope.transpose(0, 1)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
                 attn_metadata.decode.input_positions, decode_q_pe,
                 decode_k_pe)

         if has_prefill:
             assert attn_metadata.prefill is not None
-            prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
-                .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 143bfe35bb5e5..f18c9c8b6462c 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -146,4 +146,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             causal=True,
         )

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 8e7e4f10b81b8..2e6b619db6287 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -115,4 +115,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
                               attn_metadata.decode.seq_lens, attn_logits,
                               num_kv_splits, self.scale, PAGE_SIZE)

-        return self._v_up_proj_and_o_proj(o)
+        return self._v_up_proj(o)
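
Summary of the refactor (not part of the patch): the MLA attention backends no longer own `q_proj`/`o_proj`. The backend `forward` now receives the already-projected query `q`, `_v_up_proj_and_o_proj` becomes `_v_up_proj`, and the backend returns the up-projected attention output of shape `(num_tokens, num_heads * v_head_dim)`; `DeepseekV2MLAAttention.forward` applies `q_b_proj`/`q_proj` before the call and `o_proj` after it. Below is a minimal, hedged sketch of the new contract with made-up dimensions; `w_uv`, `o_proj`, and `hidden_size` are stand-ins for illustration, not the vLLM objects.

```python
# Hedged sketch, not vLLM code: dimensions and weights here are invented to
# illustrate the new backend contract, where o_proj is applied by the model
# layer rather than inside the attention backend.
import torch

B, N, L, V = 4, 16, 512, 128      # tokens, num_heads, kv_lora_rank, v_head_dim
hidden_size = 2048

w_uv = torch.randn(N, L, V)                   # per-head V up-projection (W_UV)
o_proj = torch.nn.Linear(N * V, hidden_size)  # now owned by the model layer

def v_up_proj(x: torch.Tensor) -> torch.Tensor:
    """Mirrors the renamed _v_up_proj: up-project the latent attention output."""
    x = x.view(-1, N, L).transpose(0, 1)   # (B, N, L) -> (N, B, L)
    x = torch.bmm(x, w_uv)                 # (N, B, L) x (N, L, V) -> (N, B, V)
    # (N, B, V) -> (B, N * V); o_proj is intentionally no longer applied here
    return x.transpose(0, 1).reshape(-1, N * V)

o = torch.randn(B, N, L)       # latent-space attention output from the kernel
attn_out = v_up_proj(o)        # what the backend now returns: (B, N * V)
final = o_proj(attn_out)       # the model layer applies o_proj afterwards
print(attn_out.shape, final.shape)  # torch.Size([4, 2048]) torch.Size([4, 2048])
```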