mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:44:57 +08:00
[model] Add support for openPangu_Ultra_MoE (#27521)
Signed-off-by: yuantao <2422264527@qq.com> Signed-off-by: yt0428 <51468697+yt0428@users.noreply.github.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
5fd8f02ea9
commit
05cae69f0f
@ -404,6 +404,8 @@ th {
|
||||
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ |
|
||||
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
|
||||
| `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
|
||||
| `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ |
|
||||
| `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ |
|
||||
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
|
||||
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
|
||||
@ -363,6 +363,11 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
|
||||
"Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
|
||||
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
|
||||
"OpenPanguMTPModel": _HfExamplesInfo(
|
||||
"FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False,
|
||||
),
|
||||
"OPTForCausalLM": _HfExamplesInfo(
|
||||
"facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}
|
||||
),
|
||||
@ -370,6 +375,14 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"OrionStarAI/Orion-14B-Chat", trust_remote_code=True
|
||||
),
|
||||
"OuroForCausalLM": _HfExamplesInfo("ByteDance/Ouro-1.4B", trust_remote_code=True),
|
||||
"PanguEmbeddedForCausalLM": _HfExamplesInfo(
|
||||
"FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True
|
||||
),
|
||||
"PanguUltraMoEForCausalLM": _HfExamplesInfo(
|
||||
"FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False,
|
||||
),
|
||||
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
|
||||
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
|
||||
"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
|
||||
|
||||
@ -1231,6 +1231,8 @@ class ModelConfig:
|
||||
"kimi_k2",
|
||||
"kimi_linear",
|
||||
"longcat_flash",
|
||||
"pangu_ultra_moe",
|
||||
"pangu_ultra_moe_mtp",
|
||||
):
|
||||
return self.hf_text_config.kv_lora_rank is not None
|
||||
elif self.hf_text_config.model_type == "eagle":
|
||||
@ -1379,6 +1381,7 @@ class ModelConfig:
|
||||
or self.hf_config.model_type == "glm4_moe_mtp"
|
||||
or self.hf_config.model_type == "ernie_mtp"
|
||||
or self.hf_config.model_type == "qwen3_next_mtp"
|
||||
or self.hf_config.model_type == "pangu_ultra_moe_mtp"
|
||||
):
|
||||
total_num_hidden_layers = getattr(
|
||||
self.hf_text_config, "num_nextn_predict_layers", 0
|
||||
|
||||
@ -41,6 +41,7 @@ SpeculativeMethod = Literal[
|
||||
"qwen3_next_mtp",
|
||||
"mimo_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
"mtp",
|
||||
"suffix",
|
||||
]
|
||||
@ -51,6 +52,7 @@ MTP_MODEL_TYPES = (
|
||||
"ernie_mtp",
|
||||
"qwen3_next_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
)
|
||||
|
||||
|
||||
@ -179,6 +181,13 @@ class SpeculativeConfig:
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
|
||||
)
|
||||
if hf_config.model_type in ("pangu_ultra_moe"):
|
||||
hf_config.model_type = "pangu_ultra_moe_mtp"
|
||||
if hf_config.model_type == "pangu_ultra_moe_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
hf_config.model_type = "mimo_mtp"
|
||||
|
||||
1078
vllm/model_executor/models/openpangu.py
Normal file
1078
vllm/model_executor/models/openpangu.py
Normal file
File diff suppressed because it is too large
Load Diff
265
vllm/model_executor/models/openpangu_mtp.py
Normal file
265
vllm/model_executor/models/openpangu_mtp.py
Normal file
@ -0,0 +1,265 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/deepseek_mtp.py
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMultiTokenPredictor,
|
||||
DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead,
|
||||
)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .openpangu import OpenPanguDecoderLayer
|
||||
|
||||
|
||||
class OpenPanguMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
config = vllm_config.speculative_config.draft_model_config.hf_config
|
||||
self.config = config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
|
||||
self.shared_head = SharedHead(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "shared_head"),
|
||||
)
|
||||
self.mtp_block = OpenPanguDecoderLayer(config, prefix, vllm_config)
|
||||
|
||||
|
||||
class OpenPanguMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict(
|
||||
{
|
||||
str(idx): OpenPanguMultiTokenPredictorLayer(
|
||||
vllm_config, f"{prefix}.layers.{idx}"
|
||||
)
|
||||
for idx in range(
|
||||
self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers,
|
||||
)
|
||||
}
|
||||
)
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class OpenPanguMTP(nn.Module, SupportsPP):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = OpenPanguMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
hidden_states,
|
||||
inputs_embeds,
|
||||
spec_step_idx,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor | None:
|
||||
return self.model.compute_logits(hidden_states, spec_step_idx)
|
||||
|
||||
def get_spec_layer(self, name):
|
||||
if (
|
||||
"layers" in name
|
||||
and hasattr(self.config, "num_nextn_predict_layers")
|
||||
and self.config.num_nextn_predict_layers > 0
|
||||
):
|
||||
layer_idx = int(name.split("layers.")[-1].split(".")[0])
|
||||
mtp_idx = layer_idx - self.config.num_hidden_layers
|
||||
if mtp_idx >= 0 and mtp_idx < self.config.num_nextn_predict_layers:
|
||||
return layer_idx
|
||||
return None
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||
]
|
||||
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts,
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
spec_layer = self.get_spec_layer(name)
|
||||
if spec_layer is None:
|
||||
continue
|
||||
|
||||
name = self._rewrite_spec_layer_name(spec_layer, name)
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
# QKV fusion is optional, fall back to normal
|
||||
# weight loading if it's not enabled
|
||||
if (
|
||||
param_name == "fused_qkv_a_proj"
|
||||
) and name_mapped not in params_dict:
|
||||
continue
|
||||
else:
|
||||
name = name_mapped
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if (
|
||||
spec_layer != self.model.mtp_start_layer_idx
|
||||
and ".layers" not in name
|
||||
):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
"""
|
||||
Rewrite the weight name to match the format of the original model.
|
||||
Add .mtp_block for modules in transformer layer block for spec layer
|
||||
and rename shared layer weights to be top level.
|
||||
"""
|
||||
spec_layer_weight_names = [
|
||||
"embed_tokens",
|
||||
"enorm",
|
||||
"hnorm",
|
||||
"eh_proj",
|
||||
"shared_head",
|
||||
]
|
||||
shared_weight_names = ["embed_tokens"]
|
||||
spec_layer_weight = False
|
||||
shared_weight = False
|
||||
for weight_name in spec_layer_weight_names:
|
||||
if weight_name in name:
|
||||
spec_layer_weight = True
|
||||
if weight_name in shared_weight_names:
|
||||
shared_weight = True
|
||||
break
|
||||
if not spec_layer_weight:
|
||||
# treat rest weights as weights for transformer layer block
|
||||
name = name.replace(
|
||||
f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
|
||||
)
|
||||
elif shared_weight:
|
||||
# treat shared weights as top level weights
|
||||
name = name.replace(f"model.layers.{spec_layer}.", "model.")
|
||||
return name
|
||||
@ -149,6 +149,8 @@ _TEXT_GENERATION_MODELS = {
|
||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||
"OuroForCausalLM": ("ouro", "OuroForCausalLM"),
|
||||
"PanguEmbeddedForCausalLM": ("openpangu", "PanguEmbeddedForCausalLM"),
|
||||
"PanguUltraMoEForCausalLM": ("openpangu", "PanguUltraMoEForCausalLM"),
|
||||
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
|
||||
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
@ -406,6 +408,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"LongCatFlashMTPModel": ("longcat_flash_mtp", "LongCatFlashMTP"),
|
||||
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"OpenPanguMTPModel": ("openpangu_mtp", "OpenPanguMTP"),
|
||||
"Qwen3NextMTP": ("qwen3_next_mtp", "Qwen3NextMTP"),
|
||||
# Temporarily disabled.
|
||||
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
|
||||
|
||||
@ -316,7 +316,12 @@ class EagleProposer:
|
||||
positions = target_positions[:, last_token_indices]
|
||||
else:
|
||||
positions = target_positions[last_token_indices]
|
||||
if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
|
||||
if self.method in (
|
||||
"deepseek_mtp",
|
||||
"ernie_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
):
|
||||
hidden_states = self.hidden_states[last_token_indices]
|
||||
else:
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user