[Core] Simplify mm processing cache (#22457)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-08 00:47:07 +08:00 committed by GitHub
parent 399d2a10e2
commit 8c9da6be22
4 changed files with 95 additions and 204 deletions
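
The hunks below remove `ProcessingCache`'s bespoke `get`/`put`/`get_item`/`put_item` methods, which re-hashed the model ID, modality, raw data item, and processor kwargs on every call, and instead alias the underlying LRU cache's `get`, `put`, and `clear`, keyed by the multi-modal hashes that `_hash_mm_items` already computes. A minimal sketch of the resulting usage pattern, assuming the hash string comes from `MultiModalHasher.hash_kwargs` via `_hash_mm_items`; `build_item` is a hypothetical stand-in for running the HF processor on a missing input:

from typing import Callable, Optional

from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.processing import ProcessingCache

cache = ProcessingCache(capacity_gb=4.0)

def get_or_process(
    hash_str: str,
    build_item: Callable[[], MultiModalKwargsItem],  # hypothetical helper
) -> MultiModalKwargsItem:
    # get/put are now plain aliases of the LRU cache's methods, so the
    # caller passes the precomputed hash as the key.
    cached: Optional[MultiModalKwargsItem] = cache.get(hash_str)
    if cached is not None:
        return cached
    item = build_item()        # e.g. run the HF processor on the raw input
    cache.put(hash_str, item)  # store under the same precomputed hash
    return item

# cache.reset() still clears everything; it now simply aliases clear().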

View File

@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
tokenization_kwargs: Mapping[str, object],
*,
enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]:
) -> tuple[list[int], BatchFeature, bool]:
"""
Qwen2.5-Omni reimplements this function to handle text only.
"""
@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_kwargs = self._apply_hf_processor_mm_only(
mm_processed_data = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return prompt_ids, mm_kwargs, False
return prompt_ids, mm_processed_data, False
def _apply_hf_processor_mm_only(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs:
) -> BatchFeature:
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
assert "audio" in mm_counts
mm_counts["audio"] -= mm_counts["video"]
_, mm_kwargs, _ = self._apply_hf_processor_text_mm(
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return mm_kwargs
return mm_processed_data
def _validate_mm_placeholders(
self,

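For context on the `use_audio_in_video` handling above: when the audio track is taken from the video inputs themselves, the dummy prompt fed to the HF processor must not count those audio items a second time, which is what the `mm_counts["audio"] -= mm_counts["video"]` adjustment accounts for. A small self-contained illustration with hypothetical counts (the dict literal below is not from the diff):

# Hypothetical counts: 2 videos whose audio tracks are consumed via
# use_audio_in_video, plus 1 standalone audio clip.
mm_counts = {"video": 2, "audio": 3}

# Mirrors the adjustment in _apply_hf_processor_mm_only above.
assert "audio" in mm_counts
mm_counts["audio"] -= mm_counts["video"]

assert mm_counts == {"video": 2, "audio": 1}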
View File

@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union
import regex as re
import torch
from torch import nn
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers import (AutoModel, BatchFeature, PretrainedConfig,
PreTrainedModel)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention
@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
):
) -> tuple[list[int], BatchFeature, bool]:
"""
Apply the HF processor on the prompt text and multi-modal data
together.

View File

@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from vllm.utils import flatten_2d_lists, full_groupby
from .cache import MultiModalCache
from .hasher import MultiModalHasher
@@ -887,120 +887,19 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it))
class ProcessingCacheOptionalItem(NamedTuple):
key: str
value: Optional[MultiModalKwargsItem]
class ProcessingCacheItem(NamedTuple):
key: str
value: MultiModalKwargsItem
class ProcessingCache(MultiModalCache):
def __init__(
self,
capacity_gb: float,
*,
debug_cache_hit_ratio_steps: Optional[int] = None,
) -> None:
def __init__(self, capacity_gb: float) -> None:
super().__init__()
self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
self.debug_cache_hits = 0
self.debug_cache_total = 0
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
self._cache = self.get_lru_cache(
capacity_gb,
MultiModalKwargsItem,
debug=bool(debug_cache_hit_ratio_steps),
)
self.get = self._cache.get
self.put = self._cache.put
self.reset = self._cache.clear
def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps
if not steps:
return
total = self.debug_cache_total
if total > 0 and total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f",
self.debug_cache_hits / total)
logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
self._cache.currsize / GiB_bytes,
self._cache.maxsize / GiB_bytes)
def get(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> Optional[MultiModalKwargsItem]:
"""
Get a processed multi-modal item from the cache
according to its dependencies, including:
- The model ID
- The modality of the item
- The original data item passed to the HF processor
- The configuration options of the HF processor
"""
self._maybe_log_cache_stats()
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
if self.debug_cache_hit_ratio_steps:
if cache_key in self._cache:
self.debug_cache_hits += 1
self.debug_cache_total += 1
return self._cache.get(cache_key)
def get_item(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> ProcessingCacheOptionalItem:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return ProcessingCacheOptionalItem(
key=cache_key,
value=self._cache.get(cache_key),
)
def put(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
output_kwargs: MultiModalKwargsItem,
) -> None:
"""
Put a processed multi-modal item into the cache
according to its dependencies
(see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
"""
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
self._cache[cache_key] = output_kwargs
def put_item(self, item: ProcessingCacheItem) -> None:
self._cache[item.key] = item.value
def reset(self) -> bool:
self._cache.clear()
return True
_CacheItemOrHash = Union[MultiModalKwargsItem, str]
class BaseProcessingInfo:
@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
) -> tuple[list[int], "BatchFeature", bool]:
"""
Apply the HF processor on the prompt text and multi-modal data
together.
@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids, = processed_data.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)
is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs=tokenization_kwargs,
)
return prompt_ids, mm_kwargs, is_update_applied
return prompt_ids, processed_data, is_update_applied
def _apply_hf_processor_text_only(
self, prompt_text: str,
tokenization_kwargs: Mapping[str, object]) -> list[int]:
self,
prompt_text: str,
tokenization_kwargs: Mapping[str, object],
) -> list[int]:
"""
Apply the HF processor on the prompt text only.
@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs:
) -> "BatchFeature":
"""
Apply the HF processor on the multi-modal data only.
@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
mm_counts = mm_items.get_all_counts()
_, mm_kwargs, _ = self._apply_hf_processor_text_mm(
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return mm_kwargs
return mm_processed_data
def _apply_hf_processor_main(
self,
@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]:
) -> tuple[list[int], "BatchFeature", bool]:
"""
Apply the HF processor on the prompt text and multi-modal data.
@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_kwargs = self._apply_hf_processor_mm_only(
mm_processed_data = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return prompt_ids, mm_kwargs, False
return prompt_ids, mm_processed_data, False
def _get_cache_missing_items(
self,
cache: ProcessingCache,
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
str, list[object]]]:
model_id = self.info.model_id
mm_cache_items = {
modality: [
cache.get_item(
model_id, modality, item,
dict(**hf_processor_mm_kwargs, **tokenization_kwargs))
for item in items
]
for modality, items in mm_data_items.items()
mm_hashes: MultiModalHashes,
) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
modality: [(h if (v := cache.get(h)) is None else v)
for h in hashes]
for modality, hashes in mm_hashes.items()
}
mm_missing_idxs = {
modality: [
idx for idx, item in enumerate(cache_items)
if item.value is None
idx for idx, item_or_hash in enumerate(items_or_hashes)
if isinstance(item_or_hash, str)
]
for modality, cache_items in mm_cache_items.items()
for modality, items_or_hashes in mm_cache_items_or_hashes.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
return mm_cache_items, mm_missing_data
return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)
def _hash_mm_items(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes:
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1)."""
model_id = self.info.model_id
@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _merge_mm_kwargs(
self,
cache: ProcessingCache,
mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
mm_missing_data: dict[str, list[object]],
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
mm_missing_kwargs: MultiModalKwargs,
) -> dict[str, list[ProcessingCacheItem]]:
mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}
) -> dict[str, list[MultiModalKwargsItem]]:
mm_missing_next_idx = defaultdict[str, int](lambda: 0)
merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
for modality, cache_items in mm_cache_items.items():
for cache_item in cache_items:
if cache_item.value is None:
merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
for modality, items_or_hashes in mm_cache_items_or_hashes.items():
for item_or_hash in items_or_hashes:
if isinstance(item_or_hash, str):
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=kw_item,
)
cache.put_item(cache_item_new)
cache.put(item_or_hash, kw_item)
mm_missing_next_idx[modality] += 1
else:
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=cache_item.value,
)
kw_item = item_or_hash
merged_items[modality].append(cache_item_new)
merged_items[modality].append(kw_item)
return dict(merged_items)
@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
(
prompt_ids,
mm_kwargs,
mm_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
enable_hf_prompt_update=True,
)
mm_kwargs = MultiModalKwargs.from_hf_inputs(
mm_processed_data,
self._get_mm_fields_config(mm_processed_data,
hf_processor_mm_kwargs),
)
mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
if return_mm_hashes else None)
@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return_mm_hashes=return_mm_hashes,
)
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
(
mm_cache_items,
mm_missing_data,
mm_cache_items_or_hashes,
mm_missing_data_items,
) = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_hashes=mm_hashes,
)
mm_hashes_to_return = mm_hashes if return_mm_hashes else None
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids,
mm_missing_kwargs,
mm_missing_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=self._to_mm_items(mm_missing_data),
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=False,
)
mm_missing_kwargs = MultiModalKwargs.from_hf_inputs(
mm_missing_processed_data,
self._get_mm_fields_config(mm_missing_processed_data,
hf_processor_mm_kwargs),
)
mm_cache_items_merged = self._merge_mm_kwargs(
cache,
mm_cache_items=mm_cache_items,
mm_missing_data=mm_missing_data,
mm_cache_items_or_hashes=mm_cache_items_or_hashes,
mm_missing_kwargs=mm_missing_kwargs,
)
mm_kwargs = MultiModalKwargs.from_items([
item.value for cache_items in mm_cache_items_merged.values()
item for cache_items in mm_cache_items_merged.values()
for item in cache_items
])
mm_hashes = {
modality: [item.key for item in cache_items]
for modality, cache_items in mm_cache_items_merged.items()
} if return_mm_hashes else None
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
def _bind_and_group_updates(
self,

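Taken together, the processing-cache hunks above reduce the cached path to: hash every input item (`_hash_mm_items`), look each hash up in the cache, run the HF processor only on the misses (prompt updates are deferred, per the NOTE, because the prompt no longer lines up with just the missing items), then merge cached and fresh items while writing the fresh ones back to the cache, and finally build the `MultiModalKwargs`. A condensed sketch of that flow under the same names; `process_missing_items` is a hypothetical placeholder for the `_apply_hf_processor_main` + `MultiModalKwargs.from_hf_inputs` step:

from collections import defaultdict
from typing import Callable, Union

from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItem
from vllm.multimodal.processing import ProcessingCache

_CacheItemOrHash = Union[MultiModalKwargsItem, str]

def cached_process(
    cache: ProcessingCache,
    mm_hashes: dict[str, list[str]],  # modality -> per-item hashes (MultiModalHashes)
    process_missing_items: Callable[..., MultiModalKwargs],  # hypothetical
) -> MultiModalKwargs:
    # 1) Keep the cached MultiModalKwargsItem where available; otherwise keep
    #    the hash itself as a placeholder for the missing slot.
    items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
        modality: [(h if (v := cache.get(h)) is None else v) for h in hashes]
        for modality, hashes in mm_hashes.items()
    }

    # 2) Run the HF processor only on the inputs whose slots are still hashes.
    mm_missing_kwargs = process_missing_items(items_or_hashes)

    # 3) Merge: each hash placeholder is replaced by the next freshly processed
    #    item of that modality, which is also written back to the cache.
    next_missing_idx: dict[str, int] = defaultdict(int)
    merged: dict[str, list[MultiModalKwargsItem]] = defaultdict(list)
    for modality, entries in items_or_hashes.items():
        for entry in entries:
            if isinstance(entry, str):  # cache miss
                item = mm_missing_kwargs.get_item(modality,
                                                  next_missing_idx[modality])
                cache.put(entry, item)
                next_missing_idx[modality] += 1
            else:                       # cache hit
                item = entry
            merged[modality].append(item)

    return MultiModalKwargs.from_items(
        [item for items in merged.values() for item in items])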
View File

@@ -312,25 +312,25 @@ class MsgpackDecoder:
return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = []
for item in obj:
elems = []
for v in item:
v["data"] = self._decode_nested_tensors(v["data"])
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = v["field"]
factory_meth = getattr(MultiModalFieldConfig,
factory_meth_name)
return [self._decode_mm_item(v) for v in obj]
# Special case: decode the union "slices" field of
# MultiModalFlatField
if factory_meth_name == "flat":
field_args[0] = self._decode_nested_slices(field_args[0])
def _decode_mm_item(self, obj: list) -> MultiModalKwargsItem:
return MultiModalKwargsItem.from_elems(
[self._decode_mm_field_elem(v) for v in obj])
v["field"] = factory_meth(None, *field_args).field
elems.append(MultiModalFieldElem(**v))
decoded_items.append(MultiModalKwargsItem.from_elems(elems))
return decoded_items
def _decode_mm_field_elem(self, obj: dict) -> MultiModalFieldElem:
obj["data"] = self._decode_nested_tensors(obj["data"])
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name, *field_args = obj["field"]
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
# Special case: decode the union "slices" field of
# MultiModalFlatField
if factory_meth_name == "flat":
field_args[0] = self._decode_nested_slices(field_args[0])
obj["field"] = factory_meth(None, *field_args).field
return MultiModalFieldElem(**obj)
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if isinstance(obj, (int, float)):