mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 14:18:00 +08:00
[DeepSeek] Fix DeepSeek V3.2 Rope Embedding (#28968)
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
parent
613abb50d5
commit
88f5b19f0b
@ -24,6 +24,7 @@ class MLAModules:
|
|||||||
q_b_proj: torch.nn.Module | None
|
q_b_proj: torch.nn.Module | None
|
||||||
q_proj: torch.nn.Module | None
|
q_proj: torch.nn.Module | None
|
||||||
indexer: torch.nn.Module | None
|
indexer: torch.nn.Module | None
|
||||||
|
indexer_rotary_emb: torch.nn.Module | None
|
||||||
is_sparse: bool
|
is_sparse: bool
|
||||||
topk_indices_buffer: torch.Tensor | None
|
topk_indices_buffer: torch.Tensor | None
|
||||||
|
|
||||||
@ -80,6 +81,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
|||||||
self.rotary_emb = mla_modules.rotary_emb
|
self.rotary_emb = mla_modules.rotary_emb
|
||||||
self.o_proj = mla_modules.o_proj
|
self.o_proj = mla_modules.o_proj
|
||||||
self.indexer = mla_modules.indexer
|
self.indexer = mla_modules.indexer
|
||||||
|
self.indexer_rope_emb = mla_modules.indexer_rotary_emb
|
||||||
self.is_sparse = mla_modules.is_sparse
|
self.is_sparse = mla_modules.is_sparse
|
||||||
|
|
||||||
if self.indexer is not None:
|
if self.indexer is not None:
|
||||||
@ -153,7 +155,9 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.indexer and self.is_sparse:
|
if self.indexer and self.is_sparse:
|
||||||
_topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)
|
_topk_indices = self.indexer(
|
||||||
|
hidden_states, q_c, positions, self.indexer_rope_emb
|
||||||
|
)
|
||||||
|
|
||||||
attn_out = self.mla_attn(
|
attn_out = self.mla_attn(
|
||||||
q,
|
q,
|
||||||
|
|||||||
@ -837,8 +837,8 @@ class Indexer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
|
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
|
||||||
q = torch.cat([q_pe, q_nope], dim=-1)
|
q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
|
||||||
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
|
k = torch.cat([k_pe.squeeze((0, 2)), k_nope], dim=-1)
|
||||||
|
|
||||||
# we only quant q here since k quant is fused with cache insertion
|
# we only quant q here since k quant is fused with cache insertion
|
||||||
q = q.view(-1, self.head_dim)
|
q = q.view(-1, self.head_dim)
|
||||||
@ -987,6 +987,14 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
self.is_v32 = hasattr(config, "index_topk")
|
self.is_v32 = hasattr(config, "index_topk")
|
||||||
|
|
||||||
if self.is_v32:
|
if self.is_v32:
|
||||||
|
self.indexer_rope_emb = get_rope(
|
||||||
|
qk_rope_head_dim,
|
||||||
|
rotary_dim=qk_rope_head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
base=rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
is_neox_style=True,
|
||||||
|
)
|
||||||
self.indexer = Indexer(
|
self.indexer = Indexer(
|
||||||
vllm_config,
|
vllm_config,
|
||||||
config,
|
config,
|
||||||
@ -998,6 +1006,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
f"{prefix}.indexer",
|
f"{prefix}.indexer",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self.indexer_rope_emb = None
|
||||||
self.indexer = None
|
self.indexer = None
|
||||||
|
|
||||||
mla_modules = MLAModules(
|
mla_modules = MLAModules(
|
||||||
@ -1015,6 +1024,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
|
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
|
||||||
q_proj=self.q_proj if self.q_lora_rank is None else None,
|
q_proj=self.q_proj if self.q_lora_rank is None else None,
|
||||||
indexer=self.indexer,
|
indexer=self.indexer,
|
||||||
|
indexer_rotary_emb=self.indexer_rope_emb,
|
||||||
is_sparse=self.is_v32,
|
is_sparse=self.is_v32,
|
||||||
topk_indices_buffer=topk_indices_buffer,
|
topk_indices_buffer=topk_indices_buffer,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user