mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Missing cached item in the MultiModalReceiverCache (#28525)

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: Chenguang Zheng <645327136@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

parent d0985c5feb
commit fc6acc88ca
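The bug this commit fixes: the processor side (P0) computes is_cached() once per request, but an entry could be evicted between that snapshot and the final get_and_update(), so the receiver (P1) could be asked for an item it never received. The fix adds touch_sender_cache_item / touch_receiver_cache_item hooks that pin every hash of a request before any item is processed. A minimal standalone sketch of the sender half, using a toy LRU cache rather than vLLM's actual classes:

from collections import OrderedDict


class ToyLRUCache:
    """Toy stand-in for the sender-side LRU cache; not vLLM's implementation."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._data: OrderedDict[str, bytes] = OrderedDict()

    def is_cached(self, keys: list[str]) -> list[bool]:
        return [k in self._data for k in keys]

    def touch(self, key: str) -> None:
        # Mark the key most-recently-used so it survives until this
        # request has finished processing.
        if key in self._data:
            self._data.move_to_end(key)

    def put(self, key: str, value: bytes) -> None:
        self._data[key] = value
        while len(self._data) > self.capacity:
            self._data.popitem(last=False)  # evict least-recently-used


cache = ToyLRUCache(capacity=2)
cache.put("image_A", b"a")
cache.put("image_B", b"b")

# A request referencing one cached item and one new one.
hashes = ["image_A", "image_C"]
assert cache.is_cached(hashes) == [True, False]

# The fix: pin every hash of the request *before* inserting anything,
# so the is_cached snapshot stays valid for the whole request.
for h in hashes:
    cache.touch(h)
cache.put("image_C", b"c")

# image_A survived; without the touch loop, inserting image_C would have
# evicted it and the snapshot above would be stale.
assert cache.is_cached(["image_A", "image_C"]) == [True, True]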
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import multiprocessing as mp

 import numpy as np
 import pytest
@@ -8,9 +9,16 @@ import torch
 from vllm.config import ModelConfig, ParallelConfig, VllmConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import (
+    BaseMultiModalProcessorCache,
+    BaseMultiModalReceiverCache,
     MultiModalCache,
+    MultiModalProcessorCacheInItem,
     MultiModalProcessorCacheItem,
     MultiModalProcessorCacheItemMetadata,
+    MultiModalProcessorSenderCache,
+    MultiModalReceiverCache,
+    ShmObjectStoreReceiverCache,
+    ShmObjectStoreSenderCache,
     engine_receiver_cache_from_config,
     processor_cache_from_config,
 )
@@ -22,6 +30,7 @@ from vllm.multimodal.inputs import (
     MultiModalSharedField,
 )
 from vllm.multimodal.processing import PromptInsertion
+from vllm.utils.mem_constants import GiB_bytes, MiB_bytes

 pytestmark = pytest.mark.cpu_test

@@ -144,8 +153,7 @@ def _compare_caches(
         MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items
     ]

-    # Should not be used since there is nothing to convert to text
-    prompt_update = PromptInsertion("dummy", "target", "insertion")
+    prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0)

     for it in range(n_iter):
         num_items_to_select = rng.randint(0, max_items_per_iter)
@@ -159,10 +167,11 @@ def _compare_caches(
         else:
             for _ in range(is_cached_calls_per_iter):
                 cache_0_p0.is_cached(selected_hashes)
+
         cache_0_p0_out = [
             item
             for item, _ in cache_0_p0.get_and_update(
-                [(item, prompt_update.content) for item in selected_items],
+                [(item, [prompt_update]) for item in selected_items],
                 selected_hashes,
             )
         ]
@@ -172,10 +181,11 @@ def _compare_caches(
         else:
             for _ in range(is_cached_calls_per_iter):
                 cache_1_p0.is_cached(selected_hashes)
+
         cache_1_p0_out = [
             item
             for item, _ in cache_1_p0.get_and_update(
-                [(item, prompt_update.content) for item in selected_items],
+                [(item, [prompt_update]) for item in selected_items],
                 selected_hashes,
             )
         ]
@@ -225,3 +235,289 @@ def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
         vllm_config_ipc_enabled,
         is_cached_calls_per_iter=is_cached_calls_per_iter,
     )
+
+
+def _run_test_cache_eviction_lru(
+    p0_cache: BaseMultiModalProcessorCache,
+    p1_cache: BaseMultiModalReceiverCache,
+    base_item_size: int,
+):
+    request1_hashes = [
+        "image_A",
+        "image_B",
+        "image_C",
+    ]
+    request1_items = {
+        h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size)
+        for h in request1_hashes
+    }
+
+    request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
+    request2_items = {
+        h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size)
+        for h in request2_hashes
+    }
+
+    ##########################
+    # STEP 1: Request 1 send
+    ##########################
+    sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes)
+    # Cache is empty
+    assert sender_is_cached_item_req1 == [False, False, False]
+
+    # Touch all mm hashes for the P0 cache before processing
+    for mm_hash in request1_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 1 for P0 cache
+    ###########################
+    item_tuple: MultiModalProcessorCacheInItem
+    for i, h in enumerate(request1_hashes):
+        # Use the precomputed cache state
+        is_cached = sender_is_cached_item_req1[i]
+        item_tuple = (request1_items[h], []) if not is_cached else None
+        print(f"Request 1: key={h} | cached={is_cached}")
+
+        p0_cache.get_and_update_item(item_tuple, h)
+
+    ###########################
+    # Process request 1 for P1 cache
+    ###########################
+    # Touch all mm hashes for the P1 cache before processing
+    for mm_hash in request1_hashes:
+        p1_cache.touch_receiver_cache_item(mm_hash)
+
+    for h in request1_hashes:
+        p1_cache.get_and_update_item(request1_items[h], h)
+
+    expected_hashes = ["image_A", "image_B", "image_C"]
+    assert list(p0_cache._cache.order) == expected_hashes
+
+    ##########################
+    # STEP 2: Request 2 send
+    ##########################
+    sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
+    assert sender_is_cached_item_req2 == [False, False, True, True]
+
+    # Touch all mm hashes for the P0 cache before processing
+    for mm_hash in request2_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 2 for P0 cache
+    ###########################
+    for i, h in enumerate(request2_hashes):
+        # Use the precomputed cache state again
+        is_cached = sender_is_cached_item_req2[i]
+        item_tuple = (request2_items[h], []) if not is_cached else None
+        print(f"Request 2: key={h} | cached={is_cached}")
+
+        p0_cache.get_and_update_item(item_tuple, h)
+
+    ###########################
+    # Process request 2 for P1 cache
+    ###########################
+
+    # Touch all mm hashes for the P1 cache before processing
+    for mm_hash in request2_hashes:
+        p1_cache.touch_receiver_cache_item(mm_hash)
+
+    for h in request2_hashes:
+        p1_cache.get_and_update_item(request2_items[h], h)
+
+    expected_hashes = ["image_D", "image_E", "image_A", "image_C"]
+    assert list(p0_cache._cache.order) == expected_hashes
+
+
+def test_cache_eviction_lru_cache():
+    model_config = ModelConfig(
+        model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+        mm_processor_cache_gb=6 / GiB_bytes,
+    )
+    sender_cache = MultiModalProcessorSenderCache(model_config)
+    receiver_cache = MultiModalReceiverCache(model_config)
+
+    _run_test_cache_eviction_lru(sender_cache, receiver_cache, base_item_size=1)
+
+
+# This test verifies shared-memory cache eviction behavior across processor (p0)
+# and receiver (p1) caches.
+# Flow summary:
+# 1. Request 1 adds images A, B, C — completely filling the cache.
+# 2. Request 2 tries to add image_G and image_A, but image_G cannot be added because
+#    the cache is full and A is protected from eviction — the cache remains unchanged.
+# 3. Request 3 adds image_G, image_H, image_I and image_B.
+#    This time, image_A is evicted, freeing 5MB of space,
+#    so image_G and image_H fit, while
+#    image_B is protected from eviction and image_I cannot be added.
+# This proves normal eviction and reuse behavior.
+def _run_test_cache_eviction_shm(
+    p0_cache: BaseMultiModalProcessorCache,
+    p1_cache: BaseMultiModalReceiverCache,
+    base_item_size: int,
+):
+    request1_hashes = ["image_A", "image_B", "image_C"]
+    request1_items = {
+        h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
+        for h in request1_hashes
+    }
+    request1_items_p0_result = []
+
+    request2_hashes = ["image_G", "image_A"]
+    request2_items = {
+        h: MultiModalKwargsItem.dummy(
+            h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
+        )
+        for h in request2_hashes
+    }
+    request2_items_p0_result = []
+
+    request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
+    request3_items = {
+        h: MultiModalKwargsItem.dummy(
+            h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
+        )
+        for h in request3_hashes
+    }
+    request3_items_p0_result = []
+
+    ##########################
+    # STEP 1: Request 1 send
+    # This will fill up the cache
+    ##########################
+    sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes)
+    # Cache is empty
+    assert sender_is_cached_item_req1 == [False, False, False]
+
+    # Touch all mm hashes for the P0 cache before processing
+    for mm_hash in request1_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 1 for P0 cache
+    ###########################
+    item_tuple: MultiModalProcessorCacheInItem
+    for i, h in enumerate(request1_hashes):
+        # Use the precomputed cache state
+        is_cached = sender_is_cached_item_req1[i]
+        item_tuple = (request1_items[h], []) if not is_cached else None
+        print(f"Request 1: key={h} | cached={is_cached}")
+
+        p0_result = p0_cache.get_and_update_item(item_tuple, h)
+        # Only keep the mm item; ignore the prompt update result
+        request1_items_p0_result.append(p0_result[0])
+
+    ###########################
+    # Process request 1 for P1 cache
+    ###########################
+    # Touch all mm hashes for the P1 cache before processing
+    for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
+        p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
+
+    for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
+        p1_cache.get_and_update_item(mm_item, mm_hash)
+
+    expected_hashes = ["image_A", "image_B", "image_C"]
+    assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
+
+    ##########################
+    # STEP 2: Request 2 send
+    # There is no eviction because image_A is protected,
+    # so no new item can be added to the cache
+    ##########################
+    sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
+    assert sender_is_cached_item_req2 == [False, True]
+
+    # Touch all mm hashes for the P0 cache before processing
+    for mm_hash in request2_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 2 for P0 cache
+    ###########################
+    for i, h in enumerate(request2_hashes):
+        # Use the precomputed cache state again
+        is_cached = sender_is_cached_item_req2[i]
+        item_tuple = (request2_items[h], []) if not is_cached else None
+        print(f"Request 2: key={h} | cached={is_cached}")
+
+        p0_result = p0_cache.get_and_update_item(item_tuple, h)
+        # Only keep the mm item; ignore the prompt update result
+        request2_items_p0_result.append(p0_result[0])
+
+    # image_A cannot be evicted, so
+    # image_G fails to allocate and image_A stays in the cache
+    assert p0_cache.is_cached(request2_hashes) == [False, True]
+
+    ###########################
+    # Process request 2 for P1 cache
+    ###########################
+
+    # Touch all mm hashes for the P1 cache before processing
+    for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
+        p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
+
+    for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
+        p1_cache.get_and_update_item(mm_item, mm_hash)
+
+    # Prove that the cache state is unchanged
+    expected_hashes = ["image_A", "image_B", "image_C"]
+    assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
+
+    ##########################
+    # STEP 3: Request 3 send
+    ##########################
+    ##### Prove that cache eviction works normally
+    sender_is_cached_item_req3 = p0_cache.is_cached(request3_hashes)
+    assert sender_is_cached_item_req3 == [False, False, False, True]
+
+    # Touch all mm hashes for the P0 cache before processing
+    for mm_hash in request3_hashes:
+        p0_cache.touch_sender_cache_item(mm_hash)
+
+    ###########################
+    # Process request 3 for P0 cache
+    ###########################
+    for i, h in enumerate(request3_hashes):
+        # Use the precomputed cache state again
+        is_cached = sender_is_cached_item_req3[i]
+        item_tuple = (request3_items[h], []) if not is_cached else None
+        print(f"Request 3: key={h} | cached={is_cached}")
+        p0_result = p0_cache.get_and_update_item(item_tuple, h)
+        # Only keep the mm item; ignore the prompt update result
+        request3_items_p0_result.append(p0_result[0])
+
+    # image_A gets evicted and image_G is added to the cache;
+    # image_B is still protected;
+    # image_G and image_H fit but image_I cannot
+    assert p0_cache.is_cached(request3_hashes) == [True, True, False, True]
+
+    ###########################
+    # Process request 3 for P1 cache
+    ###########################
+
+    # Touch all mm hashes for the P1 cache before processing
+    for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
+        p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
+
+    for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
+        p1_cache.get_and_update_item(mm_item, mm_hash)
+
+    expected_hashes = ["image_B", "image_C", "image_G", "image_H"]
+    assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
+
+
+def test_cache_eviction_shm_cache():
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(
+            model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+            mm_processor_cache_type="shm",
+            mm_shm_cache_max_object_size_mb=6,
+            mm_processor_cache_gb=15.2 * MiB_bytes / GiB_bytes,
+        ),
+    )
+    sender_cache = ShmObjectStoreSenderCache(vllm_config)
+    receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock())

+    _run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes)
@@ -574,7 +574,6 @@ class SingleWriterShmObjectStorage:
             value
         )
         buffer_size = self.flag_bytes + data_bytes + md_bytes
-
         # Sanity checks
         if buffer_size > self.max_object_size:
             raise ValueError(
@@ -626,6 +625,44 @@ class SingleWriterShmObjectStorage:

         return obj

+    def touch(
+        self,
+        key: str,
+        address: int = 0,
+        monotonic_id: int = 0,
+    ) -> None:
+        """
+        Touch an existing cached item to update its eviction status.
+
+        For writers (ShmObjectStoreSenderCache): increment writer_flag.
+        For readers (ShmObjectStoreReceiverCache): increment reader_count.
+
+        Args:
+            key: String key of the object to touch
+            address: Address of the object (only for readers)
+            monotonic_id: Monotonic ID of the object (only for readers)
+
+        """
+        if self._reader_lock is None:
+            if key not in self.key_index:
+                return None
+            address, monotonic_id = self.key_index[key]
+            # Writer side: increment writer_flag to raise the eviction threshold
+            self.increment_writer_flag(monotonic_id)
+        else:
+            with (
+                self._reader_lock,
+                self.ring_buffer.access_buf(address) as (data_view, _),
+            ):
+                reader_count = self.ring_buffer.byte2int(data_view[: self.flag_bytes])
+
+                # NOTE(Long):
+                # Avoid increasing the flag on a newly added item (sync with
+                # sender), since pre-touch has no effect on the writer side
+                # when a new item is added.
+                if reader_count >= self.n_readers:
+                    self.increment_reader_flag(data_view[: self.flag_bytes])
+
     def handle(self):
         """Get handle for sharing across processes."""
         return ShmObjectStorageHandle(
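The reader-side branch of touch() above only increments the flag once every reader has already consumed the entry. A hypothetical, simplified model of that guard, with the shared-memory ring buffer replaced by a plain dict for illustration:

N_READERS = 2


def reader_touch(reader_counts: dict[str, int], key: str) -> None:
    # Skip entries that not every reader has consumed yet: a newly added
    # item is still covered by the sender's writer_flag, so pre-touching
    # it would desynchronize the two sides.
    if reader_counts[key] >= N_READERS:
        reader_counts[key] += 1  # stands in for increment_reader_flag


counts = {"new_item": 1, "old_item": 2}
reader_touch(counts, "new_item")  # no-op: not all readers have seen it yet
reader_touch(counts, "old_item")  # pinned against eviction
assert counts == {"new_item": 1, "old_item": 3}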
@@ -302,6 +302,19 @@ class BaseMultiModalProcessorCache(
        """
        return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]

+    @abstractmethod
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        """
+        Update the cache eviction order for a multi-modal item.
+
+        This is used to touch the item in the cache without changing
+        its value.
+
+        Args:
+            mm_hash: The hash of the multi-modal item.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def make_stats(self, *, delta: bool = False) -> CacheInfo:
         """
@@ -353,6 +366,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):

         return mm_item

+    @override
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        self._cache.touch(mm_hash)
+
     @override
     def clear_cache(self) -> None:
         self._cache.clear()
@@ -407,6 +424,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):

         return mm_item

+    @override
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        self._cache.touch(mm_hash)
+
     @override
     def clear_cache(self) -> None:
         self._cache.clear()
@@ -501,6 +522,12 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
             logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
         return mm_item

+    @override
+    def touch_sender_cache_item(self, mm_hash: str) -> None:
+        """Touch the item in shared memory cache to prevent eviction.
+        Increments writer_flag on sender side."""
+        self._shm_cache.touch(mm_hash)
+
     @override
     def clear_cache(self) -> None:
         self._shm_cache.clear()
@@ -610,11 +637,37 @@ class BaseMultiModalReceiverCache(
         self,
         mm_features: list["MultiModalFeatureSpec"],
     ) -> list["MultiModalFeatureSpec"]:
-        """Update multimodal features with cached encoder outputs."""
+        """
+        Update multimodal features with cached encoder outputs.
+        Touch all identifiers first, before updating, so that items
+        in the updated list are not evicted during the update.
+        """
+        for feature in mm_features:
+            self.touch_receiver_cache_item(feature.identifier, feature.data)
+
         for feature in mm_features:
             feature.data = self.get_and_update_item(feature.data, feature.identifier)
         return mm_features
+
+    @abstractmethod
+    def touch_receiver_cache_item(
+        self,
+        mm_hash: str,
+        mm_item: MultiModalKwargsItem | None = None,
+    ) -> None:
+        """
+        Update the cache eviction order for a multi-modal item.
+
+        This is used to touch the item in the cache without changing
+        its value.
+
+        Args:
+            mm_hash: The hash of the multi-modal item.
+            mm_item: The multi-modal item itself. This is optional and
+                may not be needed by some cache implementations.
+        """
+        raise NotImplementedError


 class MultiModalReceiverCache(BaseMultiModalReceiverCache):
     """
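The pre-touch loop above is the receiver half of the fix named in the commit title: when a request mixes a new item with an already-cached one whose payload was therefore not re-sent, updating feature by feature could evict the cached entry before it is read. A minimal sketch of that failure mode, again using a toy ordered-dict LRU rather than vLLM's classes:

from collections import OrderedDict

CAPACITY = 2
cache: OrderedDict[str, bytes] = OrderedDict()


def touch(key: str) -> None:
    if key in cache:
        cache.move_to_end(key)


def get_and_update(item: bytes | None, key: str) -> bytes:
    if item is None:  # cached on the receiver; the payload was not re-sent
        return cache[key]  # KeyError here is the "missing cached item" bug
    cache[key] = item
    while len(cache) > CAPACITY:
        cache.popitem(last=False)  # evict least-recently-used
    return item


get_and_update(b"a", "image_A")
get_and_update(b"b", "image_B")

# New request: fresh image_C plus already-cached image_A (sent as None).
features = [("image_C", b"c"), ("image_A", None)]
for key, _ in features:
    touch(key)  # pin image_A before any insertion can evict it
for key, item in features:
    get_and_update(item, key)  # without the touch loop: KeyError on image_A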
@@ -651,6 +704,14 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache):
         self._cache[mm_hash] = mm_item
         return mm_item

+    @override
+    def touch_receiver_cache_item(
+        self,
+        mm_hash: str,
+        mm_item: MultiModalKwargsItem | None = None,
+    ) -> None:
+        self._cache.touch(mm_hash)
+
     @override
     def clear_cache(self) -> None:
         self._cache.clear()
@@ -703,6 +764,20 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):

         return mm_item

+    @override
+    def touch_receiver_cache_item(
+        self,
+        mm_hash: str,
+        mm_item: MultiModalKwargsItem | None = None,
+    ) -> None:
+        """Touch the item in shared memory cache to prevent eviction.
+        Increments reader_count on receiver side."""
+        assert mm_item is not None
+        if "address" in mm_item:
+            address = cast(int, mm_item["address"].data)
+            monotonic_id = cast(int, mm_item["monotonic_id"].data)
+            self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id)
+
     @override
     def clear_cache(self) -> None:
         self._shm_cache.clear()
@@ -721,12 +721,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     """

     @staticmethod
-    def dummy(modality: str):
+    def dummy(modality: str, nbytes: int = 1):
         """Convenience class for testing."""
         mm_elem = MultiModalFieldElem(
             modality=modality,
             key="dummy",
-            data=torch.empty(1),
+            data=torch.empty(nbytes, dtype=torch.uint8),
             field=MultiModalSharedField(1),
         )
         return MultiModalKwargsItem.from_elems([mm_elem])
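The widened dummy() signature lets the eviction tests above construct items of an exact byte size: torch.empty(nbytes, dtype=torch.uint8) allocates one byte per element. A short usage sketch, assuming a vLLM checkout where this test-only helper is available:

from vllm.multimodal.inputs import MultiModalKwargsItem

# A 5 MiB dummy image item; the single field is keyed "dummy" by the helper.
item = MultiModalKwargsItem.dummy("image", nbytes=5 * 1024 * 1024)
assert item["dummy"].data.nbytes == 5 * 1024 * 1024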
@@ -1248,7 +1248,13 @@ _I = TypeVar("_I", bound=BaseProcessingInfo)

 MultiModalHashes = dict[str, list[str]]
 """
-A collection of hashes with a similar structure as
+A collection of the multi-modal hash for each item, with a similar structure as
+[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
+"""
+
+MultiModalIsCached = dict[str, list[bool]]
+"""
+A collection of the `is_cached` flag for each item, with a similar structure as
 [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
 """

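The two aliases are parallel structures keyed by modality: mm_is_cached[m][i] says whether the item with hash mm_hashes[m][i] was already present in the sender cache. An illustrative literal (the hash strings are hypothetical):

mm_hashes: dict[str, list[str]] = {
    "image": ["image_A", "image_B"],
    "audio": ["audio_X"],
}
mm_is_cached: dict[str, list[bool]] = {
    "image": [True, False],
    "audio": [False],
}
# Per modality, the flag list lines up index-by-index with the hash list.
assert all(len(mm_is_cached[m]) == len(mm_hashes[m]) for m in mm_hashes)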
@@ -1725,7 +1731,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         cache: BaseMultiModalProcessorCache,
         mm_data_items: MultiModalDataItems,
         mm_hashes: MultiModalHashes,
-    ) -> MultiModalDataItems:
+    ) -> tuple[MultiModalIsCached, MultiModalDataItems]:
         mm_is_cached = {
             modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
         }
@@ -1752,7 +1758,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                 missing_modality_data.append(data)
             mm_missing_data[modality] = missing_modality_data

-        return self._to_mm_items(mm_missing_data)
+        return mm_is_cached, self._to_mm_items(mm_missing_data)

     def _recompute_cached_prompt_update(
         self,
@@ -1769,14 +1775,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         self,
         cache: BaseMultiModalProcessorCache,
         mm_hashes: MultiModalHashes,
+        mm_is_cached: MultiModalIsCached,
         mm_missing_kwargs: MultiModalKwargsItems,
         mm_missing_prompt_updates: MultiModalPromptUpdates,
     ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
-        # Need to calculate this at the beginning to avoid skipping cache logic
-        # for subsequently repeated items in the same modality
-        mm_is_cached = {
-            modality: cache.is_cached(hashes) for modality, hashes in mm_hashes.items()
-        }
+        # Touch all mm hashes before updating so that items in the updated
+        # list are not evicted during the update
+        for hashes in mm_hashes.values():
+            for item_hash in hashes:
+                cache.touch_sender_cache_item(item_hash)

         mm_missing_next_idx = defaultdict[str, int](lambda: 0)

@@ -1789,15 +1796,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             missing_prompt_updates = mm_missing_prompt_updates.get(modality, [])

             for item_idx, item_hash in enumerate(hashes):
-                kwargs: MultiModalKwargsItem | None
                 if not mm_is_cached[modality][item_idx]:
                     missing_next_idx = mm_missing_next_idx[modality]
-                    kwargs = missing_kwargs[missing_next_idx]
-                    updates = missing_prompt_updates[missing_next_idx]
+                    missing_kwargs_item = missing_kwargs[missing_next_idx]
+                    missing_updates_item = missing_prompt_updates[missing_next_idx]

                     mm_missing_next_idx[modality] += 1

-                    item = kwargs, updates
+                    item = missing_kwargs_item, missing_updates_item
                 else:
                     item = None

@@ -1896,7 +1902,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             mm_uuids=mm_uuids,
         )

-        mm_missing_data_items = self._get_cache_missing_items(
+        mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
             cache=cache,
             mm_data_items=mm_data_items,
             mm_hashes=mm_hashes,
@@ -1933,6 +1939,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
             cache,
             mm_hashes=mm_hashes,
+            mm_is_cached=mm_is_cached,
             mm_missing_kwargs=mm_missing_kwargs,
             mm_missing_prompt_updates=mm_missing_prompt_updates,
         )