vllm/vllm/v1/engine/mm_input_cache.py
Cyrus Leung 4dff91c93d
[Refactor] Allow optional MultiModalKwargsItem in IPC (#23022)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-08-16 11:30:49 +00:00

122 lines
4.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.utils import is_list_of
if TYPE_CHECKING:
from vllm.config import ModelConfig
# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
# up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
#
# -- P1:
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.
class MultiModalInputCacheClient:
    """Used by P0 to check whether multi-modal kwargs are cached in P1."""

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # Whether input caching is turned on for this model at all.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)
        # P0 only stores per-item *metadata* (size), never the kwargs
        # themselves, to keep frontend memory usage low.
        self.mm_cache = MultiModalCache.get_lru_cache(
            model_config.get_mm_input_cache_gb(),
            MultiModalCacheItemMetadata,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[MultiModalKwargsItem],
        mm_hashes: list[str],
    ) -> list[Optional[MultiModalKwargsItem]]:
        """For each item, return ``None`` if its hash is (by mirroring)
        already cached in P1; otherwise record its metadata locally and
        return the item itself so it gets sent to P1."""
        if not self.enabled:
            return list(mm_kwargs)

        assert len(mm_kwargs) == len(mm_hashes)

        results: list[Optional[MultiModalKwargsItem]] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            if self.mm_cache.get(item_hash) is not None:
                # Server already holds this item; send the hash only.
                results.append(None)
                continue

            self.mm_cache[item_hash] = MultiModalCacheItemMetadata.wraps(item)
            results.append(item)

        return results

    def reset(self) -> None:
        """Drop all cached entries (mirrors a server-side reset)."""
        self.mm_cache.clear()
class MultiModalInputCacheServer:
    """Used by P1 to avoid requiring past multi-modal kwargs from P0."""

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # Must agree with the client on whether caching is enabled.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)
        # Unlike P0, the server keeps the full kwargs items so omitted
        # inputs can be reconstructed from their hashes.
        self.mm_cache = MultiModalCache.get_lru_cache(
            model_config.get_mm_input_cache_gb(),
            MultiModalKwargsItem,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargsItem]:
        """Resolve each possibly-omitted item: a ``None`` placeholder is
        looked up from the cache by its hash, while a concrete item is
        stored under its hash for future requests."""
        if not self.enabled:
            full_items = list(mm_kwargs)
            # Caching disabled: P0 always sends every item in full.
            assert is_list_of(full_items, MultiModalKwargsItem)
            return full_items

        assert len(mm_kwargs) == len(mm_hashes)

        resolved: list[MultiModalKwargsItem] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            if item is None:
                # Item was omitted by P0 because it is mirrored here.
                item = self.mm_cache[item_hash]
            else:
                self.mm_cache[item_hash] = item
            resolved.append(item)

        return resolved

    def reset(self) -> None:
        """Drop all cached entries (mirrors a client-side reset)."""
        self.mm_cache.clear()