mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 15:25:48 +08:00
[Model] Add Ernie4.5 VL Model Support (#22514)
Signed-off-by: wangyafeng <wangyafeng@baidu.com>
This commit is contained in:
parent
c905684cfe
commit
644d57d531
@ -616,6 +616,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
|
||||
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
|
||||
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
|
||||
@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Ernie4.5-VL
|
||||
def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
if modality == "image":
|
||||
placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
|
||||
elif modality == "video":
|
||||
placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
|
||||
|
||||
prompts = [
|
||||
(
|
||||
f"<|begin_of_sentence|>User: {question}{placeholder}\n"
|
||||
"Assistant: <think></think>"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Florence2
|
||||
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -1602,6 +1633,7 @@ model_example_map = {
|
||||
"chameleon": run_chameleon,
|
||||
"command_a_vision": run_command_a_vision,
|
||||
"deepseek_vl_v2": run_deepseek_vl2,
|
||||
"ernie45_vl": run_ernie45_vl,
|
||||
"florence2": run_florence2,
|
||||
"fuyu": run_fuyu,
|
||||
"gemma3": run_gemma3,
|
||||
|
||||
@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
terratorch==1.1rc2 # required for PrithviMAE test
|
||||
decord==0.6.0
|
||||
|
||||
@ -156,6 +156,8 @@ datasets==3.0.2
|
||||
# mteb
|
||||
decorator==5.1.1
|
||||
# via librosa
|
||||
decord==0.6.0
|
||||
# via -r requirements/test.in
|
||||
dill==0.3.8
|
||||
# via
|
||||
# datasets
|
||||
@ -493,6 +495,7 @@ numpy==1.26.4
|
||||
# contourpy
|
||||
# cupy-cuda12x
|
||||
# datasets
|
||||
# decord
|
||||
# einx
|
||||
# encodec
|
||||
# evaluate
|
||||
|
||||
@ -272,6 +272,7 @@ def _test_processing_correctness_one(
|
||||
"CohereLabs/command-a-vision-07-2025",
|
||||
"deepseek-ai/deepseek-vl2-tiny",
|
||||
"naver-clova-ix/donut-base-finetuned-docvqa",
|
||||
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
|
||||
"microsoft/Florence-2-base",
|
||||
"adept/fuyu-8b",
|
||||
"google/gemma-3-4b-it",
|
||||
|
||||
@ -396,6 +396,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
transformers_version_reason="HF model is not compatible.", # noqa: E501
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
|
||||
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
|
||||
"Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
|
||||
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .common import apply_rotary_emb_dispatch
|
||||
from .mrope import MRotaryEmbedding
|
||||
|
||||
|
||||
class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
|
||||
"""3D rotary positional embedding. 3D is t:time h:height w:width"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
if positions.ndim == 2:
|
||||
assert self.mrope_section
|
||||
|
||||
section_h = self.mrope_section[0] # 22
|
||||
section_w = self.mrope_section[1] # 22
|
||||
section_t = self.mrope_section[2] # 20
|
||||
assert section_h == section_w
|
||||
# Split according to [h w h w h w h w... t t t...]
|
||||
section_cos_t = cos[..., -section_t:]
|
||||
section_cos_h = cos[..., :section_h + section_w:2]
|
||||
section_cos_w = cos[..., 1:section_h + section_w:2]
|
||||
|
||||
cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[
|
||||
1], section_cos_w[2]
|
||||
cos_hw = torch.stack([cos_h, cos_w],
|
||||
dim=-1).reshape(cos_h.shape[:-1] +
|
||||
(cos_h.shape[-1] * 2, ))
|
||||
cos = torch.cat([cos_hw, cos_t], dim=-1)
|
||||
|
||||
section_sin_t = sin[..., -section_t:]
|
||||
section_sin_h = sin[..., :section_h + section_w:2]
|
||||
section_sin_w = sin[..., 1:section_h + section_w:2]
|
||||
|
||||
sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[
|
||||
1], section_sin_w[2]
|
||||
sin_hw = torch.stack([sin_h, sin_w],
|
||||
dim=-1).reshape(sin_h.shape[:-1] +
|
||||
(sin_h.shape[-1] * 2, ))
|
||||
sin = torch.cat([sin_hw, sin_t], dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
|
||||
self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
@ -393,6 +393,15 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
context_len=context_len,
|
||||
seq_len=seq_len,
|
||||
)
|
||||
elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
|
||||
return cls._ernie_get_input_positions_tensor(
|
||||
input_tokens=input_tokens,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
context_len=context_len,
|
||||
seq_len=seq_len,
|
||||
)
|
||||
else:
|
||||
return cls._vl_get_input_positions_tensor(
|
||||
input_tokens=input_tokens,
|
||||
@ -513,6 +522,120 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def _ernie_get_input_positions_tensor(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value for Ernie VL."""
|
||||
|
||||
image_token_id = hf_config.im_patch_id
|
||||
video_start_token_id = hf_config.video_start_token_id
|
||||
video_end_token_id = hf_config.video_end_token_id
|
||||
spatial_conv_size = hf_config.spatial_conv_size
|
||||
temporal_conv_size = hf_config.temporal_conv_size
|
||||
llm_pos_ids_list: list = []
|
||||
|
||||
if not (image_grid_thw is None and video_grid_thw is None):
|
||||
if isinstance(image_grid_thw, torch.Tensor):
|
||||
image_grid_thw = image_grid_thw.tolist()
|
||||
|
||||
input_token_type: list[str] = []
|
||||
video_check_flg = False
|
||||
for token in input_tokens:
|
||||
if token == video_start_token_id:
|
||||
video_check_flg = True
|
||||
elif token == video_end_token_id:
|
||||
video_check_flg = False
|
||||
|
||||
if (token == image_token_id) and (video_check_flg is False):
|
||||
input_token_type.append("image")
|
||||
elif (token == image_token_id) and (video_check_flg is True):
|
||||
input_token_type.append("video")
|
||||
else:
|
||||
input_token_type.append("text")
|
||||
|
||||
input_type_group: list[tuple[str, int, int]] = []
|
||||
for key, group_iter in itertools.groupby(
|
||||
enumerate(input_token_type), lambda x: x[1]):
|
||||
group_list = list(group_iter)
|
||||
start_index = group_list[0][0]
|
||||
end_index = group_list[-1][0] + 1
|
||||
input_type_group.append((key, start_index, end_index))
|
||||
|
||||
video_frame_num = 1
|
||||
mm_data_idx = 0
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||
llm_pos_ids_list) > 0 else 0
|
||||
if modality_type == "image":
|
||||
t, h, w = (
|
||||
image_grid_thw[mm_data_idx][0],
|
||||
image_grid_thw[mm_data_idx][1],
|
||||
image_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = \
|
||||
t, h // spatial_conv_size, w // spatial_conv_size
|
||||
|
||||
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
|
||||
-1, llm_grid_h * llm_grid_w).flatten()
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
||||
llm_grid_t, -1, llm_grid_w).flatten()
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
||||
llm_grid_t, llm_grid_h, -1).flatten()
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||
mm_data_idx += 1
|
||||
|
||||
elif modality_type == "video":
|
||||
t, h, w = (
|
||||
video_grid_thw[mm_data_idx][0],
|
||||
video_grid_thw[mm_data_idx][1],
|
||||
video_grid_thw[mm_data_idx][2],
|
||||
)
|
||||
llm_grid_t, llm_grid_h, llm_grid_w = (t //
|
||||
temporal_conv_size,
|
||||
h //
|
||||
spatial_conv_size,
|
||||
w //
|
||||
spatial_conv_size)
|
||||
|
||||
for t_idx in range(llm_grid_t):
|
||||
t_index = torch.tensor(t_idx).view(-1, 1).expand(
|
||||
-1, llm_grid_h * llm_grid_w).flatten()
|
||||
h_index = torch.arange(llm_grid_h).view(
|
||||
1, -1, 1).expand(1, -1, llm_grid_w).flatten()
|
||||
w_index = torch.arange(llm_grid_w).view(
|
||||
1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
||||
llm_pos_ids_list.append(
|
||||
torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||
|
||||
mm_data_idx += 1
|
||||
video_frame_num += 1
|
||||
|
||||
else:
|
||||
text_len = end_idx - start_idx
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1) +
|
||||
st_idx)
|
||||
video_frame_num = 1
|
||||
|
||||
else:
|
||||
text_len = len(input_tokens)
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len).view(1, -1).expand(3, -1))
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
mrope_position_delta = (llm_positions.max() + 1 -
|
||||
len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
@classmethod
|
||||
def _vl_get_input_positions_tensor(
|
||||
cls,
|
||||
|
||||
1504
vllm/model_executor/models/ernie45_vl.py
Normal file
1504
vllm/model_executor/models/ernie45_vl.py
Normal file
File diff suppressed because it is too large
Load Diff
723
vllm/model_executor/models/ernie45_vl_moe.py
Normal file
723
vllm/model_executor/models/ernie45_vl_moe.py
Normal file
@ -0,0 +1,723 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The Baidu team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Erine VL model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
# from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
|
||||
Ernie4_5_VLRotaryEmbedding)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .ernie45_moe import Ernie4_5_MoeMLP
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP):
|
||||
pass
|
||||
|
||||
|
||||
class Ernie4_5_VLMoeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: Optional[int] = None,
|
||||
rope_theta: float = 500000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
freq_allocation: int = 20,
|
||||
max_position_embeddings: int = 131072,
|
||||
rms_norm_eps: float = 1e-05,
|
||||
qkv_bias: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
t_rope = freq_allocation
|
||||
h_rope = (self.head_dim // 2 - freq_allocation) // 2
|
||||
w_rope = (self.head_dim // 2 - freq_allocation) // 2
|
||||
|
||||
self.rotary_emb = Ernie4_5_VLRotaryEmbedding(
|
||||
head_size=self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
is_neox_style=False,
|
||||
dtype=torch.get_default_dtype(),
|
||||
mrope_section=[h_rope, w_rope, t_rope])
|
||||
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
# Attention
|
||||
attn_output = self.attn(q, k, v)
|
||||
# Output projection
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Ernie4_5_VLMoeMoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.layer_idx = layer_idx
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
|
||||
> 0)
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
moe_num_experts = config.moe_num_experts
|
||||
max_moe_num_experts = max(moe_num_experts)
|
||||
|
||||
if self.tp_size > max_moe_num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {moe_num_experts}.")
|
||||
|
||||
moe_layer_start_index = config.moe_layer_start_index
|
||||
text_moe_layer_start_index = moe_layer_start_index[0]
|
||||
vision_moe_layer_start_index = moe_layer_start_index[1]
|
||||
moe_layer_end_index = config.moe_layer_end_index
|
||||
moe_layer_end_index = getattr(
|
||||
config, "moe_layer_end_index",
|
||||
[config.num_hidden_layers - 1, config.num_hidden_layers - 1])
|
||||
text_moe_layer_end_index = moe_layer_end_index[0]
|
||||
vision_moe_layer_end_index = moe_layer_end_index[1]
|
||||
|
||||
assert config.moe_num_experts[0] == config.moe_num_experts[1]
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty(2, config.moe_num_experts[0]))
|
||||
|
||||
assert text_moe_layer_start_index <= text_moe_layer_end_index
|
||||
|
||||
if layer_idx >= text_moe_layer_start_index and \
|
||||
layer_idx <= text_moe_layer_end_index:
|
||||
self.text_experts_gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.moe_num_experts[0],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.text_experts_gate")
|
||||
|
||||
self.text_experts = FusedMoE(
|
||||
num_experts=config.moe_num_experts[0],
|
||||
top_k=config.moe_k,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size[0],
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
e_score_correction_bias=self.e_score_correction_bias[0],
|
||||
prefix=f"{prefix}.text_experts")
|
||||
else:
|
||||
self.text_experts = Ernie4_5_VLMoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
use_bias=getattr(config, 'use_bias', False),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
assert vision_moe_layer_start_index <= vision_moe_layer_end_index
|
||||
if layer_idx >= vision_moe_layer_start_index and \
|
||||
layer_idx <= vision_moe_layer_end_index:
|
||||
self.vision_experts_gate = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.moe_num_experts[1],
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.vision_experts_gate")
|
||||
|
||||
self.vision_experts = FusedMoE(
|
||||
num_experts=config.moe_num_experts[1],
|
||||
top_k=config.moe_k,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size[1],
|
||||
reduce_results=False,
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
e_score_correction_bias=self.e_score_correction_bias[1],
|
||||
prefix=f"{prefix}.vision_experts")
|
||||
else:
|
||||
self.vision_experts = Ernie4_5_VLMoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
use_bias=getattr(config, 'use_bias', False),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
if self.has_shared_experts:
|
||||
intermediate_size = (config.moe_intermediate_size[0] *
|
||||
config.moe_num_shared_experts)
|
||||
self.shared_experts = Ernie4_5_VLMoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
reduce_results=self.text_experts.
|
||||
must_reduce_shared_expert_outputs())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
visual_token_mask: torch.Tensor,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.has_shared_experts:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
if visual_token_mask is not None and visual_token_mask.any():
|
||||
# assert visual_token_mask.shape[0] != hidden_states.shape[0]
|
||||
visual_token_mask = visual_token_mask.repeat(
|
||||
1, self.hidden_size).bool()
|
||||
text_token_mask = ~visual_token_mask
|
||||
final_hidden_states = torch.zeros_like(hidden_states)
|
||||
|
||||
text_hidden_states = hidden_states[text_token_mask].reshape(
|
||||
-1, self.hidden_size)
|
||||
vision_hidden_states = hidden_states[visual_token_mask].reshape(
|
||||
-1, self.hidden_size)
|
||||
|
||||
text_router_logits, _ = self.text_experts_gate(text_hidden_states)
|
||||
final_hidden_states[text_token_mask] = self.text_experts(
|
||||
hidden_states=text_hidden_states,
|
||||
router_logits=text_router_logits).flatten()
|
||||
|
||||
vision_router_logits, _ = self.vision_experts_gate(
|
||||
vision_hidden_states)
|
||||
final_hidden_states[visual_token_mask] = self.vision_experts(
|
||||
hidden_states=vision_hidden_states,
|
||||
router_logits=vision_router_logits).flatten()
|
||||
else:
|
||||
# text modal input processing directly
|
||||
text_router_logits, _ = self.text_experts_gate(hidden_states)
|
||||
|
||||
final_hidden_states = self.text_experts(
|
||||
hidden_states=hidden_states, router_logits=text_router_logits)
|
||||
|
||||
if self.has_shared_experts and \
|
||||
shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
self.text_experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states))
|
||||
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
class Ernie4_5_VLMoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 500000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
freq_allocation = getattr(config, "freq_allocation", 20)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
131072)
|
||||
|
||||
self.self_attn = Ernie4_5_VLMoeAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=getattr(config, 'head_dim', None),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
freq_allocation=freq_allocation,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=getattr(config, 'use_bias', False),
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# MoE
|
||||
moe_layer_start_index = config.moe_layer_start_index
|
||||
min_moe_layer_start_index = min(moe_layer_start_index)
|
||||
moe_layer_end_index = getattr(
|
||||
config, "moe_layer_end_index",
|
||||
[config.num_hidden_layers - 1, config.num_hidden_layers - 1])
|
||||
max_moe_layer_end_index = max(moe_layer_end_index)
|
||||
assert min_moe_layer_start_index <= max_moe_layer_end_index
|
||||
moe_num_experts = config.moe_num_experts
|
||||
max_moe_num_experts = max(moe_num_experts)
|
||||
moe_layer_interval = getattr(config, "moe_layer_interval", 1)
|
||||
use_moe = getattr(config, "use_moe", max_moe_num_experts > 0)
|
||||
|
||||
if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0)
|
||||
and layer_idx >= min_moe_layer_start_index
|
||||
and layer_idx <= max_moe_layer_end_index):
|
||||
self.mlp = Ernie4_5_VLMoeMoE(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Ernie4_5_VLMoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
use_bias=getattr(config, 'use_bias', False),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
visual_token_mask: Optional[torch.Tensor],
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
if isinstance(self.mlp, Ernie4_5_VLMoeMoE):
|
||||
hidden_states = self.mlp(hidden_states, visual_token_mask,
|
||||
**kwargs)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
# Since Ernie VL distinguishes between text experts and vision experts,
|
||||
# enabling torch.compile will cause errors.
|
||||
# @support_torch_compile(
|
||||
# dynamic_arg_dims={
|
||||
# "input_ids": 0,
|
||||
# "positions": -1,
|
||||
# "intermediate_tensors": 0,
|
||||
# "inputs_embeds": 0,
|
||||
# "visual_token_mask": 0,
|
||||
# })
|
||||
class Ernie4_5_VLMoeModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
|
||||
self.im_patch_id = config.im_patch_id
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Ernie4_5_VLMoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
visual_token_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states, residual,
|
||||
visual_token_mask, **kwargs)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# only used as text backbone for ernie4.5-vl
|
||||
class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds, **kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=max(self.config.moe_num_experts))
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if self.config.tie_word_embeddings and name.endswith(
|
||||
"lm_head.weight"):
|
||||
loaded_params.add("lm_head.weight")
|
||||
continue
|
||||
# MTP will be supported soon.
|
||||
if "mtp" in name or \
|
||||
"vision_model" in name or \
|
||||
"resampler_model" in name:
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Distinguish between vision experts and text experts
|
||||
if "mlp.experts" in name:
|
||||
moe_offset = int(name.split(".")[-3])
|
||||
vision_expert_start_idx = self.config.moe_num_experts[0]
|
||||
is_text_expert = \
|
||||
moe_offset <= vision_expert_start_idx - 1
|
||||
if is_text_expert:
|
||||
name = name.replace(".experts.", ".text_experts.")
|
||||
else:
|
||||
name = name.replace(
|
||||
f".experts.{moe_offset}",
|
||||
f".vision_experts.{moe_offset-vision_expert_start_idx}"
|
||||
)
|
||||
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
# Distinguish between vision experts and text experts
|
||||
moe_offset = int(name.split(".")[-3])
|
||||
is_text_expert = \
|
||||
moe_offset <= self.config.moe_num_experts[0] - 1
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_text_expert:
|
||||
name = name.replace(".experts.", ".text_experts.")
|
||||
else:
|
||||
name = name.replace(".experts.", ".vision_experts.")
|
||||
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Distinguish between vision expert gate
|
||||
# and text expert gate
|
||||
if name.endswith("mlp.gate.weight"):
|
||||
name = name.replace("gate.weight",
|
||||
"text_experts_gate.weight")
|
||||
loaded_weight = loaded_weight.T
|
||||
elif name.endswith("mlp.gate.weight_1"):
|
||||
name = name.replace("gate.weight_1",
|
||||
"vision_experts_gate.weight")
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
if "e_score_correction_bias" in name:
|
||||
name = name.replace(".moe_statics.", ".")
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = {
|
||||
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
|
||||
"Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501
|
||||
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
|
||||
"Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501
|
||||
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
|
||||
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
|
||||
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user