# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import operator
import sys
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, TypeAlias, TypeVar, cast

import torch
from typing_extensions import override

import vllm.envs as envs
from vllm.distributed.device_communicators.shm_object_storage import (
    MsgpackSerde,
    SingleWriterShmObjectStorage,
    SingleWriterShmRingBuffer,
)
from vllm.logger import init_logger
from vllm.utils.cache import CacheInfo, LRUCache
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes

from .inputs import (
    MultiModalBatchedField,
    MultiModalFeatureSpec,
    MultiModalFieldElem,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    NestedTensors,
)

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig

    from .processing import ResolvedPromptUpdate
    from .registry import MultiModalRegistry

logger = init_logger(__name__)


class MultiModalProcessorCacheItem:
    """
    The data to store inside `MultiModalProcessorOnlyCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
        prompt_updates: The prompt updates corresponding to `item`.
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        self.item = item
        self.prompt_updates = prompt_updates


class MultiModalProcessorCacheItemMetadata:
    """
    The metadata to store inside `MultiModalProcessorSenderCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            Since P1 already stores the tensor data, we only store its size
            metadata in P0 to reduce memory usage. The size metadata is still
            needed to keep the cache eviction order in P0 consistent with P1.
        prompt_updates: The prompt updates corresponding to `item`.
            These need to stay on P0 because, for some models, they depend
            on the processed tensor data (cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        self.item_size = MultiModalCache.get_item_size(item)
        self.prompt_updates = prompt_updates


MultiModalCacheValue: TypeAlias = (
    MultiModalProcessorCacheItem
    | MultiModalProcessorCacheItemMetadata
    | MultiModalKwargsItems
    | MultiModalKwargsItem
    | Mapping[str, NestedTensors]
)

_V = TypeVar("_V", bound=MultiModalCacheValue)


class MultiModalCache:
    @classmethod
    def get_leaf_size(cls, leaf: object) -> int:
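        """Best-effort size, in bytes, of a single leaf value."""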
        if isinstance(leaf, MultiModalProcessorCacheItem):
            return cls.get_leaf_size(leaf.item)
        if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
            return leaf.item_size

        # These are not subclasses of dict
        if isinstance(
            leaf,
            (MultiModalKwargsItems, MultiModalKwargsItem, MultiModalFieldElem),
        ):
            return cls.get_item_size(leaf.data)  # type: ignore

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        size = json_reduce_leaves(
            operator.add, json_map_leaves(cls.get_leaf_size, value)
        )

        if debug:
            leaf_count = json_count_leaves(value)
            logger.debug(
                "Calculated size of %s to be %.2f GiB (%d leaves)",
                type(value),
                size / GiB_bytes,
                leaf_count,
            )

        return size

    @classmethod
    def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
        """
        Get the number of leaf elements in a multi-modal cache value.

        This provides a measure of structural complexity that can be useful
        for debugging cache performance and understanding data patterns.

        Args:
            value: The multi-modal cache value to analyze.

        Returns:
            The number of leaf elements in the nested structure.
        """
        return json_count_leaves(value)

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
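        """
        Build an `LRUCache` whose capacity is measured in bytes (as computed
        by `get_item_size`) rather than in number of entries.

        A usage sketch (the 4 GiB capacity and `mm_hash` key are illustrative
        placeholders, not defaults):

        ```
        cache = MultiModalCache.get_lru_cache(4.0, MultiModalKwargsItem)
        cache[mm_hash] = mm_item  # LRU entries are evicted once over capacity
        ```
        """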
        return LRUCache(
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: cls.get_item_size(x, debug=debug),
        )


_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)


class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract base class to read/write multi-modal items from cache.

    The idea of multi-modal caching is based on having a client and a server,
    where the client executes in the frontend process (=P0) and
    the server in the core process (=P1). The data flow is as follows:

    ```
         is_cached() x N     get_and_update()
    P0: From API -----------------> -----------------> To P1

         get_and_update()
    P1: From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This ensures that the keys of the P0 and P1 caches are mirrored,
    allowing us to determine whether a key is cached in P1 by looking
    up the P0 cache, without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly update a multi-modal item based on whether it is
        in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly update a sequence of multi-modal items based on whether they
        are in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`.

        Returns:
            A new list of updated multi-modal items.
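
        A sketch of the paired P0/P1 calls described in the class docstring
        (placeholder names; the item types on each side differ in practice):

        ```
        out_p0 = p0_cache.get_and_update(items, mm_hashes)   # on P0
        out_p1 = p1_cache.get_and_update(out_p0, mm_hashes)  # on P1
        ```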
        """
        assert len(mm_items) == len(mm_hashes)

        return [
            self.get_and_update_item(mm_item, mm_hash)
            for mm_item, mm_hash in zip(mm_items, mm_hashes)
        ]

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError


MultiModalProcessorCacheInItem: TypeAlias = (
    tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]] | None
)

MultiModalProcessorCacheOutItem: TypeAlias = tuple[
    MultiModalKwargsItem | None, Sequence["ResolvedPromptUpdate"]
]


class BaseMultiModalProcessorCache(
    BaseMultiModalCache[MultiModalProcessorCacheInItem, MultiModalProcessorCacheOutItem]
):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a multi-modal item is in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether a sequence of multi-modal items are
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]

    @abstractmethod
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        """
        Update the cache eviction order for a multi-modal item.

        This is used to touch the item in the cache without changing
        its value.

        Args:
            mm_hash: The hash of the multi-modal item.
        """
        raise NotImplementedError

    @abstractmethod
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        """
        Get the multi-modal cache stats.

        Args:
            delta: If `True`, return only the stats accumulated since the
                last call with `delta=True`, resetting the baseline.

        Returns:
            The current multi-modal caching stats.
        """
        raise NotImplementedError


class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is disabled.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes
      tensor data and metadata) into the cache, and return the input.
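
    A sketch of the resulting hit/miss behavior (`mm_item`, `updates`, and
    `mm_hash` are illustrative placeholders):

    ```
    cache = MultiModalProcessorOnlyCache(model_config)
    out1 = cache.get_and_update_item((mm_item, updates), mm_hash)  # miss: stores
    out2 = cache.get_and_update_item(None, mm_hash)  # hit: returns the cached pair
    ```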
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return cached_item.item, cached_item.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)

        return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)


class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled
    with `mm_processor_cache_type="lru"`.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.
    - If the item is not in the cache, store the metadata of that item so
      that the eviction policy remains the same as the cache on P1,
      and return the input.
      By only storing the metadata, we avoid keeping the data itself in
      memory inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return None, cached_item.prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)

        return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)


class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled
    with `mm_processor_cache_type="shm"`.

    How to update each item:

    - If the item is already in the cache, replace the input with a small
      address item (see `address_as_item`) so that the tensor data is not
      sent over IPC again.
    - If the item is not in the cache, store the data in shared memory and
      likewise return an address item; if storing fails, return the input
      unchanged.
    """

    def __init__(self, vllm_config: "VllmConfig") -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=True,  # sender is the writer
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
        )
        # cache (prompt_updates, modality) for P0 only
        self._p0_cache: dict[str, tuple[Sequence["ResolvedPromptUpdate"], str]] = {}

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def _stat(self, *, delta: bool = False) -> CacheInfo:
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return self._shm_cache.is_cached(mm_hash)

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if self._shm_cache.is_cached(mm_hash):
            self._hits += 1
            self._total += 1

            address, monotonic_id = self._shm_cache.get_cached(mm_hash)
            prompt_updates, modality = self._p0_cache[mm_hash]
            return self.address_as_item(address, monotonic_id, modality), prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._total += 1

        try:
            address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
            # Try to remove dangling items if the P0 cache has grown too large.
            if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
                self.remove_dangling_items()
            self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
            address_item = self.address_as_item(
                address, monotonic_id, mm_item[0].modality
            )
            return address_item, mm_item[1]
        except (ValueError, MemoryError) as e:
            # `put` may fail if the object is too large or the cache is full.
            # In that case, log the error and keep the original mm_input.
            logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, e)
            return mm_item

    @override
    def touch_sender_cache_item(self, mm_hash: str) -> None:
        """Touch the item in the shared memory cache to prevent eviction.

        Increments `writer_flag` on the sender side."""
        self._shm_cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
        self._p0_cache.clear()

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._stat(delta=delta)

    def remove_dangling_items(self) -> None:
        """Remove items that are no longer in the shared memory cache."""
        cached_hashes = self._shm_cache.key_index.keys()
        dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
        for mm_hash in dangling_hashes:
            del self._p0_cache[mm_hash]

    def address_as_item(
        self, address: int, monotonic_id: int, modality: str
    ) -> MultiModalKwargsItem:
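        """
        Wrap a shared-memory `(address, monotonic_id)` pair in a
        `MultiModalKwargsItem` so that it can be sent to P1 in place of the
        tensor data; `ShmObjectStoreReceiverCache.get_and_update_item`
        resolves it back into the original item.
        """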
        addr_elem = MultiModalFieldElem(
            modality=modality,
            key="address",
            data=address,
            field=MultiModalBatchedField(),
        )
        id_elem = MultiModalFieldElem(
            modality=modality,
            key="monotonic_id",
            data=monotonic_id,
            field=MultiModalBatchedField(),
        )
        mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
        return mm_item


def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    if not mm_registry.supports_multimodal_inputs(model_config):
        return False

    mm_config = model_config.get_multimodal_config()
    return mm_config.mm_processor_cache_gb > 0


def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    parallel_config = vllm_config.parallel_config
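    # IPC caching requires a 1:1 pairing between API (P0) and engine core
    # (P1) processes: either a single API process with no data parallelism,
    # or external data-parallel load balancing.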
    supports_ipc_cache = (
        parallel_config._api_process_count == 1
        and parallel_config.data_parallel_size == 1
    ) or parallel_config.data_parallel_external_lb

    return supports_ipc_cache


def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """Whether the shared memory based cache should be enabled."""

    if not _enable_ipc_cache(vllm_config):
        return False

    mm_config = vllm_config.model_config.get_multimodal_config()

    return mm_config.mm_processor_cache_type == "shm"


def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalProcessorCache | None:
    """Return a `BaseMultiModalProcessorCache`, if enabled."""
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return MultiModalProcessorOnlyCache(model_config)

    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalProcessorSenderCache(model_config)

    return ShmObjectStoreSenderCache(vllm_config)


def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> MultiModalProcessorOnlyCache | None:
    """Return a `MultiModalProcessorOnlyCache`, if enabled."""
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)


class BaseMultiModalReceiverCache(
    BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
):
    """The required interface for caches on P1."""

    def get_and_update_features(
        self,
        mm_features: list["MultiModalFeatureSpec"],
    ) -> list["MultiModalFeatureSpec"]:
        """
        Update multimodal features with cached encoder outputs.

        All identifiers are touched first, before any update, so that
        items in the list being updated are not evicted mid-update.
        """
        for feature in mm_features:
            self.touch_receiver_cache_item(feature.identifier, feature.data)

        for feature in mm_features:
            feature.data = self.get_and_update_item(feature.data, feature.identifier)
        return mm_features

    @abstractmethod
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        """
        Update the cache eviction order for a multi-modal item.

        This is used to touch the item in the cache without changing
        its value.

        Args:
            mm_hash: The hash of the multi-modal item.
            mm_item: The multi-modal item itself. This is optional and
                may not be needed by some cache implementations.
        """
        raise NotImplementedError


class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 when IPC caching is enabled
    with `mm_processor_cache_type="lru"`.

    How to update each item:

    - If the item is in the cache, replace the input with the cached item.
    - If the item is not in the cache, store that item (which includes tensor
      data) into the cache, and return the input.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        mm_config = model_config.get_multimodal_config()

        self._cache = MultiModalCache.get_lru_cache(
            mm_config.mm_processor_cache_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        if (cached_item := self._cache.get(mm_hash)) is not None:
            return cached_item

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        self._cache.touch(mm_hash)

    @override
    def clear_cache(self) -> None:
        self._cache.clear()


class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used in the P1 worker process when IPC caching is
    enabled with `mm_processor_cache_type="shm"`.

    How to update each item:

    - If the item has an address, replace the input with the cached item.
    - If not, return the input.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=envs.VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=False,  # server is a reader
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb * MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
            reader_lock=shared_worker_lock,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalKwargsItem | None,
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        assert mm_item is not None, f"Expected an address item for {mm_hash=}"
        if "address" in mm_item:
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            return self._shm_cache.get(address, monotonic_id)

        return mm_item

    @override
    def touch_receiver_cache_item(
        self,
        mm_hash: str,
        mm_item: MultiModalKwargsItem | None = None,
    ) -> None:
        """Touch the item in the shared memory cache to prevent eviction.

        Increments `reader_count` on the receiver side."""
        assert mm_item is not None
        if "address" in mm_item:
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            self._shm_cache.touch(mm_hash, address=address, monotonic_id=monotonic_id)

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()


def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the engine process.

    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled
    and `mm_processor_cache_type="lru"`.
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return None

    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalReceiverCache(model_config)

    return None


def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the worker process.

    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled
    and `mm_processor_cache_type="shm"`.
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return None

    if not _enable_mm_input_shm_cache(vllm_config):
        return None

    return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)