mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 18:34:28 +08:00
[DeepSeek] Fix DeepSeek V3.2 Rope Embedding (#28968)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
parent
613abb50d5
commit
88f5b19f0b
@ -24,6 +24,7 @@ class MLAModules:
|
||||
q_b_proj: torch.nn.Module | None
|
||||
q_proj: torch.nn.Module | None
|
||||
indexer: torch.nn.Module | None
|
||||
indexer_rotary_emb: torch.nn.Module | None
|
||||
is_sparse: bool
|
||||
topk_indices_buffer: torch.Tensor | None
|
||||
|
||||
@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
||||
self.rotary_emb = mla_modules.rotary_emb
|
||||
self.o_proj = mla_modules.o_proj
|
||||
self.indexer = mla_modules.indexer
|
||||
self.indexer_rope_emb = mla_modules.indexer_rotary_emb
|
||||
self.is_sparse = mla_modules.is_sparse
|
||||
|
||||
if self.indexer is not None:
|
||||
@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
||||
)
|
||||
|
||||
if self.indexer and self.is_sparse:
|
||||
_topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
|
||||
_topk_indices = self.indexer(
|
||||
hidden_states, q_c, positions, self.indexer_rope_emb
|
||||
)
|
||||
|
||||
attn_out = self.mla_attn(
|
||||
q,
|
||||
|
||||
@ -837,8 +837,8 @@ class Indexer(nn.Module):
|
||||
)
|
||||
|
||||
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
|
||||
q = torch.cat([q_pe, q_nope], dim=-1)
|
||||
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
|
||||
q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
|
||||
k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)
|
||||
|
||||
# we only quant q here since k quant is fused with cache insertion
|
||||
q = q.view(-1, self.head_dim)
|
||||
@ -987,6 +987,14 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
self.is_v32 = hasattr(config, "index_topk")
|
||||
|
||||
if self.is_v32:
|
||||
self.indexer_rope_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=True,
|
||||
)
|
||||
self.indexer = Indexer(
|
||||
vllm_config,
|
||||
config,
|
||||
@ -998,6 +1006,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
f"{prefix}.indexer",
|
||||
)
|
||||
else:
|
||||
self.indexer_rope_emb = None
|
||||
self.indexer = None
|
||||
|
||||
mla_modules = MLAModules(
|
||||
@ -1015,6 +1024,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else None,
|
||||
indexer=self.indexer,
|
||||
indexer_rotary_emb=self.indexer_rope_emb,
|
||||
is_sparse=self.is_v32,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user