diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 381641226855..265643a44104 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -622,7 +622,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
| `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
+| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ |
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
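The table entry above registers `openbmb/MiniCPM-V-4` alongside the existing MiniCPM-V checkpoints. For orientation, here is a minimal offline-inference sketch (not part of this change); it assumes the new checkpoint keeps the `(<image>./</image>)` placeholder and chat-template conventions of earlier MiniCPM-V releases, and `example.jpg` is a stand-in path.

```python
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="openbmb/MiniCPM-V-4",
    trust_remote_code=True,  # MiniCPM-V ships custom processing code
    limit_mm_per_prompt={"image": 1},
)

image = Image.open("example.jpg").convert("RGB")  # stand-in image path
messages = [{
    "role": "user",
    "content": "(<image>./</image>)\nDescribe this image.",
}]
prompt = llm.get_tokenizer().apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True)

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)
```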
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 69961d738518..2c2d094e048f 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -427,7 +427,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
trust_remote_code=True),
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
- extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
+ extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"}, # noqa: E501
trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
trust_remote_code=True,
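The processor change in `vllm/model_executor/models/minicpmv.py` below registers, in addition to the raw image/video placeholder patterns, their tokenizer encode-decode round-tripped forms, because that round trip does not always reproduce the original string. A toy illustration of the effect (the 2.6 tokenizer is used only as an assumed stand-in; the exact behavior depends on the checkpoint's tokenizer):

```python
# Why a placeholder may not survive an encode-decode round trip:
# tokenizers can normalize whitespace or re-segment special markers.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6",
                                    trust_remote_code=True)
pattern = "(<image>./</image>)"
round_trip = tok.decode(tok.encode(pattern, add_special_tokens=False))
print(repr(pattern))
print(repr(round_trip))  # may differ, e.g. by inserted spaces
```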
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index e172758b2f2c..3aa16bb9abe4 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -38,6 +38,8 @@ from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -339,7 +341,9 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
mm_limits = {"image": None}
-        if self.get_model_version() == (2, 6):
+        if self.get_model_version() in [
+                (2, 6),
+                (4, 0)]:
mm_limits["video"] = None
return mm_limits
@@ -620,7 +624,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys: set[str],
) -> dict[str, NestedTensors]:
# This processor supports zipping prompt and mm_data together
-        if self.info.get_model_version() == (2, 6):
+        if self.info.get_model_version() in [
+                (2, 6), (4, 0)]:
inputs = super()._call_hf_processor(
prompt=prompts, # type: ignore
mm_data=mm_data,
@@ -679,10 +684,18 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
- placeholder = {
- "image": self.info.image_pattern,
- "video": self.info.video_pattern,
- }
+ placeholders = [("image", self.info.image_pattern),
+ ("video", self.info.video_pattern)]
+
+        # Also match patterns altered by a tokenizer encode-decode round trip
+ additional_placeholders = []
+ tokenizer = self.info.get_tokenizer()
+ for modality, pattern in placeholders:
+ sub_pattern = tokenizer.decode(
+ tokenizer.encode(pattern, add_special_tokens=False))
+ if sub_pattern != pattern:
+ additional_placeholders.append((modality, sub_pattern))
+ placeholders += additional_placeholders
def get_image_replacement(item_idx: int):
images = mm_items.get_items(
@@ -714,9 +727,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
return [
PromptReplacement(modality=modality,
- target=placeholder[modality],
+ target=pattern,
replacement=get_replacement[modality])
- for modality in ("image", "video")
+ for modality, pattern in placeholders
]
def _get_mm_fields_config(
@@ -1262,11 +1275,124 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
return self.resampler(vision_embedding, tgt_sizes)
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self,
+ skip_prefixes=["apm.", "audio", "tts"])
+ return loader.load_weights(weights)
+
+
+class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ assert self.version == (4, 0)
+
+ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+ if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
+ return None
+ return quant_config
+
+ def init_llm(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ) -> nn.Module:
+ return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
+
+ def init_vision_module(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ quant_config = self._maybe_ignore_quant_config(quant_config)
+ model = Idefics2VisionTransformer(config.vision_config,
+ quant_config=quant_config,
+ prefix=prefix)
+ if self.config.drop_vision_last_layer:
+ model.encoder.layers = model.encoder.layers[:-1]
+ return model
+
+ def init_resampler(
+ self,
+ embed_dim: int,
+ vision_dim: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> nn.Module:
+ quant_config = self._maybe_ignore_quant_config(quant_config)
+ with set_default_torch_dtype(torch.float16):
+            # The resampler in 4.0 is the same as the one in 2.5/2.6.
+ resampler = Resampler2_5(num_queries=self.config.query_num,
+ embed_dim=embed_dim,
+ num_heads=embed_dim // 128,
+ kv_dim=vision_dim,
+ quant_config=quant_config,
+ prefix=prefix)
+
+ return resampler.to(device=current_platform.device_type,
+ dtype=torch.get_default_dtype())
+
+ def get_vision_hidden_states(
+ self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+ pixel_values = data["pixel_values"]
+ tgt_sizes = data["tgt_sizes"]
+
+ B = len(pixel_values)
+ P = pixel_values[0].shape[-2]
+ L = max(item.shape[-1] for item in pixel_values)
+ device = pixel_values[0].device
+ dtype = pixel_values[0].dtype
+
+ all_pixel_values = torch.zeros((B, 3, P, L),
+ dtype=dtype,
+ device=device)
+ for i, pixel_values_item in enumerate(pixel_values):
+ L_item = pixel_values_item.shape[-1]
+ all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+ num_patches = tgt_sizes.prod(-1)
+ max_patches = num_patches.max().item()
+ assert isinstance(max_patches, int)
+
+ patch_attn_mask = torch.zeros((B, max_patches),
+ dtype=torch.bool,
+ device=device)
+ for i, num_patches_item in enumerate(num_patches):
+ patch_attn_mask[i, :num_patches_item] = True
+
+ vision_embedding = self.vpm(
+ all_pixel_values,
+ patch_attention_mask=patch_attn_mask.unsqueeze(1),
+ tgt_sizes=tgt_sizes,
+ )
+
+ return self.resampler(vision_embedding, tgt_sizes)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self,
+ skip_prefixes=["apm.", "audio", "tts"])
+ return loader.load_weights(weights)
+
_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,
(2, 5): MiniCPMV2_5,
(2, 6): MiniCPMV2_6,
+ (4, 0): MiniCPMV4_0,
}
@@ -1294,8 +1420,10 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
# Dispatch class based on version
instance_cls = _SUPPORT_VERSION.get(version)
if instance_cls is None:
- raise ValueError(
- "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
+ supported_versions = ", ".join(
+ [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())])
+ raise ValueError(f"Currently, MiniCPMV only supports versions "
+ f"{supported_versions}. Got version: {version}")
# quant_config references base class members,
# so update values before init is called
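For reference, `get_vision_hidden_states` in the new `MiniCPMV4_0` class (mirroring the 2.6 implementation) right-pads variable-length image slices to a common length and builds a boolean patch-attention mask so the vision encoder ignores the padding. A standalone sketch of that padding step, with toy shapes only:

```python
import torch

# Two image slices with 14 and 28 patches; each patch row is 14 pixels wide,
# so the flattened pixel tensors have shape (3, 14, 14 * num_patches).
pixel_values = [torch.randn(3, 14, 196), torch.randn(3, 14, 392)]
tgt_sizes = torch.tensor([[7, 2], [7, 4]])  # (h_patches, w_patches) per slice

B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)

# Right-pad every slice to the longest one in the batch.
all_pixel_values = torch.zeros((B, 3, P, L))
for i, item in enumerate(pixel_values):
    all_pixel_values[i, ..., :item.shape[-1]] = item

# Mark which patch positions are real so attention can skip the padding.
num_patches = tgt_sizes.prod(-1)
max_patches = int(num_patches.max().item())
patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool)
for i, n in enumerate(num_patches):
    patch_attn_mask[i, :int(n)] = True

print(all_pixel_values.shape)   # torch.Size([2, 3, 14, 392])
print(patch_attn_mask.sum(-1))  # tensor([14, 28])
```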