From 803d5c35f3e8a6547ff7c6e6c322e54cbfec8444 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sun, 30 Mar 2025 18:20:42 +0800
Subject: [PATCH] [V1] Override `mm_counts` for dummy data creation (#15703)

Signed-off-by: DarkLight1337
---
 .../vision_language/test_models.py             | 28 ++-------------
 .../model_executor/models/llava_next_video.py  | 14 +++++---
 vllm/model_executor/models/llava_onevision.py  | 25 ++++++++-----
 vllm/model_executor/models/minicpmo.py         | 26 ++++++++------
 vllm/model_executor/models/minicpmv.py         | 36 ++++++++++++-------
 vllm/model_executor/models/qwen2_vl.py         | 26 +++++++++-----
 vllm/multimodal/profiling.py                   | 30 +++++++++++-----
 vllm/multimodal/registry.py                    |  6 ++--
 vllm/v1/worker/gpu_model_runner.py             | 16 +++------
 9 files changed, 114 insertions(+), 93 deletions(-)

diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 0d1d237e5693c..ecb637c62e439 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -385,7 +385,7 @@ VLM_TEST_SETTINGS = {
     ),
     "minicpmo_26": VLMTestInfo(
         models=["openbmb/MiniCPM-o-2_6"],
-        test_type=(VLMTestType.IMAGE),
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -394,21 +394,9 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
     ),
-    "minicpmo_26_multi_image": VLMTestInfo(
-        models=["openbmb/MiniCPM-o-2_6"],
-        test_type=(VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
-        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
-        max_model_len=4096,
-        max_num_seqs=2,
-        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
-        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
-        patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner,
-        marks=[large_gpu_mark(min_gb=32)],
-    ),
     "minicpmv_26": VLMTestInfo(
         models=["openbmb/MiniCPM-V-2_6"],
-        test_type=(VLMTestType.IMAGE),
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
         img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
         max_model_len=4096,
@@ -417,18 +405,6 @@ VLM_TEST_SETTINGS = {
         hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
         patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
     ),
-    "minicpmv_26_multi_image": VLMTestInfo(
-        models=["openbmb/MiniCPM-V-2_6"],
-        test_type=(VLMTestType.MULTI_IMAGE),
-        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
-        img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
-        max_model_len=4096,
-        max_num_seqs=2,
-        get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']),  # noqa: E501
-        hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
-        patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
-        marks=[large_gpu_mark(min_gb=32)],
- ), "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 8b1a8c9da6804..8a5edefb4a0b2 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -71,7 +71,8 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): max_video_tokens = self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features(seq_len), + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), ) return {"video": max_video_tokens} @@ -130,9 +131,12 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): return num_frames - def get_num_frames_with_most_features(self, seq_len: int) -> int: - mm_config = self.ctx.get_mm_config() - max_videos = mm_config.get_limit_per_prompt("video") + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_videos = mm_counts.get("video", 0) max_total_frames = self._get_max_video_frames(seq_len) @@ -155,7 +159,7 @@ class LlavaNextVideoDummyInputsBuilder( target_width, target_height = \ self.info.get_image_size_with_most_features() target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len) + self.info.get_num_frames_with_most_features(seq_len, mm_counts) mm_data = { "video": diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index fbc298b812498..c7e13bb352f42 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -108,7 +108,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ) -> Mapping[str, int]: return { "image": self.get_max_image_tokens(), - "video": self.get_max_video_tokens(seq_len), + "video": self.get_max_video_tokens(seq_len, mm_counts), } # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 @@ -202,10 +202,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): return num_frames - def get_num_frames_with_most_features(self, seq_len: int) -> int: - mm_config = self.ctx.get_mm_config() - max_images = mm_config.get_limit_per_prompt("image") - max_videos = mm_config.get_limit_per_prompt("video") + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - @@ -215,13 +218,18 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): return max(max_frames_per_video, 1) - def get_max_video_tokens(self, seq_len: int) -> int: + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: target_width, target_height = self.get_image_size_with_most_features() return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features(seq_len), + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), ) @@ -243,7 +251,8 @@ class LlavaOnevisionDummyInputsBuilder( target_width, target_height = \ self.info.get_image_size_with_most_features() target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len) + 
self.info.get_num_frames_with_most_features(seq_len, + mm_counts) mm_data = { "image": diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index ea37de0b806ab..c74e086d3748e 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -43,7 +43,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.profiling import ProcessorInputs -from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder, +from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, + MiniCPMVDummyInputsBuilder, MiniCPMVMultiModalDataParser, MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo, _minicpmv_field_config) @@ -203,8 +204,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): return 30 def get_max_audio_tokens(self) -> int: - return self.get_max_audio_tokens_per_chunk( - ) * self.get_max_audio_chunks_with_most_features() + num_chunks = self.get_max_audio_chunks_with_most_features() + return self.get_max_audio_tokens_per_chunk() * num_chunks def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: sampling_rate = self.get_default_audio_sampling_rate() @@ -212,21 +213,24 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2 return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 - def get_num_frames_with_most_features(self, seq_len: int) -> int: - mm_config = self.ctx.get_mm_config() - max_images = mm_config.get_limit_per_prompt("image") - max_videos = mm_config.get_limit_per_prompt("video") - max_audios = mm_config.get_limit_per_prompt("audio") + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + max_audios = mm_counts.get("audio", 0) max_image_tokens = self.get_max_image_tokens() * max_images max_audio_tokens = self.get_max_audio_tokens() * max_audios max_total_frames = self.get_max_video_frames(seq_len - max_image_tokens - max_audio_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) - num_frames = max(max_total_frames // max(max_videos, 1), 1) - - return num_frames + return max(max_frames_per_video, 1) class MiniCPMODummyInputsBuilder( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 76c7a59d656d5..2c0d37e883b90 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -69,6 +69,9 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, merge_multimodal_embeddings) from .vision import scatter_patch_features, select_patch_features +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 + class MiniCPMVImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -369,7 +372,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ) -> Mapping[str, int]: mm_max_tokens = {"image": self.get_max_image_tokens()} if self.get_model_version() == (2, 6): - mm_max_tokens["video"] = self.get_max_video_tokens(seq_len) + mm_max_tokens["video"] = self.get_max_video_tokens( + seq_len, mm_counts) return mm_max_tokens @@ -432,9 +436,14 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): use_image_id=False, ) - def get_max_video_tokens(self, seq_len: int) -> int: - return self.get_max_video_frame_tokens( - ) * self.get_num_frames_with_most_features(seq_len) + def 
get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + num_frames = self.get_num_frames_with_most_features(seq_len, mm_counts) + num_video_tokens_total = self.get_max_video_frame_tokens() * num_frames + return num_video_tokens_total def get_video_max_slice_num(self) -> int: return 1 @@ -449,18 +458,21 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): num_frames = max_tokens // num_frame_tokens return num_frames - def get_num_frames_with_most_features(self, seq_len: int) -> int: - mm_config = self.ctx.get_mm_config() - max_images = mm_config.get_limit_per_prompt("image") - max_videos = mm_config.get_limit_per_prompt("video") + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self.get_max_video_frames(seq_len - max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) - num_frames = max(max_total_frames // max(max_videos, 1), 1) - - return num_frames + return max(max_frames_per_video, 1) _I = TypeVar("_I", @@ -483,7 +495,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): video_width, video_height = \ self.info.get_video_frame_size_with_most_features() num_video_frames = \ - self.info.get_num_frames_with_most_features(seq_len) + self.info.get_num_frames_with_most_features(seq_len, mm_counts) mm_data = { "image": diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 7537671e1bb82..a7800d4153667 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -806,7 +806,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): max_pixels: Optional[int] = None, size: Optional[dict[str, int]] = None, **kwargs: object, - ): + ) -> Qwen2VLImageProcessor: return cached_image_processor_from_config( self.ctx.model_config, **self._get_image_processor_kwargs(min_pixels=min_pixels, @@ -825,7 +825,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ) -> Mapping[str, int]: return { "image": self.get_max_image_tokens(), - "video": self.get_max_video_tokens(seq_len), + "video": self.get_max_video_tokens(seq_len, mm_counts), } def _get_vision_info( @@ -941,10 +941,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): return num_frames - def get_num_frames_with_most_features(self, seq_len: int) -> int: - mm_config = self.ctx.get_mm_config() - max_images = mm_config.get_limit_per_prompt("image") - max_videos = mm_config.get_limit_per_prompt("video") + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) max_image_tokens = self.get_max_image_tokens() * max_images max_total_frames = self._get_max_video_frames(seq_len - @@ -954,13 +957,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): return max(max_frames_per_video, 1) - def get_max_video_tokens(self, seq_len: int) -> int: + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: target_width, target_height = self.get_image_size_with_most_features() return self.get_num_video_tokens( image_width=target_width, image_height=target_height, - num_frames=self.get_num_frames_with_most_features(seq_len), + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), 
image_processor=None, ) @@ -982,7 +990,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): target_width, target_height = \ self.info.get_image_size_with_most_features() target_num_frames = \ - self.info.get_num_frames_with_most_features(seq_len) + self.info.get_num_frames_with_most_features(seq_len, mm_counts) mm_data = { "image": diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index e36f8e4434ec6..1df9a1f5eba1c 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Generic, NamedTuple, TypeVar, cast +from typing import Generic, NamedTuple, Optional, TypeVar, cast import numpy as np import numpy.typing as npt @@ -160,17 +160,19 @@ class MultiModalProfiler(Generic[_I]): def get_and_validate_mm_inputs( self, seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, ) -> tuple[MultiModalInputs, Mapping[str, int]]: - mm_counts = self.get_mm_limits() + if mm_counts is None: + mm_counts = self.get_mm_limits() info = self.processing_info mm_max_tokens_per_item = info.get_mm_max_tokens_per_item( seq_len, mm_counts) - if mm_counts.keys() != mm_max_tokens_per_item.keys(): + if mm_counts.keys() - mm_max_tokens_per_item.keys(): raise AssertionError( "The keys returned by `get_supported_mm_limits` " - f"({set(mm_counts.keys())}) should be the same as those " + f"({set(mm_counts.keys())}) should be a subset of those " "returned by `get_mm_max_tokens_per_item` " f"({set(mm_max_tokens_per_item.keys())})") @@ -193,8 +195,12 @@ class MultiModalProfiler(Generic[_I]): "tokens.") return mm_inputs, total_placeholders_by_modality - def get_encoder_dummy_data(self, seq_len: int) -> DummyEncoderData: - mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len) + def get_encoder_dummy_data( + self, + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> DummyEncoderData: + mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len, mm_counts) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) # For encoder-decoder models, use encoder prompt token ids instead of @@ -207,9 +213,15 @@ class MultiModalProfiler(Generic[_I]): return DummyEncoderData(encoder_prompt_token_ids) - def get_decoder_dummy_data(self, seq_len: int) -> DummyDecoderData: - (mm_inputs, total_placeholders_by_modality - ) = self.get_and_validate_mm_inputs(seq_len) + def get_decoder_dummy_data( + self, + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> DummyDecoderData: + ( + mm_inputs, + total_placeholders_by_modality, + ) = self.get_and_validate_mm_inputs(seq_len, mm_counts) prompt_token_ids = mm_inputs["prompt_token_ids"] total_len = len(prompt_token_ids) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 8c16c3ba80750..4f41fa083f63b 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -458,6 +458,7 @@ class MultiModalRegistry: self, model_config: "ModelConfig", seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. 
@@ -466,7 +467,7 @@ class MultiModalRegistry: """ processor = self.create_processor(model_config, disable_cache=True) profiler = MultiModalProfiler(processor) - dummy_data = profiler.get_decoder_dummy_data(seq_len) + dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) # Having more tokens is over-conservative but otherwise fine token_ids = dummy_data.prompt_token_ids @@ -481,6 +482,7 @@ class MultiModalRegistry: self, model_config: "ModelConfig", seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. @@ -489,7 +491,7 @@ class MultiModalRegistry: """ processor = self.create_processor(model_config, disable_cache=True) profiler = MultiModalProfiler(processor) - dummy_data = profiler.get_encoder_dummy_data(seq_len) + dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) # Having more tokens is over-conservative but otherwise fine token_ids = dummy_data.prompt_token_ids diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4511a9aa85fd3..8071c98b269fd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1470,19 +1470,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): encoder_budget, max_num_mm_items, dummy_data_modality) # Create dummy batch of multimodal inputs. - dummy_request_data = self.mm_registry.get_decoder_dummy_data( + dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, - ) - dummy_mm_data = dummy_request_data.multi_modal_data - - # Dummy data definition may contain multiple multimodal items - # (e.g, multiple images) for a single request, therefore here we - # always replicate first item by max_num_mm_items times since in V1 - # they are scheduled to be processed separately. - dummy_mm_item = dummy_mm_data.get_item( - modality=dummy_data_modality, item_index=0) - dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item]) + mm_counts={ + dummy_data_modality: 1 + }, + ).multi_modal_data batched_dummy_mm_inputs = MultiModalKwargs.batch( [dummy_mm_kwargs] * max_num_mm_items)
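
Usage note: a minimal sketch of how the new optional `mm_counts` override can be exercised through the registry, mirroring the V1 GPU model runner change above. The registry object, method name, keyword arguments, and the `multi_modal_data` / `prompt_token_ids` attributes come from the diff; the `model_config` variable and the chosen `seq_len` and modality values are assumptions for illustration only.

    from vllm.multimodal import MULTIMODAL_REGISTRY

    # `model_config` is assumed to be an already-constructed ModelConfig,
    # e.g. taken from the engine's VllmConfig; it is not built here.
    dummy_data = MULTIMODAL_REGISTRY.get_decoder_dummy_data(
        model_config=model_config,
        seq_len=4096,
        # Override the per-prompt limits so the dummy prompt contains
        # exactly one video item; omitting `mm_counts` falls back to the
        # profiler's `get_mm_limits()` as before.
        mm_counts={"video": 1},
    )
    dummy_mm_kwargs = dummy_data.multi_modal_data
    prompt_len = len(dummy_data.prompt_token_ids)

This is the same pattern the V1 runner now uses: it profiles with a single item of the dominant modality and then batches that item `max_num_mm_items` times, instead of extracting the first item from a multi-item dummy request.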