[DeepSeek] Fix DeepSeek V3.2 Rope Embedding (#28968)
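
The V3.2 sparse indexer was being handed the rope module of the main MLA
attention path (self.rotary_emb). Give the indexer a dedicated neox-style
rotary embedding built with get_rope(..., is_neox_style=True), thread it
through MLAModules, and fix the squeeze dims of q_pe/k_pe to match the
shapes that rope returns.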

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
commit 88f5b19f0b
parent 613abb50d5
Author: Yongye Zhu
Date:   2025-11-19 16:30:04 -05:00 (committed by GitHub)
2 changed files with 17 additions and 3 deletions


@@ -24,6 +24,7 @@ class MLAModules:
     q_b_proj: torch.nn.Module | None
     q_proj: torch.nn.Module | None
     indexer: torch.nn.Module | None
+    indexer_rotary_emb: torch.nn.Module | None
     is_sparse: bool
     topk_indices_buffer: torch.Tensor | None
@@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         self.rotary_emb = mla_modules.rotary_emb
         self.o_proj = mla_modules.o_proj
         self.indexer = mla_modules.indexer
+        self.indexer_rope_emb = mla_modules.indexer_rotary_emb
         self.is_sparse = mla_modules.is_sparse
         if self.indexer is not None:
@@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
         )
         if self.indexer and self.is_sparse:
-            _topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
+            _topk_indices = self.indexer(
+                hidden_states, q_c, positions, self.indexer_rope_emb
+            )
         attn_out = self.mla_attn(
             q,
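
Below is a minimal, runnable sketch of the wiring pattern this hunk
introduces. The class names Modules and Wrapper are placeholders for
MLAModules and MultiHeadLatentAttentionWrapper above, and the indexer is
any callable taking the same four arguments.

    from dataclasses import dataclass
    import torch
    import torch.nn as nn

    @dataclass
    class Modules:
        # Mirrors the two fields of MLAModules touched by this commit.
        rotary_emb: nn.Module                 # rope of the main MLA path
        indexer_rotary_emb: nn.Module | None  # dedicated rope for the indexer

    class Wrapper(nn.Module):
        def __init__(self, modules: Modules, indexer):
            super().__init__()
            self.rotary_emb = modules.rotary_emb
            self.indexer_rope_emb = modules.indexer_rotary_emb
            self.indexer = indexer

        def forward(self, hidden_states, q_c, positions):
            # The fix: the indexer receives its own rope module,
            # not self.rotary_emb from the main attention path.
            return self.indexer(
                hidden_states, q_c, positions, self.indexer_rope_emb
            )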


@@ -837,8 +837,8 @@ class Indexer(nn.Module):
         )
         q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
-        q = torch.cat([q_pe, q_nope], dim=-1)
-        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
+        q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
+        k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)
         # we only quant q here since k quant is fused with cache insertion
         q = q.view(-1, self.head_dim)
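
The squeeze-dim change can be sanity-checked in isolation. The shapes below
are assumptions for the sketch (a leading broadcast dim of 1 added by the
rope call, plus the head dim that k_pe gains from unsqueeze(1)), not the
exact shapes inside Indexer.forward. Note that squeeze with a tuple of dims
requires PyTorch >= 2.0.

    import torch

    T, rope_dim, nope_dim = 5, 64, 64
    # Assumed rope output shapes for this sketch:
    # q_pe: (1, T, rope_dim), k_pe: (1, T, 1, rope_dim)
    q_pe = torch.randn(1, T, rope_dim)
    k_pe = torch.randn(1, T, 1, rope_dim)
    q_nope = torch.randn(T, nope_dim)
    k_nope = torch.randn(T, nope_dim)

    q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)       # (T, rope+nope)
    k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)  # (T, rope+nope)
    assert q.shape == k.shape == (T, rope_dim + nope_dim)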
@@ -987,6 +987,14 @@ class DeepseekV2MLAAttention(nn.Module):
         self.is_v32 = hasattr(config, "index_topk")
         if self.is_v32:
+            self.indexer_rope_emb = get_rope(
+                qk_rope_head_dim,
+                rotary_dim=qk_rope_head_dim,
+                max_position=max_position_embeddings,
+                base=rope_theta,
+                rope_scaling=rope_scaling,
+                is_neox_style=True,
+            )
             self.indexer = Indexer(
                 vllm_config,
                 config,
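
The notable argument here is the explicit is_neox_style=True (the main MLA
rope in this file is constructed with is_neox_style=False). As a sketch of
why the style matters, here are plain-torch versions of the two rotation
conventions that is_neox_style selects between; vLLM's kernels fuse these
with the cos/sin multiply. Applying the wrong convention silently permutes
the rotated dims rather than raising an error.

    import torch

    def rotate_neox(x: torch.Tensor) -> torch.Tensor:
        # Neox style: rotate the two contiguous halves of the rotary dims.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
        # GPT-J / interleaved style: rotate even/odd element pairs.
        x1, x2 = x[..., ::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    x = torch.arange(8.0)
    print(rotate_neox(x))  # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
    print(rotate_gptj(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])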
@@ -998,6 +1006,7 @@ class DeepseekV2MLAAttention(nn.Module):
                 f"{prefix}.indexer",
             )
         else:
+            self.indexer_rope_emb = None
             self.indexer = None
         mla_modules = MLAModules(
@@ -1015,6 +1024,7 @@ class DeepseekV2MLAAttention(nn.Module):
             q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
             q_proj=self.q_proj if self.q_lora_rank is None else None,
             indexer=self.indexer,
+            indexer_rotary_emb=self.indexer_rope_emb,
             is_sparse=self.is_v32,
             topk_indices_buffer=topk_indices_buffer,
         )
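
For reference, a standalone sketch of constructing and calling the new
indexer rope outside the engine, using vLLM's get_rope as in the hunk
above. The max_position and base values here are made up for the sketch;
the model passes max_position_embeddings, rope_theta, and its rope_scaling
from the config.

    import torch
    from vllm.model_executor.layers.rotary_embedding import get_rope

    qk_rope_head_dim = 64
    indexer_rope = get_rope(
        qk_rope_head_dim,
        rotary_dim=qk_rope_head_dim,
        max_position=4096,   # made-up; model uses max_position_embeddings
        base=10000.0,        # made-up; model uses rope_theta
        rope_scaling=None,   # model forwards its rope_scaling here
        is_neox_style=True,
    )
    positions = torch.arange(8)
    q = torch.randn(8, qk_rope_head_dim)
    k = torch.randn(8, qk_rope_head_dim)
    q, k = indexer_rope(positions, q, k)  # q/k with rope applied per position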