diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 8ff18a17d36c3..793831fd06ded 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -836,14 +836,14 @@ See [this page](#generative-models) for more information on how to use generativ * `openbmb/MiniCPM-o-2_6`, etc. * ✅︎ * ✅︎ - * + * ✅︎ - * `MiniCPMV` * MiniCPM-V * T + IE+ + VE+ * `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. * ✅︎ * ✅︎ - * + * ✅︎ - * `MllamaForConditionalGeneration` * Llama 3.2 * T + I+ diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 1312b1051732f..ea37de0b806ab 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -23,8 +23,8 @@ # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, + Union) import torch from torch import nn @@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, MultiModalDataParser) from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.profiling import ProcessorInputs -from vllm.sequence import IntermediateTensors -from vllm.utils import flatten_2d_lists from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, MiniCPMVMultiModalDataParser, @@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, _minicpmv_field_config) from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix) +from .vision import scatter_patch_features CPU_DEVICE = torch.device("cpu") class MiniCPMOAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - 
audio_features: torch.Tensor + audio_features: Union[torch.Tensor, list[torch.Tensor]] """ Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Slice here means chunk. Audio that is too long will be split into slices, @@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict): Padding is used therefore `audio_features` is `torch.Tensor`. """ - audio_feature_lens: torch.Tensor + audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]] """ - Shape: `(batch_size * num_audios * num_slices)` + Shape: `(batch_size * num_audios, num_slices)` This should be feature length of each audio slice, which equals to `audio_features.shape[-1]` """ - audio_bounds: torch.Tensor + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] """ - Shape: `(batch_size * num_audios * num_slices, 2)` + A boolean mask indicating which audio embeddings correspond + to patch tokens. - This should be in `(start, stop)` format. + Shape: `(batch_size * num_audios, num_embeds)` """ class MiniCPMOAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] - audio_embeds: torch.Tensor + audio_embeds: Union[torch.Tensor, list[torch.Tensor]] """ - Shape: `(batch_size * num_images * num_slices, hidden_size)` + Shape: `(batch_size * num_audios, num_slices, hidden_size)` `hidden_size` must match the hidden size of language model backbone. instead of a batched tensor. Length of each slice may vary, so pass it as a list. """ - audio_bounds: torch.Tensor - """ - Shape: `(batch_size * num_audios * num_slices, 2)` - This should be in `(start, stop)` format. + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which audio embeddings correspond + to patch tokens. 
+ + Shape: `(batch_size * num_audios, num_embeds)` """ @@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): + audio_features = hf_inputs.get("audio_features", torch.empty(0)) + num_audios = len(audio_features) + return dict( **_minicpmv_field_config(hf_inputs), audio_features=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"), + audio_embed_is_patch=MultiModalFieldConfig.batched("audio"), + audio_token_id=MultiModalFieldConfig.shared("audio", num_audios), ) @@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): audio_pattern = "()" def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None, "audio": None} + return {**super().get_supported_mm_limits(), "audio": None} def get_mm_max_tokens_per_item( self, @@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): mm_counts: Mapping[str, int], ) -> Mapping[str, int]: return { - "image": self.get_max_image_tokens(), - "audio": self.get_max_audio_tokens(), - "video": self.get_max_video_tokens(seq_len), + **super().get_mm_max_tokens_per_item(seq_len, mm_counts), + "audio": + self.get_max_audio_tokens(), } + def get_audio_placeholder( + self, + audio_lens: int, + chunk_input: bool = True, + chunk_length: int = 1, + ) -> str: + hf_processor = self.get_hf_processor() + + return hf_processor.get_audio_placeholder( + audio_lens, + chunk_input=chunk_input, + chunk_length=chunk_length, + ) + def get_default_audio_pool_step(self) -> int: return 2 @@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): max_videos = mm_config.get_limit_per_prompt("video") max_audios = mm_config.get_limit_per_prompt("audio") - # count tokens - # which are not in get_max_image_tokens - max_image_tokens = self.get_max_image_tokens( - ) * max_images + 4 * 
max_images - max_audio_tokens = self.get_max_audio_tokens( - ) * max_audios + 2 * max_audios + max_image_tokens = self.get_max_image_tokens() * max_images + max_audio_tokens = self.get_max_audio_tokens() * max_audios max_total_frames = self.get_max_video_frames(seq_len - max_image_tokens - max_audio_tokens) @@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder( processor_inputs = super().get_dummy_processor_inputs( seq_len, mm_counts) - mm_data = { - "image": - processor_inputs.mm_data["image"], - "video": - processor_inputs.mm_data["video"], + + audio_prompt_texts = self.info.audio_pattern * num_audios + audio_mm_data = { "audio": self._get_dummy_audios(length=audio_len, num_audios=num_audios) } - audio_prompt_texts = self.info.audio_pattern * num_audios - - return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \ - audio_prompt_texts, - mm_data=mm_data) + return ProcessorInputs( + prompt_text=processor_inputs.prompt_text + audio_prompt_texts, + mm_data={ + **processor_inputs.mm_data, + **audio_mm_data, + }, + ) class MiniCPMOMultiModalProcessor( @@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor( return MiniCPMOMultiModalDataParser( target_sr=self.info.get_default_audio_sampling_rate()) - def get_audio_prompt_texts(self, - audio_lens: int, - chunk_input: bool = True, - chunk_length: int = 1) -> str: - return self.info.get_hf_processor().get_audio_placeholder( - audio_lens, chunk_input, chunk_length) - - def get_special_tokens(self) -> Dict[str, torch.Tensor]: - tokenizer = self.info.get_tokenizer() - special_tokens = super().get_special_tokens() - if hasattr(tokenizer, "audio_start_id"): - special_tokens["audio_start_id"] = torch.tensor( - tokenizer.audio_start_id) - special_tokens["audio_end_id"] = torch.tensor( - tokenizer.audio_end_id) - return special_tokens + def get_audio_prompt_texts( + self, + audio_lens: int, + chunk_input: bool = True, + chunk_length: int = 1, + ) -> str: + return self.info.get_audio_placeholder( + audio_lens, + 
chunk_input=chunk_input, + chunk_length=chunk_length, + ) def process_audios( self, @@ -274,32 +286,65 @@ class MiniCPMOMultiModalProcessor( parsed_audios = (self._get_data_parser().parse_mm_data({ "audio": audios - }).get_items("audio", AudioProcessorItems)) + }).get_items("audio", + (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))) - audio_inputs = self._base_call_hf_processor( - prompts=[self.info.audio_pattern] * len(parsed_audios), - mm_data={"audios": [[audio] for audio in parsed_audios]}, - mm_kwargs={ - **mm_kwargs, "chunk_input": True - }, - out_keys={"audio_features", "audio_feature_lens"}, - ) + if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): + audio_inputs = {} - # Avoid padding since we need the output for each audio to be - # independent of other audios for the cache to work correctly - unpadded_audio_features = [ - feat[:, :feature_len] for feat, feature_len in zip( - audio_inputs["audio_features"], - audio_inputs["audio_feature_lens"], + audio_lens = [ + self.info.get_audio_len_by_num_chunks( + sum(map(len, + parsed_audios.get(i)["audio_embeds"]))) + for i in range(len(parsed_audios)) + ] + else: + audio_inputs = self._base_call_hf_processor( + prompts=[self.info.audio_pattern] * len(parsed_audios), + mm_data={"audios": [[audio] for audio in parsed_audios]}, + mm_kwargs={ + **mm_kwargs, + "chunk_input": True, + }, + out_keys={"audio_features", "audio_feature_lens"}, ) + + # Avoid padding since we need the output for each audio to be + # independent of other audios for the cache to work correctly + unpadded_audio_features = [ + feat[:, :feature_len] for feat, feature_len in zip( + audio_inputs["audio_features"], + audio_inputs["audio_feature_lens"], + ) + ] + audio_inputs["audio_features"] = unpadded_audio_features + + audio_lens = [ + parsed_audios.get_audio_length(i) + for i in range(len(parsed_audios)) + ] + + audio_repl_features = [ + self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens ] - 
audio_inputs["audio_features"] = unpadded_audio_features + + tokenizer = self.info.get_tokenizer() + audio_repls_feature_tokens = [ + tokenizer.encode(audio_repl, add_special_tokens=False) + for audio_repl in audio_repl_features + ] + + embed_is_patch = [ + self.get_embed_is_patch(audio_repl_tokens) + for audio_repl_tokens in audio_repls_feature_tokens + ] + audio_inputs["audio_embed_is_patch"] = embed_is_patch + + unk_token_id = tokenizer.get_vocab()[""] + audio_inputs["audio_token_id"] = torch.tensor(unk_token_id) return audio_inputs - def get_placeholder_match_pattern(self) -> str: - return r"\(<(image|video|audio)>./\)" - def process_mm_inputs( self, mm_data: Mapping[str, object], @@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor( if isinstance(audios, MiniCPMOAudioEmbeddingItems): single_audio_embeds = audios.get(item_idx)["audio_embeds"] audio_len = self.info.get_audio_len_by_num_chunks( - sum(chunk_embeds.shape[0] - for chunk_embeds in single_audio_embeds)) + sum(map(len, single_audio_embeds))) else: audio_len = audios.get_audio_length(item_idx) @@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6): self.apm = self.init_audio_module(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")) + self.audio_token_id = None + def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): # Do not use parameters temporarily audio_config = self.config.audio_config @@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6): return input_lengths_after_cnn, input_lengths_after_pooling - # Copied from HF repo of MiniCPM-o-2_6, - # designed for batched inputs and outputs - def get_audio_hidden_states(self, data: MiniCPMOAudioInputs, - chunk_length: int) -> list[torch.Tensor]: - wavforms = data.get( - "audio_features", - []) # (bs, 80, frames) or [], multi audios need filled in advance - audio_feature_lens_raw = [data.get("audio_feature_lens", - [])] # list, [[x1, x2], [y1], [z1]] + def get_audio_hidden_states( + self, data: MiniCPMOAudioFeatureInputs) -> 
list[torch.Tensor]: + chunk_length = self.config.audio_chunk_length - if len(wavforms) == 0: - return [] + # (bs, 80, frames) or [], multi audios need filled in advance + wavforms_raw = data["audio_features"] + if isinstance(wavforms_raw, list): + B = len(wavforms_raw) + C = wavforms_raw[0].shape[-2] + L = max(item.shape[-1] for item in wavforms_raw) + device = wavforms_raw[0].device + dtype = wavforms_raw[0].dtype + + wavforms = torch.zeros((B, C, L), dtype=dtype, device=device) + for i, wavforms_item in enumerate(wavforms_raw): + L_item = wavforms_item.shape[-1] + wavforms[i, ..., :L_item] = wavforms_item + else: + wavforms = wavforms_raw + + # list, [[x1, x2], [y1], [z1]] + audio_feature_lens_raw = data["audio_feature_lens"] + if isinstance(audio_feature_lens_raw, torch.Tensor): + audio_feature_lens_raw = audio_feature_lens_raw.unbind(0) audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavforms.shape @@ -625,159 +683,104 @@ class MiniCPMO(MiniCPMV2_6): num_audio_tokens = feature_lens_after_pooling - final_audio_embeds = [] + final_audio_embeds = list[torch.Tensor]() idx = 0 for i in range(len(audio_feature_lens_raw)): - target_audio_embeds = [] + target_audio_embeds_lst = list[torch.Tensor]() for _ in range(len(audio_feature_lens_raw[i])): - target_audio_embeds.append( + target_audio_embeds_lst.append( audio_embeds[idx, :num_audio_tokens[idx], :]) idx += 1 - final_audio_embeds.append(target_audio_embeds) + + final_audio_embeds.append(torch.cat(target_audio_embeds_lst)) + return final_audio_embeds - def get_embedding_with_audios(self, vlm_embedding: torch.Tensor, - audio_inputs: MiniCPMOAudioInputs, - chunk_length: int) -> torch.Tensor: - device, dtype = vlm_embedding.device, vlm_embedding.dtype - if audio_inputs["type"] == "audio_embeds": - audio_embeddings = [ - item.to(device=device, dtype=dtype) - for item in audio_inputs["audio_embeds"] - ] - else: - audio_embeddings = self.get_audio_hidden_states( - audio_inputs, 
chunk_length)[0] - if audio_embeddings is None or len(audio_embeddings) == 0: - return vlm_embedding - audio_bounds = audio_inputs["audio_bounds"] - if self.config.chunk_input: - audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device, - dtype=dtype) - audio_start_pos = 0 - for bound in audio_bounds: - audio_len = bound[1] - bound[0] - vlm_embedding[bound[0]:bound[1]] = audio_embs[ - audio_start_pos:audio_start_pos + audio_len, :] - audio_start_pos += audio_len - else: - for embs, bound in zip(audio_embeddings, audio_bounds): - audio_indices = torch.arange(bound[0], - bound[1], - dtype=torch.long).to(device) - - if embs.shape[0] != len(audio_indices): - raise ValueError( - "Shape mismatch: Trying to assign embeddings " - f"of shape {embs.shape} " - f"to input indices of length {len(audio_indices)}") - vlm_embedding[audio_indices] = embs.to(dtype) - return vlm_embedding - - def _get_audio_bounds(self, input_ids: torch.Tensor, - audio_start_id: torch.Tensor, - audio_end_id: torch.Tensor) -> torch.Tensor: - audio_start_tokens, = torch.where(input_ids == audio_start_id[0]) - audio_start_tokens += 1 - audio_end_tokens, = torch.where(input_ids == audio_end_id[0]) - valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens)) - return torch.hstack([ - audio_start_tokens[:valid_audio_nums].unsqueeze(-1), - audio_end_tokens[:valid_audio_nums].unsqueeze(-1) - ]) - - def _parse_and_validate_audio_inputs( - self, input_ids: torch.Tensor, - **kwargs: object) -> Optional[MiniCPMOAudioInputs]: + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]: audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) if audio_features is None and audio_embeds is None: return None - audio_start_id = kwargs.pop("audio_start_id") - if not isinstance(audio_start_id, torch.Tensor): - raise ValueError("Incorrect type of audio_start_id. 
" - f"Got type: {type(audio_start_id)}") + audio_token_id = kwargs.pop("audio_token_id") + if audio_token_id is not None: + assert isinstance(audio_token_id, torch.Tensor) + self.mm_token_ids.add(audio_token_id.flatten().unique().item()) - audio_end_id = kwargs.pop("audio_end_id") - if not isinstance(audio_end_id, torch.Tensor): - raise ValueError("Incorrect type of audio_end_id. " - f"Got type: {type(audio_end_id)}") + audio_embed_is_patch = kwargs.pop("audio_embed_is_patch") + if not isinstance(audio_embed_is_patch, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_embed_is_patch. " + f"Got type: {type(audio_embed_is_patch)}") + + audio_embed_is_patch = flatten_bn(audio_embed_is_patch) if audio_embeds is not None: if not isinstance(audio_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of audio_embeds. " f"Got type: {type(audio_embeds)}") + audio_embeds_flat = flatten_bn(audio_embeds) + return MiniCPMOAudioEmbeddingInputs( type="audio_embeds", - audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds), - concat=True), - audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, - audio_end_id), + audio_embeds=audio_embeds_flat, + embed_is_patch=audio_embed_is_patch, ) - if audio_features is not None: - if not isinstance(audio_features, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_features. " - f"Got type: {type(audio_features)}") + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_features. " + f"Got type: {type(audio_features)}") - audio_feature_lens = kwargs.pop("audio_feature_lens") - if not isinstance(audio_feature_lens, (torch.Tensor, list)): - raise ValueError("Incorrect type of audio_feature_lens. " - f"Got type: {type(audio_feature_lens)}") + audio_feature_lens = kwargs.pop("audio_feature_lens") + if not isinstance(audio_feature_lens, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_feature_lens. 
" + f"Got type: {type(audio_feature_lens)}") - return MiniCPMOAudioFeatureInputs( - type="audio_features", - audio_features=flatten_bn(audio_features, concat=True), - audio_feature_lens=flatten_bn( - flatten_2d_lists(audio_feature_lens), concat=True), - audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, - audio_end_id), - ) + audio_features_flat = flatten_bn(audio_features) + audio_feature_lens_flat = flatten_bn(audio_feature_lens) - raise AssertionError("This line should be unreachable.") - - def _parse_and_validate_inputs(self, input_ids: torch.Tensor, - **kwargs: object): - image_inputs = self._parse_and_validate_image_inputs( - input_ids, **kwargs) - if not any("audio" in key for key in kwargs): - return image_inputs, None - audio_inputs = self._parse_and_validate_audio_inputs( - input_ids, **kwargs) - return image_inputs, audio_inputs - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: Any, - ) -> torch.Tensor: - if intermediate_tensors is not None: - vlm_embeddings = None - else: - image_inputs, audio_inputs = \ - self._parse_and_validate_inputs(input_ids, **kwargs) - vlm_embeddings = self.get_embedding_with_vision( - input_ids, image_inputs) - - if audio_inputs is not None: - vlm_embeddings = self.get_embedding_with_audios( - vlm_embeddings, audio_inputs, - self.config.audio_chunk_length) - - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None - - output = self.llm.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=vlm_embeddings, + return MiniCPMOAudioFeatureInputs( + type="audio_features", + audio_features=audio_features_flat, + audio_feature_lens=audio_feature_lens_flat, + embed_is_patch=audio_embed_is_patch, ) - return output + + def _parse_and_validate_multimodal_inputs(self, 
**kwargs: object) -> dict: + modalities = super()._parse_and_validate_multimodal_inputs(**kwargs) + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("audio_features", + "audio_embeds") and "audios" not in modalities: + modalities["audios"] = self._parse_and_validate_audio_input( + **kwargs) + + return modalities + + def _process_audio_input( + self, + audio_input: MiniCPMOAudioInputs, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + if audio_input["type"] == "audio_embeds": + return audio_input["audio_embeds"] + + return self.get_audio_hidden_states(audio_input) + + def _process_multimodal_inputs(self, modalities: dict): + multimodal_embeddings = super()._process_multimodal_inputs(modalities) + + for modality in modalities: + if modality == "audios": + audio_input = modalities["audios"] + audio_features = self._process_audio_input(audio_input) + multimodal_embeddings += tuple( + scatter_patch_features( + audio_features, + audio_input["embed_is_patch"], + )) + + return multimodal_embeddings diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 23c010c63d558..76c7a59d656d5 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -23,17 +23,15 @@ # limitations under the License. 
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math -import re from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import cached_property, partial -from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, + Union) import numpy as np import torch import torch.types -from PIL import Image from torch import nn from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar @@ -50,9 +48,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, NestedTensors, - PlaceholderRange) +from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageProcessorItems, ImageSize, ModalityData, ModalityDataItems, @@ -67,13 +63,11 @@ from vllm.sequence import IntermediateTensors from vllm.utils import flatten_2d_lists from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, - SupportsV0Only) -from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix - -CPU_DEVICE = torch.device("cpu") - -RawImageType = Union[Image.Image, torch.Tensor] +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, + merge_multimodal_embeddings) +from .vision import scatter_patch_features, select_patch_features class MiniCPMVImagePixelInputs(TypedDict): @@ -86,13 +80,6 @@ class 
MiniCPMVImagePixelInputs(TypedDict): instead of a batched tensor. """ - image_bounds: torch.Tensor - """ - Shape: `(batch_size * num_images * num_slices, 2)` - - This should be in `(start, stop)` format. - """ - tgt_sizes: torch.Tensor """ Shape: `(batch_size * num_images * num_slices, 2)` @@ -100,23 +87,34 @@ class MiniCPMVImagePixelInputs(TypedDict): This should be in `(height, width)` format. """ + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] + """ + A boolean mask indicating which image embeddings correspond + to patch tokens. + + Shape: `(batch_size * num_images, num_embeds)` + """ + + num_slices: torch.Tensor + """Shape: `(batch_size * num_images)`""" + class MiniCPMVImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - image_embeds: torch.Tensor + image_embeds: Union[torch.Tensor, list[torch.Tensor]] """ - Shape: `(batch_size * num_images * num_slices, - image_feature_size, hidden_size)` + Shape: `(batch_size * num_images, num_slices, hidden_size)` `hidden_size` must match the hidden size of language model backbone. instead of a batched tensor. """ - image_bounds: torch.Tensor + embed_is_patch: Union[torch.Tensor, list[torch.Tensor]] """ - Shape: `(batch_size * num_images * num_slices, 2)` + A boolean mask indicating which image embeddings correspond + to patch tokens. - This should be in `(start, stop)` format. 
+ Shape: `(batch_size * num_images, num_embeds)` """ @@ -233,15 +231,25 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): + pixel_values = hf_inputs.get("pixel_values", torch.empty(0)) + num_images = len(pixel_values) + + video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0)) + num_videos = len(video_pixel_values) + return dict( pixel_values=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"), tgt_sizes=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"), + embed_is_patch=MultiModalFieldConfig.batched("image"), video_pixel_values=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"), + video_embed_is_patch=MultiModalFieldConfig.batched("video"), + image_token_id=MultiModalFieldConfig.shared("image", num_images), + video_token_id=MultiModalFieldConfig.shared("video", num_videos), ) @@ -348,10 +356,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): return get_version_by_config(self.get_hf_config()) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + mm_limits = {"image": None} if self.get_model_version() == (2, 6): - return {"image": None, "video": None} - else: - return {"image": None} + mm_limits["video"] = None + + return mm_limits def get_mm_max_tokens_per_item( self, @@ -361,70 +370,79 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): mm_max_tokens = {"image": self.get_max_image_tokens()} if self.get_model_version() == (2, 6): mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) + return mm_max_tokens + def get_slice_image_placeholder( + self, + image_size: ImageSize, + # For MiniCPM V/O 2.6 + image_idx: int = 0, + max_slice_nums: Optional[int] = None, + use_image_id: bool = True, + ) -> str: + 
image_processor = self.get_image_processor() + version = self.get_model_version() + + if version == (2, 0) or version == (2, 5): + return image_processor.get_slice_image_placeholder(image_size) + + return image_processor.get_slice_image_placeholder( + image_size, + image_idx=image_idx, + max_slice_nums=max_slice_nums, + use_image_id=use_image_id, + ) + + def get_num_image_tokens( + self, + image_size: ImageSize, + max_slice_nums: Optional[int] = None, + use_image_id: bool = True, + ) -> int: + tokenizer = self.get_tokenizer() + image_placeholders = self.get_slice_image_placeholder( + image_size, + max_slice_nums=max_slice_nums, + use_image_id=use_image_id, + ) + image_token_ids = tokenizer.encode(image_placeholders, + add_special_tokens=False) + + return len(image_token_ids) + + def get_max_image_tokens(self) -> int: + image_size = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_size) + + def get_image_max_slice_num(self) -> int: + return getattr(self.get_hf_config(), "max_slice_num", 9) + + def get_image_size_with_most_features(self) -> ImageSize: + image_size = getattr(self.get_hf_config(), "image_size", 448) + max_slice_num = self.get_image_max_slice_num() + return ImageSize(width=image_size, height=image_size * max_slice_num) + def get_max_video_frame_tokens(self) -> int: frame_size = self.get_video_frame_size_with_most_features() - return self.get_num_image_tokens(frame_size, - self.get_video_max_slice_num()) + + return self.get_num_image_tokens( + frame_size, + max_slice_nums=self.get_video_max_slice_num(), + use_image_id=False, + ) def get_max_video_tokens(self, seq_len: int) -> int: return self.get_max_video_frame_tokens( ) * self.get_num_frames_with_most_features(seq_len) - def get_slice_query_num(self) -> int: - hf_config = self.get_hf_config() - query_num = getattr(hf_config, "query_num", 64) - return query_num - - def get_max_slice_num(self) -> int: - hf_config = self.get_hf_config() - max_slice_num = getattr(hf_config, 
"max_slice_num", 9) - return max_slice_num - - def get_sliced_grid(self, image_size: ImageSize, - max_slice_num: int) -> Tuple[int, int]: - if self.get_model_version() == (2, 6): - slice_grid = self.get_image_processor().get_sliced_grid( - image_size, max_slice_num) - else: - slice_grid = self.get_image_processor().get_sliced_grid(image_size) - return slice_grid - - def get_num_image_tokens(self, image_size: ImageSize, - max_slice_num: int) -> int: - slice_grid = self.get_sliced_grid(image_size, max_slice_num) - num_tokens = self.get_slice_query_num( - ) + 2 # ( * query_num) - if slice_grid is not None: - if self.get_model_version() == (2, 6): - num_additional_tokens = 0 - else: - # ( * query_num) - num_additional_tokens = 2 - num_tokens += ((self.get_slice_query_num() + 2) \ - * slice_grid[0] * slice_grid[1]) \ - + slice_grid[1] - 1 + num_additional_tokens - return num_tokens - - def get_image_slice_nums(self, image_size: torch.Tensor, - max_slice_nums: int) -> int: - grid = self.get_sliced_grid(image_size, max_slice_nums) - return 1 if grid is None else grid[0] * grid[1] + 1 - - def get_max_image_tokens(self) -> int: - image_size = self.get_image_size_with_most_features() - return self.get_num_image_tokens(image_size, self.get_max_slice_num()) - - def get_image_size_with_most_features(self) -> ImageSize: - # Result in the max possible feature size (h:w = 9:1) - return self.get_default_image_sizes(self.get_max_slice_num()) - def get_video_max_slice_num(self) -> int: return 1 def get_video_frame_size_with_most_features(self) -> ImageSize: - return self.get_default_image_sizes(self.get_video_max_slice_num()) + image_size = getattr(self.get_hf_config(), "image_size", 448) + max_slice_num = self.get_video_max_slice_num() + return ImageSize(width=image_size, height=image_size * max_slice_num) def get_max_video_frames(self, max_tokens: int) -> int: num_frame_tokens = self.get_max_video_frame_tokens() @@ -436,10 +454,7 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): 
max_images = mm_config.get_limit_per_prompt("image") max_videos = mm_config.get_limit_per_prompt("video") - # count tokens - # which are not in get_max_image_tokens - max_image_tokens = self.get_max_image_tokens( - ) * max_images + 4 * max_images + max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self.get_max_video_frames(seq_len - max_image_tokens) @@ -447,10 +462,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): return num_frames - def get_default_image_sizes(self, num_slices: int) -> ImageSize: - image_size = getattr(self.get_hf_config(), "image_size", 448) - return ImageSize(width=image_size, height=image_size * num_slices) - _I = TypeVar("_I", bound=MiniCPMVProcessingInfo, @@ -499,42 +510,30 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): def _get_data_parser(self) -> MultiModalDataParser: return MiniCPMVMultiModalDataParser() - def get_slice_image_placeholder(self, image_size: ImageSize, - **kwargs) -> str: - image_processor = self.info.get_image_processor() - version = self.info.get_model_version() - if version == (2, 0) or version == (2, 5): - return image_processor.get_slice_image_placeholder(image_size) - return image_processor.get_slice_image_placeholder( - image_size, **kwargs) - def get_image_prompt_texts(self, image_size: ImageSize, image_idx: int = 0) -> str: - return self.get_slice_image_placeholder(image_size, - image_idx=image_idx) + return self.info.get_slice_image_placeholder( + image_size, + image_idx=image_idx, + ) def get_video_prompt_texts(self, image_size: ImageSize, num_frames: int) -> str: - return self.get_slice_image_placeholder( + return self.info.get_slice_image_placeholder( image_size=image_size, image_idx=0, max_slice_nums=self.info.get_video_max_slice_num(), use_image_id=False, ) * num_frames - def get_special_tokens(self) -> Dict[str, torch.Tensor]: + def get_embed_is_patch( + self, + input_ids: list[int], + ) -> torch.Tensor: tokenizer = self.info.get_tokenizer() - - 
special_tokens = { - "im_start_id": tokenizer.im_start_id, - "im_end_id": tokenizer.im_end_id, - } - if hasattr(tokenizer, "slice_start_id"): - special_tokens["slice_start_id"] = tokenizer.slice_start_id - special_tokens["slice_end_id"] = tokenizer.slice_end_id - - return {k: torch.tensor(v) for k, v in special_tokens.items()} + unk_token_id = tokenizer.get_vocab()[""] + return torch.tensor(input_ids) == unk_token_id def process_images( self, @@ -546,14 +545,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): parsed_images = (self._get_data_parser().parse_mm_data({ "image": images - }).get_items("image", ImageProcessorItems)) + }).get_items("image", + (MiniCPMVImageEmbeddingItems, ImageProcessorItems))) - return self._base_call_hf_processor( - prompts=[self.info.image_pattern] * len(parsed_images), - mm_data={"images": [[image] for image in parsed_images]}, - mm_kwargs=mm_kwargs, - out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, - ) + if isinstance(parsed_images, MiniCPMVImageEmbeddingItems): + image_inputs = {} + else: + image_inputs = self._base_call_hf_processor( + prompts=[self.info.image_pattern] * len(parsed_images), + mm_data={"images": [[image] for image in parsed_images]}, + mm_kwargs=mm_kwargs, + out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, + ) + + image_sizes = [ + parsed_images.get_image_size(i) for i in range(len(parsed_images)) + ] + image_repl_features = [ + self.get_image_prompt_texts(size, idx) + for idx, size in enumerate(image_sizes) + ] + + tokenizer = self.info.get_tokenizer() + image_repls_feature_tokens = [ + tokenizer.encode(image_repl, add_special_tokens=False) + for image_repl in image_repl_features + ] + + embed_is_patch = [ + self.get_embed_is_patch(image_repl_tokens) + for image_repl_tokens in image_repls_feature_tokens + ] + image_inputs["embed_is_patch"] = embed_is_patch + + unk_token_id = tokenizer.get_vocab()[""] + image_inputs["image_token_id"] = torch.tensor(unk_token_id) + + return image_inputs 
def process_videos( self, @@ -565,25 +593,55 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): parsed_videos = (self._get_data_parser().parse_mm_data({ "video": videos - }).get_items("video", VideoProcessorItems)) + }).get_items("video", + (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))) - max_slice_num = self.info.get_video_max_slice_num() + if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems): + video_inputs = {} + else: + video_inputs = self._base_call_hf_processor( + prompts=[ + self.info.image_pattern * len(video) + for video in parsed_videos + ], + mm_data={"images": list(parsed_videos)}, + mm_kwargs={ + **mm_kwargs, + "max_slice_nums": + self.info.get_video_max_slice_num(), + }, + out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, + ) - video_inputs = self._base_call_hf_processor( - prompts=[ - self.info.image_pattern * len(video) for video in parsed_videos - ], - mm_data={"images": list(parsed_videos)}, - mm_kwargs={ - **mm_kwargs, "max_slice_nums": max_slice_num - }, - out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, - ) + frame_sizes = [ + parsed_videos.get_frame_size(i) for i in range(len(parsed_videos)) + ] + num_frames = [ + parsed_videos.get_num_frames(i) for i in range(len(parsed_videos)) + ] + video_repl_features = [ + self.get_video_prompt_texts(size, nframes) + for size, nframes in zip(frame_sizes, num_frames) + ] - return {f"video_{k}": v for k, v in video_inputs.items()} + tokenizer = self.info.get_tokenizer() + video_repls_feature_tokens = [ + tokenizer.encode(video_repl, add_special_tokens=False) + for video_repl in video_repl_features + ] - def get_placeholder_match_pattern(self) -> str: - return r"\(<(image|video)>./\)" + embed_is_patch = [ + self.get_embed_is_patch(video_repl_tokens) + for video_repl_tokens in video_repls_feature_tokens + ] + video_inputs["embed_is_patch"] = embed_is_patch + + video_inputs = {f"video_{k}": v for k, v in video_inputs.items()} + + unk_token_id = tokenizer.get_vocab()["<unk>"] + 
video_inputs["video_token_id"] = torch.tensor(unk_token_id) + + return video_inputs def process_mm_inputs( self, @@ -602,7 +660,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_kwargs: Mapping[str, object], *, out_keys: set[str], - ) -> Mapping[str, NestedTensors]: + ) -> dict[str, NestedTensors]: # This processor supports zipping prompt and mm_data together if self.info.get_model_version() == (2, 6): inputs = super()._call_hf_processor( @@ -635,14 +693,13 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> BatchFeature: - # Do not support combination inputs of images and videos for now - # Try to handle interleaved multimodal data tokenizer = self.info.get_tokenizer() + + input_ids = torch.tensor([tokenizer.encode(prompt)]) mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) return BatchFeature({ - "input_ids": - torch.tensor([tokenizer.encode(prompt)]), + "input_ids": input_ids, **mm_inputs, }) @@ -701,39 +758,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ) -> Mapping[str, MultiModalFieldConfig]: return _minicpmv_field_config(hf_inputs) - def apply( - self, - prompt: Union[str, List[int]], - mm_data: MultiModalDataDict, - hf_processor_mm_kwargs: Mapping[str, object], - return_mm_hashes: bool = False, - ) -> MultiModalInputs: - if isinstance(prompt, list): - prompt = self.info.get_tokenizer().decode(prompt) - matches = re.findall(self.get_placeholder_match_pattern(), prompt) - mm_orders = { - f"{modality}_orders": - torch.tensor( - [index for index, m in enumerate(matches) if m == modality]) - for modality in self.info.get_supported_mm_limits() - } - result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - return_mm_hashes) - # Exclude x from placeholders - if "image" in result["mm_placeholders"] and \ - self.info.get_model_version() == (2, 6): - result["mm_placeholders"]["image"] = [ - PlaceholderRange(offset=p["offset"] + 
3 + idx // 10, - length=p["length"] - 3 - idx // 10) - for idx, p in enumerate(result["mm_placeholders"]["image"]) - ] - result["mm_kwargs"].update(**mm_orders) - result["mm_kwargs"].update(**self.get_special_tokens()) - return result - -class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, - SupportsV0Only): +class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): """ The abstract class of MiniCPMV can only be inherited, but cannot be instantiated. @@ -767,6 +793,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, prefix=maybe_prefix( prefix, "resampler")) + self.mm_token_ids = set[int]() self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) @@ -777,233 +804,191 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, return get_sampler() - def get_embedding_with_vision( + def _parse_and_validate_vision_input( self, - input_ids: torch.Tensor, - image_inputs: Optional[MiniCPMVImageInputs], - ) -> torch.Tensor: - vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) - - if image_inputs is None: - return vlm_embedding - - if image_inputs["type"] == "image_embeds": - vision_hidden_states = image_inputs["image_embeds"].to( - device=vlm_embedding.device, - dtype=vlm_embedding.dtype, - ) - else: - vision_hidden_states = self.get_vision_hidden_states(image_inputs) - - # See NOTE in _parse_and_validate_inputs - image_bounds = image_inputs["image_bounds"] - if len(image_bounds) > 0: - image_indices = torch.stack([ - torch.arange(start, end, dtype=torch.long) - for start, end in image_bounds.tolist() - ]).to(vlm_embedding.device) - - vlm_embedding.scatter_( - 0, - image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), - vision_hidden_states.view(-1, vision_hidden_states.shape[-1]), - ) - - return vlm_embedding - - def _get_image_bounds( - self, - input_ids: torch.Tensor, - im_start_id: torch.Tensor, - im_end_id: torch.Tensor, - slice_start_id: Optional[torch.Tensor] 
= None, - slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor: - # All the images in the batch should share the same special image - # bound token ids. - start_cond = input_ids == im_start_id[0] - end_cond = input_ids == im_end_id[0] - if slice_start_id is not None: - start_cond |= (input_ids == slice_start_id[0]) - end_cond |= (input_ids == slice_end_id[0]) - - image_start_tokens, = torch.where(start_cond) - image_start_tokens += 1 - image_end_tokens, = torch.where(end_cond) - valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) - - if valid_image_nums == 0: - return torch.zeros((0, 2), device=input_ids.device) - - return torch.hstack([ - image_start_tokens[:valid_image_nums].unsqueeze(-1), - image_end_tokens[:valid_image_nums].unsqueeze(-1), - ]) - - def _parse_and_validate_image_inputs( - self, - input_ids: torch.Tensor, + modality: str, **kwargs: object, ) -> Optional[MiniCPMVImageInputs]: - image_keys = {"pixel_values", "tgt_sizes"} - pixel_data = { - "image": { - key: kwargs.pop(key, None) - for key in image_keys - }, - "video": { - key: kwargs.pop("video_" + key, None) - for key in image_keys - } - } - embed_data = { - "image": kwargs.pop("image_embeds", None), - "video": kwargs.pop("video_embeds", None), - } + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) - all_pixel_data = [ - v for vs in pixel_data.values() for v in vs.values() - if v is not None - ] - all_embed_data = [v for v in embed_data.values() if v is not None] - if len(all_pixel_data) == 0 and len(all_embed_data) == 0: + if pixel_values is None and image_embeds is None: return None - im_start_id = kwargs.pop("im_start_id") - if not isinstance(im_start_id, torch.Tensor): - raise ValueError("Incorrect type of im_start_id. 
" - f"Got type: {type(im_start_id)}") + image_token_id = kwargs.pop("image_token_id") + if image_token_id is not None: + assert isinstance(image_token_id, torch.Tensor) + self.mm_token_ids.add(image_token_id.flatten().unique().item()) - im_end_id = kwargs.pop("im_end_id") - if not isinstance(im_end_id, torch.Tensor): - raise ValueError("Incorrect type of im_end_id. " - f"Got type: {type(im_end_id)}") + embed_is_patch = kwargs.pop("embed_is_patch") + if not isinstance(embed_is_patch, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of embed_is_patch for {modality=}. " + f"Got type: {type(embed_is_patch)}") - slice_start_id = kwargs.pop("slice_start_id", None) - if slice_start_id is not None and not isinstance( - slice_start_id, torch.Tensor): - raise ValueError("Incorrect type of slice_start_id. " - f"Got type: {type(slice_start_id)}") + embed_is_patch = flatten_bn(embed_is_patch) - slice_end_id = kwargs.pop("slice_end_id", None) - if slice_end_id is not None and not isinstance(slice_end_id, - torch.Tensor): - raise ValueError("Incorrect type of slice_end_id. " - f"Got type: {type(slice_end_id)}") + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of image_embeds for {modality=}. " + f"Got type: {type(image_embeds)}") - if len(all_embed_data) > 0: - if len(all_embed_data) > 1: - raise ValueError("Incorrect inputs for vision embeddings. " - "Image embeds and video embeds can not " - "exist simultaneously.") - - vision_embeds, = all_embed_data - if not isinstance(vision_embeds, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of vision_embeds. 
" - f"Got type: {type(vision_embeds)}") + image_embeds_flat = flatten_bn(image_embeds) return MiniCPMVImageEmbeddingInputs( type="image_embeds", - image_embeds=flatten_bn(flatten_2d_lists(vision_embeds), - concat=True), - image_bounds=self._get_image_bounds(input_ids, im_start_id, - im_end_id, slice_start_id, - slice_end_id), + image_embeds=image_embeds_flat, + embed_is_patch=embed_is_patch, ) - order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]() - for modality in ("image", "video"): - modality_orders = kwargs.pop(f"{modality}_orders", None) - if modality_orders is not None: - if not isinstance(modality_orders, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of {modality}_orders. " - f"Got type: {type(modality_orders)}") + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of pixel_values for {modality=}. " + f"Got type: {type(pixel_values)}") - order_data[modality] = modality_orders + tgt_sizes = kwargs.pop("tgt_sizes") + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of tgt_sizes for {modality=}. 
" + f"Got type: {type(tgt_sizes)}") - batch_sizes = { - modality: len(modality_orders) - for modality, modality_orders in order_data.items() - } - unique_batch_sizes = set(batch_sizes.values()) - assert len(unique_batch_sizes) == 1, ( - f"Found inconsistent batch sizes: {batch_sizes}") - batch_size, = unique_batch_sizes + num_slices = [[len(p) for p in ps] for ps in pixel_values] + num_slices_flat = flatten_bn(torch.tensor(num_slices)) - pixel_values_flat = list[torch.Tensor]() - tgt_sizes_flat = list[torch.Tensor]() - for b in range(batch_size): - mm_orders_b = [(idx_b.item(), modality) - for modality, modality_orders in order_data.items() - for idx_b in modality_orders[b]] + pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values)) + tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True) - for _, modality in sorted(mm_orders_b, key=lambda x: x[0]): - modality_pixel_data = pixel_data[modality] - - modality_pixel_values = modality_pixel_data["pixel_values"] - if not isinstance(modality_pixel_values, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of pixel_values for {modality=}. " - f"Got type: {type(modality_pixel_values)}") - - modality_tgt_sizes = modality_pixel_data["tgt_sizes"] - if not isinstance(modality_tgt_sizes, (torch.Tensor, list)): - raise ValueError( - f"Incorrect type of tgt_sizes for {modality=}. " - f"Got type: {type(modality_tgt_sizes)}") - - pixel_values_flat += flatten_2d_lists(modality_pixel_values[b]) - tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b]) - - # NOTE: Input IDs does not contain image tokens during memory profiling, - # so we allow it to be empty if len(pixel_values_flat) != len(tgt_sizes_flat): raise ValueError("Inconsistent flattened lengths, found: " f"{len(pixel_values_flat)} vs. 
" f"{len(tgt_sizes_flat)}") - if len(pixel_values_flat) == 0: - return None - return MiniCPMVImagePixelInputs( type="pixel_values", pixel_values=pixel_values_flat, - tgt_sizes=torch.stack(tgt_sizes_flat), - image_bounds=self._get_image_bounds(input_ids, im_start_id, - im_end_id, slice_start_id, - slice_end_id), + tgt_sizes=tgt_sizes_flat, + embed_is_patch=embed_is_patch, + num_slices=num_slices_flat, ) - def _parse_and_validate_inputs(self, input_ids: torch.Tensor, - **kwargs: object): - return self._parse_and_validate_image_inputs(input_ids, **kwargs) + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_vision_input( + "images", **kwargs) + if input_key in ("video_pixel_values", + "video_embeds") and "videos" not in modalities: + + def _image_key(video_key: str): + if video_key == "video_token_id": + return "image_token_id" + + return video_key.removeprefix("video_") + + modalities["videos"] = self._parse_and_validate_vision_input( + "videos", **{ + _image_key(k): v + for k, v in kwargs.items() + }) + + return modalities + + def _process_vision_input( + self, + image_input: MiniCPMVImageInputs, + ) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"] + + image_features_flat = self.get_vision_hidden_states(image_input) + + # Reconstruct the batch dimension + return image_features_flat.split(image_input["num_slices"].tolist()) + + def _process_multimodal_inputs(self, modalities: dict): + # The result multimodal_embeddings is tuple of tensors, with each + # tensor corresponding to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] 
= () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + image_features = self._process_vision_input(image_input) + multimodal_embeddings += tuple( + scatter_patch_features( + image_features, + image_input["embed_is_patch"], + )) + if modality == "videos": + video_input = modalities["videos"] + video_features = self._process_vision_input(video_input) + multimodal_embeddings += tuple( + scatter_patch_features( + video_features, + video_input["embed_is_patch"], + )) + + return multimodal_embeddings + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + return self._process_multimodal_inputs(modalities) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + assert len(self.mm_token_ids) > 0 + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + select_patch_features(multimodal_embeddings), + list(self.mm_token_ids), + ) + return inputs_embeds def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: Any, ) -> torch.Tensor: if intermediate_tensors is not None: - vlm_embeddings = None - else: - image_inputs = \ - self._parse_and_validate_inputs(input_ids, **kwargs) - vlm_embeddings = self.get_embedding_with_vision( - input_ids, image_inputs) + inputs_embeds = None - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None + # NOTE: 
In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) - output = self.llm.model( + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + hidden_states = self.llm.model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, - inputs_embeds=vlm_embeddings, + inputs_embeds=inputs_embeds, ) - return output + return hidden_states def compute_logits( self, @@ -1105,9 +1090,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel): return model - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_tokens(input_ids) - def init_resampler(self, embed_dim: int, vision_dim: int, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 9224687d8a5d3..b2f795155f17b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -92,8 +92,8 @@ class MolmoImageInputs(TypedDict): Shape: `(batch_size * num_images, num_embeds)` """ - num_crops: Union[torch.Tensor, list[torch.Tensor]] - """Shape: `(batch_size, num_images)`""" + num_crops: torch.Tensor + """Shape: `(batch_size * num_images)`""" @dataclass @@ -1492,6 +1492,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, self.img_patch_id = img_patch_id.flatten().unique().item() embed_is_patch = flatten_bn(embed_is_patch) + num_crops = flatten_bn(num_crops, concat=True) return MolmoImageInputs( images=images, @@ -1510,31 +1511,24 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, feat_is_patch = image_input["feat_is_patch"] num_crops = image_input["num_crops"] - if isinstance(images, list): - # Call the vision backbone on the whole batch at once - images_flat = flatten_bn(images, concat=True) - 
image_masks_flat = (None if image_masks is None else flatten_bn( - image_masks, concat=True)) + # Call the vision backbone on the whole batch at once + images_flat = flatten_bn(images, concat=True) + image_masks_flat = (None if image_masks is None else flatten_bn( + image_masks, concat=True)) + feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True) - image_features_flat = self.vision_backbone( - images=images_flat.unsqueeze(0), - image_masks=(None if image_masks_flat is None else - image_masks_flat.unsqueeze(0)), - ).squeeze(0) - - # Reconstruct the batch dimension - num_crops_per_image = [nc.sum().item() for nc in num_crops] - image_features = image_features_flat.split(num_crops_per_image) - else: - image_features = self.vision_backbone( - images=images, - image_masks=image_masks, - ) + image_features_flat = self.vision_backbone( + images=images_flat.unsqueeze(0), + image_masks=(None if image_masks_flat is None else + image_masks_flat.unsqueeze(0)), + ).squeeze(0) # Only the features corresponding to patch tokens are relevant return [ - feats[f_is_patch] - for feats, f_is_patch in zip(image_features, feat_is_patch) + feats[f_is_patch] for feats, f_is_patch in zip( + image_features_flat.split(num_crops.tolist()), + feat_is_patch_flat.split(num_crops.tolist()), + ) ] def get_multimodal_embeddings(