diff --git a/tests/models/registry.py b/tests/models/registry.py index 739d96227971..6e6acfb8cd22 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -556,6 +556,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { is_available_online=False, speculative_model="openbmb/MiniCPM-2B-sft-bf16", tokenizer="openbmb/MiniCPM-2B-sft-bf16"), + "ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", + trust_remote_code=True, + speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"), "Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5", speculative_model="zai-org/GLM-4.5", min_transformers_version="4.54", diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 56a749789b6a..801fa97fe5da 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1463,7 +1463,8 @@ class ModelConfig: from vllm.distributed.utils import get_pp_indices if (self.hf_text_config.model_type == "deepseek_mtp" or self.hf_config.model_type == "mimo_mtp" - or self.hf_config.model_type == "glm4_moe_mtp"): + or self.hf_config.model_type == "glm4_moe_mtp" + or self.hf_config.model_type == "ernie_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1911,7 +1912,8 @@ class DeviceConfig: SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp"] + "mlp_speculator", "draft_model", "deepseek_mtp", + "ernie_mtp"] @config @@ -2044,6 +2046,16 @@ class SpeculativeConfig: "architectures": ["Glm4MoeMTPModel"] }) + if hf_config.model_type == "ernie4_5_moe": + hf_config.model_type = "ernie_mtp" + if hf_config.model_type == "ernie_mtp": + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "n_predict": n_predict, + "architectures": ["ErnieMTPModel"] + }) + return hf_config + return hf_config def __post_init__(self): @@ -2062,8 +2074,8 @@ class SpeculativeConfig: if self.target_model_config and \ (self.target_model_config.hf_text_config.model_type \ == "deepseek_v3" or - self.target_model_config.hf_text_config.model_type \ - == "mimo"): + self.target_model_config.hf_text_config.model_type in + ("mimo","ernie4_5_moe")): # use the draft model from the same model: self.model = self.target_model_config.model elif self.method in ("ngram", "[ngram]"): @@ -2161,6 +2173,15 @@ class SpeculativeConfig: "one layer. Might need some code changes " \ "to support multiple layers." ) + elif (self.draft_model_config.hf_config.model_type == + "ernie_mtp"): + self.method = "ernie_mtp" + if self.num_speculative_tokens > 1: + logger.warning( + "All Ernie MTP models only have " \ + "one layer. Might need some code changes " \ + "to support multiple layers." + ) else: self.method = "draft_model" raise NotImplementedError( @@ -2376,7 +2397,7 @@ class SpeculativeConfig: return self.num_speculative_tokens def use_eagle(self) -> bool: - return self.method in ("eagle", "eagle3", "deepseek_mtp") + return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") def __repr__(self) -> str: method = self.method diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py new file mode 100644 index 000000000000..90a1267b28f0 --- /dev/null +++ b/vllm/model_executor/models/ernie_mtp.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Ernie-MTP model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .llama import LlamaDecoderLayer +from .utils import is_pp_missing_parameter, maybe_prefix + + +class ErnieMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.mtp_emb_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_hidden_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.mtp_linear_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config, + prefix) + + def forward( + self, + inputs_embeds: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + + inputs_embeds = self.mtp_emb_norm(inputs_embeds) + previous_hidden_states = self.mtp_hidden_norm(previous_hidden_states) + + hidden_states = self.mtp_linear_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + + return hidden_states + + +class ErnieMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + 
ErnieMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + return self.layers[str(self.mtp_start_layer_idx + spec_step_idx)]( + inputs_embeds, + positions, + previous_hidden_states, + spec_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + lm_head: ParallelLMHead, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] + logits = self.logits_processor(lm_head, hidden_states, + sampling_metadata) + return logits + + +class ErnieMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config = vllm_config.model_config.hf_config + self.model = ErnieMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.sampler = get_sampler() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + assert spec_step_idx == 0, "ernie_mtp only support predict one token" + hidden_states = self.model(input_ids, positions, hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, self.lm_head, + sampling_metadata, spec_step_idx) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + continue + if "rotary_emb.inv_freq" in name: + continue + if "mtp" in name: + name = self._rewrite_spec_layer_name(self.config, name) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if "mtp" not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. 
+ # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if "mtp_" not in name and ("embed_tokens" not in name + and "lm_head" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, config: PretrainedConfig, + name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + """ + spec_layer_weight_names = [ + "embed_tokens", "mtp_emb_norm", "mtp_hidden_norm", + "mtp_linear_proj" + ] + layer_idx = config.num_hidden_layers + for weight_name in spec_layer_weight_names: + if weight_name in name: + name = name.replace( + f"model.{weight_name}.0.", + f"model.layers.{layer_idx}.{weight_name}.") + return name + name = name.replace("model.mtp_block.0.", + f"model.layers.{layer_idx}.mtp_block.") + return name diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a94231b0f846..78ef270598b8 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -266,6 +266,7 @@ _SPECULATIVE_DECODING_MODELS = { # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), + "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), # Temporarily disabled. diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a8a160a0f995..8cd2ad12cfa3 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -194,7 +194,7 @@ class EagleProposer: hidden_states=self.hidden_states[:num_input_tokens], inputs_embeds=inputs_embeds, ) - if self.method == "deepseek_mtp": + if self.method in ("deepseek_mtp", "ernie_mtp"): last_hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9dfea947568d..7a01e585ba6d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -77,7 +77,8 @@ class Worker(LocalOrDistributedWorkerBase): "eagle", "deepseek_mtp", "glm4_moe_mtp", - "mimo_mtp")) \ + "mimo_mtp", + "ernie_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
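
For reference, a minimal sketch of how the new method could be exercised once this patch is applied. This is illustrative only and not part of the change: it assumes the dict-style `speculative_config` argument accepted by `vllm.LLM`, and reuses the model path from `tests/models/registry.py`. Because the target `ernie4_5_moe` checkpoint carries its own MTP weights, no separate draft model is passed; `SpeculativeConfig` resolves the draft model to the target model, as in the diff above.

```python
# Illustrative sketch, not part of the patch. Assumes the dict-style
# speculative_config accepted by vllm.LLM in recent vLLM versions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="baidu/ERNIE-4.5-21B-A3B-PT",
    trust_remote_code=True,
    speculative_config={
        # "ernie_mtp" is the method name registered by this patch.
        "method": "ernie_mtp",
        # Ernie MTP checkpoints ship a single MTP layer, so one
        # speculative token is the supported configuration; the patch
        # logs a warning for larger values.
        "num_speculative_tokens": 1,
    },
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```

If serving instead of running offline, the same dict should map onto the JSON accepted by `--speculative-config`, assuming that flag is available in the target vLLM version.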