[Refactor] Get prompt updates earlier (#23097)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-08-18 20:31:53 +08:00, committed by GitHub
Parent: 5a30bd10d8
Commit: d3f71f1224
6 changed files with 84 additions and 69 deletions
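
In outline, the change moves prompt-update construction out of `_maybe_apply_prompt_updates` and into `_apply_hf_processor` / `_cached_apply_hf_processor`, which now return a single `MultiModalProcessingInfo` bundle (kwargs, hashes, prompt updates) instead of separate `mm_kwargs` and `mm_hashes`. A minimal sketch of the new flow, using simplified dict stand-ins rather than the real vllm.multimodal types:

# Sketch only: dicts stand in for MultiModalKwargsItems, MultiModalHashes and
# BoundPromptUpdate; names mirror the diff but nothing here is vLLM's code.
from typing import NamedTuple, Optional


class MultiModalProcessingInfo(NamedTuple):
    kwargs: dict                 # stand-in for MultiModalKwargsItems
    hashes: Optional[dict]       # stand-in for MultiModalHashes
    prompt_updates: dict         # stand-in for MultiModalPromptUpdates


def cached_apply_hf_processor(
        prompt: str) -> tuple[list[int], MultiModalProcessingInfo, bool]:
    # Previously returned (prompt_ids, mm_kwargs, mm_hashes, is_update_applied);
    # now the prompt updates are built here, once, and travel with the rest.
    prompt_ids = [1, 2, 3]                           # placeholder tokenization
    mm_kwargs = {"image": [{"pixel_values": None}]}  # placeholder processor output
    mm_prompt_updates = {"image": ["<image-placeholder-update>"]}
    mm_info = MultiModalProcessingInfo(kwargs=mm_kwargs,
                                       hashes=None,
                                       prompt_updates=mm_prompt_updates)
    return prompt_ids, mm_info, False


def maybe_apply_prompt_updates(prompt_ids, mm_kwargs, mm_prompt_updates,
                               is_update_applied):
    # After the refactor this step consumes the pre-built updates instead of
    # recomputing them from hf_processor_mm_kwargs.
    if not is_update_applied:
        pass  # apply mm_prompt_updates to prompt_ids here
    return prompt_ids


ids, info, applied = cached_apply_hf_processor("USER: <image> describe this")
maybe_apply_prompt_updates(ids, info.kwargs, info.prompt_updates, applied)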

View File

@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
@@ -291,8 +292,7 @@ class DeepseekVL2MultiModalProcessor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 2 vs > 2
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only

View File

@@ -20,8 +20,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargsItems
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    MultiModalDataItems)
-from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.processing import (MultiModalProcessingInfo,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.transformers_utils.tokenizer import AnyTokenizer

 from .intern_vit import InternVisionModel
@@ -480,8 +481,7 @@ class H2OVLMultiModalProcessor(
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         # The processor logic is different for len(images) <= 1 vs > 1
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only

View File

@@ -39,7 +39,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -309,14 +310,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
-        (
-            prompt_ids,
-            mm_kwargs,
-            mm_hashes,
-            _,
-        ) = super()._cached_apply_hf_processor(
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
+        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -325,7 +320,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         )

         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, mm_hashes, True
+        return prompt_ids, mm_info, True


 @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,

View File

@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
                                    ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalPromptUpdates,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -88,10 +89,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
     video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
     video_grid_sizes = video_grid_thw.prod(-1)

-    # vllm use `second_per_grid_ts` to compute multimodal rotary embedding
-    video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
-    if video_second_per_grid is not None:
-        hf_inputs["second_per_grid_ts"] = video_second_per_grid
+    num_videos = len(video_grid_sizes)

     return dict(
         input_audio_features=MultiModalFieldConfig.flat_from_sizes(
@@ -109,6 +107,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
             "video", video_grid_sizes),
         video_grid_thw=MultiModalFieldConfig.batched("video"),
         second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+        use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
     )
@@ -251,6 +250,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         if ('audio_feature_lengths' not in hf_inputs
                 and feature_attention_mask is not None):
             hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)
+
+        video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
+        if video_second_per_grid is not None:
+            hf_inputs["second_per_grid_ts"] = video_second_per_grid
+
+        use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
+        hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video)
+
         return hf_inputs

     def _get_mm_fields_config(
@@ -263,27 +270,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
         prompt_ids: list[int],
         mm_kwargs: MultiModalKwargsItems,
+        mm_prompt_updates: MultiModalPromptUpdates,
         is_update_applied: bool,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
-        unbound_prompt_updates = self._get_prompt_updates(
-            mm_items,
-            hf_processor_mm_kwargs,
-            mm_kwargs,
-        )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

-        use_audio_in_video = hf_processor_mm_kwargs.get(
-            "use_audio_in_video", False)
+        use_audio_in_video = (all(
+            item["use_audio_in_video"].data
+            for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)

         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
@@ -316,9 +316,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         tokenizer = self.info.get_tokenizer()
         prompt = decode_tokens(tokenizer, prompt_ids)

-        if use_audio_in_video:
-            mm_kwargs["use_audio_in_video"] = True
-
         return prompt_ids, prompt, mm_placeholders

     def _get_prompt_updates(
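
With the `use_audio_in_video` change above, the flag now travels as a per-video-item tensor in the processor outputs and is reduced back to a single bool with `all(...)`. A toy illustration of that read-back, assuming a simplified item layout (plain dicts of tensors standing in for `MultiModalKwargsItems`, with the `.data` indirection elided):

# Toy read-back of the per-item flag; nothing below is vLLM's actual data model.
import torch

mm_kwargs = {
    "video": [
        {"use_audio_in_video": torch.tensor(True)},
        {"use_audio_in_video": torch.tensor(True)},
    ],
}

# Mirrors the reduction in the diff: True only if every video item was
# processed with the flag set, and False when there are no videos at all.
use_audio_in_video = (all(
    bool(item["use_audio_in_video"])
    for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)

print(use_audio_in_video)  # -> True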

View File

@@ -35,7 +35,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, MultiModalHashes,
+                                        BaseProcessingInfo,
+                                        MultiModalProcessingInfo,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -289,10 +290,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
-        prompt_ids, mm_kwargs, mm_hashes, _ = super(
-        )._cached_apply_hf_processor(
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
+        prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -301,7 +300,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
         )

         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, mm_hashes, True
+        return prompt_ids, mm_info, True

     def _get_data_parser(self) -> MultiModalDataParser:
         sampling_rate = self.info.get_hf_processor().sampling_rate

View File

@@ -989,6 +989,18 @@ A collection of hashes with a similar structure as
 [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
 """

+MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]]
+"""
+A collection of prompt updates with a similar structure as
+[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
+"""
+
+
+class MultiModalProcessingInfo(NamedTuple):
+    kwargs: MultiModalKwargsItems
+    hashes: Optional[MultiModalHashes]
+    prompt_updates: MultiModalPromptUpdates
+
+
 class BaseMultiModalProcessor(ABC, Generic[_I]):
     """
@@ -1363,7 +1375,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         cache: ProcessingCache,
         mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
         mm_missing_kwargs: MultiModalKwargsItems,
-    ) -> dict[str, list[MultiModalKwargsItem]]:
+    ) -> MultiModalKwargsItems:
         mm_missing_next_idx = defaultdict[str, int](lambda: 0)

         merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
@ -1379,7 +1391,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
merged_items[modality].append(kw_item) merged_items[modality].append(kw_item)
return dict(merged_items) return MultiModalKwargsItems(merged_items)
def _apply_hf_processor( def _apply_hf_processor(
self, self,
@@ -1389,8 +1401,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         (
             prompt_ids,
             mm_processed_data,
@@ -1413,7 +1424,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                                          tokenization_kwargs)
                      if return_mm_hashes else None)

-        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
+        unbound_prompt_updates = self._get_prompt_updates(
+            mm_data_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        mm_prompt_updates = self._bind_and_group_updates(
+            unbound_prompt_updates)
+
+        mm_info = MultiModalProcessingInfo(
+            kwargs=mm_kwargs,
+            hashes=mm_hashes,
+            prompt_updates=mm_prompt_updates,
+        )
+
+        return prompt_ids, mm_info, is_update_applied

     def _cached_apply_hf_processor(
         self,
@@ -1423,8 +1448,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         return_mm_hashes: bool,
-    ) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
-               bool]:
+    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
         """
         Apply the HF processor on the full prompt text,
         caching the results and reusing cached results.
@@ -1475,18 +1499,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                                        hf_processor_mm_kwargs),
         )

-        mm_cache_items_merged = self._merge_mm_kwargs(
+        mm_kwargs = self._merge_mm_kwargs(
             cache,
             mm_cache_items_or_hashes=mm_cache_items_or_hashes,
             mm_missing_kwargs=mm_missing_kwargs,
         )

-        mm_kwargs = MultiModalKwargsItems.from_seq([
-            item for cache_items in mm_cache_items_merged.values()
-            for item in cache_items
-        ])
+        unbound_prompt_updates = self._get_prompt_updates(
+            mm_data_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        mm_prompt_updates = self._bind_and_group_updates(
+            unbound_prompt_updates)

-        return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
+        mm_info = MultiModalProcessingInfo(
+            kwargs=mm_kwargs,
+            hashes=mm_hashes_to_return,
+            prompt_updates=mm_prompt_updates,
+        )
+
+        return prompt_ids, mm_info, is_update_applied

     def _bind_and_group_updates(
         self,
@@ -1626,19 +1659,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     def _maybe_apply_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
         prompt_ids: list[int],
         mm_kwargs: MultiModalKwargsItems,
+        mm_prompt_updates: MultiModalPromptUpdates,
         is_update_applied: bool,
     ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
-        unbound_prompt_updates = self._get_prompt_updates(
-            mm_items,
-            hf_processor_mm_kwargs,
-            mm_kwargs,
-        )
-        mm_prompt_updates = self._bind_and_group_updates(
-            unbound_prompt_updates)
-
         mm_item_counts = mm_items.get_all_counts()
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
@@ -1694,8 +1719,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         (
             prompt_ids,
-            mm_kwargs,
-            mm_hashes,
+            mm_info,
             is_update_applied,
         ) = self._cached_apply_hf_processor(
             prompt,
@@ -1708,9 +1732,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         # NOTE: tokenization_kwargs are not required to init processor
         prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
             mm_items=mm_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             prompt_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
+            mm_kwargs=mm_info.kwargs,
+            mm_prompt_updates=mm_info.prompt_updates,
             is_update_applied=is_update_applied,
         )
@@ -1723,8 +1747,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
-            mm_hashes=mm_hashes,
+            mm_kwargs=mm_info.kwargs,
+            mm_hashes=mm_info.hashes,
             mm_placeholders=mm_placeholder_ranges,
         )