[V1][Speculative Decoding] Fix DeepSeek MTP (#20022)
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
parent: bf5181583f
commit: 8359f4c8d8
@@ -52,11 +52,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ class DeepSeekMultiTokenPredictor(nn.Module):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)

     def forward(
@@ -123,6 +119,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
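The four hunks above move the token embedding out of each `DeepSeekMultiTokenPredictorLayer` and into the enclosing `DeepSeekMultiTokenPredictor`, which now does the lookup once per forward call and passes `inputs_embeds` down to the selected MTP layer. The sketch below illustrates only that ownership change and is not the vLLM code: it uses a plain `nn.Embedding` in place of `VocabParallelEmbedding`, invented constructor arguments, and omits the transformer block and output head.

```python
import torch
import torch.nn as nn


class PredictorLayer(nn.Module):
    """One MTP step; after this change it no longer owns an embedding."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.enorm = nn.RMSNorm(hidden_size)  # nn.RMSNorm needs PyTorch >= 2.4
        self.hnorm = nn.RMSNorm(hidden_size)
        self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

    def forward(self, inputs_embeds: torch.Tensor, positions: torch.Tensor,
                previous_hidden_states: torch.Tensor) -> torch.Tensor:
        # The caller supplies the embeddings; position 0 is masked out, as in
        # the real layer.  The transformer block and output head are omitted.
        inputs_embeds[positions == 0] = 0
        return self.eh_proj(
            torch.cat((self.enorm(inputs_embeds),
                       self.hnorm(previous_hidden_states)), dim=-1))


class Predictor(nn.Module):
    """Owns the single shared embedding and dispatches to the MTP layers."""

    def __init__(self, vocab_size: int, hidden_size: int,
                 num_mtp_layers: int) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(
            PredictorLayer(hidden_size) for _ in range(num_mtp_layers))

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                previous_hidden_states: torch.Tensor,
                spec_step_idx: int = 0) -> torch.Tensor:
        # Embedding lookup now happens once here, mirroring the new hunk.
        inputs_embeds = self.embed_tokens(input_ids)
        layer = self.layers[spec_step_idx % len(self.layers)]
        return layer(inputs_embeds, positions, previous_hidden_states)
```

With a single shared `embed_tokens`, the per-layer copies of the embedding in the checkpoint also become redundant, which the weight-loading hunks below handle.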
@@ -242,6 +240,12 @@ class DeepSeekMTP(nn.Module, SupportsPP):
                 if name.endswith(".bias") and name not in params_dict:
                     continue

+                # According to DeepSeek-V3 Technical Report, MTP modules
+                # shares embedding layer. We only load the first weights.
+                if (spec_layer != self.model.mtp_start_layer_idx
+                        and ".layers" not in name):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
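The hunk above makes the loader skip the shared tensors for every MTP layer after the first, so the embedding is loaded only once. The guard appears to run on names that have already been rewritten by `_rewrite_spec_layer_name` (next hunk), so shared weights arrive as top-level names without `.layers` while block weights keep their `.layers` prefix; that ordering, the layer indices, and the weight names in the sketch below are assumptions for illustration only.

```python
# Sketch of the duplicate-weight filter added above (illustrative names).
mtp_start_layer_idx = 61  # hypothetical index of the first MTP layer

rewritten = [
    ("model.embed_tokens.weight", 61),                        # shared, first MTP layer
    ("model.embed_tokens.weight", 62),                        # shared, later MTP layer
    ("model.layers.62.mtp_block.mlp.gate_proj.weight", 62),   # per-layer block weight
]

for name, spec_layer in rewritten:
    if spec_layer != mtp_start_layer_idx and ".layers" not in name:
        print("skip duplicate shared weight:", name)
        continue
    print("load:", name)
```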
@@ -253,17 +257,25 @@ class DeepSeekMTP(nn.Module, SupportsPP):
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
             "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name
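The hunk above extends the name rewrite so that shared weights (currently just `embed_tokens`) are hoisted to top-level names, other spec-layer weights keep their layer prefix, and everything else gains a `.mtp_block.` segment. The sketch below restates that logic as a free function and shows the resulting mapping; the spec layer index 61 and the checkpoint names are hypothetical examples.

```python
def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
    # Mirrors the updated _rewrite_spec_layer_name above, lifted out of the
    # class so it can run standalone.
    spec_layer_weight_names = [
        "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
    ]
    shared_weight_names = ["embed_tokens"]
    spec_layer_weight = False
    shared_weight = False
    for weight_name in spec_layer_weight_names:
        if weight_name in name:
            spec_layer_weight = True
            if weight_name in shared_weight_names:
                shared_weight = True
            break
    if not spec_layer_weight:
        # transformer-block weight: insert ".mtp_block"
        name = name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    elif shared_weight:
        # shared weight: hoist to a top-level name
        name = name.replace(f"model.layers.{spec_layer}.", "model.")
    return name


# Hypothetical checkpoint names for a spec layer at index 61:
print(rewrite_spec_layer_name(61, "model.layers.61.embed_tokens.weight"))
# -> model.embed_tokens.weight                        (shared, hoisted)
print(rewrite_spec_layer_name(61, "model.layers.61.enorm.weight"))
# -> model.layers.61.enorm.weight                     (spec-layer weight, unchanged)
print(rewrite_spec_layer_name(61, "model.layers.61.mlp.gate_proj.weight"))
# -> model.layers.61.mtp_block.mlp.gate_proj.weight   (block weight)
```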
@@ -148,7 +148,7 @@ class EagleProposer:
         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        attn_metadata = self.runner.attn_metadata_builders[0].build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )