[Bugfix] Missing cached item in the MultiModalReceiverCache (#28525)

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: Chenguang Zheng <645327136@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
knlnguyen1802 2025-12-02 02:18:07 +08:00 committed by GitHub
parent d0985c5feb
commit fc6acc88ca
5 changed files with 436 additions and 21 deletions
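
The fix centers on a touch-before-update protocol: every multi-modal hash referenced by a request is touched (pinned) in the cache before any item is read or inserted, so earlier items of the same request cannot be evicted while later ones are processed. A minimal sketch of the resulting sender-side (P0) flow, using the touch_sender_cache_item API added in this commit; the setup of sender_cache, mm_hashes, and mm_items is illustrative only:

is_cached = sender_cache.is_cached(mm_hashes)

# Pin every referenced hash first so already-cached items cannot be evicted
# while the missing ones are inserted below.
for mm_hash in mm_hashes:
    sender_cache.touch_sender_cache_item(mm_hash)

for cached, mm_hash in zip(is_cached, mm_hashes):
    item = None if cached else (mm_items[mm_hash], [])
    sender_cache.get_and_update_item(item, mm_hash)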

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing as mp
import numpy as np
import pytest
@ -8,9 +9,16 @@ import torch
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import (
BaseMultiModalProcessorCache,
BaseMultiModalReceiverCache,
MultiModalCache,
MultiModalProcessorCacheInItem,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
MultiModalProcessorSenderCache,
MultiModalReceiverCache,
ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache,
engine_receiver_cache_from_config,
processor_cache_from_config,
)
@ -22,6 +30,7 @@ from vllm.multimodal.inputs import (
MultiModalSharedField,
)
from vllm.multimodal.processing import PromptInsertion
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
pytestmark = pytest.mark.cpu_test
@ -144,8 +153,7 @@ def _compare_caches(
MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items
]
# Should not be used since there is nothing to convert to text
prompt_update = PromptInsertion("dummy", "target", "insertion")
prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0)
for it in range(n_iter):
num_items_to_select = rng.randint(0, max_items_per_iter)
@ -159,10 +167,11 @@ def _compare_caches(
else:
for _ in range(is_cached_calls_per_iter):
cache_0_p0.is_cached(selected_hashes)
cache_0_p0_out = [
item
for item, _ in cache_0_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
[(item, [prompt_update]) for item in selected_items],
selected_hashes,
)
]
@ -172,10 +181,11 @@ def _compare_caches(
else:
for _ in range(is_cached_calls_per_iter):
cache_1_p0.is_cached(selected_hashes)
cache_1_p0_out = [
item
for item, _ in cache_1_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
[(item, [prompt_update]) for item in selected_items],
selected_hashes,
)
]
@ -225,3 +235,289 @@ def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
vllm_config_ipc_enabled,
is_cached_calls_per_iter=is_cached_calls_per_iter,
)
def _run_test_cache_eviction_lru(
p0_cache: BaseMultiModalProcessorCache,
p1_cache: BaseMultiModalReceiverCache,
base_item_size: int,
):
request1_hashes = [
"image_A",
"image_B",
"image_C",
]
request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size)
for h in request1_hashes
}
request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
request2_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size)
for h in request2_hashes
}
##########################
# STEP 1: Request 1 send
##########################
sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes)
# Cache is empty
assert sender_is_cached_item_req1 == [False, False, False]
# Touch all mm hashes for P0 Cache before processing
for mm_hash in request1_hashes:
p0_cache.touch_sender_cache_item(mm_hash)
###########################
# Process request 1 for P0 Cache
###########################
item_tuple: MultiModalProcessorCacheInItem
for i, h in enumerate(request1_hashes):
# Use precomputed cache state
is_cached = sender_is_cached_item_req1[i]
item_tuple = (request1_items[h], []) if not is_cached else None
print(f"Request 1: key={h} | cached={is_cached}")
p0_cache.get_and_update_item(item_tuple, h)
###########################
# Process request 1 for P1 Cache
###########################
# Touch all mm hashes for P1 Cache before processing
for mm_hash in request1_hashes:
p1_cache.touch_receiver_cache_item(mm_hash)
for h in request1_hashes:
p1_cache.get_and_update_item(request1_items[h], h)
expected_hashes = ["image_A", "image_B", "image_C"]
assert list(p0_cache._cache.order) == expected_hashes
##########################
# STEP 2: Request 2 send
##########################
sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
assert sender_is_cached_item_req2 == [False, False, True, True]
# Touch all mm hashes for P0 Cache before processing
for mm_hash in request2_hashes:
p0_cache.touch_sender_cache_item(mm_hash)
###########################
# Process request 2 for P0 Cache
###########################
for i, h in enumerate(request2_hashes):
# Use precomputed cache state again
is_cached = sender_is_cached_item_req2[i]
item_tuple = (request2_items[h], []) if not is_cached else None
print(f"Request 2: key={h} | cached={is_cached}")
p0_cache.get_and_update_item(item_tuple, h)
###########################
# Process request 2 for P1 Cache
###########################
# Touch all mm hashes for P1 Cache before processing
for mm_hash in request2_hashes:
p1_cache.touch_receiver_cache_item(mm_hash)
for h in request2_hashes:
p1_cache.get_and_update_item(request2_items[h], h)
expected_hashes = ["image_D", "image_E", "image_A", "image_C"]
assert list(p0_cache._cache.order) == expected_hashes
def test_cache_eviction_lru_cache():
model_config = ModelConfig(
model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
mm_processor_cache_gb=6 / GiB_bytes,
)
sender_cache = MultiModalProcessorSenderCache(model_config)
receiver_cache = MultiModalReceiverCache(model_config)
_run_test_cache_eviction_lru(sender_cache, receiver_cache, base_item_size=1)
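
As a rough check of the numbers in this LRU test (derived from the sizes in the code above, not stated in the commit): the budget of 6 / GiB_bytes GB is 6 bytes, request-1 items are 2 bytes each and request-2 items 1 byte each.

# Request 1: A(2) + B(2) + C(2) = 6 bytes -> the cache is exactly full.
# Request 2: A and C are reused (touched, hence protected); adding D(1) and E(1)
#            needs 2 bytes, so the only untouched entry, B(2), is evicted,
#            leaving the asserted order D, E, A, C.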
# This test verifies shared-memory cache eviction behavior across the processor (p0)
# and receiver (p1) caches.
# Flow summary:
# 1. Request 1 adds images A, B, and C, completely filling the cache.
# 2. Request 2 tries to add image_G and reuse image_A; image_G cannot be added
#    because the cache is full and A is protected from eviction, so the cache
#    remains unchanged.
# 3. Request 3 references image_G, image_H, image_I, and image_B. This time image_A
#    is evicted, freeing 5 MiB, so image_G and image_H fit; image_B is protected
#    from eviction, and image_I cannot be added.
# This demonstrates normal eviction and reuse behavior.
def _run_test_cache_eviction_shm(
p0_cache: BaseMultiModalProcessorCache,
p1_cache: BaseMultiModalReceiverCache,
base_item_size: int,
):
request1_hashes = ["image_A", "image_B", "image_C"]
request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
for h in request1_hashes
}
request1_items_p0_result = []
request2_hashes = ["image_G", "image_A"]
request2_items = {
h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
)
for h in request2_hashes
}
request2_items_p0_result = []
request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
request3_items = {
h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
)
for h in request3_hashes
}
request3_items_p0_result = []
##########################
# STEP 1: Request 1 send
# This will fill up the cache
##########################
sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes)
# Cache is empty
assert sender_is_cached_item_req1 == [False, False, False]
# Touch all mm hashes for P0 Cache before processing
for mm_hash in request1_hashes:
p0_cache.touch_sender_cache_item(mm_hash)
###########################
# Process request 1 for P0 Cache
###########################
item_tuple: MultiModalProcessorCacheInItem
for i, h in enumerate(request1_hashes):
# Use precomputed cache state
is_cached = sender_is_cached_item_req1[i]
item_tuple = (request1_items[h], []) if not is_cached else None
print(f"Request 1: key={h} | cached={is_cached}")
p0_result = p0_cache.get_and_update_item(item_tuple, h)
# Only get mm item, ignore prompt update result
request1_items_p0_result.append(p0_result[0])
###########################
# Process request 1 for P1 Cache
###########################
# Touch all mm hashes for P1 Cache before processing
for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
p1_cache.get_and_update_item(mm_item, mm_hash)
expected_hashes = ["image_A", "image_B", "image_C"]
assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
##########################
# STEP 2: Request 2 send
# There is no eviction because image_A is protected
# No new item can be added to the cache
##########################
sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
assert sender_is_cached_item_req2 == [False, True]
# Touch all mm hashes for P0 Cache before processing
for mm_hash in request2_hashes:
p0_cache.touch_sender_cache_item(mm_hash)
###########################
# Process request 2 for P0 Cache
###########################
for i, h in enumerate(request2_hashes):
# Use precomputed cache state again
is_cached = sender_is_cached_item_req2[i]
item_tuple = (request2_items[h], []) if not is_cached else None
print(f"Request 2: key={h} | cached={is_cached}")
p0_result = p0_cache.get_and_update_item(item_tuple, h)
# Only get mm item, ignore prompt update result
request2_items_p0_result.append(p0_result[0])
# image_A cannot be evicted, so image_G fails to allocate
# and image_A remains in the cache
assert p0_cache.is_cached(request2_hashes) == [False, True]
###########################
# Process request 2 for P1 Cache
###########################
# Touch all mm hashes for P1 Cache before processing
for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
p1_cache.get_and_update_item(mm_item, mm_hash)
# Verify that the cache state is unchanged
expected_hashes = ["image_A", "image_B", "image_C"]
assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
##########################
# STEP 3: Request 3 send
##########################
##### Verify that cache eviction works normally
sender_is_cached_item_req3 = p0_cache.is_cached(request3_hashes)
assert sender_is_cached_item_req3 == [False, False, False, True]
# Touch all mm hashes for P0 Cache before processing
for mm_hash in request3_hashes:
p0_cache.touch_sender_cache_item(mm_hash)
###########################
# Process request 3 for P0 Cache
###########################
for i, h in enumerate(request3_hashes):
# Use precomputed cache state again
is_cached = sender_is_cached_item_req3[i]
item_tuple = (request3_items[h], []) if not is_cached else None
print(f"Request 3: key={h} | cached={is_cached}")
p0_result = p0_cache.get_and_update_item(item_tuple, h)
# Only get mm item, ignore prompt update result
request3_items_p0_result.append(p0_result[0])
# image_A gets evicted and image_G is added to the cache
# image_B is still protected
# image_G and image_H fit, but image_I does not
assert p0_cache.is_cached(request3_hashes) == [True, True, False, True]
###########################
# Process request 3 for P1 Cache
###########################
# Touch all mm hashes for P1 Cache before processing
for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
p1_cache.get_and_update_item(mm_item, mm_hash)
expected_hashes = ["image_B", "image_C", "image_G", "image_H"]
assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
def test_cache_eviction_shm_cache():
vllm_config = VllmConfig(
model_config=ModelConfig(
model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
mm_processor_cache_type="shm",
mm_shm_cache_max_object_size_mb=6,
mm_processor_cache_gb=15.2 * MiB_bytes / GiB_bytes,
),
)
sender_cache = ShmObjectStoreSenderCache(vllm_config)
receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock())
_run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes)
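
A rough capacity check for the shared-memory test (again derived from the sizes in the code, ignoring per-object flag and metadata overhead): the budget of 15.2 * MiB_bytes / GiB_bytes GB is about 15.2 MiB, request-1 items are 5 MiB and the new items 2 MiB each.

# Request 1: A(5) + B(5) + C(5) = 15 MiB of ~15.2 MiB -> effectively full.
# Request 2: G(2) would require an eviction, but A is touched/protected -> G is rejected.
# Request 3: A can now be evicted, freeing 5 MiB; G(2) and H(2) fit,
#            I(2) no longer does, and B stays protected throughout.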

View File

@ -574,7 +574,6 @@ class SingleWriterShmObjectStorage:
value
)
buffer_size = self.flag_bytes + data_bytes + md_bytes
# Sanity checks
if buffer_size > self.max_object_size:
raise ValueError(
@ -626,6 +625,44 @@ class SingleWriterShmObjectStorage:
return obj
def touch(
self,
key: str,
address: int = 0,
monotonic_id: int = 0,
) -> None:
"""
Touch an existing cached item to update its eviction status.
For writers (ShmObjectStoreSenderCache): Increment writer_flag
For readers (ShmObjectStoreReceiverCache): Increment reader_count
Args:
key: String key of the object to touch
address: Address of the object (only for readers)
monotonic_id: Monotonic ID of the object (only for readers)
"""
if self._reader_lock is None:
if key not in self.key_index:
return None
address, monotonic_id = self.key_index[key]
# Writer side: increment writer_flag to raise eviction threshold
self.increment_writer_flag(monotonic_id)
else:
with (
self._reader_lock,
self.ring_buffer.access_buf(address) as (data_view, _),
):
reader_count = self.ring_buffer.byte2int(data_view[: self.flag_bytes])
# NOTE(Long):
# Avoid increasing the flag on a newly added item (to stay in sync with the
# sender), since pre-touching a newly added item has no effect on the writer side.
if reader_count >= self.n_readers:
self.increment_reader_flag(data_view[: self.flag_bytes])
def handle(self):
"""Get handle for sharing across processes."""
return ShmObjectStorageHandle(

View File

@ -302,6 +302,19 @@ class BaseMultiModalProcessorCache(
"""
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
@abstractmethod
def touch_sender_cache_item(self, mm_hash: str) -> None:
"""
Update the cache eviction order for a multi-modal item.
This is used to touch the item in the cache without changing
its value.
Args:
mm_hash: The hash of the multi-modal item.
"""
raise NotImplementedError
@abstractmethod
def make_stats(self, *, delta: bool = False) -> CacheInfo:
"""
@ -353,6 +366,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
return mm_item
@override
def touch_sender_cache_item(self, mm_hash: str) -> None:
self._cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._cache.clear()
@ -407,6 +424,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
return mm_item
@override
def touch_sender_cache_item(self, mm_hash: str) -> None:
self._cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._cache.clear()
@ -501,6 +522,12 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
return mm_item
@override
def touch_sender_cache_item(self, mm_hash: str) -> None:
"""Touch the item in shared memory cache to prevent eviction.
Increments writer_flag on sender side."""
self._shm_cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._shm_cache.clear()
@ -610,11 +637,37 @@ class BaseMultiModalReceiverCache(
self,
mm_features: list["MultiModalFeatureSpec"],
) -> list["MultiModalFeatureSpec"]:
"""Update multimodal features with cached encoder outputs."""
"""
Update multimodal features with cached encoder outputs.
Touch all identifiers first, before updating, so that items
in the list cannot be evicted while the update is in progress.
"""
for feature in mm_features:
self.touch_receiver_cache_item(feature.identifier, feature.data)
for feature in mm_features:
feature.data = self.get_and_update_item(feature.data, feature.identifier)
return mm_features
@abstractmethod
def touch_receiver_cache_item(
self,
mm_hash: str,
mm_item: MultiModalKwargsItem | None = None,
) -> None:
"""
Update the cache eviction order for a multi-modal item.
This is used to touch the item in the cache without changing
its value.
Args:
mm_hash: The hash of the multi-modal item.
mm_item: The multi-modal item itself. This is optional and
may not be needed by some cache implementations.
"""
raise NotImplementedError
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
"""
@ -651,6 +704,14 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache):
self._cache[mm_hash] = mm_item
return mm_item
@override
def touch_receiver_cache_item(
self,
mm_hash: str,
mm_item: MultiModalKwargsItem | None = None,
) -> None:
self._cache.touch(mm_hash)
@override
def clear_cache(self) -> None:
self._cache.clear()
@ -703,6 +764,20 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
return mm_item
@override
def touch_receiver_cache_item(
self,
mm_hash: str,
mm_item: MultiModalKwargsItem | None = None,
) -> None:
"""Touch the item in shared memory cache to prevent eviction.
Increments reader_count on receiver side."""
assert mm_item is not None
if "address" in mm_item:
address = cast(int, mm_item["address"].data)
monotonic_id = cast(int, mm_item["monotonic_id"].data)
self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id)
@override
def clear_cache(self) -> None:
self._shm_cache.clear()

View File

@ -721,12 +721,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
"""
@staticmethod
def dummy(modality: str):
def dummy(modality: str, nbytes: int = 1):
"""Convenience class for testing."""
mm_elem = MultiModalFieldElem(
modality=modality,
key="dummy",
data=torch.empty(1),
data=torch.empty(nbytes, dtype=torch.uint8),
field=MultiModalSharedField(1),
)
return MultiModalKwargsItem.from_elems([mm_elem])

View File

@ -1248,7 +1248,13 @@ _I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as
A collection of the multi-modal hash for each item, with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
MultiModalIsCached = dict[str, list[bool]]
"""
A collection of the `is_cached` flag for each item, with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
@ -1725,7 +1731,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache: BaseMultiModalProcessorCache,
mm_data_items: MultiModalDataItems,
mm_hashes: MultiModalHashes,
) -> MultiModalDataItems:
) -> tuple[MultiModalIsCached, MultiModalDataItems]:
mm_is_cached = {
modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
}
@ -1752,7 +1758,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
missing_modality_data.append(data)
mm_missing_data[modality] = missing_modality_data
return self._to_mm_items(mm_missing_data)
return mm_is_cached, self._to_mm_items(mm_missing_data)
def _recompute_cached_prompt_update(
self,
@ -1769,14 +1775,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
cache: BaseMultiModalProcessorCache,
mm_hashes: MultiModalHashes,
mm_is_cached: MultiModalIsCached,
mm_missing_kwargs: MultiModalKwargsItems,
mm_missing_prompt_updates: MultiModalPromptUpdates,
) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
# Need to calculate this at the beginning to avoid skipping cache logic
# for subsequently repeated items in the same modality
mm_is_cached = {
modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
}
# Touch all mm hashes before updating so that hashes in the list
# cannot be evicted while the update is in progress
for hashes in mm_hashes.values():
for item_hash in hashes:
cache.touch_sender_cache_item(item_hash)
mm_missing_next_idx = defaultdict[str, int](lambda: 0)
@ -1789,15 +1796,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
missing_prompt_updates = mm_missing_prompt_updates.get(modality, [])
for item_idx, item_hash in enumerate(hashes):
kwargs: MultiModalKwargsItem | None
if not mm_is_cached[modality][item_idx]:
missing_next_idx = mm_missing_next_idx[modality]
kwargs = missing_kwargs[missing_next_idx]
updates = missing_prompt_updates[missing_next_idx]
missing_kwargs_item = missing_kwargs[missing_next_idx]
missing_updates_item = missing_prompt_updates[missing_next_idx]
mm_missing_next_idx[modality] += 1
item = kwargs, updates
item = missing_kwargs_item, missing_updates_item
else:
item = None
@ -1896,7 +1902,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuids=mm_uuids,
)
mm_missing_data_items = self._get_cache_missing_items(
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
@ -1933,6 +1939,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_is_cached=mm_is_cached,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)