From a9e879b3167be9d04148ff6ee42313f71e0fd4c0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 25 Mar 2025 18:22:52 +0800 Subject: [PATCH] [Misc] Clean up MiniCPM-V/O code (#15337) Signed-off-by: DarkLight1337 --- examples/offline_inference/vision_language.py | 1 + .../vision_language/test_models.py | 65 +- .../multimodal/processing/test_common.py | 84 +-- vllm/model_executor/models/gemma3_mm.py | 2 - vllm/model_executor/models/minicpmo.py | 338 ++++----- vllm/model_executor/models/minicpmv.py | 675 ++++++++---------- vllm/multimodal/inputs.py | 7 + 7 files changed, 521 insertions(+), 651 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 1cc2562759d47..0adbe574370d3 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -361,6 +361,7 @@ def run_llava_next_video(questions: list[str], engine_args = EngineArgs( model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192, + max_num_seqs=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 0235140187990..94b61b6ae7803 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -163,24 +163,24 @@ VLM_TEST_SETTINGS = { marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), #### Extended model tests - # "aria": VLMTestInfo( - # models=["rhymes-ai/Aria"], - # test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - # prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 - # img_idx_to_prompt=lambda idx: "<|img|>\n", - # max_model_len=4096, - # max_num_seqs=2, - # auto_cls=AutoModelForImageTextToText, - # single_image_prompts=IMAGE_ASSETS.prompts({ - # "stop_sign": "Please describe the image shortly.", - # "cherry_blossom": "Please infer the season with reason.", # noqa: E501 - # }), - # multi_image_prompt="Describe the two images shortly.", # noqa: E501 - # stop_str=["<|im_end|>"], - # image_size_factors=[(0.10, 0.15)], - # max_tokens=64, - # marks=[large_gpu_mark(min_gb=64)], - # ), + "aria": VLMTestInfo( + models=["rhymes-ai/Aria"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|img|>\n", + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "Please describe the image shortly.", + "cherry_blossom": "Please infer the season with reason.", # noqa: E501 + }), + multi_image_prompt="Describe the two images shortly.", # noqa: E501 + stop_str=["<|im_end|>"], + image_size_factors=[(0.10, 0.15)], + max_tokens=64, + marks=[large_gpu_mark(min_gb=64)], + ), "blip2": VLMTestInfo( models=["Salesforce/blip2-opt-2.7b"], test_type=VLMTestType.IMAGE, @@ -352,6 +352,7 @@ VLM_TEST_SETTINGS = { prompt_formatter=lambda vid_prompt: f"USER: {vid_prompt} ASSISTANT:", num_video_frames=16, max_model_len=4096, + max_num_seqs=2, auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, ), @@ -384,7 +385,7 @@ VLM_TEST_SETTINGS = { ), "minicpmo_26": VLMTestInfo( models=["openbmb/MiniCPM-o-2_6"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + test_type=(VLMTestType.IMAGE), prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 img_idx_to_prompt=lambda idx: "(./)\n", max_model_len=4096, @@ -393,9 +394,21 @@ VLM_TEST_SETTINGS = { hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, ), + "minicpmo_26_multi_image": VLMTestInfo( + models=["openbmb/MiniCPM-o-2_6"], + test_type=(VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "(./)\n", + max_model_len=4096, + max_num_seqs=2, + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], + ), "minicpmv_26": VLMTestInfo( models=["openbmb/MiniCPM-V-2_6"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + test_type=(VLMTestType.IMAGE), prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 img_idx_to_prompt=lambda idx: "(./)\n", max_model_len=4096, @@ -404,6 +417,18 @@ VLM_TEST_SETTINGS = { hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, ), + "minicpmv_26_multi_image": VLMTestInfo( + models=["openbmb/MiniCPM-V-2_6"], + test_type=(VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "(./)\n", + max_model_len=4096, + max_num_seqs=2, + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], + ), "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], test_type=(VLMTestType.IMAGE), diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index f761190a8d097..078ed21537b8d 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -import copy from functools import partial from typing import Optional, Union @@ -29,7 +28,7 @@ def _test_processing_correctness( hit_rate: float, num_batches: int, simplify_rate: float, - ignore_mm_keys: Optional[list[str]] = None, + ignore_mm_keys: Optional[set[str]] = None, ): model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_available_online(on_fail="skip") @@ -145,7 +144,7 @@ def _test_processing_correctness_hf( baseline_processor: BaseMultiModalProcessor, cached_processor: BaseMultiModalProcessor, batch_idx: int, - ignore_mm_keys: Optional[list[str]] = None, + ignore_mm_keys: Optional[set[str]] = None, ): if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"): # For some multimodal models, tokenizer will always add bos_token @@ -167,11 +166,12 @@ def _test_processing_correctness_hf( hf_processor_mm_kwargs={}, ) - assert _inputs_equal( + _assert_inputs_equal( baseline_result, cached_result, - ignore_mm_keys, - ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" + ignore_mm_keys=ignore_mm_keys, + msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", + ) baseline_tokenized_result = baseline_processor.apply( token_prompt, @@ -179,11 +179,12 @@ def _test_processing_correctness_hf( hf_processor_mm_kwargs={}, ) - assert _inputs_equal( + _assert_inputs_equal( baseline_result, baseline_tokenized_result, - ignore_mm_keys, - ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" + ignore_mm_keys=ignore_mm_keys, + msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", + ) cached_tokenized_result = cached_processor.apply( token_prompt, @@ -191,11 +192,12 @@ def _test_processing_correctness_hf( hf_processor_mm_kwargs={}, ) - assert _inputs_equal( + _assert_inputs_equal( cached_result, cached_tokenized_result, - ignore_mm_keys, - ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" + ignore_mm_keys=ignore_mm_keys, + msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", + ) def _test_processing_correctness_mistral( @@ -206,7 +208,7 @@ def _test_processing_correctness_mistral( baseline_processor: BaseMultiModalProcessor, cached_processor: BaseMultiModalProcessor, batch_idx: int, - ignore_mm_keys: Optional[list[str]] = None, + ignore_mm_keys: Optional[set[str]] = None, ): images = mm_data.get("image", []) if not isinstance(images, list): @@ -233,11 +235,12 @@ def _test_processing_correctness_mistral( hf_processor_mm_kwargs={}, ) - assert _inputs_equal( + _assert_inputs_equal( baseline_tokenized_result, cached_tokenized_result, - ignore_mm_keys, - ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})" + ignore_mm_keys=ignore_mm_keys, + msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})", + ) # yapf: disable @@ -261,6 +264,7 @@ def _test_processing_correctness_mistral( "TIGER-Lab/Mantis-8B-siglip-llama3", "mistralai/Pixtral-12B-2409", "mistral-community/pixtral-12b", + "openbmb/MiniCPM-Llama3-V-2_5", "openbmb/MiniCPM-o-2_6", "openbmb/MiniCPM-V-2_6", "allenai/Molmo-7B-D-0924", @@ -290,7 +294,7 @@ def test_processing_correctness( # In Ultravox, the audio_features can be different depending on padding # The slight difference should not be a problem though, since # attention_mask lets us ignore the difference. - ignore_mm_keys = ['audio_features'] + ignore_mm_keys = {"audio_features"} _test_processing_correctness( model_id, @@ -328,38 +332,26 @@ def test_processing_correctness_phi3v( ) -def _inputs_equal( +def _assert_inputs_equal( a: MultiModalInputs, b: MultiModalInputs, - ignore_mm_keys: Optional[list[str]] = None, + *, + ignore_mm_keys: Optional[set[str]] = None, + msg: str = "", ): - return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys( - b, ignore_mm_keys) + if ignore_mm_keys is None: + ignore_mm_keys = set() + if msg is None: + assert "mm_kwargs" in a and "mm_kwargs" in b + else: + assert "mm_kwargs" in a and "mm_kwargs" in b, msg -def _drop_mm_kwargs_keys( - result: MultiModalInputs, - ignore_mm_keys: Optional[list[str]] = None, -) -> MultiModalInputs: - """Drop specified keys from result['mm_kwargs']. + for key in ignore_mm_keys: + a["mm_kwargs"].pop(key, None) + b["mm_kwargs"].pop(key, None) - This is mainly to avoid doing exact match of audio_features in ultravox. - - Args: - result: Result to drop keys from - ignore_mm_keys: List of keys to ignore, e.g. ['audio_features'] - """ - if not ignore_mm_keys: - return result - - if 'mm_kwargs' in result: - result = copy.deepcopy(result) - mm_kwargs = result['mm_kwargs'] - for key in ignore_mm_keys: - mm_kwargs.pop(key, None) - for items in mm_kwargs._items_by_modality.values(): - for item in items: - for key in ignore_mm_keys: - item.pop(key, None) - - return result + if msg is None: + assert a == b + else: + assert a == b, msg diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 8db2bfb901bf3..d843232ca1b6b 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -295,8 +295,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): # HF processor pops the `num_crops` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: - assert isinstance(images, list) - parsed_images = (self._get_data_parser().parse_mm_data({ "image": images diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index ac10c211fa81f..1312b1051732f 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -23,7 +23,7 @@ # limitations under the License. """Inference-only MiniCPM-O model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence -from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, +from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple, TypedDict, Union) import torch @@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.profiling import ProcessorInputs from vllm.sequence import IntermediateTensors +from vllm.utils import flatten_2d_lists from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, MiniCPMVMultiModalDataParser, MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, _minicpmv_field_config) -from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix +from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, + maybe_prefix) CPU_DEVICE = torch.device("cpu") class MiniCPMOAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: torch.Tensor + audio_features: torch.Tensor """ Shape: `(batch_size * num_audios * num_slices, num_channels, length)` Slice here means chunk. Audio that is too long will be split into slices, which is the same as image. - Padding is used therefore `data` is `torch.Tensor`. + Padding is used therefore `audio_features` is `torch.Tensor`. """ audio_feature_lens: torch.Tensor @@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict): Shape: `(batch_size * num_audios * num_slices)` This should be feature length of each audio slice, - which equals to `data.shape[-1]` + which equals to `audio_features.shape[-1]` """ audio_bounds: torch.Tensor @@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict): class MiniCPMOAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] - data: List[torch.Tensor] + audio_embeds: torch.Tensor """ Shape: `(batch_size * num_images * num_slices, hidden_size)` @@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): - audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0)) - return dict( **_minicpmv_field_config(hf_inputs), - audio_features=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_num_slices), - audio_feature_lens=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_num_slices), - audio_num_slices=MultiModalFieldConfig.batched("audio"), - audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"), - audio_embeds=MultiModalFieldConfig.flat_from_sizes( - "audio", audio_num_slices), + audio_features=MultiModalFieldConfig.batched("audio"), + audio_feature_lens=MultiModalFieldConfig.batched("audio"), + audio_embeds=MultiModalFieldConfig.batched("audio"), ) @@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): audio_pattern = "()" - def get_supported_mm_modalities(self) -> List[str]: - return ["image", "video", "audio"] - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None, "audio": None} @@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor( mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: - mm_data = dict(mm_data) + if (audios := mm_data.get("audios")) is None: + return {} - audios = mm_data.pop("audios", []) - audio_embeds = mm_data.pop("audio_embeds", []) - if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0: - audio_outputs = { - "audio_lens": [], - "audio_features": [], - "audio_feature_lens": [], - "audio_num_segments": [] - } - for audio in audios: - single_audio_outputs = super().call_base_hf_processor( - prompt=self.info.audio_pattern, - mm_data={ - "audios": audio, - "chunk_input": True - }, - mm_kwargs=mm_kwargs) - audio_outputs["audio_lens"].append(len(audio)) - audio_outputs["audio_features"].append( - single_audio_outputs["audio_features"]) - audio_outputs["audio_num_segments"].append( - len(single_audio_outputs["audio_feature_lens"][0])) - audio_outputs["audio_feature_lens"] += \ - single_audio_outputs["audio_feature_lens"] - audio_outputs["audio_features"] = [ - audio_feature for single_audio_features in \ - audio_outputs["audio_features"] - for audio_feature in single_audio_features - ] - audio_outputs["audio_feature_lens"] = torch.cat( - audio_outputs["audio_feature_lens"]) - elif len(audio_embeds): - audio_outputs = { - "audio_lens": [ - self.info.get_audio_len_by_num_chunks( - sum(chunk_embeds.shape[0] - for chunk_embeds in single_audio_embeds)) - for single_audio_embeds in audio_embeds - ], - "audio_embeds": [ - chunk_embeds for single_audio_embeds in audio_embeds - for chunk_embeds in single_audio_embeds - ], - "audio_num_segments": [ - len(single_audio_embeds) - for single_audio_embeds in audio_embeds - ] - } - else: - audio_outputs = {} - return audio_outputs + parsed_audios = (self._get_data_parser().parse_mm_data({ + "audio": audios + }).get_items("audio", AudioProcessorItems)) + + audio_inputs = self._base_call_hf_processor( + prompts=[self.info.audio_pattern] * len(parsed_audios), + mm_data={"audios": [[audio] for audio in parsed_audios]}, + mm_kwargs={ + **mm_kwargs, "chunk_input": True + }, + out_keys={"audio_features", "audio_feature_lens"}, + ) + + # Avoid padding since we need the output for each audio to be + # independent of other audios for the cache to work correctly + unpadded_audio_features = [ + feat[:, :feature_len] for feat, feature_len in zip( + audio_inputs["audio_features"], + audio_inputs["audio_feature_lens"], + ) + ] + audio_inputs["audio_features"] = unpadded_audio_features + + return audio_inputs def get_placeholder_match_pattern(self) -> str: return r"\(<(image|video|audio)>./\)" - def get_placeholder_split_pattern(self) -> str: - return r"\(<(?:image|video|audio)>./\)" - def process_mm_inputs( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], - ) -> Mapping[str, Mapping[str, NestedTensors]]: + ) -> Mapping[str, NestedTensors]: return { - "image": self.process_images(mm_data, mm_kwargs), - "video": self.process_videos(mm_data, mm_kwargs), - "audio": self.process_audios(mm_data, mm_kwargs), + **super().process_mm_inputs(mm_data, mm_kwargs), + **self.process_audios(mm_data, mm_kwargs), } - def get_modality_num_counter(self, modality: str) -> str: - if modality == "audio": - return "audio_lens" - return super().get_modality_num_counter(modality) - - def get_num_slices_by_modality(self, inputs: Dict[str, object], - modality: str, index: int) -> int: - if modality == "audio": - return inputs["audio"]["audio_num_segments"][index] - return super().get_num_slices_by_modality(inputs, modality, index) - - def get_prompt_texts_by_modality(self, inputs: Dict[str, object], - modality: str, index: int) -> str: - if modality == "audio": - return self.get_audio_prompt_texts( - inputs["audio"]["audio_lens"][index]) - return super().get_prompt_texts_by_modality(inputs, modality, index) - def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -622,86 +566,84 @@ class MiniCPMO(MiniCPMV2_6): # Copied from HF repo of MiniCPM-o-2_6, # designed for batched inputs and outputs def get_audio_hidden_states(self, data: MiniCPMOAudioInputs, - chunk_length: int) -> torch.Tensor: + chunk_length: int) -> list[torch.Tensor]: wavforms = data.get( - "data", + "audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance audio_feature_lens_raw = [data.get("audio_feature_lens", [])] # list, [[x1, x2], [y1], [z1]] - # exist audio - if len(wavforms) > 0: - audio_feature_lens = torch.hstack(audio_feature_lens_raw) - batch_size, _, max_mel_seq_len = wavforms.shape - max_seq_len = (max_mel_seq_len - 1) // 2 + 1 - - # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = (torch.arange( - 0, - max_seq_len, - dtype=audio_feature_lens.dtype, - device=audio_feature_lens.device).unsqueeze(0).expand( - batch_size, max_seq_len)) - lengths_expand = audio_feature_lens.unsqueeze(1).expand( - batch_size, max_seq_len) - # Create mask - padding_mask = seq_range >= lengths_expand # 1 for padded values - - audio_attention_mask_ = padding_mask.view( - batch_size, 1, 1, max_seq_len).expand(batch_size, 1, - max_seq_len, max_seq_len) - audio_attention_mask = audio_attention_mask_.to( - dtype=self.apm.conv1.weight.dtype, - device=self.apm.conv1.weight.device) - - if chunk_length > 0: - chunk_num_frame = int(chunk_length * 50) - chunk_mask = self.subsequent_chunk_mask( - size=max_seq_len, - chunk_size=chunk_num_frame, - num_left_chunks=-1, - device=audio_attention_mask_.device, - ) - audio_attention_mask_ = torch.logical_or( - audio_attention_mask_, torch.logical_not(chunk_mask)) - - audio_attention_mask[audio_attention_mask_] = float("-inf") - audio_states = self.apm( - wavforms, attention_mask=audio_attention_mask).hidden_states[ - self.audio_encoder_layer] - audio_embeds = self.audio_projection_layer(audio_states) - - audio_embeds = audio_embeds.transpose(1, 2) - audio_embeds = self.audio_avg_pooler(audio_embeds) - audio_embeds = audio_embeds.transpose(1, 2) - - _, feature_lens_after_pooling = \ - self._get_feat_extract_output_lengths(audio_feature_lens) - - num_audio_tokens = feature_lens_after_pooling - - final_audio_embeds = [] - idx = 0 - for i in range(len(audio_feature_lens_raw)): - target_audio_embeds = [] - for _ in range(len(audio_feature_lens_raw[i])): - target_audio_embeds.append( - audio_embeds[idx, :num_audio_tokens[idx], :]) - idx += 1 - final_audio_embeds.append(target_audio_embeds) - return final_audio_embeds - else: + if len(wavforms) == 0: return [] + audio_feature_lens = torch.hstack(audio_feature_lens_raw) + batch_size, _, max_mel_seq_len = wavforms.shape + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = (torch.arange( + 0, + max_seq_len, + dtype=audio_feature_lens.dtype, + device=audio_feature_lens.device).unsqueeze(0).expand( + batch_size, max_seq_len)) + lengths_expand = audio_feature_lens.unsqueeze(1).expand( + batch_size, max_seq_len) + # Create mask + padding_mask = seq_range >= lengths_expand # 1 for padded values + + audio_attention_mask_ = padding_mask.view( + batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len, + max_seq_len) + audio_attention_mask = audio_attention_mask_.to( + dtype=self.apm.conv1.weight.dtype, + device=self.apm.conv1.weight.device) + + if chunk_length > 0: + chunk_num_frame = int(chunk_length * 50) + chunk_mask = self.subsequent_chunk_mask( + size=max_seq_len, + chunk_size=chunk_num_frame, + num_left_chunks=-1, + device=audio_attention_mask_.device, + ) + audio_attention_mask_ = torch.logical_or( + audio_attention_mask_, torch.logical_not(chunk_mask)) + + audio_attention_mask[audio_attention_mask_] = float("-inf") + audio_states = self.apm( + wavforms, attention_mask=audio_attention_mask).hidden_states[ + self.audio_encoder_layer] + audio_embeds = self.audio_projection_layer(audio_states) + + audio_embeds = audio_embeds.transpose(1, 2) + audio_embeds = self.audio_avg_pooler(audio_embeds) + audio_embeds = audio_embeds.transpose(1, 2) + + _, feature_lens_after_pooling = \ + self._get_feat_extract_output_lengths(audio_feature_lens) + + num_audio_tokens = feature_lens_after_pooling + + final_audio_embeds = [] + idx = 0 + for i in range(len(audio_feature_lens_raw)): + target_audio_embeds = [] + for _ in range(len(audio_feature_lens_raw[i])): + target_audio_embeds.append( + audio_embeds[idx, :num_audio_tokens[idx], :]) + idx += 1 + final_audio_embeds.append(target_audio_embeds) + return final_audio_embeds + def get_embedding_with_audios(self, vlm_embedding: torch.Tensor, - audio_inputs: Optional[MiniCPMOAudioInputs], + audio_inputs: MiniCPMOAudioInputs, chunk_length: int) -> torch.Tensor: device, dtype = vlm_embedding.device, vlm_embedding.dtype if audio_inputs["type"] == "audio_embeds": - audio_embeddings = audio_inputs["data"] audio_embeddings = [ - audio_embeddings[i].to(device=device, dtype=dtype) - for i in range(len(audio_embeddings)) + item.to(device=device, dtype=dtype) + for item in audio_inputs["audio_embeds"] ] else: audio_embeddings = self.get_audio_hidden_states( @@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6): def _parse_and_validate_audio_inputs( self, input_ids: torch.Tensor, - **kwargs: object) -> Tuple[MiniCPMOAudioInputs]: - audio_features = kwargs.pop("audio_features", []) - audio_feature_lens = kwargs.pop("audio_feature_lens", []) + **kwargs: object) -> Optional[MiniCPMOAudioInputs]: + audio_features = kwargs.pop("audio_features", None) audio_embeds = kwargs.pop("audio_embeds", None) - audio_start_id = kwargs.pop("audio_start_id", None) - audio_end_id = kwargs.pop("audio_end_id", None) + + if audio_features is None and audio_embeds is None: + return None + + audio_start_id = kwargs.pop("audio_start_id") + if not isinstance(audio_start_id, torch.Tensor): + raise ValueError("Incorrect type of audio_start_id. " + f"Got type: {type(audio_start_id)}") + + audio_end_id = kwargs.pop("audio_end_id") + if not isinstance(audio_end_id, torch.Tensor): + raise ValueError("Incorrect type of audio_end_id. " + f"Got type: {type(audio_end_id)}") + if audio_embeds is not None: - audio_embeds = [ - audio_embeds[i][j] for i in range(len(audio_embeds)) - for j in range(len(audio_embeds[i])) - ] + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_embeds. " + f"Got type: {type(audio_embeds)}") + return MiniCPMOAudioEmbeddingInputs( + type="audio_embeds", + audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds), + concat=True), audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, audio_end_id), - data=audio_embeds, - type="audio_embeds") - if len(audio_features) > 0: - audio_features_all = [ - i.permute(1, 0) for audio_feature in audio_features - for i in audio_feature - ] - audio_features = torch.nn.utils.rnn.pad_sequence( - audio_features_all, batch_first=True, - padding_value=0.0).permute(0, 2, 1) - audio_feature_lens = torch.cat( - [item for item in audio_feature_lens]) + ) + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_features. " + f"Got type: {type(audio_features)}") + + audio_feature_lens = kwargs.pop("audio_feature_lens") + if not isinstance(audio_feature_lens, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_feature_lens. " + f"Got type: {type(audio_feature_lens)}") return MiniCPMOAudioFeatureInputs( + type="audio_features", + audio_features=flatten_bn(audio_features, concat=True), + audio_feature_lens=flatten_bn( + flatten_2d_lists(audio_feature_lens), concat=True), audio_bounds=self._get_audio_bounds(input_ids, audio_start_id, audio_end_id), - data=audio_features, - audio_feature_lens=audio_feature_lens, - type="audio_features") - return None + ) + + raise AssertionError("This line should be unreachable.") def _parse_and_validate_inputs(self, input_ids: torch.Tensor, **kwargs: object): @@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6): else: image_inputs, audio_inputs = \ self._parse_and_validate_inputs(input_ids, **kwargs) - vlm_embeddings, _ = self.get_embedding_with_vision( + vlm_embeddings = self.get_embedding_with_vision( input_ids, image_inputs) if audio_inputs is not None: diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 48c8572c05f65..23c010c63d558 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -24,6 +24,7 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re +from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import cached_property, partial from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple, @@ -63,11 +64,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import flatten_2d_lists from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsV0Only) -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix CPU_DEVICE = torch.device("cpu") @@ -76,7 +78,7 @@ RawImageType = Union[Image.Image, torch.Tensor] class MiniCPMVImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: List[torch.Tensor] + pixel_values: list[torch.Tensor] """ Shape: `(batch_size * num_images * num_slices, num_channels, height, width)` @@ -101,7 +103,7 @@ class MiniCPMVImagePixelInputs(TypedDict): class MiniCPMVImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: torch.Tensor + image_embeds: torch.Tensor """ Shape: `(batch_size * num_images * num_slices, image_feature_size, hidden_size)` @@ -231,26 +233,15 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): - image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0)) - video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0)) - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_slices), + pixel_values=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"), - tgt_sizes=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_slices), - image_num_slices=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.flat_from_sizes( - "image", image_num_slices), - video_pixel_values=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_slices), + tgt_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"), - video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_slices), - video_embeds=MultiModalFieldConfig.flat_from_sizes( - "video", video_num_slices), - video_num_slices=MultiModalFieldConfig.batched("video"), + video_tgt_sizes=MultiModalFieldConfig.batched("video"), + video_embeds=MultiModalFieldConfig.batched("video"), ) @@ -356,12 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): def get_model_version(self): return get_version_by_config(self.get_hf_config()) - def get_supported_mm_modalities(self) -> List[str]: - if self.get_model_version() == (2, 6): - return ["image", "video"] - else: - return ["image"] - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: if self.get_model_version() == (2, 6): return {"image": None, "video": None} @@ -526,187 +511,123 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): def get_image_prompt_texts(self, image_size: ImageSize, image_idx: int = 0) -> str: - prompt_texts = self.get_slice_image_placeholder(image_size, - image_idx=image_idx) - return prompt_texts + return self.get_slice_image_placeholder(image_size, + image_idx=image_idx) def get_video_prompt_texts(self, image_size: ImageSize, num_frames: int) -> str: - prompt_texts = "".join( - self.get_slice_image_placeholder( - image_size=image_size, - image_idx=0, - max_slice_nums=self.info.get_video_max_slice_num(), - use_image_id=False) for image_idx in range(num_frames)) - return prompt_texts + return self.get_slice_image_placeholder( + image_size=image_size, + image_idx=0, + max_slice_nums=self.info.get_video_max_slice_num(), + use_image_id=False, + ) * num_frames def get_special_tokens(self) -> Dict[str, torch.Tensor]: tokenizer = self.info.get_tokenizer() + special_tokens = { - "im_start_id": torch.tensor(tokenizer.im_start_id), - "im_end_id": torch.tensor(tokenizer.im_end_id) + "im_start_id": tokenizer.im_start_id, + "im_end_id": tokenizer.im_end_id, } if hasattr(tokenizer, "slice_start_id"): - special_tokens["slice_start_id"] = torch.tensor( - tokenizer.slice_start_id) - special_tokens["slice_end_id"] = torch.tensor( - tokenizer.slice_end_id) - return special_tokens + special_tokens["slice_start_id"] = tokenizer.slice_start_id + special_tokens["slice_end_id"] = tokenizer.slice_end_id - @staticmethod - def repack_processor_outputs(outputs: Any) -> BatchFeature: - valid_keys = ["pixel_values", "image_sizes", "tgt_sizes"] - outputs = {key: outputs[key][0] for key in valid_keys} - return outputs + return {k: torch.tensor(v) for k, v in special_tokens.items()} def process_images( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: - mm_data = dict(mm_data) + if (images := mm_data.get("images")) is None: + return {} - images = mm_data.pop("images", []) - image_embeds = mm_data.pop("image_embeds", []) - if isinstance(images, Image.Image): - images = [images] - if isinstance(images, (list, torch.Tensor)) and len(images) > 0: - image_outputs = super()._call_hf_processor( - prompt=self.info.image_pattern * len(images), - mm_data={"images": images}, - mm_kwargs=mm_kwargs) - image_outputs = self.repack_processor_outputs(image_outputs) - elif len(image_embeds) > 0: - image_sizes = mm_data.pop("image_sizes", None) - image_outputs = { - "image_embeds": torch.cat(image_embeds), - "image_sizes": image_sizes - } - else: - image_outputs = {} - return image_outputs + parsed_images = (self._get_data_parser().parse_mm_data({ + "image": images + }).get_items("image", ImageProcessorItems)) + + return self._base_call_hf_processor( + prompts=[self.info.image_pattern] * len(parsed_images), + mm_data={"images": [[image] for image in parsed_images]}, + mm_kwargs=mm_kwargs, + out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, + ) def process_videos( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], ) -> Mapping[str, NestedTensors]: - mm_data = dict(mm_data) + if (videos := mm_data.get("videos")) is None: + return {} - videos = mm_data.pop("videos", []) - video_embeds = mm_data.pop("video_embeds", []) - if len(videos) > 0 and isinstance(videos[0], Image.Image): - videos = [videos] - if isinstance(videos, list) and len(videos) > 0: - video_outputs = { - "video_pixel_values": [], - "video_image_sizes": [], - "video_tgt_sizes": [], - "num_frames": [] - } - for video in videos: - parsed_video = [] - for frame in video: - if isinstance(frame, np.ndarray): - parsed_video.append(Image.fromarray(frame)) - else: - parsed_video.append(frame) - video = parsed_video - single_video_outputs = super()._call_hf_processor( - prompt=self.info.image_pattern * len(video), - mm_data={"images": video}, - mm_kwargs={ - **mm_kwargs, "max_slice_nums": - self.info.get_video_max_slice_num() - }) - video_outputs["num_frames"].append(len(video)) - for key in single_video_outputs: - if "video_" + key in video_outputs: - if key == "image_sizes": - video_outputs["video_" + key].append( - single_video_outputs[key][0][0]) - else: - video_outputs["video_" + - key] += single_video_outputs[key][0] - elif len(video_embeds): - image_sizes = mm_data.pop("image_sizes", None) - num_frames = mm_data.pop("num_frames", None) - video_outputs = { - "video_embeds": torch.cat(video_embeds), - "video_image_sizes": image_sizes, - "num_frames": num_frames - } - else: - video_outputs = {} - return video_outputs + parsed_videos = (self._get_data_parser().parse_mm_data({ + "video": videos + }).get_items("video", VideoProcessorItems)) + + max_slice_num = self.info.get_video_max_slice_num() + + video_inputs = self._base_call_hf_processor( + prompts=[ + self.info.image_pattern * len(video) for video in parsed_videos + ], + mm_data={"images": list(parsed_videos)}, + mm_kwargs={ + **mm_kwargs, "max_slice_nums": max_slice_num + }, + out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, + ) + + return {f"video_{k}": v for k, v in video_inputs.items()} def get_placeholder_match_pattern(self) -> str: return r"\(<(image|video)>./\)" - def get_placeholder_split_pattern(self) -> str: - return r"\(<(?:image|video)>./\)" - def process_mm_inputs( self, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], - ) -> Mapping[str, Mapping[str, NestedTensors]]: + ) -> Mapping[str, NestedTensors]: return { - "image": self.process_images(mm_data, mm_kwargs), - "video": self.process_videos(mm_data, mm_kwargs), + **self.process_images(mm_data, mm_kwargs), + **self.process_videos(mm_data, mm_kwargs), } - def get_input_modalities(self, mm_data) -> List[str]: - supported_mm_modalities = self.info.get_supported_mm_modalities() - input_modalities = [] - for modality in supported_mm_modalities: - if modality in mm_data and mm_data[modality] != {}: - input_modalities.append(modality) - return input_modalities - - def get_modality_num_counter(self, modality: str) -> str: - if modality == "image": - return "image_sizes" - elif modality == "video": - return "video_image_sizes" - - raise NotImplementedError(modality) - - def get_num_slices_by_modality(self, inputs: dict[str, Any], modality: str, - index: int) -> int: - if modality == "image": - return self.info.get_image_slice_nums( - inputs[modality]["image_sizes"][index], - self.info.get_max_slice_num()) - elif modality == "video": - return self.info.get_image_slice_nums( - inputs[modality]["video_image_sizes"][index], - self.info.get_video_max_slice_num() - ) * inputs[modality]["num_frames"][index] - else: - raise ValueError(f"Unexpected modality: {modality}") - - def get_prompt_texts_by_modality(self, inputs: dict[str, Any], - modality: str, index: int) -> str: - if modality == "image": - return self.get_image_prompt_texts( - inputs["image"]["image_sizes"][index], index) - elif modality == "video": - return self.get_video_prompt_texts( - inputs["video"]["video_image_sizes"][index], - inputs["video"]["num_frames"][index]) - else: - raise ValueError(f"Unexpected modality: {modality}") - - def call_base_hf_processor( + def _base_call_hf_processor( self, - prompt: str, - mm_data: Mapping[str, object], + prompts: list[str], + mm_data: Mapping[str, Sequence[object]], mm_kwargs: Mapping[str, object], - ) -> BatchFeature: - return super()._call_hf_processor(prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs) + *, + out_keys: set[str], + ) -> Mapping[str, NestedTensors]: + # This processor supports zipping prompt and mm_data together + if self.info.get_model_version() == (2, 6): + inputs = super()._call_hf_processor( + prompt=prompts, # type: ignore + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + else: + inputs = defaultdict[str, list[torch.Tensor]](list) + + for i, prompt in enumerate(prompts): + inputs_one = super()._call_hf_processor( + prompt=prompt, + mm_data={ + k: v[i] + for k, v in mm_data.items() + }, + mm_kwargs=mm_kwargs, + ) + + for k, v in inputs_one.items(): + assert len(v) == 1, (k, len(v)) + inputs[k].append(v[0]) + + return {k: inputs[k] for k in out_keys} def _call_hf_processor( self, @@ -717,35 +638,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): # Do not support combination inputs of images and videos for now # Try to handle interleaved multimodal data tokenizer = self.info.get_tokenizer() - inputs = self.process_mm_inputs(mm_data, mm_kwargs) - mm_input_modalities = self.get_input_modalities(inputs) - - num_mm_slices_lst = { - modality: list[int]() - for modality in mm_input_modalities - } - for modality in mm_input_modalities: - num_counter_key = self.get_modality_num_counter(modality) - for index in range(len(inputs[modality][num_counter_key])): - num_mm_slices_lst[modality].append( - self.get_num_slices_by_modality(inputs, modality, index)) - - num_mm_slices = { - modality: torch.tensor(v) - for modality, v in num_mm_slices_lst.items() - } + mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) return BatchFeature({ - "input_ids": np.array([tokenizer.encode(prompt)]), - **{ - key: value - for modality in inputs - for key, value in inputs[modality].items() - }, - **{ - f"{modality}_num_slices": num_mm_slices[modality] - for modality in mm_input_modalities - } + "input_ids": + torch.tensor([tokenizer.encode(prompt)]), + **mm_inputs, }) def _hf_processor_applies_updates( @@ -810,7 +708,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): hf_processor_mm_kwargs: Mapping[str, object], return_mm_hashes: bool = False, ) -> MultiModalInputs: - supported_mm_modalities = self.info.get_supported_mm_modalities() if isinstance(prompt, list): prompt = self.info.get_tokenizer().decode(prompt) matches = re.findall(self.get_placeholder_match_pattern(), prompt) @@ -818,7 +715,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): f"{modality}_orders": torch.tensor( [index for index, m in enumerate(matches) if m == modality]) - for modality in supported_mm_modalities + for modality in self.info.get_supported_mm_limits() } result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, return_mm_hashes) @@ -884,35 +781,35 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) - if image_inputs is None: # No image - vision_hidden_states = torch.tensor([], device=input_ids.device) + if image_inputs is None: + return vlm_embedding + + if image_inputs["type"] == "image_embeds": + vision_hidden_states = image_inputs["image_embeds"].to( + device=vlm_embedding.device, + dtype=vlm_embedding.dtype, + ) else: - if image_inputs["type"] == "image_embeds": - vision_hidden_states = (image_inputs["data"].type( - vlm_embedding.dtype).to(vlm_embedding.device)) - else: - vision_hidden_states = self.get_vision_hidden_states( - image_inputs) + vision_hidden_states = self.get_vision_hidden_states(image_inputs) - # See NOTE in _parse_and_validate_inputs - image_bounds = image_inputs["image_bounds"] - if len(image_bounds) > 0: - image_indices = torch.stack([ - torch.arange(start, end, dtype=torch.long) - for start, end in image_bounds.tolist() - ]).to(vlm_embedding.device) - vlm_embedding.scatter_( - 0, - image_indices.view(-1, 1).repeat(1, - vlm_embedding.shape[-1]), - vision_hidden_states.view(-1, - vision_hidden_states.shape[-1]), - ) + # See NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack([ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ]).to(vlm_embedding.device) - return vlm_embedding, vision_hidden_states + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, vision_hidden_states.shape[-1]), + ) + + return vlm_embedding def _get_image_bounds( self, @@ -947,90 +844,115 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, input_ids: torch.Tensor, **kwargs: object, ) -> Optional[MiniCPMVImageInputs]: - mm_data = { + image_keys = {"pixel_values", "tgt_sizes"} + pixel_data = { "image": { - key: kwargs.pop(key, []) - for key in ["pixel_values", "tgt_sizes", "image_num_slices"] + key: kwargs.pop(key, None) + for key in image_keys }, "video": { - "pixel_values": kwargs.pop("video_pixel_values", []), - "tgt_sizes": kwargs.pop("video_tgt_sizes", []), - "video_num_slices": kwargs.pop("video_num_slices", []) + key: kwargs.pop("video_" + key, None) + for key in image_keys } } - im_start_id = kwargs.pop("im_start_id", None) - im_end_id = kwargs.pop("im_end_id", None) - slice_start_id = kwargs.pop("slice_start_id", None) - slice_end_id = kwargs.pop("slice_end_id", None) - mm_orders = { - f"{modality}": kwargs.pop(f"{modality}_orders", None) - for modality in ["image", "video", "audio"] + embed_data = { + "image": kwargs.pop("image_embeds", None), + "video": kwargs.pop("video_embeds", None), } - batch_size = max(len(mm_data["image"]["pixel_values"]), - len(mm_data["video"]["pixel_values"])) - image_embeds = kwargs.pop("image_embeds", None) - video_embeds = kwargs.pop("video_embeds", None) - if image_embeds is not None and video_embeds is not None: - raise ValueError( - "Incorrect inputs for vision embeddings. " - "Image embeds and video embeds can not exist simultaneously.") - if video_embeds is not None: - image_embeds = video_embeds - if image_embeds is not None: - if not isinstance(image_embeds, (torch.Tensor, list)): - raise ValueError(f"Incorrect type of image embeds. " - f"Got type: {type(image_embeds)}") - image_embeds = torch.concat( - [image_embeds[i] for i in range(len(image_embeds))]) + + all_pixel_data = [ + v for vs in pixel_data.values() for v in vs.values() + if v is not None + ] + all_embed_data = [v for v in embed_data.values() if v is not None] + if len(all_pixel_data) == 0 and len(all_embed_data) == 0: + return None + + im_start_id = kwargs.pop("im_start_id") + if not isinstance(im_start_id, torch.Tensor): + raise ValueError("Incorrect type of im_start_id. " + f"Got type: {type(im_start_id)}") + + im_end_id = kwargs.pop("im_end_id") + if not isinstance(im_end_id, torch.Tensor): + raise ValueError("Incorrect type of im_end_id. " + f"Got type: {type(im_end_id)}") + + slice_start_id = kwargs.pop("slice_start_id", None) + if slice_start_id is not None and not isinstance( + slice_start_id, torch.Tensor): + raise ValueError("Incorrect type of slice_start_id. " + f"Got type: {type(slice_start_id)}") + + slice_end_id = kwargs.pop("slice_end_id", None) + if slice_end_id is not None and not isinstance(slice_end_id, + torch.Tensor): + raise ValueError("Incorrect type of slice_end_id. " + f"Got type: {type(slice_end_id)}") + + if len(all_embed_data) > 0: + if len(all_embed_data) > 1: + raise ValueError("Incorrect inputs for vision embeddings. " + "Image embeds and video embeds can not " + "exist simultaneously.") + + vision_embeds, = all_embed_data + if not isinstance(vision_embeds, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of vision_embeds. " + f"Got type: {type(vision_embeds)}") return MiniCPMVImageEmbeddingInputs( + type="image_embeds", + image_embeds=flatten_bn(flatten_2d_lists(vision_embeds), + concat=True), image_bounds=self._get_image_bounds(input_ids, im_start_id, im_end_id, slice_start_id, slice_end_id), - data=image_embeds, - type="image_embeds", ) - for modality, modality_mm_data in mm_data.items(): - if not isinstance(modality_mm_data["pixel_values"], - (torch.Tensor, list)): - raise ValueError( - "Incorrect type of pixel values. " - f"Got type: {type(modality_mm_data['pixel_values'])}") - if not isinstance(modality_mm_data["tgt_sizes"], - (torch.Tensor, list)): - raise ValueError( - "Incorrect type of target sizes. " - f"Got type: {type(modality_mm_data['tgt_sizes'])}") + order_data = dict[str, Union[torch.Tensor, list[torch.Tensor]]]() + for modality in ("image", "video"): + modality_orders = kwargs.pop(f"{modality}_orders", None) + if modality_orders is not None: + if not isinstance(modality_orders, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {modality}_orders. " + f"Got type: {type(modality_orders)}") - if len(modality_mm_data["pixel_values"]) != len( - modality_mm_data["tgt_sizes"]): - raise ValueError( - "Inconsistent batch lengths, found: " - f"{len(modality_mm_data['pixel_values'])} vs. " - f"{len(modality_mm_data['tgt_sizes'])}") + order_data[modality] = modality_orders - pixel_values_flat: List[torch.Tensor] = [] - tgt_sizes_flat: List[torch.Tensor] = [] + batch_sizes = { + modality: len(modality_orders) + for modality, modality_orders in order_data.items() + } + unique_batch_sizes = set(batch_sizes.values()) + assert len(unique_batch_sizes) == 1, ( + f"Found inconsistent batch sizes: {batch_sizes}") + batch_size, = unique_batch_sizes + + pixel_values_flat = list[torch.Tensor]() + tgt_sizes_flat = list[torch.Tensor]() for b in range(batch_size): - mm_counts = {"image": 0, "video": 0} if self.version == (2, 6) \ - else {"image": 0} - mm_slice_counts = {"image": 0, "video": 0} \ - if self.version == (2, 6) else {"image": 0} - mm_orders_b = [(index, modality) for modality in mm_counts - for index in mm_orders[modality][b]] + mm_orders_b = [(idx_b.item(), modality) + for modality, modality_orders in order_data.items() + for idx_b in modality_orders[b]] + for _, modality in sorted(mm_orders_b, key=lambda x: x[0]): - pos = mm_counts[modality] - num_slices = mm_data[modality][f"{modality}_num_slices"][b][ - pos] - slice_start_idx = mm_slice_counts[modality] - slice_end_idx = slice_start_idx + num_slices - pixel_values_flat += mm_data[modality]["pixel_values"][b][ - slice_start_idx:slice_end_idx] - tgt_sizes_flat += mm_data[modality]["tgt_sizes"][b][ - slice_start_idx:slice_end_idx] - mm_counts[modality] += 1 - mm_slice_counts[modality] += num_slices + modality_pixel_data = pixel_data[modality] + + modality_pixel_values = modality_pixel_data["pixel_values"] + if not isinstance(modality_pixel_values, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of pixel_values for {modality=}. " + f"Got type: {type(modality_pixel_values)}") + + modality_tgt_sizes = modality_pixel_data["tgt_sizes"] + if not isinstance(modality_tgt_sizes, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of tgt_sizes for {modality=}. " + f"Got type: {type(modality_tgt_sizes)}") + + pixel_values_flat += flatten_2d_lists(modality_pixel_values[b]) + tgt_sizes_flat += flatten_2d_lists(modality_tgt_sizes[b]) # NOTE: Input IDs does not contain image tokens during memory profiling, # so we allow it to be empty @@ -1042,16 +964,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, if len(pixel_values_flat) == 0: return None - if im_start_id is None: - return None - return MiniCPMVImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values_flat, + tgt_sizes=torch.stack(tgt_sizes_flat), image_bounds=self._get_image_bounds(input_ids, im_start_id, im_end_id, slice_start_id, slice_end_id), - data=pixel_values_flat, - tgt_sizes=torch.stack(tgt_sizes_flat), - type="pixel_values", ) def _parse_and_validate_inputs(self, input_ids: torch.Tensor, @@ -1070,7 +989,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, else: image_inputs = \ self._parse_and_validate_inputs(input_ids, **kwargs) - vlm_embeddings, _ = self.get_embedding_with_vision( + vlm_embeddings = self.get_embedding_with_vision( input_ids, image_inputs) # always pass the input via `inputs_embeds` @@ -1136,16 +1055,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP, prefix: str = "") -> nn.Module: raise NotImplementedError - def get_vision_embedding( - self, - pixel_values: List[torch.Tensor], - patch_attn_mask: Optional[torch.Tensor] = None, - tgt_sizes: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - raise NotImplementedError - - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: raise NotImplementedError @@ -1216,35 +1127,27 @@ class MiniCPMV2_0(MiniCPMVBaseModel): return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype()) - def get_vision_embedding( - self, - pixel_values: List[torch.Tensor], - patch_attn_mask: Optional[torch.Tensor] = None, - tgt_sizes: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - res = [] - dtype = self.vpm.pos_embed.data.dtype + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + + P_h, P_w = self.vpm.patch_embed.patch_size + dtype: torch.dtype = self.vpm.pos_embed.data.dtype + num_prefix_tokens = getattr(self.vpm, "num_prefix_tokens", 0) + + res = list[torch.Tensor]() for pixel_value in pixel_values: H, W = pixel_value[0].shape[-2:] - tgt_size = ( - math.ceil(H / self.vpm.patch_embed.patch_size[0]), - math.ceil(W / self.vpm.patch_embed.patch_size[0]), - ) + tgt_size = (math.ceil(H / P_h), math.ceil(W / P_w)) vision_embedding = self.vpm.forward_features( pixel_value.unsqueeze(0).type(dtype)) - if (hasattr(self.vpm, "num_prefix_tokens") - and self.vpm.num_prefix_tokens > 0): - vision_embedding = vision_embedding[:, self.vpm. - num_prefix_tokens:] + + if num_prefix_tokens > 0: + vision_embedding = vision_embedding[:, num_prefix_tokens:] res.append(self.resampler(vision_embedding, tgt_size)) + return torch.vstack(res) - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: - pixel_values = data["data"] - - return self.get_vision_embedding(pixel_values) - class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): packed_modules_mapping = { @@ -1299,45 +1202,41 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype()) - def get_vision_embedding( - self, - pixel_values: List[torch.Tensor], - patch_attn_mask: Optional[torch.Tensor] = None, - tgt_sizes: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - vision_embedding = self.vpm(pixel_values, - patch_attention_mask=patch_attn_mask) - vision_embedding = self.resampler(vision_embedding, tgt_sizes) - return vision_embedding - - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: - pixel_values = data["data"] + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] - device = self.vpm.embeddings.position_embedding.weight.device - dtype = self.vpm.embeddings.position_embedding.weight.dtype - all_pixel_values_lst = [ - i.flatten(end_dim=1).permute(1, 0) for i in pixel_values - ] + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype - max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() assert isinstance(max_patches, int) - all_pixel_values = torch.nn.utils.rnn.pad_sequence( - all_pixel_values_lst, batch_first=True, padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute(0, 2, - 1).reshape(B, 3, -1, L) - - patch_attn_mask = torch.zeros((B, 1, max_patches), + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True - return self.get_vision_embedding(all_pixel_values.type(dtype), - patch_attn_mask, tgt_sizes) + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=None, + ) + + return self.resampler(vision_embedding, tgt_sizes) class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): @@ -1394,47 +1293,37 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): return resampler.to(device=current_platform.device_type, dtype=torch.get_default_dtype()) - def get_vision_embedding( - self, - pixel_values: List[torch.Tensor], - patch_attn_mask: Optional[torch.Tensor] = None, - tgt_sizes: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - vision_embedding = self.vpm( - pixel_values, - patch_attention_mask=patch_attn_mask, - tgt_sizes=tgt_sizes, - ) - return vision_embedding - - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: - pixel_values = data["data"] + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] - device = self.vpm.embeddings.position_embedding.weight.device - dtype = self.vpm.embeddings.position_embedding.weight.dtype - all_pixel_values_lst = [ - i.flatten(end_dim=1).permute(1, 0) for i in pixel_values - ] + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype - max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() assert isinstance(max_patches, int) - all_pixel_values = torch.nn.utils.rnn.pad_sequence( - all_pixel_values_lst, batch_first=True, padding_value=0.0) - B, L, _ = all_pixel_values.shape - all_pixel_values = all_pixel_values.permute(0, 2, - 1).reshape(B, 3, -1, L) - - patch_attn_mask = torch.zeros((B, 1, max_patches), + patch_attn_mask = torch.zeros((B, max_patches), dtype=torch.bool, device=device) - for i in range(B): - patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + vision_embedding = self.vpm( - all_pixel_values.type(dtype), - patch_attention_mask=patch_attn_mask, + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), tgt_sizes=tgt_sizes, ) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 3c609fd967650..3a588bb4eaba1 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -665,6 +665,13 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): return cast(BatchedTensorInputs, json_mapped) + def __delitem__(self, key: str) -> None: + super().__delitem__(key) + + for items in self._items_by_modality.values(): + for item in items: + item.pop(key, None) + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False