diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index c4c44b83ae6bf..6ebfa47a9dc3f 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -24,6 +24,7 @@ class MLAModules:
     q_b_proj: torch.nn.Module | None
     q_proj: torch.nn.Module | None
     indexer: torch.nn.Module | None
+    indexer_rotary_emb: torch.nn.Module | None
     is_sparse: bool
     topk_indices_buffer: torch.Tensor | None
 
@@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         self.rotary_emb = mla_modules.rotary_emb
         self.o_proj = mla_modules.o_proj
         self.indexer = mla_modules.indexer
+        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
         self.is_sparse = mla_modules.is_sparse
 
         if self.indexer is not None:
@@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         )
 
         if self.indexer and self.is_sparse:
-            _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
+            _topk_indices = self.indexer(
+                hidden_states, q_c, positions, self.indexer_rope_emb
+            )
 
         attn_out = self.mla_attn(
             q,
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 6675b2133f386..c0ff621d84085 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -837,8 +837,8 @@ class Indexer(nn.Module):
         )
 
         q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
-        q = torch.cat([q_pe, q_nope], dim=-1)
-        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
+        q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
+        k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)
 
         # we only quant q here since k quant is fused with cache insertion
         q = q.view(-1, self.head_dim)
@@ -987,6 +987,14 @@ class DeepseekV2MLAAttention(nn.Module):
         self.is_v32 = hasattr(config, "index_topk")
 
         if self.is_v32:
+            self.indexer_rope_emb = get_rope(
+                qk_rope_head_dim,
+                rotary_dim=qk_rope_head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+                is_neox_style=True,
+            )
             self.indexer = Indexer(
                 vllm_config,
                 config,
@@ -998,6 +1006,7 @@ class DeepseekV2MLAAttention(nn.Module):
                 f"{prefix}.indexer",
             )
         else:
+            self.indexer_rope_emb = None
             self.indexer = None
 
         mla_modules = MLAModules(
@@ -1015,6 +1024,7 @@ class DeepseekV2MLAAttention(nn.Module):
             q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
             q_proj=self.q_proj if self.q_lora_rank is None else None,
             indexer=self.indexer,
+            indexer_rotary_emb=self.indexer_rope_emb,
             is_sparse=self.is_v32,
             topk_indices_buffer=topk_indices_buffer,
         )