[Model][V1] Support Ernie MTP (#22169)

Signed-off-by: zhouchong <zhouchong03@baidu.com>
Co-authored-by: zhouchong <zhouchong03@baidu.com>
This commit is contained in:
xyxinyang 2025-08-20 20:41:55 +08:00 committed by GitHub
parent 50df09fe13
commit 7cd17e22d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 320 additions and 7 deletions

View File

@ -556,6 +556,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
is_available_online=False, is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16", speculative_model="openbmb/MiniCPM-2B-sft-bf16",
tokenizer="openbmb/MiniCPM-2B-sft-bf16"), tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
"ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
trust_remote_code=True,
speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"),
"Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5", "Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5",
speculative_model="zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.54", min_transformers_version="4.54",

View File

@ -1463,7 +1463,8 @@ class ModelConfig:
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
if (self.hf_text_config.model_type == "deepseek_mtp" if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp" or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"): or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config, total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0) "num_nextn_predict_layers", 0)
else: else:
@ -1911,7 +1912,8 @@ class DeviceConfig:
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"] "mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
@config @config
@ -2044,6 +2046,16 @@ class SpeculativeConfig:
"architectures": ["Glm4MoeMTPModel"] "architectures": ["Glm4MoeMTPModel"]
}) })
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config
return hf_config return hf_config
def __post_init__(self): def __post_init__(self):
@ -2062,8 +2074,8 @@ class SpeculativeConfig:
if self.target_model_config and \ if self.target_model_config and \
(self.target_model_config.hf_text_config.model_type \ (self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or == "deepseek_v3" or
self.target_model_config.hf_text_config.model_type \ self.target_model_config.hf_text_config.model_type in
== "mimo"): ("mimo","ernie4_5_moe")):
# use the draft model from the same model: # use the draft model from the same model:
self.model = self.target_model_config.model self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"): elif self.method in ("ngram", "[ngram]"):
@ -2161,6 +2173,15 @@ class SpeculativeConfig:
"one layer. Might need some code changes " \ "one layer. Might need some code changes " \
"to support multiple layers." "to support multiple layers."
) )
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else: else:
self.method = "draft_model" self.method = "draft_model"
raise NotImplementedError( raise NotImplementedError(
@ -2376,7 +2397,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens return self.num_speculative_tokens
def use_eagle(self) -> bool: def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp") return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
def __repr__(self) -> str: def __repr__(self) -> str:
method = self.method method = self.method

View File

@ -0,0 +1,287 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Ernie-MTP model."""
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .llama import LlamaDecoderLayer
from .utils import is_pp_missing_parameter, maybe_prefix
class ErnieMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.mtp_emb_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.mtp_hidden_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
prefix)
def forward(
self,
inputs_embeds: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = self.mtp_emb_norm(inputs_embeds)
previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states)
hidden_states = self.mtp_linear_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
class ErnieMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
ErnieMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)](
inputs_embeds,
positions,
previous_hidden_states,
spec_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
lm_head: ParallelLMHead,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> torch.Tensor:
self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]
logits = self.logits_processor(lm_head, hidden_states,
sampling_metadata)
return logits
class ErnieMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.sampler = get_sampler()
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
assert spec_step_idx == 0, "ernie_mtp only support predict one token"
hidden_states = self.model(input_ids, positions, hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.model.compute_logits(hidden_states, self.lm_head,
sampling_metadata, spec_step_idx)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.config.tie_word_embeddings and name.endswith(
"lm_head.weight"):
continue
if "rotary_emb.inv_freq" in name:
continue
if "mtp" in name:
name = self._rewrite_spec_layer_name(self.config, name)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
if "mtp" not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# According to DeepSeek-V3 Technical Report, MTP modules
# shares embedding layer. We only load the first weights.
if "mtp_" not in name and ("embed_tokens" not in name
and "lm_head" not in name):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _rewrite_spec_layer_name(self, config: PretrainedConfig,
name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
"""
spec_layer_weight_names = [
"embed_tokens", "mtp_emb_norm", "mtp_hidden_norm",
"mtp_linear_proj"
]
layer_idx = config.num_hidden_layers
for weight_name in spec_layer_weight_names:
if weight_name in name:
name = name.replace(
f"model.{weight_name}.0.",
f"model.layers.{layer_idx}.{weight_name}.")
return name
name = name.replace("model.mtp_block.0.",
f"model.layers.{layer_idx}.mtp_block.")
return name

View File

@ -266,6 +266,7 @@ _SPECULATIVE_DECODING_MODELS = {
# "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
"MedusaModel": ("medusa", "Medusa"), "MedusaModel": ("medusa", "Medusa"),
# Temporarily disabled. # Temporarily disabled.

View File

@ -194,7 +194,7 @@ class EagleProposer:
hidden_states=self.hidden_states[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if self.method == "deepseek_mtp": if self.method in ("deepseek_mtp", "ernie_mtp"):
last_hidden_states = ret_hidden_states last_hidden_states = ret_hidden_states
else: else:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states

View File

@ -77,7 +77,8 @@ class Worker(LocalOrDistributedWorkerBase):
"eagle", "eagle",
"deepseek_mtp", "deepseek_mtp",
"glm4_moe_mtp", "glm4_moe_mtp",
"mimo_mtp")) \ "mimo_mtp",
"ernie_mtp")) \
else {"return_hidden_states": True} else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner