diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 531674c30f55f..2ddc93f8daf7b 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import multiprocessing as mp import numpy as np import pytest @@ -8,9 +9,16 @@ import torch from vllm.config import ModelConfig, ParallelConfig, VllmConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import ( + BaseMultiModalProcessorCache, + BaseMultiModalReceiverCache, MultiModalCache, + MultiModalProcessorCacheInItem, MultiModalProcessorCacheItem, MultiModalProcessorCacheItemMetadata, + MultiModalProcessorSenderCache, + MultiModalReceiverCache, + ShmObjectStoreReceiverCache, + ShmObjectStoreSenderCache, engine_receiver_cache_from_config, processor_cache_from_config, ) @@ -22,6 +30,7 @@ from vllm.multimodal.inputs import ( MultiModalSharedField, ) from vllm.multimodal.processing import PromptInsertion +from vllm.utils.mem_constants import GiB_bytes, MiB_bytes pytestmark = pytest.mark.cpu_test @@ -144,8 +153,7 @@ def _compare_caches( MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items ] - # Should not be used since there is nothing to convert to text - prompt_update = PromptInsertion("dummy", "target", "insertion") + prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0) for it in range(n_iter): num_items_to_select = rng.randint(0, max_items_per_iter) @@ -159,10 +167,11 @@ def _compare_caches( else: for _ in range(is_cached_calls_per_iter): cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ item for item, _ in cache_0_p0.get_and_update( - [(item, prompt_update.content) for item in selected_items], + [(item, [prompt_update]) for item in selected_items], selected_hashes, ) ] @@ -172,10 +181,11 @@ def _compare_caches( else: for _ in range(is_cached_calls_per_iter): cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ item for item, _ in cache_1_p0.get_and_update( - [(item, prompt_update.content) for item in selected_items], + [(item, [prompt_update]) for item in selected_items], selected_hashes, ) ] @@ -225,3 +235,289 @@ def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): vllm_config_ipc_enabled, is_cached_calls_per_iter=is_cached_calls_per_iter, ) + + +def _run_test_cache_eviction_lru( + p0_cache: BaseMultiModalProcessorCache, + p1_cache: BaseMultiModalReceiverCache, + base_item_size: int, +): + request1_hashes = [ + "image_A", + "image_B", + "image_C", + ] + request1_items = { + h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size) + for h in request1_hashes + } + + request2_hashes = ["image_D", "image_E", "image_A", "image_C"] + request2_items = { + h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size) + for h in request2_hashes + } + + ########################## + # STEP 1: Request 1 send + ########################## + sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes) + # Cache is empty + assert sender_is_cached_item_req1 == [False, False, False] + + # Touch all mm hash for P0 Cache before process + for mm_hash in request1_hashes: + p0_cache.touch_sender_cache_item(mm_hash) + + ########################### + # Process request 1 for P0 Cache + ########################### + item_tuple: MultiModalProcessorCacheInItem + for i, h in enumerate(request1_hashes): + # Use precomputed cache state + is_cached = sender_is_cached_item_req1[i] + item_tuple = 
(request1_items[h], []) if not is_cached else None
+        print(f"Request 1: key={h} | cached={is_cached}")
+
+        p0_cache.get_and_update_item(item_tuple, h)
+
+    ###########################
+    # Process request 1 for P1 Cache
+    ###########################
+    # Touch all mm hash for P1 Cache before process
+    for mm_hash in request1_hashes:
+        p1_cache.touch_receiver_cache_item(mm_hash)
+
+    for h in request1_hashes:
+        p1_cache.get_and_update_item(request1_items[h], h)
+
+    expected_hashes = ["image_A", "image_B", "image_C"]
+    assert list(p0_cache._cache.order) == expected_hashes
+
+    ##########################
+    # STEP 2: Request 2 send
+    ##########################
+    sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
+    assert sender_is_cached_item_req2 == [False, False, True, True]
+
+    # Touch all mm hash for P0 Cache before process
+    for mm_hash in request2_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 2 for P0 Cache
+    ###########################
+    for i, h in enumerate(request2_hashes):
+        # Use precomputed cache state again
+        is_cached = sender_is_cached_item_req2[i]
+        item_tuple = (request2_items[h], []) if not is_cached else None
+        print(f"Request 2: key={h} | cached={is_cached}")
+
+        p0_cache.get_and_update_item(item_tuple, h)
+
+    ###########################
+    # Process request 2 for P1 Cache
+    ###########################
+
+    # Touch all mm hash for P1 Cache before process
+    for mm_hash in request2_hashes:
+        p1_cache.touch_receiver_cache_item(mm_hash)
+
+    for h in request2_hashes:
+        p1_cache.get_and_update_item(request2_items[h], h)
+
+    expected_hashes = ["image_D", "image_E", "image_A", "image_C"]
+    assert list(p0_cache._cache.order) == expected_hashes
+
+
+def test_cache_eviction_lru_cache():
+    model_config = ModelConfig(
+        model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+        mm_processor_cache_gb=6 / GiB_bytes,
+    )
+    sender_cache = MultiModalProcessorSenderCache(model_config)
+    receiver_cache = MultiModalReceiverCache(model_config)
+
+    _run_test_cache_eviction_lru(sender_cache, receiver_cache, base_item_size=1)
+
+
+# This test verifies shared-memory cache eviction behavior across the processor
+# (p0) and receiver (p1) caches.
+# Flow summary:
+# 1. Request 1 adds image_A, image_B, image_C, completely filling the cache.
+# 2. Request 2 tries to add image_G and image_A. image_G cannot be added because
+#    the cache is full and image_A is protected from eviction, so the cache
+#    remains unchanged.
+# 3. Request 3 adds image_G, image_H, image_I and image_B. This time image_A is
+#    evicted, freeing 5 MB, so image_G and image_H fit. image_B is protected
+#    from eviction, so image_I cannot be added.
+#    This proves normal eviction and reuse behavior.
+def _run_test_cache_eviction_shm(
+    p0_cache: BaseMultiModalProcessorCache,
+    p1_cache: BaseMultiModalReceiverCache,
+    base_item_size: int,
+):
+    request1_hashes = ["image_A", "image_B", "image_C"]
+    request1_items = {
+        h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
+        for h in request1_hashes
+    }
+    request1_items_p0_result = []
+
+    request2_hashes = ["image_G", "image_A"]
+    request2_items = {
+        h: MultiModalKwargsItem.dummy(
+            h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
+        )
+        for h in request2_hashes
+    }
+    request2_items_p0_result = []
+
+    request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
+    request3_items = {
+        h: MultiModalKwargsItem.dummy(
+            h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
+        )
+        for h in request3_hashes
+    }
+    request3_items_p0_result = []
+
+    ##########################
+    # STEP 1: Request 1 send
+    # This will fill up the cache
+    ##########################
+    sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes)
+    # Cache is empty
+    assert sender_is_cached_item_req1 == [False, False, False]
+
+    # Touch all mm hash for P0 Cache before process
+    for mm_hash in request1_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 1 for P0 Cache
+    ###########################
+    item_tuple: MultiModalProcessorCacheInItem
+    for i, h in enumerate(request1_hashes):
+        # Use precomputed cache state
+        is_cached = sender_is_cached_item_req1[i]
+        item_tuple = (request1_items[h], []) if not is_cached else None
+        print(f"Request 1: key={h} | cached={is_cached}")
+
+        p0_result = p0_cache.get_and_update_item(item_tuple, h)
+        # Only get mm item, ignore prompt update result
+        request1_items_p0_result.append(p0_result[0])
+
+    ###########################
+    # Process request 1 for P1 Cache
+    ###########################
+    # Touch all mm hash for P1 Cache before process
+    for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
+        p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
+
+    for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
+        p1_cache.get_and_update_item(mm_item, mm_hash)
+
+    expected_hashes = ["image_A", "image_B", "image_C"]
+    assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
+
+    ##########################
+    # STEP 2: Request 2 send
+    # There is no eviction because image_A is protected,
+    # so no new item can be added to the cache
+    ##########################
+    sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
+    assert sender_is_cached_item_req2 == [False, True]
+
+    # Touch all mm hash for P0 Cache before process
+    for mm_hash in request2_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 2 for P0 Cache
+    ###########################
+    for i, h in enumerate(request2_hashes):
+        # Use precomputed cache state again
+        is_cached = sender_is_cached_item_req2[i]
+        item_tuple = (request2_items[h], []) if not is_cached else None
+        print(f"Request 2: key={h} | cached={is_cached}")
+
+        p0_result = p0_cache.get_and_update_item(item_tuple, h)
+        # Only get mm item, ignore prompt update result
+        request2_items_p0_result.append(p0_result[0])
+
+    # image_A cannot be evicted, so image_G fails to allocate
+    # and image_A is still in the cache
+    assert p0_cache.is_cached(request2_hashes) == [False, True]
+
+    ###########################
+    # Process request 2 for P1 Cache
+    ###########################
+
+    # Touch all mm hash for P1 Cache before process
+    for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
+        p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
+
+    for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
+        p1_cache.get_and_update_item(mm_item, mm_hash)
+
+    # Prove that cache state is unchanged
+    expected_hashes = ["image_A", "image_B", "image_C"]
+    assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
+
+    ##########################
+    # STEP 3: Request 3 send
+    ##########################
+    ##### Prove that cache eviction works normally
+    sender_is_cached_item_req3 = p0_cache.is_cached(request3_hashes)
+    assert sender_is_cached_item_req3 == [False, False, False, True]
+
+    # Touch all mm hash for P0 Cache before process
+    for mm_hash in request3_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 3 for P0 Cache
+    ###########################
+    for i, h in enumerate(request3_hashes):
+        # Use precomputed cache state again
+        is_cached = sender_is_cached_item_req3[i]
+        item_tuple = (request3_items[h], []) if not is_cached else None
+        print(f"Request 3: key={h} | cached={is_cached}")
+        p0_result = p0_cache.get_and_update_item(item_tuple, h)
+        # Only get mm item, ignore prompt update result
+        request3_items_p0_result.append(p0_result[0])
+
+    # image_A gets evicted and image_G is added to the cache
+    # image_B is still protected
+    # image_G and image_H fit but image_I cannot fit
+    assert p0_cache.is_cached(request3_hashes) == [True, True, False, True]
+
+    ###########################
+    # Process request 3 for P1 Cache
+    ###########################
+
+    # Touch all mm hash for P1 Cache before process
+    for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
+        p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
+
+    for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
+        p1_cache.get_and_update_item(mm_item, mm_hash)
+
+    expected_hashes = ["image_B", "image_C", "image_G", "image_H"]
+    assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
+
+
+def test_cache_eviction_shm_cache():
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(
+            model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+            mm_processor_cache_type="shm",
+            mm_shm_cache_max_object_size_mb=6,
+            mm_processor_cache_gb=15.2 * MiB_bytes / GiB_bytes,
+        ),
+    )
+    sender_cache = ShmObjectStoreSenderCache(vllm_config)
+    receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock())
+
+    _run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes)
diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py
index 4af2caa16b0d6..5da261fbc6cfc 100644
--- a/vllm/distributed/device_communicators/shm_object_storage.py
+++ b/vllm/distributed/device_communicators/shm_object_storage.py
@@ -574,7 +574,6 @@ class SingleWriterShmObjectStorage:
             value
         )
         buffer_size = self.flag_bytes + data_bytes + md_bytes
-
         # Sanity checks
         if buffer_size > self.max_object_size:
             raise ValueError(
@@ -626,6 +625,44 @@ class SingleWriterShmObjectStorage:
 
         return obj
 
+    def touch(
+        self,
+        key: str,
+        address: int = 0,
+        monotonic_id: int = 0,
+    ) -> None:
+        """
+        Touch an existing cached item to update its eviction status.
+
+        For writers (ShmObjectStoreSenderCache): Increment writer_flag
+        For readers (ShmObjectStoreReceiverCache): Increment reader_count
+
+        Args:
+            key: String key of the object to touch
+            address: Address of the object (only for readers)
+            monotonic_id: Monotonic ID of the object (only for readers)
+
+        """
+        if self._reader_lock is None:
+            if key not in self.key_index:
+                return None
+            address, monotonic_id = self.key_index[key]
+            # Writer side: increment writer_flag to raise eviction threshold
+            self.increment_writer_flag(monotonic_id)
+        else:
+            with (
+                self._reader_lock,
+                self.ring_buffer.access_buf(address) as (data_view, _),
+            ):
+                reader_count = self.ring_buffer.byte2int(data_view[: self.flag_bytes])
+
+                # NOTE(Long):
+                # Avoid increasing the flag on a newly added item (to stay in
+                # sync with the sender): when an item has just been added,
+                # a pre-touch has no effect on the writer side.
+                if reader_count >= self.n_readers:
+                    self.increment_reader_flag(data_view[: self.flag_bytes])
+
     def handle(self):
         """Get handle for sharing across processes."""
         return ShmObjectStorageHandle(
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index c1531cbfdc31d..97f6aa461b90c 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -302,6 +302,19 @@ class BaseMultiModalProcessorCache(
         """
         return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
 
+    @abstractmethod
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        """
+        Update the cache eviction order for a multi-modal item.
+
+        This is used to touch the item in the cache without changing
+        its value.
+
+        Args:
+            mm_hash: The hash of the multi-modal item.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def make_stats(self, *, delta: bool = False) -> CacheInfo:
         """
@@ -353,6 +366,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
 
         return mm_item
 
+    @override
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        self._cache.touch(mm_hash)
+
     @override
     def clear_cache(self) -> None:
         self._cache.clear()
@@ -407,6 +424,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
 
         return mm_item
 
+    @override
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        self._cache.touch(mm_hash)
+
     @override
     def clear_cache(self) -> None:
         self._cache.clear()
@@ -501,6 +522,12 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
             logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
         return mm_item
 
+    @override
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        """Touch the item in the shared memory cache to prevent eviction.
+        Increments writer_flag on the sender side."""
+        self._shm_cache.touch(mm_hash)
+
     @override
     def clear_cache(self) -> None:
         self._shm_cache.clear()
@@ -610,11 +637,37 @@ class BaseMultiModalReceiverCache(
         self,
         mm_features: list["MultiModalFeatureSpec"],
     ) -> list["MultiModalFeatureSpec"]:
-        """Update multimodal features with cached encoder outputs."""
+        """
+        Update multimodal features with cached encoder outputs.
+        Touch all identifiers first so that items in the list being
+        updated are not evicted while the update is in progress.
+        """
+        for feature in mm_features:
+            self.touch_receiver_cache_item(feature.identifier, feature.data)
+
         for feature in mm_features:
             feature.data = self.get_and_update_item(feature.data, feature.identifier)
         return mm_features
 
+    @abstractmethod
+    def touch_receiver_cache_item(
+        self,
+        mm_hash: str,
+        mm_item: MultiModalKwargsItem | None = None,
+    ) -> None:
+        """
+        Update the cache eviction order for a multi-modal item.
+ + This is used to touch the item in the cache without changing + its value. + + Args: + mm_hash: The hash of the multi-modal item. + mm_item: The multi-modal item itself. This is optional and + may not be needed by some cache implementations. + """ + raise NotImplementedError + class MultiModalReceiverCache(BaseMultiModalReceiverCache): """ @@ -651,6 +704,14 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache): self._cache[mm_hash] = mm_item return mm_item + @override + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + self._cache.touch(mm_hash) + @override def clear_cache(self) -> None: self._cache.clear() @@ -703,6 +764,20 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): return mm_item + @override + def touch_receiver_cache_item( + self, + mm_hash: str, + mm_item: MultiModalKwargsItem | None = None, + ) -> None: + """Touch the item in shared memory cache to prevent eviction. + Increments reader_count on receiver side.""" + assert mm_item is not None + if "address" in mm_item: + address = cast(int, mm_item["address"].data) + monotonic_id = cast(int, mm_item["monotonic_id"].data) + self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id) + @override def clear_cache(self) -> None: self._shm_cache.clear() diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 7518a023c5f50..f4e38b1f3325f 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -721,12 +721,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): """ @staticmethod - def dummy(modality: str): + def dummy(modality: str, nbytes: int = 1): """Convenience class for testing.""" mm_elem = MultiModalFieldElem( modality=modality, key="dummy", - data=torch.empty(1), + data=torch.empty(nbytes, dtype=torch.uint8), field=MultiModalSharedField(1), ) return MultiModalKwargsItem.from_elems([mm_elem]) diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 912cff2343dd0..2f651bd71706f 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1248,7 +1248,13 @@ _I = TypeVar("_I", bound=BaseProcessingInfo) MultiModalHashes = dict[str, list[str]] """ -A collection of hashes with a similar structure as +A collection of the multi-modal hash for each item, with a similar structure as +[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. +""" + +MultiModalIsCached = dict[str, list[bool]] +""" +A collection of the `is_cached` flag for each item, with a similar structure as [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. 
""" @@ -1725,7 +1731,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): cache: BaseMultiModalProcessorCache, mm_data_items: MultiModalDataItems, mm_hashes: MultiModalHashes, - ) -> MultiModalDataItems: + ) -> tuple[MultiModalIsCached, MultiModalDataItems]: mm_is_cached = { modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() } @@ -1752,7 +1758,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): missing_modality_data.append(data) mm_missing_data[modality] = missing_modality_data - return self._to_mm_items(mm_missing_data) + return mm_is_cached, self._to_mm_items(mm_missing_data) def _recompute_cached_prompt_update( self, @@ -1769,14 +1775,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, cache: BaseMultiModalProcessorCache, mm_hashes: MultiModalHashes, + mm_is_cached: MultiModalIsCached, mm_missing_kwargs: MultiModalKwargsItems, mm_missing_prompt_updates: MultiModalPromptUpdates, ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: - # Need to calculate this at the beginning to avoid skipping cache logic - # for subsequently repeated items in the same modality - mm_is_cached = { - modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items() - } + # Need to touch all mm hashes before update to avoid hash in updated + # list evict during update + for hashes in mm_hashes.values(): + for item_hash in hashes: + cache.touch_sender_cache_item(item_hash) mm_missing_next_idx = defaultdict[str, int](lambda: 0) @@ -1789,15 +1796,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): missing_prompt_updates = mm_missing_prompt_updates.get(modality, []) for item_idx, item_hash in enumerate(hashes): - kwargs: MultiModalKwargsItem | None if not mm_is_cached[modality][item_idx]: missing_next_idx = mm_missing_next_idx[modality] - kwargs = missing_kwargs[missing_next_idx] - updates = missing_prompt_updates[missing_next_idx] + missing_kwargs_item = missing_kwargs[missing_next_idx] + missing_updates_item = missing_prompt_updates[missing_next_idx] mm_missing_next_idx[modality] += 1 - item = kwargs, updates + item = missing_kwargs_item, missing_updates_item else: item = None @@ -1896,7 +1902,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_uuids=mm_uuids, ) - mm_missing_data_items = self._get_cache_missing_items( + mm_is_cached, mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, mm_hashes=mm_hashes, @@ -1933,6 +1939,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( cache, mm_hashes=mm_hashes, + mm_is_cached=mm_is_cached, mm_missing_kwargs=mm_missing_kwargs, mm_missing_prompt_updates=mm_missing_prompt_updates, )