[Model] Add Ernie4.5 VL Model Support (#22514)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>
This commit is contained in:
CSWYF3634076 2025-08-27 12:02:55 +08:00 committed by GitHub
parent c905684cfe
commit 644d57d531
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 2463 additions and 0 deletions

View File

@ -616,6 +616,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | | `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |

View File

@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
) )
# Ernie4.5-VL
def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
)
if modality == "image":
placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
elif modality == "video":
placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
prompts = [
(
f"<|begin_of_sentence|>User: {question}{placeholder}\n"
"Assistant: <think></think>"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)
# Florence2 # Florence2
def run_florence2(questions: list[str], modality: str) -> ModelRequestData: def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
@ -1602,6 +1633,7 @@ model_example_map = {
"chameleon": run_chameleon, "chameleon": run_chameleon,
"command_a_vision": run_command_a_vision, "command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2, "deepseek_vl_v2": run_deepseek_vl2,
"ernie45_vl": run_ernie45_vl,
"florence2": run_florence2, "florence2": run_florence2,
"fuyu": run_fuyu, "fuyu": run_fuyu,
"gemma3": run_gemma3, "gemma3": run_gemma3,

View File

@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10 fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10 pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test terratorch==1.1rc2 # required for PrithviMAE test
decord==0.6.0

View File

@ -156,6 +156,8 @@ datasets==3.0.2
# mteb # mteb
decorator==5.1.1 decorator==5.1.1
# via librosa # via librosa
decord==0.6.0
# via -r requirements/test.in
dill==0.3.8 dill==0.3.8
# via # via
# datasets # datasets
@ -493,6 +495,7 @@ numpy==1.26.4
# contourpy # contourpy
# cupy-cuda12x # cupy-cuda12x
# datasets # datasets
# decord
# einx # einx
# encodec # encodec
# evaluate # evaluate

View File

@ -272,6 +272,7 @@ def _test_processing_correctness_one(
"CohereLabs/command-a-vision-07-2025", "CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny", "deepseek-ai/deepseek-vl2-tiny",
"naver-clova-ix/donut-base-finetuned-docvqa", "naver-clova-ix/donut-base-finetuned-docvqa",
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
"microsoft/Florence-2-base", "microsoft/Florence-2-base",
"adept/fuyu-8b", "adept/fuyu-8b",
"google/gemma-3-4b-it", "google/gemma-3-4b-it",

View File

@ -396,6 +396,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
transformers_version_reason="HF model is not compatible.", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
trust_remote_code=True),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501

View File

@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding
class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
"""3D rotary positional embedding. 3D is t:time h:height w:width"""
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
section_h = self.mrope_section[0] # 22
section_w = self.mrope_section[1] # 22
section_t = self.mrope_section[2] # 20
assert section_h == section_w
# Split according to [h w h w h w h w... t t t...]
section_cos_t = cos[..., -section_t:]
section_cos_h = cos[..., :section_h + section_w:2]
section_cos_w = cos[..., 1:section_h + section_w:2]
cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[
1], section_cos_w[2]
cos_hw = torch.stack([cos_h, cos_w],
dim=-1).reshape(cos_h.shape[:-1] +
(cos_h.shape[-1] * 2, ))
cos = torch.cat([cos_hw, cos_t], dim=-1)
section_sin_t = sin[..., -section_t:]
section_sin_h = sin[..., :section_h + section_w:2]
section_sin_w = sin[..., 1:section_h + section_w:2]
sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[
1], section_sin_w[2]
sin_hw = torch.stack([sin_h, sin_w],
dim=-1).reshape(sin_h.shape[:-1] +
(sin_h.shape[-1] * 2, ))
sin = torch.cat([sin_hw, sin_t], dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@ -393,6 +393,15 @@ class MRotaryEmbedding(RotaryEmbedding):
context_len=context_len, context_len=context_len,
seq_len=seq_len, seq_len=seq_len,
) )
elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
return cls._ernie_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
context_len=context_len,
seq_len=seq_len,
)
else: else:
return cls._vl_get_input_positions_tensor( return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens, input_tokens=input_tokens,
@ -513,6 +522,120 @@ class MRotaryEmbedding(RotaryEmbedding):
len(input_tokens)).item() len(input_tokens)).item()
return llm_positions, mrope_position_delta return llm_positions, mrope_position_delta
@classmethod
def _ernie_get_input_positions_tensor(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
context_len: int = 0,
seq_len: Optional[int] = None,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value for Ernie VL."""
image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = []
if not (image_grid_thw is None and video_grid_thw is None):
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
if modality_type == "image":
t, h, w = (
image_grid_thw[mm_data_idx][0],
image_grid_thw[mm_data_idx][1],
image_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_conv_size, w // spatial_conv_size
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_grid_thw[mm_data_idx][0],
video_grid_thw[mm_data_idx][1],
video_grid_thw[mm_data_idx][2],
)
llm_grid_t, llm_grid_h, llm_grid_w = (t //
temporal_conv_size,
h //
spatial_conv_size,
w //
spatial_conv_size)
for t_idx in range(llm_grid_t):
t_index = torch.tensor(t_idx).view(-1, 1).expand(
-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) +
st_idx)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = llm_positions[:, context_len:seq_len]
mrope_position_delta = (llm_positions.max() + 1 -
len(input_tokens)).item()
return llm_positions, mrope_position_delta
@classmethod @classmethod
def _vl_get_input_positions_tensor( def _vl_get_input_positions_tensor(
cls, cls,

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,723 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Baidu team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Erine VL model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention
# from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
Ernie4_5_VLRotaryEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .ernie45_moe import Ernie4_5_MoeMLP
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP):
pass
class Ernie4_5_VLMoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: Optional[int] = None,
rope_theta: float = 500000,
rope_scaling: Optional[dict[str, Any]] = None,
freq_allocation: int = 20,
max_position_embeddings: int = 131072,
rms_norm_eps: float = 1e-05,
qkv_bias: bool = False,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0
self.layer_idx = layer_idx
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
t_rope = freq_allocation
h_rope = (self.head_dim // 2 - freq_allocation) // 2
w_rope = (self.head_dim // 2 - freq_allocation) // 2
self.rotary_emb = Ernie4_5_VLRotaryEmbedding(
head_size=self.head_dim,
rotary_dim=self.head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
is_neox_style=False,
dtype=torch.get_default_dtype(),
mrope_section=[h_rope, w_rope, t_rope])
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
# Attention
attn_output = self.attn(q, k, v)
# Output projection
output, _ = self.o_proj(attn_output)
return output
class Ernie4_5_VLMoeMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
layer_idx = extract_layer_index(prefix)
self.layer_idx = layer_idx
self.tp_size = get_tensor_model_parallel_world_size()
self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
> 0)
self.hidden_size = config.hidden_size
moe_num_experts = config.moe_num_experts
max_moe_num_experts = max(moe_num_experts)
if self.tp_size > max_moe_num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {moe_num_experts}.")
moe_layer_start_index = config.moe_layer_start_index
text_moe_layer_start_index = moe_layer_start_index[0]
vision_moe_layer_start_index = moe_layer_start_index[1]
moe_layer_end_index = config.moe_layer_end_index
moe_layer_end_index = getattr(
config, "moe_layer_end_index",
[config.num_hidden_layers - 1, config.num_hidden_layers - 1])
text_moe_layer_end_index = moe_layer_end_index[0]
vision_moe_layer_end_index = moe_layer_end_index[1]
assert config.moe_num_experts[0] == config.moe_num_experts[1]
self.e_score_correction_bias = nn.Parameter(
torch.empty(2, config.moe_num_experts[0]))
assert text_moe_layer_start_index <= text_moe_layer_end_index
if layer_idx >= text_moe_layer_start_index and \
layer_idx <= text_moe_layer_end_index:
self.text_experts_gate = ReplicatedLinear(
config.hidden_size,
config.moe_num_experts[0],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.text_experts_gate")
self.text_experts = FusedMoE(
num_experts=config.moe_num_experts[0],
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size[0],
reduce_results=False,
renormalize=True,
quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[0],
prefix=f"{prefix}.text_experts")
else:
self.text_experts = Ernie4_5_VLMoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
use_bias=getattr(config, 'use_bias', False),
quant_config=quant_config,
prefix=f"{prefix}.mlp")
assert vision_moe_layer_start_index <= vision_moe_layer_end_index
if layer_idx >= vision_moe_layer_start_index and \
layer_idx <= vision_moe_layer_end_index:
self.vision_experts_gate = ReplicatedLinear(
config.hidden_size,
config.moe_num_experts[1],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.vision_experts_gate")
self.vision_experts = FusedMoE(
num_experts=config.moe_num_experts[1],
top_k=config.moe_k,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size[1],
reduce_results=False,
renormalize=True,
quant_config=quant_config,
e_score_correction_bias=self.e_score_correction_bias[1],
prefix=f"{prefix}.vision_experts")
else:
self.vision_experts = Ernie4_5_VLMoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
use_bias=getattr(config, 'use_bias', False),
quant_config=quant_config,
prefix=f"{prefix}.mlp")
if self.has_shared_experts:
intermediate_size = (config.moe_intermediate_size[0] *
config.moe_num_shared_experts)
self.shared_experts = Ernie4_5_VLMoeMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.shared_experts",
reduce_results=self.text_experts.
must_reduce_shared_expert_outputs())
def forward(
self,
hidden_states: torch.Tensor,
visual_token_mask: torch.Tensor,
**kwargs: object,
) -> torch.Tensor:
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
if self.has_shared_experts:
shared_output = self.shared_experts(hidden_states)
if visual_token_mask is not None and visual_token_mask.any():
# assert visual_token_mask.shape[0] != hidden_states.shape[0]
visual_token_mask = visual_token_mask.repeat(
1, self.hidden_size).bool()
text_token_mask = ~visual_token_mask
final_hidden_states = torch.zeros_like(hidden_states)
text_hidden_states = hidden_states[text_token_mask].reshape(
-1, self.hidden_size)
vision_hidden_states = hidden_states[visual_token_mask].reshape(
-1, self.hidden_size)
text_router_logits, _ = self.text_experts_gate(text_hidden_states)
final_hidden_states[text_token_mask] = self.text_experts(
hidden_states=text_hidden_states,
router_logits=text_router_logits).flatten()
vision_router_logits, _ = self.vision_experts_gate(
vision_hidden_states)
final_hidden_states[visual_token_mask] = self.vision_experts(
hidden_states=vision_hidden_states,
router_logits=vision_router_logits).flatten()
else:
# text modal input processing directly
text_router_logits, _ = self.text_experts_gate(hidden_states)
final_hidden_states = self.text_experts(
hidden_states=hidden_states, router_logits=text_router_logits)
if self.has_shared_experts and \
shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = (
self.text_experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(orig_shape)
class Ernie4_5_VLMoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 500000)
rope_scaling = getattr(config, "rope_scaling", None)
freq_allocation = getattr(config, "freq_allocation", 20)
max_position_embeddings = getattr(config, "max_position_embeddings",
131072)
self.self_attn = Ernie4_5_VLMoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=getattr(config, 'head_dim', None),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
freq_allocation=freq_allocation,
max_position_embeddings=max_position_embeddings,
rms_norm_eps=config.rms_norm_eps,
qkv_bias=getattr(config, 'use_bias', False),
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
layer_idx = extract_layer_index(prefix)
self.layer_idx = layer_idx
# MoE
moe_layer_start_index = config.moe_layer_start_index
min_moe_layer_start_index = min(moe_layer_start_index)
moe_layer_end_index = getattr(
config, "moe_layer_end_index",
[config.num_hidden_layers - 1, config.num_hidden_layers - 1])
max_moe_layer_end_index = max(moe_layer_end_index)
assert min_moe_layer_start_index <= max_moe_layer_end_index
moe_num_experts = config.moe_num_experts
max_moe_num_experts = max(moe_num_experts)
moe_layer_interval = getattr(config, "moe_layer_interval", 1)
use_moe = getattr(config, "use_moe", max_moe_num_experts > 0)
if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0)
and layer_idx >= min_moe_layer_start_index
and layer_idx <= max_moe_layer_end_index):
self.mlp = Ernie4_5_VLMoeMoE(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Ernie4_5_VLMoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
use_bias=getattr(config, 'use_bias', False),
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
visual_token_mask: Optional[torch.Tensor],
**kwargs: object,
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if isinstance(self.mlp, Ernie4_5_VLMoeMoE):
hidden_states = self.mlp(hidden_states, visual_token_mask,
**kwargs)
else:
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
# Since Ernie VL distinguishes between text experts and vision experts,
# enabling torch.compile will cause errors.
# @support_torch_compile(
# dynamic_arg_dims={
# "input_ids": 0,
# "positions": -1,
# "intermediate_tensors": 0,
# "inputs_embeds": 0,
# "visual_token_mask": 0,
# })
class Ernie4_5_VLMoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
self.im_patch_id = config.im_patch_id
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Ernie4_5_VLMoeDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
visual_token_mask: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual,
visual_token_mask, **kwargs)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
# only used as text backbone for ernie4.5-vl
class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
else:
self.lm_head = PPMissingLayer()
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=max(self.config.moe_num_experts))
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.config.tie_word_embeddings and name.endswith(
"lm_head.weight"):
loaded_params.add("lm_head.weight")
continue
# MTP will be supported soon.
if "mtp" in name or \
"vision_model" in name or \
"resampler_model" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Distinguish between vision experts and text experts
if "mlp.experts" in name:
moe_offset = int(name.split(".")[-3])
vision_expert_start_idx = self.config.moe_num_experts[0]
is_text_expert = \
moe_offset <= vision_expert_start_idx - 1
if is_text_expert:
name = name.replace(".experts.", ".text_experts.")
else:
name = name.replace(
f".experts.{moe_offset}",
f".vision_experts.{moe_offset-vision_expert_start_idx}"
)
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
# Distinguish between vision experts and text experts
moe_offset = int(name.split(".")[-3])
is_text_expert = \
moe_offset <= self.config.moe_num_experts[0] - 1
name = name.replace(weight_name, param_name)
if is_text_expert:
name = name.replace(".experts.", ".text_experts.")
else:
name = name.replace(".experts.", ".vision_experts.")
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Distinguish between vision expert gate
# and text expert gate
if name.endswith("mlp.gate.weight"):
name = name.replace("gate.weight",
"text_experts_gate.weight")
loaded_weight = loaded_weight.T
elif name.endswith("mlp.gate.weight_1"):
name = name.replace("gate.weight_1",
"vision_experts_gate.weight")
loaded_weight = loaded_weight.T
if "e_score_correction_bias" in name:
name = name.replace(".moe_statics.", ".")
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = {
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
"Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501