diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py
index b9fed79c84cdd..a3af541d20676 100644
--- a/vllm/model_executor/models/qwen2_5_omni_thinker.py
+++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         tokenization_kwargs: Mapping[str, object],
         *,
         enable_hf_prompt_update: bool,
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], BatchFeature, bool]:
         """
         Qwen2.5-Omni reimplements this function to handle text only.
         """
@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         else:
             prompt_ids = self._apply_hf_processor_tokens_only(prompt)

-        mm_kwargs = self._apply_hf_processor_mm_only(
+        mm_processed_data = self._apply_hf_processor_mm_only(
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

-        return prompt_ids, mm_kwargs, False
+        return prompt_ids, mm_processed_data, False

     def _apply_hf_processor_mm_only(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> MultiModalKwargs:
+    ) -> BatchFeature:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         assert "audio" in mm_counts
         mm_counts["audio"] -= mm_counts["video"]

-        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
+        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
             prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

-        return mm_kwargs
+        return mm_processed_data

     def _validate_mm_placeholders(
         self,
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 0c3df267edb11..92e132045c278 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union
 import regex as re
 import torch
 from torch import nn
-from transformers import AutoModel, PretrainedConfig, PreTrainedModel
+from transformers import (AutoModel, BatchFeature, PretrainedConfig,
+                          PreTrainedModel)
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

 from vllm.attention import Attention
@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ):
+    ) -> tuple[list[int], BatchFeature, bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data
         together.
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 0378539495fda..38c5d5d99f63e 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                                encode_tokens)
-from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
+from vllm.utils import flatten_2d_lists, full_groupby

 from .cache import MultiModalCache
 from .hasher import MultiModalHasher
@@ -887,120 +887,19 @@ def find_mm_placeholders(
     return dict(full_groupby_modality(it))


-class ProcessingCacheOptionalItem(NamedTuple):
-    key: str
-    value: Optional[MultiModalKwargsItem]
-
-
-class ProcessingCacheItem(NamedTuple):
-    key: str
-    value: MultiModalKwargsItem
-
-
 class ProcessingCache(MultiModalCache):

-    def __init__(
-        self,
-        capacity_gb: float,
-        *,
-        debug_cache_hit_ratio_steps: Optional[int] = None,
-    ) -> None:
+    def __init__(self, capacity_gb: float) -> None:
         super().__init__()

-        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
-        self.debug_cache_hits = 0
-        self.debug_cache_total = 0
+        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)

-        self._cache = self.get_lru_cache(
-            capacity_gb,
-            MultiModalKwargsItem,
-            debug=bool(debug_cache_hit_ratio_steps),
-        )
+        self.get = self._cache.get
+        self.put = self._cache.put
+        self.reset = self._cache.clear

-    def _maybe_log_cache_stats(self) -> None:
-        steps = self.debug_cache_hit_ratio_steps
-        if not steps:
-            return
-        total = self.debug_cache_total
-        if total > 0 and total % steps == 0:
-            logger.debug("ProcessingCache: hit_ratio = %.2f",
-                         self.debug_cache_hits / total)
-            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
-                         self._cache.currsize / GiB_bytes,
-                         self._cache.maxsize / GiB_bytes)
-
-    def get(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-    ) -> Optional[MultiModalKwargsItem]:
-        """
-        Get a processed multi-modal item from the cache
-        according to its dependencies, including:
-
-        - The model ID
-        - The modality of the item
-        - The original data item passed to the HF processor
-        - The configuration options of the HF processor
-        """
-        self._maybe_log_cache_stats()
-
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-
-        if self.debug_cache_hit_ratio_steps:
-            if cache_key in self._cache:
-                self.debug_cache_hits += 1
-
-            self.debug_cache_total += 1
-
-        return self._cache.get(cache_key)
-
-    def get_item(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-    ) -> ProcessingCacheOptionalItem:
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-
-        return ProcessingCacheOptionalItem(
-            key=cache_key,
-            value=self._cache.get(cache_key),
-        )
-
-    def put(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-        output_kwargs: MultiModalKwargsItem,
-    ) -> None:
-        """
-        Put a processed multi-modal item into the cache
-        according to its dependencies
-        (see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
-        """
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-        self._cache[cache_key] = output_kwargs
-
-    def put_item(self, item: ProcessingCacheItem) -> None:
-        self._cache[item.key] = item.value
-
-    def reset(self) -> bool:
-        self._cache.clear()
-
-        return True
+_CacheItemOrHash = Union[MultiModalKwargsItem, str]


 class BaseProcessingInfo:
@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], "BatchFeature", bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data
         together.
@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

         prompt_ids, = processed_data.pop("input_ids").tolist()

-        mm_kwargs = MultiModalKwargs.from_hf_inputs(
-            processed_data,
-            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
-        )
-
         is_update_applied = self._hf_processor_applies_updates(
             prompt_text=prompt_text,
             mm_items=mm_items,
@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             tokenization_kwargs=tokenization_kwargs,
         )

-        return prompt_ids, mm_kwargs, is_update_applied
+        return prompt_ids, processed_data, is_update_applied

     def _apply_hf_processor_text_only(
-            self, prompt_text: str,
-            tokenization_kwargs: Mapping[str, object]) -> list[int]:
+        self,
+        prompt_text: str,
+        tokenization_kwargs: Mapping[str, object],
+    ) -> list[int]:
         """
         Apply the HF processor on the prompt text only.

@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> MultiModalKwargs:
+    ) -> "BatchFeature":
         """
         Apply the HF processor on the multi-modal data only.

@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         """
         mm_counts = mm_items.get_all_counts()

-        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
+        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
             prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

-        return mm_kwargs
+        return mm_processed_data

     def _apply_hf_processor_main(
         self,
@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         enable_hf_prompt_update: bool,
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], "BatchFeature", bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data.
@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         else:
             prompt_ids = self._apply_hf_processor_tokens_only(prompt)

-        mm_kwargs = self._apply_hf_processor_mm_only(
+        mm_processed_data = self._apply_hf_processor_mm_only(
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

-        return prompt_ids, mm_kwargs, False
+        return prompt_ids, mm_processed_data, False

     def _get_cache_missing_items(
         self,
         cache: ProcessingCache,
         mm_data_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
-            str, list[object]]]:
-        model_id = self.info.model_id
-
-        mm_cache_items = {
-            modality: [
-                cache.get_item(
-                    model_id, modality, item,
-                    dict(**hf_processor_mm_kwargs, **tokenization_kwargs))
-                for item in items
-            ]
-            for modality, items in mm_data_items.items()
+        mm_hashes: MultiModalHashes,
+    ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
+        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
+            modality: [(h if (v := cache.get(h)) is None else v)
+                       for h in hashes]
+            for modality, hashes in mm_hashes.items()
         }

         mm_missing_idxs = {
             modality: [
-                idx for idx, item in enumerate(cache_items)
-                if item.value is None
+                idx for idx, item_or_hash in enumerate(items_or_hashes)
+                if isinstance(item_or_hash, str)
             ]
-            for modality, cache_items in mm_cache_items.items()
+            for modality, items_or_hashes in mm_cache_items_or_hashes.items()
         }
         mm_missing_data = {
             modality: [mm_data_items[modality][idx] for idx in idxs]
             for modality, idxs in mm_missing_idxs.items()
         }

-        return mm_cache_items, mm_missing_data
+        return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)

     def _hash_mm_items(
-            self, mm_items: MultiModalDataItems,
-            hf_processor_mm_kwargs: Mapping[str, object],
-            tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes:
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> MultiModalHashes:
         """Create MM hashes to be returned (only used in V1)."""

         model_id = self.info.model_id
@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     def _merge_mm_kwargs(
         self,
         cache: ProcessingCache,
-        mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
-        mm_missing_data: dict[str, list[object]],
+        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
         mm_missing_kwargs: MultiModalKwargs,
-    ) -> dict[str, list[ProcessingCacheItem]]:
-        mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}
+    ) -> dict[str, list[MultiModalKwargsItem]]:
+        mm_missing_next_idx = defaultdict[str, int](lambda: 0)

-        merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
-        for modality, cache_items in mm_cache_items.items():
-            for cache_item in cache_items:
-                if cache_item.value is None:
+        merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
+        for modality, items_or_hashes in mm_cache_items_or_hashes.items():
+            for item_or_hash in items_or_hashes:
+                if isinstance(item_or_hash, str):
                     kw_item = mm_missing_kwargs.get_item(
                         modality,
                         mm_missing_next_idx[modality],
                     )
-                    cache_item_new = ProcessingCacheItem(
-                        key=cache_item.key,
-                        value=kw_item,
-                    )
-
-                    cache.put_item(cache_item_new)
+                    cache.put(item_or_hash, kw_item)

                     mm_missing_next_idx[modality] += 1
                 else:
-                    cache_item_new = ProcessingCacheItem(
-                        key=cache_item.key,
-                        value=cache_item.value,
-                    )
+                    kw_item = item_or_hash

-                merged_items[modality].append(cache_item_new)
+                merged_items[modality].append(kw_item)

         return dict(merged_items)
@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         (
             prompt_ids,
-            mm_kwargs,
+            mm_processed_data,
             is_update_applied,
         ) = self._apply_hf_processor_main(
             prompt=prompt,
@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             enable_hf_prompt_update=True,
         )

+        mm_kwargs = MultiModalKwargs.from_hf_inputs(
+            mm_processed_data,
+            self._get_mm_fields_config(mm_processed_data,
+                                       hf_processor_mm_kwargs),
+        )
+
         mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                          tokenization_kwargs)
                      if return_mm_hashes else None)
@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             return_mm_hashes=return_mm_hashes,
         )

+        mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
+                                        tokenization_kwargs)
         (
-            mm_cache_items,
-            mm_missing_data,
+            mm_cache_items_or_hashes,
+            mm_missing_data_items,
         ) = self._get_cache_missing_items(
             cache=cache,
             mm_data_items=mm_data_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-            tokenization_kwargs=tokenization_kwargs,
+            mm_hashes=mm_hashes,
         )

+        mm_hashes_to_return = mm_hashes if return_mm_hashes else None
+
         # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
         # so we can't apply prompt updates until the new multimodal
         # items are combined with the cached multimodal items
         (
             prompt_ids,
-            mm_missing_kwargs,
+            mm_missing_processed_data,
             is_update_applied,
         ) = self._apply_hf_processor_main(
             prompt=prompt,
-            mm_items=self._to_mm_items(mm_missing_data),
+            mm_items=mm_missing_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
             enable_hf_prompt_update=False,
         )

+        mm_missing_kwargs = MultiModalKwargs.from_hf_inputs(
+            mm_missing_processed_data,
+            self._get_mm_fields_config(mm_missing_processed_data,
+                                       hf_processor_mm_kwargs),
+        )
+
         mm_cache_items_merged = self._merge_mm_kwargs(
             cache,
-            mm_cache_items=mm_cache_items,
-            mm_missing_data=mm_missing_data,
+            mm_cache_items_or_hashes=mm_cache_items_or_hashes,
             mm_missing_kwargs=mm_missing_kwargs,
         )

         mm_kwargs = MultiModalKwargs.from_items([
-            item.value for cache_items in mm_cache_items_merged.values()
+            item for cache_items in mm_cache_items_merged.values()
             for item in cache_items
         ])

-        mm_hashes = {
-            modality: [item.key for item in cache_items]
-            for modality, cache_items in mm_cache_items_merged.items()
-        } if return_mm_hashes else None
-
-        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
+        return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied

     def _bind_and_group_updates(
         self,
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index 809a60c1962f8..9d063f1edad06 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -312,25 +312,25 @@ class MsgpackDecoder:
         return arr.view(torch_dtype).view(shape)

     def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
-        decoded_items = []
-        for item in obj:
-            elems = []
-            for v in item:
-                v["data"] = self._decode_nested_tensors(v["data"])
-                # Reconstruct the field processor using MultiModalFieldConfig
-                factory_meth_name, *field_args = v["field"]
-                factory_meth = getattr(MultiModalFieldConfig,
-                                       factory_meth_name)
+        return [self._decode_mm_item(v) for v in obj]

-                # Special case: decode the union "slices" field of
-                # MultiModalFlatField
-                if factory_meth_name == "flat":
-                    field_args[0] = self._decode_nested_slices(field_args[0])
+    def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
+        return MultiModalKwargsItem.from_elems(
+            [self._decode_mm_field_elem(v) for v in obj])

-                v["field"] = factory_meth(None, *field_args).field
-                elems.append(MultiModalFieldElem(**v))
-            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
-        return decoded_items
+    def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
+        obj["data"] = self._decode_nested_tensors(obj["data"])
+        # Reconstruct the field processor using MultiModalFieldConfig
+        factory_meth_name, *field_args = obj["field"]
+        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
+
+        # Special case: decode the union "slices" field of
+        # MultiModalFlatField
+        if factory_meth_name == "flat":
+            field_args[0] = self._decode_nested_slices(field_args[0])
+
+        obj["field"] = factory_meth(None, *field_args).field
+        return MultiModalFieldElem(**obj)

     def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if isinstance(obj, (int, float)):
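
Reviewer note (not part of the patch): the refactor above keys ProcessingCache by precomputed content hashes (plain strings) and lets a lookup yield either a cached MultiModalKwargsItem or the hash itself when the item still has to be produced by the HF processor (see _CacheItemOrHash, _get_cache_missing_items, and _merge_mm_kwargs). The standalone sketch below mirrors that flow with simplified stand-ins; compute_hash, the dict-based cache, and the dict-valued items are illustrative names only and do not exist in vLLM.

    from collections import defaultdict
    from typing import Union

    KwargsItem = dict                     # stand-in for MultiModalKwargsItem
    ItemOrHash = Union[KwargsItem, str]   # mirrors _CacheItemOrHash

    cache: dict[str, KwargsItem] = {}     # stand-in for the LRU ProcessingCache

    def compute_hash(modality: str, item: bytes) -> str:
        # Stand-in for the MultiModalHasher-based hashing done in _hash_mm_items.
        return f"{modality}:{hash(item)}"

    def lookup(mm_hashes: dict[str, list[str]]) -> dict[str, list[ItemOrHash]]:
        # Mirrors _get_cache_missing_items: keep the cached item on a hit,
        # otherwise keep the hash so the item can be filled in after processing.
        return {
            modality: [cache.get(h, h) for h in hashes]
            for modality, hashes in mm_hashes.items()
        }

    def merge(items_or_hashes: dict[str, list[ItemOrHash]],
              missing: dict[str, list[KwargsItem]]) -> dict[str, list[KwargsItem]]:
        # Mirrors _merge_mm_kwargs: freshly processed items are written back to
        # the cache under their hash; cached items are reused as-is.
        next_idx = defaultdict(int)
        merged = defaultdict(list)
        for modality, entries in items_or_hashes.items():
            for entry in entries:
                if isinstance(entry, str):            # cache miss -> hash string
                    item = missing[modality][next_idx[modality]]
                    cache[entry] = item               # cache.put(hash, item)
                    next_idx[modality] += 1
                else:                                 # cache hit -> item
                    item = entry
                merged[modality].append(item)
        return dict(merged)

    hashes = {"image": [compute_hash("image", b"raw image bytes")]}
    print(lookup(hashes))                 # first pass: miss, only the hash
    merge(lookup(hashes), {"image": [{"pixel_values": "..."}]})
    print(lookup(hashes))                 # second pass: hit, the cached item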