[Model] support MiniMax-VL-01 model (#16328)
Signed-off-by: qingjun <qingjun@minimaxi.com>
This commit is contained in: parent 96e06e3cb7, commit cde384cd92
@@ -446,6 +446,19 @@ VLM_TEST_SETTINGS = {
        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
        patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
    ),
    "minimax_vl_01": VLMTestInfo(
        models=["MiniMaxAI/MiniMax-VL-01"],
        prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>",  # noqa: E501
        img_idx_to_prompt=lambda _: "<image>",
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
        max_model_len=8192,
        max_num_seqs=4,
        dtype="bfloat16",
        hf_output_post_proc=model_utils.minimax_vl_01_hf_output,
        patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner,
        auto_cls=AutoModelForImageTextToText,
        marks=[large_gpu_mark(min_gb=80)],
    ),
    "molmo": VLMTestInfo(
        models=["allenai/Molmo-7B-D-0924"],
        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -229,6 +229,14 @@ def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
    return output_ids, output_str, out_logprobs


def minimax_vl_01_hf_output(hf_output: RunnerOutput,
                            model: str) -> RunnerOutput:
    output_ids, output_str, out_logprobs = hf_output
    if output_str.endswith("<end_of_sentence>"):
        output_str = output_str.split("<end_of_sentence>")[0]
    return output_ids, output_str, out_logprobs


####### Functions for converting image assets to embeddings
def get_llava_embeddings(image_assets: _ImageAssets):
    return [asset.image_embeds for asset in image_assets]

@@ -627,6 +635,17 @@ def minicpmv_26_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    return hf_model


def minimax_vl_01_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    orig_generate = hf_model.model.generate

    def _generate(self, *args, image_sizes=None, **kwargs):
        return orig_generate(*args, decode_text=False, **kwargs)

    hf_model.model.generate = types.MethodType(_generate, hf_model.model)

    return hf_model


def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
    """Patches and returns an instance of the HfRunner to use for Molmo."""
    hf_processor = hf_model.processor
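As a side note, the two helpers above only change how the HF reference model's outputs are compared against vLLM: the post-processor trims the trailing stop marker, and the runner patch rebinds `generate` so raw token ids come back instead of decoded text. A minimal sketch of the same two steps on a stand-in object (the `_FakeHfModel` class is hypothetical, not part of vLLM):

import types


class _FakeHfModel:

    def generate(self, *args, decode_text=True, **kwargs):
        # Stand-in for the MiniMax-VL-01 HF model, which decodes to text by
        # default; the patch forces decode_text=False so raw ids are returned.
        return "decoded text" if decode_text else [[1, 2, 3]]


model = _FakeHfModel()
orig_generate = model.generate


def _generate(self, *args, image_sizes=None, **kwargs):
    return orig_generate(*args, decode_text=False, **kwargs)


model.generate = types.MethodType(_generate, model)
assert model.generate() == [[1, 2, 3]]

# The output post-processor simply trims the trailing stop marker:
output_str = "A photo of a cat<end_of_sentence>"
if output_str.endswith("<end_of_sentence>"):
    output_str = output_str.split("<end_of_sentence>")[0]
assert output_str == "A photo of a cat"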
tests/models/multimodal/processing/test_minimax_vl_01.py (new file, 99 lines)
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from PIL import Image

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
    image_assets: _ImageAssets,
    model_id: str,
    num_imgs: int,
):
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    prompt = "<image>" * num_imgs
    image = Image.new("RGB", size=(364, 364))
    mm_data = {"image": [image] * num_imgs}

    processed_inputs = processor.apply(prompt, mm_data, {})
    image_placeholders = processed_inputs["mm_placeholders"]["image"]

    assert len(image_placeholders) == num_imgs


def _validate_image_prompt_replacements_one(
    processor: BaseMultiModalProcessor,
    num_imgs: int,
    failed_size_excs: list[tuple[ImageSize, Exception]],
    image_size: ImageSize,
) -> None:
    prompt = "<image>" * num_imgs
    image = Image.new("RGB", size=image_size)
    mm_data = {"image": [image] * num_imgs}

    try:
        processed_inputs = processor.apply(prompt, mm_data, {})

        image_placeholders = processed_inputs["mm_placeholders"]["image"]
        assert len(image_placeholders) == num_imgs

    except Exception as exc:
        failed_size_excs.append((image_size, exc))


def _test_image_prompt_replacements(
    processor,
    *,
    num_imgs: int,
    image_sizes: list[ImageSize],
) -> None:

    failed_size_excs = list[tuple[ImageSize, Exception]]()

    for size in image_sizes:
        _validate_image_prompt_replacements_one(processor, num_imgs,
                                                failed_size_excs, size)

    if failed_size_excs:
        msg = "Found failing image sizes:" \
            + "\n========\n".join(f"[{size}]\n{exc}"
                                  for size, exc in failed_size_excs)
        raise AssertionError(msg)


@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(model_id, num_imgs):
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

    image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
                    (488, 183), (2560, 1669)]
    image_sizes = [
        size for w, h in image_ratios
        for size in [ImageSize(w, h), ImageSize(h, w)]
    ]

    _test_image_prompt_replacements(
        processor,
        num_imgs=num_imgs,
        image_sizes=image_sizes,
    )
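The regression test sweeps both orientations of several aspect ratios that previously tripped the placeholder accounting. For reference, the same processor can also be exercised outside pytest; this sketch assumes it runs from a vLLM source checkout (so the test helper `build_model_context` is importable) and that the HF config can be fetched:

from PIL import Image

from vllm.multimodal import MULTIMODAL_REGISTRY

# build_model_context is a test helper living in tests/models/utils.py.
from tests.models.utils import build_model_context

ctx = build_model_context(
    "MiniMaxAI/MiniMax-VL-01",
    mm_processor_kwargs=None,
    limit_mm_per_prompt={"image": 1},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

out = processor.apply("<image>", {"image": [Image.new("RGB", (364, 364))]}, {})
print(len(out["mm_placeholders"]["image"]))  # expected: 1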
@@ -337,6 +337,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
    "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
                                extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
                                trust_remote_code=True),
    "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
                                                           trust_remote_code=True),
    "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # noqa: E501
                                                        extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}),  # noqa: E501
    "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -3,7 +3,7 @@
import copy
import math
import re
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
import torch.distributed
@@ -110,7 +110,17 @@ class MiniMaxText01RMSNormTP(CustomOp):
        variance = tensor_model_parallel_all_reduce(
            variance) / self.tp_world
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight

        weight = self.weight
        if x.size(-1) != self.weight.size(0):
            if self.weight.size(0) < x.size(-1):
                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
                full_weight = self.weight.repeat(repeat_count)
                weight = full_weight[:x.size(-1)]
            else:
                weight = self.weight[:x.size(-1)]

        x = x.to(orig_dtype) * weight
        return x

    def forward(
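# ---------------------------------------------------------------------------
# Illustrative aside (not part of this commit): the branch added above matches
# the RMSNorm weight to the hidden size of `x` when the two disagree; a shorter
# weight is repeated and truncated, a longer one is sliced. Minimal standalone
# sketch with made-up sizes:
# ---------------------------------------------------------------------------
import torch

_x = torch.randn(2, 4)        # hidden size 4
_w = torch.arange(6.0)        # norm weight of size 6 (mismatched on purpose)

if _x.size(-1) != _w.size(0):
    if _w.size(0) < _x.size(-1):
        # Shorter weight: repeat, then truncate to the hidden size.
        _repeat = (_x.size(-1) + _w.size(0)) // _x.size(-1)
        _w = _w.repeat(_repeat)[:_x.size(-1)]
    else:
        # Longer weight: slice it down to the hidden size.
        _w = _w[:_x.size(-1)]

print((_x * _w).shape)        # torch.Size([2, 4])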
@@ -421,6 +431,10 @@ class MiniMaxText01LinearAttention(nn.Module):
                               attn_metadata):
        hidden = []
        for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_idx >= len(attn_metadata.query_start_loc):
                break
            if _prefill_idx >= len(state_indices_tensor):
                break
            _start = attn_metadata.query_start_loc[_prefill_idx]
            _end = attn_metadata.query_start_loc[_prefill_idx + 1]
            slot_id = state_indices_tensor[_prefill_idx]
@@ -443,6 +457,10 @@ class MiniMaxText01LinearAttention(nn.Module):
            hidden.append(
                self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                                   attn_metadata))

        if not hidden:
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

        hidden = torch.concat(hidden, dim=0).contiguous()
        return hidden
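# ---------------------------------------------------------------------------
# Illustrative aside (not part of this commit): query_start_loc delimits each
# prefill request inside the packed token batch, which is what the new bounds
# checks above guard against when metadata is shorter than expected.
# Example with made-up metadata:
# ---------------------------------------------------------------------------
import torch

_query_start_loc = torch.tensor([0, 5, 12])    # request 0 -> rows 0:5, request 1 -> rows 5:12
_state_indices_tensor = torch.tensor([3, 7])   # cache slot per request

_segments = []
for _prefill_idx in range(2):                  # num_prefills == 2
    if _prefill_idx >= len(_query_start_loc):
        break
    if _prefill_idx >= len(_state_indices_tensor):
        break
    _start = _query_start_loc[_prefill_idx]
    _end = _query_start_loc[_prefill_idx + 1]
    _segments.append((int(_start), int(_end), int(_state_indices_tensor[_prefill_idx])))

print(_segments)                               # [(0, 5, 3), (5, 12, 7)]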
@@ -663,6 +681,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
        self.shared_moe = False

        shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
        if isinstance(shared_intermediate, list):
            shared_intermediate = shared_intermediate[
                layer_id] if layer_id < len(shared_intermediate) else 0
        if shared_intermediate > 0:
            self.shared_moe = True
            self.shared_mlp = MiniMaxText01MLP(
@@ -875,6 +896,8 @@ class MiniMaxText01Model(nn.Module):

        slots_to_clear = []
        for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
            if _prefill_id >= len(seq_id_map):
                break
            seq_id = seq_id_map[_prefill_id]
            if attn_metadata.context_lens_tensor[
                    _prefill_id] == 0 and seq_id in seq_to_slot_maps:
@@ -886,13 +909,18 @@ class MiniMaxText01Model(nn.Module):
                                        dtype=torch.long)
            minimax_cache_tensors[:, slots_tensor, ...] = 0

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(self,
                input_ids: Optional[torch.Tensor],
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                intermediate_tensors=None,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
@@ -901,6 +929,7 @@ class MiniMaxText01Model(nn.Module):
            kwargs["request_ids_to_seq_ids"] = {}
        if "finished_requests_ids" not in kwargs:
            kwargs["finished_requests_ids"] = []

        (
            minimax_cache_tensors,
            state_indices_tensor,
@@ -922,15 +951,11 @@ class MiniMaxText01Model(nn.Module):
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        kv_cache_index = 0
        minimax_cache_index = 0
        attn_metadata.rotary_emb = self.rotary_emb
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            _caches = None
            if isinstance(layer.self_attn, MiniMaxText01Attention):
                _caches = kv_caches[kv_cache_index]
                kv_cache_index += 1
            if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
                current_state_layer = minimax_cache_index
                _caches = minimax_cache_params.at_layer_idx(
@@ -1009,15 +1034,20 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
        return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
            batch_size)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, self.kv_cache,
                                   intermediate_tensors, inputs_embeds,
                                   **kwargs)
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds, **kwargs)

        return hidden_states

@@ -1043,8 +1073,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
        })

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> None:
                                                   torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        def which_layer(name: str) -> int:
            if "layers" in name:
@@ -1108,6 +1139,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                            weight_name,
                            expert_id=expert_id,
                            shard_id=shard_id)
                        loaded_params.add(name)
                        break
                else:
                    if is_pp_missing_parameter(name, self):
@@ -1117,6 +1149,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                            default_weight_loader)
                    weight_loader = weight_loader_with_alias(name)(weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
            return

        def is_shared_mlp_weight(name: str) -> bool:
@@ -1154,6 +1187,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                else:
                    raise AssertionError(
                        "MLP weight not in [gate_up_proj, down_proj]")
            loaded_params.add(name)
            return

        def is_mha_weight(name: str) -> bool:
@@ -1170,6 +1204,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                MiniMaxText01LinearAttention.weight_direct_load)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
@@ -1194,6 +1229,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                            default_weight_loader)
                    weight_loader = weight_loader_with_alias(name)(weight_loader)
                    weight_loader(param, loaded_weight, shard_id)
                    loaded_params.add(name)
                    break
                else:
                    if is_pp_missing_parameter(name, self):
@@ -1204,6 +1240,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                            default_weight_loader)
                    weight_loader = weight_loader_with_alias(name)(weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
            return

        def is_layer_norm_weight(name: str) -> bool:
@@ -1219,6 +1256,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        def load_basic_weight(name: str, loaded_weight: torch.Tensor,
@@ -1230,6 +1268,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                    default_weight_loader)
            weight_loader = weight_loader_with_alias(name)(weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
            return

        for name, loaded_weight in weights:
@@ -1258,4 +1297,4 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                continue

            load_basic_weight(name, loaded_weight, self)
        return
        return loaded_params
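The signature change above makes `load_weights` report which parameters it actually initialized, matching the convention used by vLLM's auto weight loader so that missing or unexpected checkpoint tensors can be flagged. A generic, MiniMax-agnostic sketch of that pattern:

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn


class TinyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            param = params_dict[name]
            param.data.copy_(loaded_weight)   # stand-in for default_weight_loader
            loaded_params.add(name)
        # Callers can compare this set against named_parameters() to detect
        # weights that were never loaded.
        return loaded_params


model = TinyModel()
loaded = model.load_weights([("proj.weight", torch.zeros(4, 4)),
                             ("proj.bias", torch.zeros(4))])
print(sorted(loaded))  # ['proj.bias', 'proj.weight']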
vllm/model_executor/models/minimax_vl_01.py (new file, 615 lines)
@@ -0,0 +1,615 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)

import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature, CLIPVisionConfig, PretrainedConfig
from transformers.image_processing_utils import select_best_resolution

from vllm.config import VllmConfig
from vllm.jsontree import json_map_leaves
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config

from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info

logger = init_logger(__name__)


# For dummy input only
@dataclass
class MaxImageTokenMeta:
    width: int = 1024
    height: int = 1024


class MiniMaxVL01ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """


class MiniMaxVL01ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    if not isinstance(grid_pinpoints, list):
        raise TypeError("grid_pinpoints should be a list of tuples or lists")

    # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple,
    # otherwise it will cause wrong calculate
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError("image_size invalid type " +
                            f"{type(image_size)} with value {image_size}")
        image_size = image_size.tolist()

    best_resolution = select_best_resolution(image_size, grid_pinpoints)
    height, width = best_resolution
    num_patches = 0
    # consider change to ceil(height/patch_size)*ceil(width/patch_size) + 1
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            num_patches += 1
    # add the base patch
    num_patches += 1
    return num_patches


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    if not isinstance(grid_pinpoints, list):
        raise TypeError("grid_pinpoints should be a list of tuples or lists")

    # ! VERY IMPORTANT if image_size is tensor,
    # must convert to into tuple,
    # otherwise it will cause wrong calculate
    if not isinstance(image_size, (list, tuple)):
        if not isinstance(image_size, (torch.Tensor, np.ndarray)):
            raise TypeError(
                "image_size invalid type " +
                f"{type(image_size)} not valid, " +
                "should be either list, tuple, np.ndarray or tensor")
        image_size = image_size.tolist()

    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size


def unpad_image(tensor, original_size):
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        new_height = int(original_height * current_width) // original_width
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding:current_height - padding, :]
    else:
        new_width = int(original_width * current_height) // original_height
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding:current_width - padding]

    return unpadded_tensor
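# ---------------------------------------------------------------------------
# Illustrative aside (not part of this commit): a quick numeric check of the
# geometry helpers above. With the default grid pinpoints and a 336-pixel
# patch, a wide 336x1008 (height x width) image resolves to a 1x3 patch grid
# plus the base patch.
# ---------------------------------------------------------------------------
_GRID_PINPOINTS = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]

assert image_size_to_num_patches((336, 1008), _GRID_PINPOINTS, 336) == 4
assert get_anyres_image_grid_shape((336, 1008), _GRID_PINPOINTS, 336) == (1, 3)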
class MiniMaxVL01MultiModalProjector(nn.Module):

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=multimodal_projector_bias,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=multimodal_projector_bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


class MiniMaxVL01LikeConfig(Protocol):
    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[Union[int, list[int]]]


class MiniMaxVL01LikeProcessor(Protocol):
    image_token: Final[str]


_I = TypeVar("_I", bound=BaseProcessingInfo)


class MiniMaxVL01DummyInputsBuilder(BaseDummyInputsBuilder[_I]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)
        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        return {
            "image":
            self._get_dummy_images(width=MaxImageTokenMeta.width,
                                   height=MaxImageTokenMeta.height,
                                   num_images=num_images)
        }


class MiniMaxVL01ProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config(MiniMaxVL01Config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        if strategy == "default":
            return encoder_num_image_tokens - 1
        if strategy == "full":
            return encoder_num_image_tokens

        msg = f"Unexpected feature select strategy: {strategy!r}"
        raise NotImplementedError(msg)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )


class BaseMiniMaxVL01MultiModalProcessor(BaseMultiModalProcessor[_I]):

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


class MiniMaxVL01MultiModalProcessor(
        BaseMiniMaxVL01MultiModalProcessor[MiniMaxVL01ProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            image_sizes = processed_outputs["image_sizes"]
            min_len = min(len(pixel_values), len(image_sizes))
            pixel_values = pixel_values[:min_len]
            image_sizes = image_sizes[:min_len]
            assert len(pixel_values) == len(image_sizes)

            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


def _get_num_hidden_layers(hf_config: MiniMaxVL01LikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index
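# ---------------------------------------------------------------------------
# Illustrative aside (not part of this commit): with the default
# vision_feature_layer of -2 (see MiniMaxVL01Config below) and a 24-layer CLIP
# tower, only 23 encoder layers need to be instantiated.
# ---------------------------------------------------------------------------
assert _get_layer_index(-2, num_hidden_layers=24) == 23
assert _get_layer_index(5, num_hidden_layers=24) == 5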
def init_vision_tower_for_MiniMaxVL01(
    hf_config: MiniMaxVL01LikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_processor(
    MiniMaxVL01MultiModalProcessor,
    info=MiniMaxVL01ProcessingInfo,
    dummy_inputs=MiniMaxVL01DummyInputsBuilder)
class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsPP):

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_MiniMaxVL01(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=True,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.vision_feature_layer = config.vision_feature_layer
        self.vocab_size = config.text_config.vocab_size
        self.pad_token_id = -1
        if self.config.pad_token_id is not None:
            self.pad_token_id = self.config.pad_token_id

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _process_image_pixels(
        self,
        inputs: Union[MiniMaxVL01ImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: MiniMaxVL01ImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[MiniMaxVL01ImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return MiniMaxVL01ImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return MiniMaxVL01ImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:

        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
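With the model, processor, and configs registered by this commit, MiniMax-VL-01 runs through the regular vLLM entry points. A minimal offline-inference sketch (not part of the diff): the prompt format mirrors the `prompt_formatter` used in the tests above, `example.jpg` is a placeholder, and the tests mark the model as needing an 80 GB-class GPU.

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="MiniMaxAI/MiniMax-VL-01",
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"image": 1},
)

prompt = ("<beginning_of_sentence>user: <image>\n"
          "Describe this image. assistant:")
image = Image.open("example.jpg").convert("RGB")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=128, temperature=0.0),
)
print(outputs[0].outputs[0].text)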
@@ -189,6 +189,7 @@ _MULTIMODAL_MODELS = {
    "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),  # noqa: E501
    "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),  # noqa: E501
    "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"),  # noqa: E501
    "MiniMaxVL01ForConditionalGeneration": ("minimax_vl_01", "MiniMaxVL01ForConditionalGeneration"),  # noqa: E501
    "MiniCPMO": ("minicpmo", "MiniCPMO"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "Mistral3ForConditionalGeneration": ("mistral3", "Mistral3ForConditionalGeneration"),  # noqa: E501
@@ -34,11 +34,13 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                             H2OVLChatConfig,
                                             InternVLChatConfig, JAISConfig,
                                             KimiVLConfig, MedusaConfig,
                                             MllamaConfig, MLPSpeculatorConfig,
                                             MPTConfig, NemotronConfig,
                                             NVLM_D_Config, RWConfig,
                                             SkyworkR1VChatConfig, SolarConfig,
                                             Telechat2Config, UltravoxConfig)
                                             MiniMaxText01Config,
                                             MiniMaxVL01Config, MllamaConfig,
                                             MLPSpeculatorConfig, MPTConfig,
                                             NemotronConfig, NVLM_D_Config,
                                             RWConfig, SkyworkR1VChatConfig,
                                             SolarConfig, Telechat2Config,
                                             UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
@@ -73,6 +75,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "exaone": ExaoneConfig,
    "h2ovl_chat": H2OVLChatConfig,
    "internvl_chat": InternVLChatConfig,
    "minimax_text_01": MiniMaxText01Config,
    "minimax_vl_01": MiniMaxVL01Config,
    "nemotron": NemotronConfig,
    "NVLM_D": NVLM_D_Config,
    "solar": SolarConfig,
@@ -15,6 +15,8 @@ from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.minimax_text_01 import MiniMaxText01Config
from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config
from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.moonvit import MoonViTConfig
@@ -39,6 +41,8 @@ __all__ = [
    "MedusaConfig",
    "EAGLEConfig",
    "ExaoneConfig",
    "MiniMaxText01Config",
    "MiniMaxVL01Config",
    "MllamaConfig",
    "MLPSpeculatorConfig",
    "MoonViTConfig",
vllm/transformers_utils/configs/minimax_text_01.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
""" MiniMaxText01 model configuration"""

from transformers.configuration_utils import PretrainedConfig


class MiniMaxText01Config(PretrainedConfig):
    model_type = "MiniMaxText01"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None,
        tie_word_embeddings=False,
        rope_theta=1e6,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=8,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout

        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
vllm/transformers_utils/configs/minimax_vl_01.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
"""MiniMaxVL01 model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto import CONFIG_MAPPING

from .minimax_text_01 import MiniMaxText01Config


class MiniMaxVL01Config(PretrainedConfig):
    model_type = "minimax_vl_01"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        image_grid_pinpoints=None,
        tie_word_embeddings=False,
        image_seq_length=576,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.image_seq_length = image_seq_length

        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError("vision_feature_select_strategy should " +
                             "be one of 'default', 'full'." +
                             f"Got: {vision_feature_select_strategy}")

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        image_grid_pinpoints = (
            image_grid_pinpoints if image_grid_pinpoints is not None else
            [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]])
        self.image_grid_pinpoints = image_grid_pinpoints

        if isinstance(vision_config, dict):
            if "model_type" not in vision_config:
                vision_config["model_type"] = "clip_vision_model"
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](
                **vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )

        self.vision_config = vision_config

        if text_config is not None:
            text_config = MiniMaxText01Config(**text_config)
        else:
            text_config = MiniMaxText01Config()

        self.text_config = text_config

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
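As a quick sanity sketch of the defaults above (illustrative, not part of the commit): constructing the config with no arguments exercises the fallback branches, yielding a 24-layer CLIP vision config and a default MiniMaxText01 text config.

from vllm.transformers_utils.configs.minimax_vl_01 import MiniMaxVL01Config

cfg = MiniMaxVL01Config()
print(type(cfg.vision_config).__name__)   # CLIPVisionConfig (from CONFIG_MAPPING)
print(cfg.text_config.model_type)         # MiniMaxText01
print(cfg.image_grid_pinpoints[:2])       # [[336, 672], [672, 336]]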