From 79f05e4436fd97383bfd6319a1e80886bceb0fd3 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Thu, 21 Aug 2025 07:23:28 -0700 Subject: [PATCH] [Multimodal] Always enable hashing mm data (#23308) Signed-off-by: Roger Wang Signed-off-by: DarkLight1337 Co-authored-by: DarkLight1337 --- vllm/config/__init__.py | 9 -- vllm/inputs/preprocess.py | 53 ++------- vllm/model_executor/models/deepseek_vl2.py | 4 - vllm/model_executor/models/h2ovl.py | 4 - vllm/model_executor/models/llava.py | 3 +- vllm/model_executor/models/mllama.py | 3 +- vllm/model_executor/models/paligemma.py | 3 +- vllm/model_executor/models/pixtral.py | 3 - .../models/prithvi_geospatial_mae.py | 112 +++++++++++------- vllm/model_executor/models/transformers.py | 1 - vllm/model_executor/models/voxtral.py | 3 - vllm/multimodal/hasher.py | 2 +- vllm/multimodal/inputs.py | 2 +- vllm/multimodal/processing.py | 20 +--- vllm/v1/engine/processor.py | 20 +--- 15 files changed, 94 insertions(+), 148 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 2973cb92d195b..fbc4dd3989f57 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1685,15 +1685,6 @@ class ModelConfig: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None - @property - def processor_return_mm_hashes(self) -> bool: - """Whether the multi-modal processor should output hashes.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - @property def enable_mm_processor_cache(self) -> bool: """Whether the multi-modal processor cache should be enabled.""" diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index de5dc0876651a..3f521012e82a2 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -254,7 +254,6 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Apply the model's multi-modal processor to a multi-modal prompt, @@ -271,8 +270,7 @@ class InputPreprocessor: return mm_processor.apply(prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + tokenization_kwargs=tokenization_kwargs) async def _process_multimodal_async( self, @@ -281,7 +279,6 @@ class InputPreprocessor: mm_processor_kwargs: Optional[Mapping[str, object]], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Async version of @@ -297,8 +294,7 @@ class InputPreprocessor: return mm_processor.apply(prompt, mm_data, hf_processor_mm_kwargs=mm_processor_kwargs, - tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes) + tokenization_kwargs=tokenization_kwargs) def _process_embeds( self, @@ -335,7 +331,6 @@ class InputPreprocessor: parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = parsed_content["prompt_token_ids"] token_type_ids = parsed_content.get("token_type_ids") @@ -348,7 +343,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: inputs = 
token_inputs( @@ -366,7 +360,6 @@ class InputPreprocessor: parsed_content: TokensPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_token_ids = parsed_content["prompt_token_ids"] token_type_ids = parsed_content.get("token_type_ids") @@ -379,7 +372,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: inputs = token_inputs( @@ -397,7 +389,6 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -409,7 +400,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: prompt_token_ids = self._tokenize_prompt( @@ -432,7 +422,6 @@ class InputPreprocessor: parsed_content: TextPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> Union[TokenInputs, MultiModalInputs]: prompt_text = parsed_content["prompt"] @@ -444,7 +433,6 @@ class InputPreprocessor: parsed_content.get("mm_processor_kwargs"), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) else: prompt_token_ids = await self._tokenize_prompt_async( @@ -467,7 +455,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> SingletonInputs: """ Extract the singleton inputs from a prompt. 
@@ -476,7 +463,6 @@ class InputPreprocessor: * prompt: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts - * return_mm_hashes: whether to return multimodal hashes Returns: @@ -490,21 +476,18 @@ class InputPreprocessor: return self._process_tokens( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "text": return self._process_text( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "str": return self._process_text( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) assert_never(parsed) @@ -514,7 +497,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> SingletonInputs: """ Async version of @@ -528,21 +510,18 @@ class InputPreprocessor: return await self._process_tokens_async( parsed["content"], lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "text": return await self._process_text_async( parsed["content"], tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) if parsed["type"] == "str": return await self._process_text_async( TextPrompt(prompt=parsed["content"]), tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) assert_never(parsed) @@ -785,7 +764,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> DecoderOnlyInputs: """ For decoder-only models: @@ -796,7 +774,6 @@ class InputPreprocessor: * prompt: input prompt * lora_request - * return_mm_hashes Returns: @@ -807,7 +784,6 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -817,7 +793,6 @@ class InputPreprocessor: prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> DecoderOnlyInputs: """ Async version of @@ -827,7 +802,6 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) return self._build_decoder_only_llm_inputs(prompt_comps) @@ -837,17 +811,15 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> ProcessorInputs: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder + # input prompts to encoder & decoder. 
return self._process_encoder_decoder_prompt( - prompt, tokenization_kwargs) + prompt, + tokenization_kwargs, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -858,7 +830,6 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) async def preprocess_async( @@ -866,19 +837,18 @@ class InputPreprocessor: prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - return_mm_hashes: bool = False, ) -> ProcessorInputs: """ Async version of [`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess]. """ if self.model_config.is_encoder_decoder: - assert not return_mm_hashes, ( - "Multimodal hashes for encoder-decoder models should not be ", - "returned until they are supported on vLLM V1.") # Encoder-decoder model requires special mapping of - # input prompts to encoder & decoder - return await self._process_encoder_decoder_prompt_async(prompt) + # input prompts to encoder & decoder. + return await self._process_encoder_decoder_prompt_async( + prompt, + tokenization_kwargs, + ) if is_explicit_encoder_decoder_prompt(prompt): raise ValueError("Cannot pass encoder-decoder prompt " @@ -889,5 +859,4 @@ class InputPreprocessor: prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 421076348386b..ceb5e1364b68d 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -290,8 +290,6 @@ class DeepseekVL2MultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is @@ -303,7 +301,6 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) return super()._cached_apply_hf_processor( @@ -311,7 +308,6 @@ class DeepseekVL2MultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 75ab4dbe7b57d..87e451a2769ea 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -479,8 +479,6 @@ class H2OVLMultiModalProcessor( mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is @@ -492,7 +490,6 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) return super()._cached_apply_hf_processor( @@ -500,7 +497,6 @@ class H2OVLMultiModalProcessor( mm_data_items=mm_data_items, 
hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 3caaaa9f7d1e3..cd41d4fb43885 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -795,7 +795,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -807,7 +806,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + tokenization_kwargs) mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 9d2ac771474e5..bb3267ce5b004 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -168,10 +168,9 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + tokenization_kwargs) image_token_id = self.info.get_hf_config().image_token_index # Check that the number of image tokens in the decoder prompt matches diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index f15e7a17d5d4d..7d6a6207c7c89 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -194,10 +194,9 @@ class PaliGemmaMultiModalProcessor( mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, - tokenization_kwargs, return_mm_hashes) + tokenization_kwargs) prompt_token_ids = mm_inputs["prompt_token_ids"] tokenizer = self.info.get_tokenizer() diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 25be44e3f6e13..c01074e2122bb 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -308,15 +308,12 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) # NOTE: The tokens are already inserted by the chat template diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 442596a6b555c..59e9f3e8a47b0 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ 
b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -18,7 +18,7 @@ """Inference-only IBM/NASA Prithvi Geospatial model.""" from collections.abc import Iterable, Mapping, Sequence -from typing import Optional, Union +from typing import Any, Optional, Union import torch import torch.nn as nn @@ -32,18 +32,56 @@ from vllm.model_executor.models.interfaces import ( default_pooling_type) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalFieldElem, MultiModalInputs, - MultiModalKwargsItem, - MultiModalKwargsItems, - MultiModalSharedField, PlaceholderRange) -from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, MultiModalKwargsItems, + PlaceholderRange) +from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): + # This model receives in input a multi-dimensional tensor representing + # a single image patch and therefore it is not to be split + # into multiple elements, but rather to be considered a single one. + # Hence, the decision of using a MultiModalSharedField. + # The expected shape is (num_channels, width, height). + + # This model however allows the user to also submit multiple image + # patches as a batch, adding a further dimension to the above shape. + # At this stage we only support submitting one patch per request and + # batching is achieved via vLLM batching. + # TODO (christian-pinto): enable support for multi patch requests + # in tandem with vLLM batching. + return dict( + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + ) + + +class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> Optional[ModalityDataItems[Any, Any]]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={"pixel_values", "location_coords"}, + fields_factory=_prithvi_field_config, + ) + + return super()._parse_image_data(data) + + class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: @@ -64,26 +102,26 @@ class PrithviGeoSpatialMAEInputBuilder( # This model input is fixed and is in the form of a torch Tensor. # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. 
- return { + image_data = { "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } + return {"image": image_data} + class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): + def _get_data_parser(self) -> MultiModalDataParser: + return PrithviGeoSpatialMAEMultiModalDataParser() + def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, - modality="image"), - ) + return _prithvi_field_config(hf_inputs) def _get_prompt_updates( self, @@ -99,46 +137,32 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: - mm_kwargs = {} + if "image" in mm_data: + image_data = mm_data["image"] + else: + image_data = mm_data + mm_data = {"image": mm_data} - for k, v in mm_data.items(): - if isinstance(v, dict) and k == "image": - mm_kwargs.update(v) - else: - mm_kwargs[k] = v + mm_items = self._to_mm_items(mm_data) + mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, + tokenization_kwargs or {}) mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} - # This model receives in input a multi-dimensional tensor representing - # a single image patch and therefore it is not to be split - # into multiple elements, but rather to be considered a single one. - # Hence, the decision of using a MultiModalSharedField. - # The expected shape is (num_channels, width, height). + mm_processed_data = BatchFeature(image_data) - # This model however allows the user to also submit multiple image - # patches as a batch, adding a further dimension to the above shape. - # At this stage we only support submitting one patch per request and - # batching is achieved via vLLM batching. - # TODO (christian-pinto): enable support for multi patch requests - # in tandem with vLLM batching. - multimodal_kwargs_items = [ - MultiModalKwargsItem.from_elems([ - MultiModalFieldElem( - modality="image", - key=key, - data=data, - field=MultiModalSharedField(1), - ) for key, data in mm_kwargs.items() - ]) - ] + mm_kwargs = MultiModalKwargsItems.from_hf_inputs( + mm_processed_data, + self._get_mm_fields_config(mm_processed_data, + hf_processor_mm_kwargs), + ) return MultiModalInputs( type="multimodal", prompt=prompt, prompt_token_ids=[1], - mm_kwargs=MultiModalKwargsItems.from_seq(multimodal_kwargs_items), - mm_hashes=None, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, mm_placeholders=mm_placeholders, ) diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index ed9d6c0ab4ce4..fc242d1adafd0 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -310,7 +310,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. 
diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index d0e8e3d39b451..77f11a691e080 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -288,15 +288,12 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) # NOTE: The tokens are already inserted by the chat template diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index c9ce1f0be5f88..210a4ec762879 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -43,7 +43,7 @@ class MultiModalHasher: return cls.item_to_bytes( "image", np.asarray(convert_image_mode(obj, "RGBA"))) if isinstance(obj, torch.Tensor): - return cls.item_to_bytes("tensor", obj.numpy()) + return cls.item_to_bytes("tensor", obj.cpu().numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index d46d81fe14484..581f9a109cce6 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -901,7 +901,7 @@ class MultiModalInputs(TypedDict): mm_kwargs: MultiModalKwargsItems """Keyword arguments to be directly passed to the model after batching.""" - mm_hashes: Optional["MultiModalHashDict"] + mm_hashes: "MultiModalHashDict" """The hashes of the multi-modal data.""" mm_placeholders: "MultiModalPlaceholderDict" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index e1363b7b0d891..55fd1479d2de5 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -998,7 +998,7 @@ A collection of prompt updates with a similar structure as class MultiModalProcessingInfo(NamedTuple): kwargs: MultiModalKwargsItems - hashes: Optional[MultiModalHashes] + hashes: MultiModalHashes prompt_updates: MultiModalPromptUpdates @@ -1399,8 +1399,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, @@ -1420,9 +1418,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs), ) - mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, - tokenization_kwargs) - if return_mm_hashes else None) + mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, + tokenization_kwargs) unbound_prompt_updates = self._get_prompt_updates( mm_data_items, @@ -1446,8 +1443,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object], - *, - return_mm_hashes: bool, ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, @@ -1462,7 +1457,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, 
tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, @@ -1476,8 +1470,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_hashes=mm_hashes, ) - mm_hashes_to_return = mm_hashes if return_mm_hashes else None - # NOTE: `prompt` does not correspond to `mm_missing_data_items`, # so we can't apply prompt updates until the new multimodal # items are combined with the cached multimodal items @@ -1515,7 +1507,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_info = MultiModalProcessingInfo( kwargs=mm_kwargs, - hashes=mm_hashes_to_return, + hashes=mm_hashes, prompt_updates=mm_prompt_updates, ) @@ -1697,7 +1689,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalInputs: """ Process multi-modal inputs to be used in vLLM. @@ -1726,7 +1717,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items, hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, - return_mm_hashes=return_mm_hashes, ) # NOTE: tokenization_kwargs are not required to init processor @@ -1811,7 +1801,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data: MultiModalDataDict, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Optional[Mapping[str, object]] = None, - return_mm_hashes: bool = False, ) -> MultiModalEncDecInputs: """ Process multi-modal inputs to be used in vLLM. @@ -1826,7 +1815,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): mm_data, hf_processor_mm_kwargs, tokenization_kwargs, - return_mm_hashes, ) return self._get_enc_dec_inputs( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 97d79c2ae0931..69f8e531e01b1 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -17,7 +17,6 @@ from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from vllm.utils import is_list_of from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( @@ -253,13 +252,10 @@ class Processor: # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. 
- return_mm_hashes = (self.model_config.processor_return_mm_hashes - or bool(self.cache_config.enable_prefix_caching)) processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - return_mm_hashes=return_mm_hashes, ) from vllm.platforms import current_platform current_platform.validate_request( @@ -302,7 +298,7 @@ class Processor: if decoder_inputs["type"] == "multimodal": decoder_mm_inputs = decoder_inputs["mm_kwargs"] decoder_mm_positions = decoder_inputs["mm_placeholders"] - decoder_mm_hashes = decoder_inputs.get("mm_hashes") + decoder_mm_hashes = decoder_inputs["mm_hashes"] # Merge and flatten multimodal placeholders, hashes and inputs # from dictionaries to lists, and sort them by each item's position @@ -317,19 +313,15 @@ class Processor: decoder_mm_positions[modality][idx] for modality, idx in sorted_mm_idxs ] - sorted_mm_hashes = None if decoder_mm_hashes is None else [ + sorted_mm_hashes = [ decoder_mm_hashes[modality][idx] for modality, idx in sorted_mm_idxs ] - if sorted_mm_hashes is not None: - sorted_mm_inputs = self.mm_input_cache_client.get_and_update( - orig_sorted_mm_inputs, - sorted_mm_hashes, - ) - else: - assert is_list_of(orig_sorted_mm_inputs, MultiModalKwargsItem) - sorted_mm_inputs = orig_sorted_mm_inputs + sorted_mm_inputs = self.mm_input_cache_client.get_and_update( + orig_sorted_mm_inputs, + sorted_mm_hashes, + ) return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id,
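
Note (illustration, not part of the patch): the hasher.py hunk above changes tensor serialization to go through .cpu().numpy(), so CUDA tensors can be hashed, and the neighbouring context shows non-contiguous arrays being copied via tobytes(). Below is a minimal, standalone sketch of that serialization rule, assuming hypothetical helper names (hash_mm_item, item_to_bytes) and sha256 as the digest; it is not vLLM's actual MultiModalHasher beyond the two behaviours visible in the diff.

import hashlib

import numpy as np
import torch


def item_to_bytes(key: str, value: np.ndarray) -> bytes:
    # Frame each payload with its key so different fields with identical
    # contents do not collide.
    # Non-contiguous arrays must be copied before their bytes can be read,
    # mirroring the c_contiguous check in the diff context above.
    data = value.data if value.flags.c_contiguous else value.tobytes()
    return key.encode("utf-8") + b"\x00" + bytes(data)


def hash_mm_item(modality: str, obj) -> str:
    if isinstance(obj, torch.Tensor):
        # Move to CPU first: a CUDA tensor cannot be converted to numpy
        # directly, which is exactly what the hasher.py change guards against.
        obj = obj.cpu().numpy()
    digest = hashlib.sha256()
    digest.update(item_to_bytes(modality, np.asarray(obj)))
    return digest.hexdigest()


if __name__ == "__main__":
    # Same dummy shape the Prithvi dummy-input builder uses in this patch.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    print(hash_mm_item("image", pixel_values))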
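
Note (illustration, not part of the patch): because mm_hashes are now always populated, the processor.py hunk can unconditionally call mm_input_cache_client.get_and_update(sorted_mm_inputs, sorted_mm_hashes) instead of branching on whether hashes exist. The sketch below shows the kind of content-addressed, LRU-style lookup this enables; the class name MMInputCache and its internals are hypothetical and do not describe vLLM's MultiModalInputCacheClient.

from collections import OrderedDict
from typing import Any


class MMInputCache:
    """LRU cache keyed by multi-modal content hashes (illustrative only)."""

    def __init__(self, capacity: int = 256) -> None:
        self._cache: OrderedDict[str, Any] = OrderedDict()
        self._capacity = capacity

    def get_and_update(self, items: list[Any], hashes: list[str]) -> list[Any]:
        out: list[Any] = []
        for item, mm_hash in zip(items, hashes):
            if mm_hash in self._cache:
                # Hash seen before: reuse the cached processed item and mark
                # it as most recently used.
                self._cache.move_to_end(mm_hash)
                out.append(self._cache[mm_hash])
            else:
                # New content: cache it and evict the least recently used
                # entry if over capacity.
                self._cache[mm_hash] = item
                if len(self._cache) > self._capacity:
                    self._cache.popitem(last=False)
                out.append(item)
        return out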