[Refactor] Get prompt updates earlier (#23097)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-18 20:31:53 +08:00 committed by GitHub
parent 5a30bd10d8
commit d3f71f1224
6 changed files with 84 additions and 69 deletions


@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
@@ -291,8 +292,7 @@ class DeepseekVL2MultiModalProcessor(
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
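For orientation: `MultiModalProcessingInfo` is a named tuple introduced later in this diff that bundles the kwargs, optional hashes, and pre-bound prompt updates that the old signature returned separately. A toy sketch of the unpacking change, using a stand-in type rather than the real class:

from collections import namedtuple

# Stand-in for MultiModalProcessingInfo, only to show the new return shape.
_Info = namedtuple("_Info", ["kwargs", "hashes", "prompt_updates"])

# old: prompt_ids, mm_kwargs, mm_hashes, is_update_applied = processor_output
processor_output = ([101, 102], _Info({"image": []}, None, {"image": []}), False)
prompt_ids, mm_info, is_update_applied = processor_output
print(mm_info.kwargs, mm_info.hashes)  # kwargs and hashes now travel together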


@@ -20,8 +20,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.processing import (MultiModalProcessingInfo,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
@@ -480,8 +481,7 @@ class H2OVLMultiModalProcessor(
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only


@@ -39,7 +39,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -309,14 +310,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
(
prompt_ids,
mm_kwargs,
mm_hashes,
_,
) = super()._cached_apply_hf_processor(
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -325,7 +320,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, mm_hashes, True
return prompt_ids, mm_info, True
@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,


@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -88,10 +89,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
# vllm use `second_per_grid_ts` to compute multimodal rotary embedding
video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
if video_second_per_grid is not None:
hf_inputs["second_per_grid_ts"] = video_second_per_grid
num_videos = len(video_grid_sizes)
return dict(
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
@@ -109,6 +107,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
)
@@ -251,6 +250,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
if ('audio_feature_lengths' not in hf_inputs
and feature_attention_mask is not None):
hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)
video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
if video_second_per_grid is not None:
hf_inputs["second_per_grid_ts"] = video_second_per_grid
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video)
return hf_inputs
def _get_mm_fields_config(
@@ -263,27 +270,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
use_audio_in_video = hf_processor_mm_kwargs.get(
"use_audio_in_video", False)
use_audio_in_video = (all(
item["use_audio_in_video"].data
for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
@@ -316,9 +316,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
if use_audio_in_video:
mm_kwargs["use_audio_in_video"] = True
return prompt_ids, prompt, mm_placeholders
def _get_prompt_updates(
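In short, `use_audio_in_video` now travels as a per-video kwargs item (the `shared` field registered above) rather than being read from `hf_processor_mm_kwargs`, so this override can recover it from `mm_kwargs` alone. A minimal sketch of that aggregation, with plain booleans standing in for the tensor-backed items:

# Minimal sketch: the real items wrap torch tensors; plain booleans used here.
mm_kwargs = {"video": [{"use_audio_in_video": True},
                       {"use_audio_in_video": True}]}

use_audio_in_video = (all(item["use_audio_in_video"]
                          for item in mm_kwargs["video"])
                      if "video" in mm_kwargs else False)
print(use_audio_in_video)  # True only if every video item requested it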


@@ -35,7 +35,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -289,10 +290,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -301,7 +300,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, mm_hashes, True
return prompt_ids, mm_info, True
def _get_data_parser(self) -> MultiModalDataParser:
sampling_rate = self.info.get_hf_processor().sampling_rate


@@ -989,6 +989,18 @@ A collection of hashes with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]]
"""
A collection of prompt updates with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
class MultiModalProcessingInfo(NamedTuple):
kwargs: MultiModalKwargsItems
hashes: Optional[MultiModalHashes]
prompt_updates: MultiModalPromptUpdates
class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
@@ -1363,7 +1375,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache: ProcessingCache,
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
mm_missing_kwargs: MultiModalKwargsItems,
) -> dict[str, list[MultiModalKwargsItem]]:
) -> MultiModalKwargsItems:
mm_missing_next_idx = defaultdict[str, int](lambda: 0)
merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
@@ -1379,7 +1391,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
merged_items[modality].append(kw_item)
return dict(merged_items)
return MultiModalKwargsItems(merged_items)
def _apply_hf_processor(
self,
@@ -1389,8 +1401,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
(
prompt_ids,
mm_processed_data,
@@ -1413,7 +1424,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs)
if return_mm_hashes else None)
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
unbound_prompt_updates = self._get_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs,
hashes=mm_hashes,
prompt_updates=mm_prompt_updates,
)
return prompt_ids, mm_info, is_update_applied
def _cached_apply_hf_processor(
self,
@@ -1423,8 +1448,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
@@ -1475,18 +1499,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs),
)
mm_cache_items_merged = self._merge_mm_kwargs(
mm_kwargs = self._merge_mm_kwargs(
cache,
mm_cache_items_or_hashes=mm_cache_items_or_hashes,
mm_missing_kwargs=mm_missing_kwargs,
)
mm_kwargs = MultiModalKwargsItems.from_seq([
item for cache_items in mm_cache_items_merged.values()
for item in cache_items
])
unbound_prompt_updates = self._get_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs,
hashes=mm_hashes_to_return,
prompt_updates=mm_prompt_updates,
)
return prompt_ids, mm_info, is_update_applied
def _bind_and_group_updates(
self,
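Taken together, the cached path now derives and binds prompt updates right after the merged kwargs are available, and hands everything back in one bundle so the later prompt-update step only consumes it. A toy end-to-end sketch under that reading, with illustrative helper names rather than the actual vLLM methods:

from typing import NamedTuple, Optional

class Info(NamedTuple):              # stand-in for MultiModalProcessingInfo
    kwargs: dict
    hashes: Optional[dict]
    prompt_updates: dict

def cached_apply_sketch(prompt: str, items: dict) -> tuple[list[int], Info, bool]:
    prompt_ids = list(range(len(prompt.split())))       # pretend tokenization
    mm_kwargs = {m: ["cached-or-fresh-item"] for m in items}
    # "Earlier": updates are derived and bound here, right after merging kwargs.
    prompt_updates = {m: [f"<{m}>"] for m in items}
    return prompt_ids, Info(mm_kwargs, None, prompt_updates), False

def maybe_apply_updates_sketch(prompt_ids, info: Info, is_update_applied: bool):
    # Only consumes the pre-bound updates; no recomputation from HF kwargs.
    return prompt_ids, info.prompt_updates

ids, info, applied = cached_apply_sketch("a cat <image>", {"image": [object()]})
print(maybe_apply_updates_sketch(ids, info, applied))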
@@ -1626,19 +1659,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
@@ -1694,8 +1719,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
(
prompt_ids,
mm_kwargs,
mm_hashes,
mm_info,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
@@ -1708,9 +1732,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
prompt_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
is_update_applied=is_update_applied,
)
@@ -1723,8 +1747,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_kwargs=mm_info.kwargs,
mm_hashes=mm_info.hashes,
mm_placeholders=mm_placeholder_ranges,
)