[Core] Simplify mm processing cache (#22457)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-08 00:47:07 +08:00 committed by GitHub
parent 399d2a10e2
commit 8c9da6be22
4 changed files with 95 additions and 204 deletions
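
At a high level, this commit stops ProcessingCache from hashing raw inputs inside every get()/put() call (it used to invoke MultiModalHasher.hash_kwargs itself) and instead keys the cache by the hashes the processor already computes in _hash_mm_items; the cache shrinks to a thin wrapper whose get/put/reset delegate to the underlying LRU cache. The following standalone Python sketch (hypothetical LruCache and hash_item helpers, not the vLLM classes) illustrates that hash-once, key-by-hash pattern:

# Standalone sketch, not the vLLM implementation: a minimal LRU cache keyed
# by precomputed hash strings, plus a stand-in for MultiModalHasher.
from collections import OrderedDict
from hashlib import sha256
from typing import Any, Optional, Union


class LruCache:
    """Minimal LRU cache; get/put/clear mirror the delegated methods."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data: OrderedDict[str, Any] = OrderedDict()

    def get(self, key: str) -> Optional[Any]:
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def put(self, key: str, value: Any) -> None:
        self._data[key] = value
        self._data.move_to_end(key)
        if len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict least recently used

    def clear(self) -> None:
        self._data.clear()


def hash_item(modality: str, raw: bytes, mm_kwargs: dict) -> str:
    """Stand-in for MultiModalHasher.hash_kwargs: hash the inputs once."""
    h = sha256()
    h.update(modality.encode())
    h.update(raw)
    h.update(repr(sorted(mm_kwargs.items())).encode())
    return h.hexdigest()


# Hash once, then use the hash as the cache key everywhere.
cache = LruCache(capacity=128)
key = hash_item("image", b"<raw image bytes>", {"do_resize": True})

cached = cache.get(key)
item_or_hash: Union[dict, str] = key if cached is None else cached
if isinstance(item_or_hash, str):
    # Cache miss: run the (expensive) HF processor, then backfill the cache.
    processed = {"pixel_values": "..."}
    cache.put(key, processed)

A cache hit yields the processed item directly; a miss leaves the hash string in place so the caller knows which items still need to go through the HF processor, mirroring the _CacheItemOrHash union introduced in the diff below.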

View File

@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         tokenization_kwargs: Mapping[str, object],
         *,
         enable_hf_prompt_update: bool,
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], BatchFeature, bool]:
         """
         Qwen2.5-Omni reimplements this function to handle text only.
         """
@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         else:
             prompt_ids = self._apply_hf_processor_tokens_only(prompt)

-        mm_kwargs = self._apply_hf_processor_mm_only(
+        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

-        return prompt_ids, mm_kwargs, False
+        return prompt_ids, mm_processed_data, False

     def _apply_hf_processor_mm_only(
         self,
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> MultiModalKwargs:
+    ) -> BatchFeature:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
             assert "audio" in mm_counts
             mm_counts["audio"] -= mm_counts["video"]

-        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
+        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
            prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

-        return mm_kwargs
+        return mm_processed_data

     def _validate_mm_placeholders(
         self,

View File

@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union
 import regex as re
 import torch
 from torch import nn
-from transformers import AutoModel, PretrainedConfig, PreTrainedModel
+from transformers import (AutoModel, BatchFeature, PretrainedConfig,
+                          PreTrainedModel)
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

 from vllm.attention import Attention
@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ):
+    ) -> tuple[list[int], BatchFeature, bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data
         together.

View File

@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                                encode_tokens)
-from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
+from vllm.utils import flatten_2d_lists, full_groupby

 from .cache import MultiModalCache
 from .hasher import MultiModalHasher
@@ -887,120 +887,19 @@ def find_mm_placeholders(
     return dict(full_groupby_modality(it))


-class ProcessingCacheOptionalItem(NamedTuple):
-    key: str
-    value: Optional[MultiModalKwargsItem]
-
-
-class ProcessingCacheItem(NamedTuple):
-    key: str
-    value: MultiModalKwargsItem
-
-
 class ProcessingCache(MultiModalCache):

-    def __init__(
-        self,
-        capacity_gb: float,
-        *,
-        debug_cache_hit_ratio_steps: Optional[int] = None,
-    ) -> None:
+    def __init__(self, capacity_gb: float) -> None:
         super().__init__()

-        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
-        self.debug_cache_hits = 0
-        self.debug_cache_total = 0
-
-        self._cache = self.get_lru_cache(
-            capacity_gb,
-            MultiModalKwargsItem,
-            debug=bool(debug_cache_hit_ratio_steps),
-        )
-
-    def _maybe_log_cache_stats(self) -> None:
-        steps = self.debug_cache_hit_ratio_steps
-        if not steps:
-            return
-
-        total = self.debug_cache_total
-        if total > 0 and total % steps == 0:
-            logger.debug("ProcessingCache: hit_ratio = %.2f",
-                         self.debug_cache_hits / total)
-            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
-                         self._cache.currsize / GiB_bytes,
-                         self._cache.maxsize / GiB_bytes)
-
-    def get(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-    ) -> Optional[MultiModalKwargsItem]:
-        """
-        Get a processed multi-modal item from the cache
-        according to its dependencies, including:
-
-        - The model ID
-        - The modality of the item
-        - The original data item passed to the HF processor
-        - The configuration options of the HF processor
-        """
-        self._maybe_log_cache_stats()
-
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-
-        if self.debug_cache_hit_ratio_steps:
-            if cache_key in self._cache:
-                self.debug_cache_hits += 1
-            self.debug_cache_total += 1
-
-        return self._cache.get(cache_key)
-
-    def get_item(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-    ) -> ProcessingCacheOptionalItem:
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-
-        return ProcessingCacheOptionalItem(
-            key=cache_key,
-            value=self._cache.get(cache_key),
-        )
-
-    def put(
-        self,
-        model_id: str,
-        modality: str,
-        input_item: object,
-        input_kwargs: Mapping[str, object],
-        output_kwargs: MultiModalKwargsItem,
-    ) -> None:
-        """
-        Put a processed multi-modal item into the cache
-        according to its dependencies
-        (see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
-        """
-        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: input_item},
-                                                 **input_kwargs)
-
-        self._cache[cache_key] = output_kwargs
-
-    def put_item(self, item: ProcessingCacheItem) -> None:
-        self._cache[item.key] = item.value
-
-    def reset(self) -> bool:
-        self._cache.clear()
-        return True
+        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
+
+        self.get = self._cache.get
+        self.put = self._cache.put
+        self.reset = self._cache.clear
+
+
+_CacheItemOrHash = Union[MultiModalKwargsItem, str]


 class BaseProcessingInfo:
@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], "BatchFeature", bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data
         together.
@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

         prompt_ids, = processed_data.pop("input_ids").tolist()

-        mm_kwargs = MultiModalKwargs.from_hf_inputs(
-            processed_data,
-            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
-        )
-
         is_update_applied = self._hf_processor_applies_updates(
             prompt_text=prompt_text,
             mm_items=mm_items,
@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             tokenization_kwargs=tokenization_kwargs,
         )

-        return prompt_ids, mm_kwargs, is_update_applied
+        return prompt_ids, processed_data, is_update_applied

     def _apply_hf_processor_text_only(
-            self, prompt_text: str,
-            tokenization_kwargs: Mapping[str, object]) -> list[int]:
+        self,
+        prompt_text: str,
+        tokenization_kwargs: Mapping[str, object],
+    ) -> list[int]:
         """
         Apply the HF processor on the prompt text only.
@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
         tokenization_kwargs: Mapping[str, object],
-    ) -> MultiModalKwargs:
+    ) -> "BatchFeature":
         """
         Apply the HF processor on the multi-modal data only.
@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         """
         mm_counts = mm_items.get_all_counts()

-        _, mm_kwargs, _ = self._apply_hf_processor_text_mm(
+        _, mm_processed_data, _ = self._apply_hf_processor_text_mm(
             prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
             mm_items=mm_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
         )

-        return mm_kwargs
+        return mm_processed_data

     def _apply_hf_processor_main(
         self,
@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         tokenization_kwargs: Mapping[str, object],
         *,
         enable_hf_prompt_update: bool,
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+    ) -> tuple[list[int], "BatchFeature", bool]:
         """
         Apply the HF processor on the prompt text and multi-modal data.
@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         else:
             prompt_ids = self._apply_hf_processor_tokens_only(prompt)

-        mm_kwargs = self._apply_hf_processor_mm_only(
+        mm_processed_data = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

-        return prompt_ids, mm_kwargs, False
+        return prompt_ids, mm_processed_data, False

     def _get_cache_missing_items(
         self,
         cache: ProcessingCache,
         mm_data_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
-            str, list[object]]]:
-        model_id = self.info.model_id
-
-        mm_cache_items = {
-            modality: [
-                cache.get_item(
-                    model_id, modality, item,
-                    dict(**hf_processor_mm_kwargs, **tokenization_kwargs))
-                for item in items
-            ]
-            for modality, items in mm_data_items.items()
+        mm_hashes: MultiModalHashes,
+    ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
+        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
+            modality: [(h if (v := cache.get(h)) is None else v)
+                       for h in hashes]
+            for modality, hashes in mm_hashes.items()
         }

         mm_missing_idxs = {
             modality: [
-                idx for idx, item in enumerate(cache_items)
-                if item.value is None
+                idx for idx, item_or_hash in enumerate(items_or_hashes)
+                if isinstance(item_or_hash, str)
             ]
-            for modality, cache_items in mm_cache_items.items()
+            for modality, items_or_hashes in mm_cache_items_or_hashes.items()
         }
         mm_missing_data = {
             modality: [mm_data_items[modality][idx] for idx in idxs]
             for modality, idxs in mm_missing_idxs.items()
         }

-        return mm_cache_items, mm_missing_data
+        return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)

     def _hash_mm_items(
-            self, mm_items: MultiModalDataItems,
-            hf_processor_mm_kwargs: Mapping[str, object],
-            tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes:
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> MultiModalHashes:
         """Create MM hashes to be returned (only used in V1)."""
         model_id = self.info.model_id
@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     def _merge_mm_kwargs(
         self,
         cache: ProcessingCache,
-        mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
-        mm_missing_data: dict[str, list[object]],
+        mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
         mm_missing_kwargs: MultiModalKwargs,
-    ) -> dict[str, list[ProcessingCacheItem]]:
-        mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}
+    ) -> dict[str, list[MultiModalKwargsItem]]:
+        mm_missing_next_idx = defaultdict[str, int](lambda: 0)

-        merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
-        for modality, cache_items in mm_cache_items.items():
-            for cache_item in cache_items:
-                if cache_item.value is None:
+        merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
+        for modality, items_or_hashes in mm_cache_items_or_hashes.items():
+            for item_or_hash in items_or_hashes:
+                if isinstance(item_or_hash, str):
                     kw_item = mm_missing_kwargs.get_item(
                         modality,
                         mm_missing_next_idx[modality],
                     )
-                    cache_item_new = ProcessingCacheItem(
-                        key=cache_item.key,
-                        value=kw_item,
-                    )
-                    cache.put_item(cache_item_new)
+                    cache.put(item_or_hash, kw_item)
                     mm_missing_next_idx[modality] += 1
                 else:
-                    cache_item_new = ProcessingCacheItem(
-                        key=cache_item.key,
-                        value=cache_item.value,
-                    )
+                    kw_item = item_or_hash

-                merged_items[modality].append(cache_item_new)
+                merged_items[modality].append(kw_item)

         return dict(merged_items)
@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         (
             prompt_ids,
-            mm_kwargs,
+            mm_processed_data,
             is_update_applied,
         ) = self._apply_hf_processor_main(
             prompt=prompt,
@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             enable_hf_prompt_update=True,
         )

+        mm_kwargs = MultiModalKwargs.from_hf_inputs(
+            mm_processed_data,
+            self._get_mm_fields_config(mm_processed_data,
+                                       hf_processor_mm_kwargs),
+        )
+
         mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                          tokenization_kwargs)
                      if return_mm_hashes else None)
@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                 return_mm_hashes=return_mm_hashes,
             )

+        mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
+                                        tokenization_kwargs)
+
         (
-            mm_cache_items,
-            mm_missing_data,
+            mm_cache_items_or_hashes,
+            mm_missing_data_items,
         ) = self._get_cache_missing_items(
             cache=cache,
             mm_data_items=mm_data_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-            tokenization_kwargs=tokenization_kwargs,
+            mm_hashes=mm_hashes,
         )

+        mm_hashes_to_return = mm_hashes if return_mm_hashes else None
+
         # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
         # so we can't apply prompt updates until the new multimodal
         # items are combined with the cached multimodal items
         (
             prompt_ids,
-            mm_missing_kwargs,
+            mm_missing_processed_data,
             is_update_applied,
         ) = self._apply_hf_processor_main(
             prompt=prompt,
-            mm_items=self._to_mm_items(mm_missing_data),
+            mm_items=mm_missing_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             tokenization_kwargs=tokenization_kwargs,
             enable_hf_prompt_update=False,
         )

+        mm_missing_kwargs = MultiModalKwargs.from_hf_inputs(
+            mm_missing_processed_data,
+            self._get_mm_fields_config(mm_missing_processed_data,
+                                       hf_processor_mm_kwargs),
+        )
+
         mm_cache_items_merged = self._merge_mm_kwargs(
             cache,
-            mm_cache_items=mm_cache_items,
-            mm_missing_data=mm_missing_data,
+            mm_cache_items_or_hashes=mm_cache_items_or_hashes,
             mm_missing_kwargs=mm_missing_kwargs,
         )

         mm_kwargs = MultiModalKwargs.from_items([
-            item.value for cache_items in mm_cache_items_merged.values()
+            item for cache_items in mm_cache_items_merged.values()
             for item in cache_items
         ])

-        mm_hashes = {
-            modality: [item.key for item in cache_items]
-            for modality, cache_items in mm_cache_items_merged.items()
-        } if return_mm_hashes else None
-
-        return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
+        return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied

     def _bind_and_group_updates(
         self,

View File

@@ -312,25 +312,25 @@ class MsgpackDecoder:
         return arr.view(torch_dtype).view(shape)

     def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
-        decoded_items = []
-        for item in obj:
-            elems = []
-            for v in item:
-                v["data"] = self._decode_nested_tensors(v["data"])
-                # Reconstruct the field processor using MultiModalFieldConfig
-                factory_meth_name, *field_args = v["field"]
-                factory_meth = getattr(MultiModalFieldConfig,
-                                       factory_meth_name)
-                # Special case: decode the union "slices" field of
-                # MultiModalFlatField
-                if factory_meth_name == "flat":
-                    field_args[0] = self._decode_nested_slices(field_args[0])
-                v["field"] = factory_meth(None, *field_args).field
-                elems.append(MultiModalFieldElem(**v))
-            decoded_items.append(MultiModalKwargsItem.from_elems(elems))
-        return decoded_items
+        return [self._decode_mm_item(v) for v in obj]
+
+    def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
+        return MultiModalKwargsItem.from_elems(
+            [self._decode_mm_field_elem(v) for v in obj])
+
+    def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
+        obj["data"] = self._decode_nested_tensors(obj["data"])
+        # Reconstruct the field processor using MultiModalFieldConfig
+        factory_meth_name, *field_args = obj["field"]
+        factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
+        # Special case: decode the union "slices" field of
+        # MultiModalFlatField
+        if factory_meth_name == "flat":
+            field_args[0] = self._decode_nested_slices(field_args[0])
+        obj["field"] = factory_meth(None, *field_args).field
+        return MultiModalFieldElem(**obj)

     def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if isinstance(obj, (int, float)):
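
Taken together, the processing changes above reduce the caching flow to: hash all items up front, substitute cached results where a hash hits, run the HF processor only on the misses, then merge the two streams back in their original order while backfilling the cache. The following standalone sketch of that flow loosely mirrors _get_cache_missing_items and _merge_mm_kwargs, but uses hypothetical helper names and flat lists instead of per-modality dicts:

# Standalone sketch of the hit/miss merge flow (hypothetical names/types,
# not the vLLM implementation).
from typing import Union

ItemOrHash = Union[dict, str]  # a cached processed item, or its hash on a miss


def split_hits_and_misses(cache: dict, hashes: list[str]) -> list[ItemOrHash]:
    # Keep the cached item where available; otherwise keep the hash itself
    # so the caller knows which positions still need processing.
    return [cache.get(h, h) for h in hashes]


def merge(cache: dict, items_or_hashes: list[ItemOrHash],
          freshly_processed: list[dict]) -> list[dict]:
    merged: list[dict] = []
    next_idx = 0
    for entry in items_or_hashes:
        if isinstance(entry, str):
            # Miss: consume the next newly processed item and backfill
            # the cache under the precomputed hash.
            item = freshly_processed[next_idx]
            cache[entry] = item
            next_idx += 1
        else:
            item = entry  # hit: reuse the cached item as-is
        merged.append(item)
    return merged


# Example: two images requested, only the first is already cached.
cache = {"h1": {"pixel_values": "cached"}}
order = split_hits_and_misses(cache, ["h1", "h2"])       # [<item>, "h2"]
result = merge(cache, order, [{"pixel_values": "new"}])  # misses processed
assert [r["pixel_values"] for r in result] == ["cached", "new"]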