diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 0c7496334fb7..03d830fe90f1 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -121,17 +121,21 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": 1} ``` -### Maximum number of placeholder feature tokens +## 3. Specify dummy inputs -Also, override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item` -to return the maximum number of placeholder feature tokens per input item for each modality. +Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for +HF processing as well as memory profiling. -When calling the model, the output embeddings from the visual encoder are assigned to the input positions -containing placeholder feature tokens. Therefore, the number of placeholder feature tokens should be equal -to the size of the output embeddings. +### For memory profiling -:::::{tab-set} -::::{tab-item} Basic example: LLaVA +Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs` +to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of +the model so that vLLM can reserve the correct amount of memory for it. + +Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. + +::::{tab-set} +:::{tab-item} Basic example: LLaVA :sync: llava Looking at the code of HF's `LlavaForConditionalGeneration`: @@ -240,7 +244,7 @@ def get_num_image_tokens( ``` Notice that the number of image tokens doesn't depend on the image width and height. -So, we can calculate the maximum number of image tokens using any image size: +We can simply use a dummy `image_size`: ```python def get_image_size_with_most_features(self) -> ImageSize: @@ -248,33 +252,35 @@ def get_image_size_with_most_features(self) -> ImageSize: width = height = hf_config.image_size return ImageSize(width=width, height=height) -def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) -``` - -And thus, we can override the method as: - -```python -def get_mm_max_tokens_per_item( +def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], -) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} +) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + hf_config = self.get_hf_config() + target_width, target_height = self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) ``` -:::{note} -Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP. ::: -:::: - -::::{tab-item} Non-consecutive feature tokens: Fuyu +:::{tab-item} No input placeholders: Fuyu :sync: fuyu Looking at the code of HF's `FuyuForCausalLM`: @@ -394,188 +400,16 @@ num_patches_per_dim_w = image_width // patch_width num_patches = num_patches_per_dim_h * num_patches_per_dim_w ``` -We can calculate this in vLLM using this code: - -```python -def get_num_image_patches( - self, - *, - image_width: int, - image_height: int, -) -> int: - image_processor = self.get_image_processor() - target_width = image_processor.size["width"] - target_height = image_processor.size["height"] - patch_width = image_processor.patch_size["width"] - patch_height = image_processor.patch_size["height"] - - if not (image_width <= target_width and image_height <= target_height): - height_scale_factor = target_height / image_height - width_scale_factor = target_width / image_width - optimal_scale_factor = min(height_scale_factor, width_scale_factor) - - image_height = int(image_height * optimal_scale_factor) - image_width = int(image_width * optimal_scale_factor) - - ncols = math.ceil(image_width / patch_width) - nrows = math.ceil(image_height / patch_height) - return ncols * nrows -``` - -These image patches correspond to placeholder tokens (`|SPEAKER|`). However, the processor also -inserts newline tokens (`|NEWLINE|`) as shown here: - -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L654-L670 -tensor_of_image_ids = torch.full( - [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device -) -patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0) -assert num_patches == patches.shape[0] - -if variable_sized: - # Now terminate each line with |NEWLINE|. - tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width) - newline_ids = torch.full( - [tensor_of_image_ids.shape[0], 1], - image_newline_id, - dtype=torch.int32, - device=image_input.device, - ) - tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1) - tensor_of_image_ids = tensor_of_image_ids.reshape(-1) -``` - -So, the layout of tokens for an image is: - -``` -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -... -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -``` - -This makes the placeholder tokens non-consecutive in the prompt. -Since vLLM requires the feature tokens to be consecutive, **we also treat the newline tokens as feature tokens**. - -So overall, the total number of feature tokens is - -```python -def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, -) -> int: - image_processor = self.get_image_processor() - target_width = image_processor.size["width"] - target_height = image_processor.size["height"] - patch_width = image_processor.patch_size["width"] - patch_height = image_processor.patch_size["height"] - - if not (image_width <= target_width and image_height <= target_height): - height_scale_factor = target_height / image_height - width_scale_factor = target_width / image_width - optimal_scale_factor = min(height_scale_factor, width_scale_factor) - - image_height = int(image_height * optimal_scale_factor) - image_width = int(image_width * optimal_scale_factor) - - ncols = math.ceil(image_width / patch_width) - nrows = math.ceil(image_height / patch_height) - return (ncols + 1) * nrows -``` - -To calculate the maximum number of image tokens, recall that input images are first resized -to fit within `image_processor.size`. The maximum possible dimensions of the image before -being converted into patches is therefore equal to `image_processor.size`. +These image patches correspond to placeholder tokens (`|SPEAKER|`). So, we just need to maximize the number of image patches. Since input images are first resized +to fit within `image_processor.size`, we can maximize the number of image patches by inputting an image with size equal to `image_processor.size`. ```python def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() return ImageSize(width=image_processor.size["width"], height=image_processor.size["height"]) - -def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) ``` -And thus, we can override the method as: - -```python -def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], -) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} -``` - -:::{note} -Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) returns `ncols` and `nrows` directly instead of the total token count. -This is because `ncols` and `nrows` are used to specify the layout of the feature tokens (as shown in Step 4 of this guide). -::: - -:::: -::::: - -## 3. Specify dummy inputs - -Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for -HF processing as well as memory profiling. - -### For memory profiling - -Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs` -to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of -the model so that vLLM can reserve the correct amount of memory for it. - -Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed based -on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`. - -::::{tab-set} -:::{tab-item} Basic example: LLaVA -:sync: llava - -Making use of the `get_image_size_with_most_features` method implemented in Step 2: - -```python -def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], -) -> ProcessorInputs: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - image_token = processor.image_token - - hf_config = self.get_hf_config() - target_width, target_height = self.info.get_image_size_with_most_features() - - mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) -``` - -::: - -:::{tab-item} No input placeholders: Fuyu -:sync: fuyu - Fuyu does not expect image placeholders in the inputs to HF processor, so the dummy prompt text is empty regardless of the number of images. Otherwise, the logic of this method is very similar to LLaVA: diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 578dcd4a4445..2bfc2785feb6 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -76,11 +76,6 @@ def test_processor_override( if v == config.boi_token_index] # patch sizes and masks - patch_token_id = vocab[hf_processor.img_patch_token] - num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id) - mm_counts = {"image": num_imgs} - assert num_patches / num_imgs <= \ - processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"] num_patches_per_chunk = processor.info.get_patch_per_chunk( config.vision_config) assert prompt_token_ids.count(config.image_token_index) \ diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index af340feffcf9..23b8ef89268d 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -408,13 +408,6 @@ class AriaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() return max(hf_config.projector_patch_to_query_dict.values()) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 929c8f2a82a2..cdec31602503 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -117,31 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): def get_image_processor(self) -> GotOcr2ImageProcessor: return self.get_hf_processor().image_processor - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - - def get_max_image_tokens(self) -> int: - hf_processor = self.get_hf_processor() - image_processor = hf_processor.image_processor - - image_size = self.get_image_size_with_most_features() - num_patches = self.get_num_patches( - image_width=image_size.width, - image_height=image_size.height, - size=image_processor.size, - min_patches=image_processor.min_patches, - max_patches=image_processor.max_patches, - ) - - img_patches_per_tile = (hf_processor.img_size // - hf_processor.patch_size)**2 - - return num_patches * img_patches_per_tile - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index a1f20ea4e614..dde78ee52a3d 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -406,13 +406,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() return hf_config.num_query_tokens diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index d46ae5327dcb..fb2f4b677c5a 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -64,13 +64,6 @@ class ChameleonProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: processor = self.get_hf_processor() return processor.image_seq_length diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index dc3aa9cbe86b..153054e5c028 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): ) -> int: return self.get_patch_grid_length()**2 + 1 - def get_max_image_tokens(self) -> int: - return self.get_patch_grid_length()**2 + 1 - def get_image_size(self) -> int: return self.vision_config.image_size diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 03d5be2927bb..951185bc9bd0 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -168,20 +168,6 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): image_width=x[1], image_height=x[0])) return ImageSize(width=width, height=height) - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - num_images = mm_counts.get("image", 0) - max_image_size = self.get_image_size_with_most_features() - max_image_tokens = self.get_num_image_tokens( - image_height=max_image_size.height, - image_width=max_image_size.width, - cropping=num_images <= 2) - - return {"image": max_image_tokens} - class DeepseekVL2DummyInputsBuilder( BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 62fd09398fac..56572bd59a35 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -764,17 +764,10 @@ class Florence2ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_max_image_tokens(self) -> int: + def get_num_image_tokens(self) -> int: processor_config = self.ctx.get_hf_image_processor_config() return processor_config["image_seq_length"] - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - class Florence2DummyInputsBuilder( BaseDummyInputsBuilder[Florence2ProcessingInfo]): @@ -871,7 +864,7 @@ class Florence2MultiModalProcessor( ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() pad_token_id = hf_config.pad_token_id - num_image_tokens = self.info.get_max_image_tokens() + num_image_tokens = self.info.get_num_image_tokens() image_tokens = [pad_token_id] * num_image_tokens return [ diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index c0a0f572ff3c..5fc6bb846388 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -80,13 +80,6 @@ class FuyuProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_image_feature_grid_size( self, *, @@ -129,14 +122,6 @@ class FuyuProcessingInfo(BaseProcessingInfo): return ImageSize(width=image_processor.size["width"], height=image_processor.size["height"]) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 93d0aa301f54..34d856f4b203 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -68,13 +68,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def _resolve_image_kwargs( self, processor: Gemma3Processor, @@ -228,15 +221,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): # Result in the max possible feature size (h:w = max_num_crops:1) return ImageSize(height=50 * max_num_crops, width=50) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 6d7b760d0dd7..02954eecc42c 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_feature_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() vision_config = hf_config.vision_config diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index f975a19a364e..15e126b0f4ce 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): **kwargs, ) - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - max_tokens_one_image = self.get_max_image_tokens(use_msac=None) - if mm_counts.get("image", 0) <= 1: - max_tokens_per_image = max_tokens_one_image - else: - max_tokens_per_image = self.get_max_image_tokens(use_msac=False) - - return {"image": max_tokens_per_image} - def get_num_image_tokens( self, *, @@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): use_msac=use_msac, ) - def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - use_msac=use_msac, - ) - class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] ): diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index ec02d1c8862a..655db1c85634 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -97,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def _resize_output_size(self, *, height: int, @@ -287,15 +280,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): height=image_processor.size["longest_edge"], ) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] ): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7fd628fa6c38..08741b3a3c11 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -458,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -480,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): image_height=image_height, ) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 9516550005d5..5804cb4419b6 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -137,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def _apply_feature_select_strategy( self, strategy: str, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 6fc4c187efa7..281c9c0e8ebe 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -61,22 +61,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"video": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - target_width, target_height = self.get_image_size_with_most_features() - - max_video_tokens = self.get_num_video_tokens( - image_width=target_width, - image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), - ) - - return {"video": max_video_tokens} - def get_image_size_with_most_features(self) -> ImageSize: vision_encoder_info = self.get_vision_encoder_info() width = height = vision_encoder_info.get_image_size() diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 5fbd27b9b0b3..f6256771d982 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return { - "image": self.get_max_image_tokens(), - "video": self.get_max_video_tokens(seq_len, mm_counts), - } - # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # with additional logic afterwards taken from LlavaOnevisionProcessor def _get_num_unpadded_features( diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index a4fb0cb1741e..8bb41a108b5a 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -142,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {**super().get_supported_mm_limits(), "audio": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return { - **super().get_mm_max_tokens_per_item(seq_len, mm_counts), - "audio": - self.get_max_audio_tokens(), - } - def get_audio_placeholder( self, audio_lens: int, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 12b5364cbaf8..87c690219583 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -346,18 +346,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): return mm_limits - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - mm_max_tokens = {"image": self.get_max_image_tokens()} - if self.get_model_version() == (2, 6): - mm_max_tokens["video"] = self.get_max_video_tokens( - seq_len, mm_counts) - - return mm_max_tokens - def get_slice_image_placeholder( self, image_size: ImageSize, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 67c0e2ec233b..d2c600feb4b2 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -162,13 +162,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -186,14 +179,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): width = height = vision_encoder_info.get_image_size() return ImageSize(width=width, height=height) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index d332b17f910e..b61e42f31d88 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -106,16 +106,6 @@ class MllamaProcessingInfo(BaseProcessingInfo): image_size = self.get_hf_config().vision_config.image_size return calc_token_per_chunk(image_size) - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - vision_config = self.get_hf_config().vision_config - token_per_chunk = self.get_token_per_chunk_from_config() - mm_max_tokens = vision_config.max_num_tiles * token_per_chunk - return {"image": mm_max_tokens} - def get_num_tiles_per_image(self, image_height: int, image_width: int) -> int: vision_config = self.get_hf_config().vision_config diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 17171f823cb0..4f709751ae62 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -498,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): image_processor = self.get_hf_processor().image_processor return image_processor.max_patches - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - vision_config = self.get_hf_config().vision_config - patch_per_chunk = self.get_patch_per_chunk(vision_config) - num_patches = self.get_max_num_tiles() + 1 - - return {"image": patch_per_chunk * num_patches} - def get_image_size_with_most_features(self) -> ImageSize: vision_config = self.get_hf_config().vision_config image_size = vision_config.image_size @@ -516,14 +505,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ): diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index a7551e613dfc..d896431b166b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1164,13 +1164,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -1195,15 +1188,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): return extra + joint - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 274163ac9c42..ae8eee4515e0 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -13,7 +13,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs) -from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, PromptInsertion, PromptUpdate, @@ -72,16 +73,18 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( + def get_num_image_tokens( self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - - def get_num_image_tokens(self) -> int: + *, + image_width: int, + image_height: int, + ) -> int: vision_encoder_info = self.get_vision_encoder_info() - return vision_encoder_info.get_max_image_tokens() + + return vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ) class PaliGemmaDummyInputsBuilder( @@ -148,12 +151,30 @@ class PaliGemmaMultiModalProcessor( image_token_id = hf_config.image_token_index tokenizer = self.info.get_tokenizer() - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [image_token_id] * num_image_tokens bos_token_id = tokenizer.bos_token_id assert isinstance(bos_token_id, int) + def get_insertion(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + image_tokens = [image_token_id] * num_image_tokens + + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=image_token_id, + ) + # Paligemma 1 and 2 have different tokenizer.add_bos_token # Insert *n + after for Paligemma 1 # Insert *n + for Paligemma 2 @@ -162,10 +183,7 @@ class PaliGemmaMultiModalProcessor( modality="image", target=PromptIndexTargets.prefix( [bos_token_id] if tokenizer.add_bos_token else []), - insertion=PromptUpdateDetails.select_token_id( - image_tokens + [bos_token_id], - embed_token_id=image_token_id, - ), + insertion=get_insertion, ) ] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 344f348cd3d9..cce700f02f59 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -321,21 +321,6 @@ class Phi3VProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - target_width, target_height = self.get_image_size_with_most_features() - - max_image_tokens = self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - - return {"image": max_image_tokens} - def get_num_image_tokens( self, *, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 328d52711b5e..fdd342ccf6b5 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -167,13 +167,6 @@ class PixtralProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_vision_config( self, processor: Optional[PixtralProcessorAdapter] = None, @@ -207,14 +200,6 @@ class PixtralProcessingInfo(BaseProcessingInfo): return ImageSize(width=max_image_size, height=max_image_size) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): @@ -938,14 +923,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): ) return ncols * nrows - def get_max_image_tokens(self) -> int: - image_size = self.get_image_size() - - return self.get_num_image_tokens( - image_width=image_size, - image_height=image_size, - ) - def get_image_size(self) -> int: return self.vision_config.image_size diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index a69c0fc54e4c..e3a93e95530c 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -45,9 +45,6 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": 0} - class PrithviGeoSpatialMAEInputBuilder( BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 9f2593fc94f4..ba4646f5583f 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -109,17 +109,6 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - hf_config = self.get_hf_config() - max_source_positions = hf_config.audio_config.max_source_positions - max_output_lengths = (max_source_positions - 2) // 2 + 1 - - return {"audio": max_output_lengths} - class Qwen2AudioDummyInputsBuilder( BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f93654d0fcb3..23f27e7ef9fb 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -818,16 +818,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return { - "image": self.get_max_image_tokens(), - "video": self.get_max_video_tokens(seq_len, mm_counts), - } - def _get_vision_info( self, *, diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 2e941f3b7a31..403d47a39d17 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -530,13 +530,6 @@ class QwenVLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() vision_config = hf_config.visual diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index cecad9e8935e..75fcf540b0b1 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -33,9 +33,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): ) -> int: return self.get_patch_grid_length()**2 - def get_max_image_tokens(self) -> int: - return self.get_patch_grid_length()**2 - def get_image_size(self) -> int: return self.vision_config.image_size diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index a8460a2e1043..09a212a9face 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -459,13 +459,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -481,15 +474,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): image_height=image_height, ) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 6e9d15261b79..3ff5a0516b65 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" -import math from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union @@ -107,17 +106,6 @@ class UltravoxProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - feature_extractor = self.get_feature_extractor() - max_audio_tokens = math.ceil(feature_extractor.chunk_length * - _AUDIO_TOKENS_PER_SECOND) - - return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE} - class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] ): diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 347f51499b7b..05e3b3f3ccdf 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -33,10 +33,6 @@ class VisionEncoderInfo(ABC, Generic[_C]): ) -> int: raise NotImplementedError - @abstractmethod - def get_max_image_tokens(self) -> int: - raise NotImplementedError - @abstractmethod def get_image_size(self) -> int: raise NotImplementedError diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 7751f96da6ae..341e22a4a8bb 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -538,16 +538,9 @@ class WhisperProcessingInfo(BaseProcessingInfo): assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - def get_max_audio_tokens(self) -> int: + def get_num_audio_tokens(self) -> int: return self.get_hf_config().max_source_positions - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"audio": self.get_max_audio_tokens()} - class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): @@ -630,7 +623,7 @@ class WhisperMultiModalProcessor( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: - num_tokens = self.info.get_max_audio_tokens() + num_tokens = self.info.get_num_audio_tokens() return [ PromptReplacement( modality="audio", diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 64f657db94bb..fefeefd21375 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1034,21 +1034,6 @@ class BaseProcessingInfo: """ raise NotImplementedError - @abstractmethod - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - """ - Get the maximum possible number of tokens per data item - for each modality. - - The dictionary returned by this method should have the same - keys as that returned by :meth:`get_supported_mm_limits`. - """ - raise NotImplementedError - _I = TypeVar("_I", bound=BaseProcessingInfo) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index ec3625f2f426..7efe86448fdd 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -68,7 +68,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ) -> ProcessorInputs: """ Build the input which, after processing, results in - :code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens. + the maximum possible number of placeholder tokens. """ raise NotImplementedError @@ -152,8 +152,11 @@ class MultiModalProfiler(Generic[_I]): def _get_dummy_mm_inputs( self, seq_len: int, - mm_counts: Mapping[str, int], + mm_counts: Optional[Mapping[str, int]] = None, ) -> MultiModalInputs: + if mm_counts is None: + mm_counts = self.get_mm_limits() + factory = self.dummy_inputs processor_inputs = factory.get_dummy_processor_inputs( seq_len, mm_counts) @@ -164,53 +167,23 @@ class MultiModalProfiler(Generic[_I]): hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, ) - def get_and_validate_mm_inputs( + def _get_mm_num_tokens( self, - seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, - ) -> tuple[MultiModalInputs, Mapping[str, int]]: - if mm_counts is None: - mm_counts = self.get_mm_limits() - - info = self.processing_info - mm_max_tokens_per_item = info.get_mm_max_tokens_per_item( - seq_len, mm_counts) - - if mm_counts.keys() - mm_max_tokens_per_item.keys(): - raise AssertionError( - "The keys returned by `get_supported_mm_limits` " - f"({set(mm_counts.keys())}) should be a subset of those " - "returned by `get_mm_max_tokens_per_item` " - f"({set(mm_max_tokens_per_item.keys())})") - - mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + mm_inputs: MultiModalInputs, + ) -> Mapping[str, int]: placeholders_by_modality = mm_inputs["mm_placeholders"] - total_placeholders_by_modality = { + return { modality: sum(item.get_num_embeds() for item in placeholders) for modality, placeholders in placeholders_by_modality.items() } - expected_placeholders_by_modality = { - modality: mm_max_tokens_per_item[modality] * mm_counts[modality] - for modality in placeholders_by_modality - } - if total_placeholders_by_modality != expected_placeholders_by_modality: - raise AssertionError( - f"The processed dummy data has a total of " - f"{total_placeholders_by_modality} placeholder tokens, which " - f"is not the expected {expected_placeholders_by_modality} " - "tokens.") - return mm_inputs, total_placeholders_by_modality def get_encoder_dummy_data( self, seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, ) -> DummyEncoderData: - ( - mm_inputs, - total_placeholders_by_modality, - ) = self.get_and_validate_mm_inputs(seq_len, mm_counts) + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) # For encoder-decoder models, use encoder prompt token ids instead of @@ -232,7 +205,7 @@ class MultiModalProfiler(Generic[_I]): " is too short " "to hold the multi-modal embeddings in the worst case " f"({total_len} tokens in total, out of which " - f"{total_placeholders_by_modality} are reserved for " + f"{self._get_mm_num_tokens(mm_inputs)} are reserved for " "multi-modal embeddings). This may cause certain " "multi-modal inputs to fail during inference, even when " "the input text is short. To avoid this, you should " @@ -246,10 +219,7 @@ class MultiModalProfiler(Generic[_I]): seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, ) -> DummyDecoderData: - ( - mm_inputs, - total_placeholders_by_modality, - ) = self.get_and_validate_mm_inputs(seq_len, mm_counts) + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) prompt_token_ids = mm_inputs["prompt_token_ids"] total_len = len(prompt_token_ids) @@ -263,7 +233,7 @@ class MultiModalProfiler(Generic[_I]): "is too short " "to hold the multi-modal embeddings in the worst case " f"({total_len} tokens in total, out of which " - f"{total_placeholders_by_modality} are reserved for " + f"{self._get_mm_num_tokens(mm_inputs)} are reserved for " "multi-modal embeddings). This may cause certain " "multi-modal inputs to fail during inference, even when " "the input text is short. To avoid this, you should " @@ -278,3 +248,12 @@ class MultiModalProfiler(Generic[_I]): multi_modal_data=mm_inputs["mm_kwargs"], multi_modal_placeholders=mm_inputs["mm_placeholders"], ) + + def get_mm_max_tokens( + self, + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> Mapping[str, int]: + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + + return self._get_mm_num_tokens(mm_inputs) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 4f41fa083f63..eafa28d612a6 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -258,10 +258,16 @@ class MultiModalRegistry: """ if self.has_processor(model_config): processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + seq_len = model_config.max_model_len mm_limits = self.get_mm_limits_per_prompt(model_config) - return processor.info.get_mm_max_tokens_per_item( - seq_len, mm_limits) + + return profiler.get_mm_max_tokens( + seq_len, + {modality: 1 + for modality in mm_limits}, + ) return { key: plugin.get_max_multimodal_tokens(model_config)