From 41b67f4263e6ee06cfb5e74073970e2cee854d5e Mon Sep 17 00:00:00 2001
From: tc-mb <157115220+tc-mb@users.noreply.github.com>
Date: Thu, 7 Aug 2025 09:35:46 +0800
Subject: [PATCH] [model] Support MiniCPM-V 4.0 (#22166)

Co-authored-by: imning3
---
 docs/models/supported_models.md        |   2 +-
 tests/models/registry.py               |   2 +-
 vllm/model_executor/models/minicpmv.py | 148 +++++++++++++++++++++++--
 3 files changed, 140 insertions(+), 12 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 381641226855..265643a44104 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -622,7 +622,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
 | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
 | `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
+| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ |
 | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
 | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
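For context on the new table entry, a minimal offline-inference sketch follows, assuming a vLLM build with this patch applied and access to the `openbmb/MiniCPM-V-4` checkpoint; the chat markup, image file name, and sampling settings are illustrative placeholders, not part of this change.

from PIL import Image
from vllm import LLM, SamplingParams

# MiniCPM-V ships custom processing code, so trust_remote_code is required.
llm = LLM(
    model="openbmb/MiniCPM-V-4",
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 1},
)

# "(<image>./</image>)" is the MiniCPM-V image placeholder handled by the
# processor; the surrounding chat markup depends on the model's own template.
prompt = ("<|im_start|>user\n(<image>./</image>)\n"
          "What is shown in this image?<|im_end|>\n"
          "<|im_start|>assistant\n")
image = Image.open("example.jpg").convert("RGB")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)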
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 69961d738518..2c2d094e048f 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -427,7 +427,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
                                 trust_remote_code=True),
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
-                                extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
+                                extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"},  # noqa: E501
                                 trust_remote_code=True),
     "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
                                                            trust_remote_code=True,
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index e172758b2f2c..3aa16bb9abe4 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -38,6 +38,8 @@ from typing_extensions import TypeVar
 
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
                                                   get_2d_sincos_pos_embed)
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -339,7 +341,9 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         mm_limits = {"image": None}
-        if self.get_model_version() == (2, 6):
+        if self.get_model_version() == (2,
+                                        6) or self.get_model_version() == (4,
+                                                                           0):
             mm_limits["video"] = None
 
         return mm_limits
@@ -620,7 +624,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         out_keys: set[str],
     ) -> dict[str, NestedTensors]:
         # This processor supports zipping prompt and mm_data together
-        if self.info.get_model_version() == (2, 6):
+        if self.info.get_model_version() == (
+                2, 6) or self.info.get_model_version() == (4, 0):
             inputs = super()._call_hf_processor(
                 prompt=prompts,  # type: ignore
                 mm_data=mm_data,
@@ -679,10 +684,18 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
-        placeholder = {
-            "image": self.info.image_pattern,
-            "video": self.info.video_pattern,
-        }
+        placeholders = [("image", self.info.image_pattern),
+                        ("video", self.info.video_pattern)]
+
+        # Handle patterns that change after an encode-decode round trip
+        additional_placeholders = []
+        tokenizer = self.info.get_tokenizer()
+        for modality, pattern in placeholders:
+            sub_pattern = tokenizer.decode(
+                tokenizer.encode(pattern, add_special_tokens=False))
+            if sub_pattern != pattern:
+                additional_placeholders.append((modality, sub_pattern))
+        placeholders += additional_placeholders
 
         def get_image_replacement(item_idx: int):
             images = mm_items.get_items(
@@ -714,9 +727,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
 
         return [
             PromptReplacement(modality=modality,
-                              target=placeholder[modality],
+                              target=pattern,
                               replacement=get_replacement[modality])
-            for modality in ("image", "video")
+            for modality, pattern in placeholders
         ]
 
     def _get_mm_fields_config(
@@ -1262,11 +1275,124 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
 
         return self.resampler(vision_embedding, tgt_sizes)
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+
+
+class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
+            return None
+        return quant_config
+
+    def init_llm(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config,
+                                          prefix=prefix)
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
+
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
+
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+        pixel_values = data["pixel_values"]
+        tgt_sizes = data["tgt_sizes"]
+
+        B = len(pixel_values)
+        P = pixel_values[0].shape[-2]
+        L = max(item.shape[-1] for item in pixel_values)
+        device = pixel_values[0].device
+        dtype = pixel_values[0].dtype
+
+        all_pixel_values = torch.zeros((B, 3, P, L),
+                                       dtype=dtype,
+                                       device=device)
+        for i, pixel_values_item in enumerate(pixel_values):
+            L_item = pixel_values_item.shape[-1]
+            all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+        num_patches = tgt_sizes.prod(-1)
+        max_patches = num_patches.max().item()
+        assert isinstance(max_patches, int)
+
+        patch_attn_mask = torch.zeros((B, max_patches),
+                                      dtype=torch.bool,
+                                      device=device)
+        for i, num_patches_item in enumerate(num_patches):
+            patch_attn_mask[i, :num_patches_item] = True
+
+        vision_embedding = self.vpm(
+            all_pixel_values,
+            patch_attention_mask=patch_attn_mask.unsqueeze(1),
+            tgt_sizes=tgt_sizes,
+        )
+
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+
+
 _SUPPORT_VERSION = {
     (2, 0): MiniCPMV2_0,
     (2, 5): MiniCPMV2_5,
     (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
 }
@@ -1294,8 +1420,10 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
         # Dispatch class based on version
         instance_cls = _SUPPORT_VERSION.get(version)
         if instance_cls is None:
-            raise ValueError(
-                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
+            supported_versions = ", ".join(
+                [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())])
+            raise ValueError(f"Currently, MiniCPMV only supports versions "
+                             f"{supported_versions}. Got version: {version}")
 
         # quant_config references base class members,
         # so update values before init is called
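As a standalone illustration of the padding scheme in `MiniCPMV4_0.get_vision_hidden_states` above, the sketch below (plain PyTorch, with made-up shapes) shows how variable-length image slices are zero-padded into a single batch and how the boolean patch mask marks only the valid patches that the vision encoder and resampler should attend to.

import torch

# Made-up example inputs: three images whose flattened pixel slices have
# different widths, with tgt_sizes giving each image's (rows, cols) of patches.
pixel_values = [torch.randn(3, 14, 14 * n) for n in (32, 18, 25)]
tgt_sizes = torch.tensor([[4, 8], [3, 6], [5, 5]])

B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)

# Zero-pad every slice on the right so the batch becomes one (B, 3, P, L) tensor.
all_pixel_values = torch.zeros((B, 3, P, L), dtype=pixel_values[0].dtype)
for i, item in enumerate(pixel_values):
    all_pixel_values[i, ..., :item.shape[-1]] = item

# Mark only the real patches of each image; padded positions stay False.
num_patches = tgt_sizes.prod(-1)
patch_attn_mask = torch.zeros((B, num_patches.max().item()), dtype=torch.bool)
for i, n in enumerate(num_patches):
    patch_attn_mask[i, :n] = True

print(all_pixel_values.shape)        # torch.Size([3, 3, 14, 448])
print(patch_attn_mask.sum(dim=-1))   # tensor([32, 18, 25])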