From d3f71f1224403fef0d59ef73b894ac51800b8068 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 18 Aug 2025 20:31:53 +0800 Subject: [PATCH] [Refactor] Get prompt updates earlier (#23097) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/deepseek_vl2.py | 6 +- vllm/model_executor/models/h2ovl.py | 8 +- vllm/model_executor/models/pixtral.py | 15 ++-- .../models/qwen2_5_omni_thinker.py | 33 ++++---- vllm/model_executor/models/voxtral.py | 11 ++- vllm/multimodal/processing.py | 80 ++++++++++++------- 6 files changed, 84 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index e881e9c6ddb6..421076348386 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -291,8 +292,7 @@ class DeepseekVL2MultiModalProcessor( tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes], - bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 2 vs > 2 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 9ab3f4d0d9a1..75ab4dbe7b57 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -20,8 +20,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, MultiModalDataItems) -from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement, - PromptUpdate, PromptUpdateDetails) +from vllm.multimodal.processing import (MultiModalProcessingInfo, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) from vllm.transformers_utils.tokenizer import AnyTokenizer from .intern_vit import InternVisionModel @@ -480,8 +481,7 @@ class H2OVLMultiModalProcessor( tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes], - bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: # The processor logic is different for len(images) <= 1 vs > 1 # Since the processing cache assumes that the processor output is # invariant of how many images are passed per prompt, we only diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 5427e9a5935c..25be44e3f6e1 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -39,7 +39,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, 
PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs @@ -309,14 +310,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes], - bool]: - ( - prompt_ids, - mm_kwargs, - mm_hashes, - _, - ) = super()._cached_apply_hf_processor( + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -325,7 +320,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index d43573ea2752..5aadebc33324 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -59,6 +59,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems, ModalityDataItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -88,10 +89,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) - # vllm use `second_per_grid_ts` to compute multimodal rotary embedding - video_second_per_grid = hf_inputs.get("video_second_per_grid", None) - if video_second_per_grid is not None: - hf_inputs["second_per_grid_ts"] = video_second_per_grid + num_videos = len(video_grid_sizes) return dict( input_audio_features=MultiModalFieldConfig.flat_from_sizes( @@ -109,6 +107,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]): "video", video_grid_sizes), video_grid_thw=MultiModalFieldConfig.batched("video"), second_per_grid_ts=MultiModalFieldConfig.batched("video"), + use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos), ) @@ -251,6 +250,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor( if ('audio_feature_lengths' not in hf_inputs and feature_attention_mask is not None): hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1) + + video_second_per_grid = hf_inputs.get("video_second_per_grid", None) + if video_second_per_grid is not None: + hf_inputs["second_per_grid_ts"] = video_second_per_grid + + use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) + hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video) + return hf_inputs def _get_mm_fields_config( @@ -263,27 +270,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor( def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: """ Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. 
""" - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) - use_audio_in_video = hf_processor_mm_kwargs.get( - "use_audio_in_video", False) + use_audio_in_video = (all( + item["use_audio_in_video"].data + for item in mm_kwargs["video"]) if "video" in mm_kwargs else False) if is_update_applied: mm_placeholders = self._find_mm_placeholders( @@ -316,9 +316,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor( tokenizer = self.info.get_tokenizer() prompt = decode_tokens(tokenizer, prompt_ids) - if use_audio_in_video: - mm_kwargs["use_audio_in_video"] = True - return prompt_ids, prompt, mm_placeholders def _get_prompt_updates( diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 70ba561642a2..d0e8e3d39b45 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -35,7 +35,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, MultiModalHashes, + BaseProcessingInfo, + MultiModalProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors @@ -289,10 +290,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes], - bool]: - prompt_ids, mm_kwargs, mm_hashes, _ = super( - )._cached_apply_hf_processor( + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: + prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt=prompt, mm_data_items=mm_data_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, @@ -301,7 +300,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] ) # NOTE: The tokens are already inserted by the chat template - return prompt_ids, mm_kwargs, mm_hashes, True + return prompt_ids, mm_info, True def _get_data_parser(self) -> MultiModalDataParser: sampling_rate = self.info.get_hf_processor().sampling_rate diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 08113da74ada..e1363b7b0d89 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -989,6 +989,18 @@ A collection of hashes with a similar structure as [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ +MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]] +""" +A collection of prompt updates with a similar structure as +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. 
+""" + + +class MultiModalProcessingInfo(NamedTuple): + kwargs: MultiModalKwargsItems + hashes: Optional[MultiModalHashes] + prompt_updates: MultiModalPromptUpdates + class BaseMultiModalProcessor(ABC, Generic[_I]): """ @@ -1363,7 +1375,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): cache: ProcessingCache, mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], mm_missing_kwargs: MultiModalKwargsItems, - ) -> dict[str, list[MultiModalKwargsItem]]: + ) -> MultiModalKwargsItems: mm_missing_next_idx = defaultdict[str, int](lambda: 0) merged_items = defaultdict[str, list[MultiModalKwargsItem]](list) @@ -1379,7 +1391,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): merged_items[modality].append(kw_item) - return dict(merged_items) + return MultiModalKwargsItems(merged_items) def _apply_hf_processor( self, @@ -1389,8 +1401,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes], - bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: ( prompt_ids, mm_processed_data, @@ -1413,7 +1424,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): tokenization_kwargs) if return_mm_hashes else None) - return prompt_ids, mm_kwargs, mm_hashes, is_update_applied + unbound_prompt_updates = self._get_prompt_updates( + mm_data_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) + + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes, + prompt_updates=mm_prompt_updates, + ) + + return prompt_ids, mm_info, is_update_applied def _cached_apply_hf_processor( self, @@ -1423,8 +1448,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): tokenization_kwargs: Mapping[str, object], *, return_mm_hashes: bool, - ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes], - bool]: + ) -> tuple[list[int], MultiModalProcessingInfo, bool]: """ Apply the HF processor on the full prompt text, caching the results and reusing cached results. 
@@ -1475,18 +1499,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs), ) - mm_cache_items_merged = self._merge_mm_kwargs( + mm_kwargs = self._merge_mm_kwargs( cache, mm_cache_items_or_hashes=mm_cache_items_or_hashes, mm_missing_kwargs=mm_missing_kwargs, ) - mm_kwargs = MultiModalKwargsItems.from_seq([ - item for cache_items in mm_cache_items_merged.values() - for item in cache_items - ]) + unbound_prompt_updates = self._get_prompt_updates( + mm_data_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) - return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied + mm_info = MultiModalProcessingInfo( + kwargs=mm_kwargs, + hashes=mm_hashes_to_return, + prompt_updates=mm_prompt_updates, + ) + + return prompt_ids, mm_info, is_update_applied def _bind_and_group_updates( self, @@ -1626,19 +1659,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], prompt_ids: list[int], mm_kwargs: MultiModalKwargsItems, + mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: - unbound_prompt_updates = self._get_prompt_updates( - mm_items, - hf_processor_mm_kwargs, - mm_kwargs, - ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) - mm_item_counts = mm_items.get_all_counts() self._validate_mm_kwargs(mm_kwargs, mm_item_counts) @@ -1694,8 +1719,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ( prompt_ids, - mm_kwargs, - mm_hashes, + mm_info, is_update_applied, ) = self._cached_apply_hf_processor( prompt, @@ -1708,9 +1732,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): # NOTE: tokenization_kwargs are not required to init processor prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( mm_items=mm_items, - hf_processor_mm_kwargs=hf_processor_mm_kwargs, prompt_ids=prompt_ids, - mm_kwargs=mm_kwargs, + mm_kwargs=mm_info.kwargs, + mm_prompt_updates=mm_info.prompt_updates, is_update_applied=is_update_applied, ) @@ -1723,8 +1747,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): type="multimodal", prompt=prompt, prompt_token_ids=prompt_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, + mm_kwargs=mm_info.kwargs, + mm_hashes=mm_info.hashes, mm_placeholders=mm_placeholder_ranges, )
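
The core of this refactor: _apply_hf_processor and _cached_apply_hf_processor now return a single MultiModalProcessingInfo bundle carrying the kwargs, the optional hashes, and the already-bound prompt updates, instead of separate mm_kwargs/mm_hashes values, so _maybe_apply_prompt_updates no longer recomputes the updates from hf_processor_mm_kwargs. The sketch below is a minimal, self-contained illustration of that data flow under simplified assumptions; the stand-in types and the example values are not the real vLLM classes.

from typing import NamedTuple, Optional, Sequence

# Simplified stand-ins for the vLLM types (assumptions for illustration only).
MultiModalKwargsItems = dict      # modality -> list of per-item processor outputs
MultiModalHashes = dict           # modality -> list of item hashes
BoundPromptUpdate = str           # placeholder for a bound prompt update
MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]]


class MultiModalProcessingInfo(NamedTuple):
    # Mirrors the NamedTuple introduced in vllm/multimodal/processing.py.
    kwargs: MultiModalKwargsItems
    hashes: Optional[MultiModalHashes]
    prompt_updates: MultiModalPromptUpdates


def cached_apply_hf_processor(
        prompt: str) -> tuple[list[int], MultiModalProcessingInfo, bool]:
    # The prompt updates are now built here, right after the (cached) HF
    # processor outputs are merged, and travel alongside the kwargs and hashes.
    prompt_ids = [101, 102, 103]                       # dummy token ids
    mm_kwargs = {"image": [{"pixel_values": None}]}    # dummy processor output
    mm_hashes = {"image": ["0123abcd"]}                # dummy content hashes
    mm_prompt_updates = {"image": ["<image>"]}         # dummy bound updates

    mm_info = MultiModalProcessingInfo(
        kwargs=mm_kwargs,
        hashes=mm_hashes,
        prompt_updates=mm_prompt_updates,
    )
    return prompt_ids, mm_info, False


# Call sites (see apply() in processing.py) now unpack three values and read
# fields off mm_info rather than threading mm_kwargs/mm_hashes separately.
prompt_ids, mm_info, is_update_applied = cached_apply_hf_processor("a photo of <image>")
print(mm_info.prompt_updates)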
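
For Qwen2.5-Omni, use_audio_in_video is no longer read back from hf_processor_mm_kwargs inside _maybe_apply_prompt_updates; it is stored in the cached kwargs as a per-video shared field and re-derived from the video items themselves. A rough sketch of that derivation, using plain booleans in place of the tensor-backed items (the real code reads item["use_audio_in_video"].data):

def derive_use_audio_in_video(mm_kwargs: dict) -> bool:
    # Mirrors the expression added in qwen2_5_omni_thinker.py: the flag is
    # True only when every cached video item carries use_audio_in_video=True.
    if "video" not in mm_kwargs:
        return False
    return all(item["use_audio_in_video"] for item in mm_kwargs["video"])


print(derive_use_audio_in_video({"video": [{"use_audio_in_video": True}]}))  # True
print(derive_use_audio_in_video({}))                                         # False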