mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:25:01 +08:00
122 lines
4.5 KiB
Python
122 lines
4.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections.abc import Sequence
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from vllm.multimodal import MultiModalRegistry
|
|
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
|
|
from vllm.multimodal.inputs import MultiModalKwargsItem
|
|
from vllm.utils import is_list_of
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import ModelConfig
|
|
|
|
# The idea of multimodal input caching is based on having a client and
|
|
# a server, where the client executes in the frontend process (=P0) and the
|
|
# server in the core process (=P1).
|
|
#
|
|
# -- P0:
|
|
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
|
|
# each input multi-modal item (e.g. image),
|
|
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
|
|
# which are MultiModalKwargsItem instances that each correspond to an
|
|
# input multi-modal item.
|
|
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
|
|
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
|
|
# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
|
|
# up additional memory in P0.
|
|
# - The `mm_hash` is always sent to P1.
|
|
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
|
|
# in MultiModalInputCacheServer.
|
|
#
|
|
# -- P1:
|
|
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
|
|
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
|
|
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
|
|
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
|
|
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
|
|
# the engine for model execution.
|
|
#
|
|
# Both Client and Server must perform cache update and eviction based on the
|
|
# same item size. This ensures that the keys of MultiModalInputCacheClient
|
|
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
|
|
# whether a key is cached in MultiModalInputCacheServer by querying
|
|
# MultiModalInputCacheClient without having to communicate with P1.
|
|
|
|
|
|
class MultiModalInputCacheClient:
    """Frontend-side (P0) mirror of the multi-modal input cache held by P1.

    Entries are keyed by ``mm_hash`` and hold only a
    ``MultiModalCacheItemMetadata`` wrapper of each item rather than the
    item itself, so P0 can tell which items P1 already has without
    keeping a second copy of the kwargs.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # The registry decides whether input caching applies to this model.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        cache_size_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            cache_size_gb,
            MultiModalCacheItemMetadata,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[MultiModalKwargsItem],
        mm_hashes: list[str],
    ) -> list[Optional[MultiModalKwargsItem]]:
        """Replace already-cached items with ``None`` and record new ones.

        Args:
            mm_kwargs: The processed multi-modal items, one per input.
            mm_hashes: The hash of each item, aligned with ``mm_kwargs``.

        Returns:
            A list aligned with the inputs. A position holds ``None`` when
            its hash is already present in this cache, and the original
            item otherwise. When caching is disabled, a plain list copy of
            ``mm_kwargs`` is returned and the cache is left untouched.

        Raises:
            AssertionError: If caching is enabled and the two inputs
                differ in length.
        """
        if not self.enabled:
            return list(mm_kwargs)

        assert len(mm_kwargs) == len(mm_hashes)

        to_send: list[Optional[MultiModalKwargsItem]] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            already_cached = self.mm_cache.get(item_hash) is not None
            if already_cached:
                # Emit a placeholder instead of the item itself.
                to_send.append(None)
                continue

            # First sighting: remember its metadata and pass the item on.
            metadata = MultiModalCacheItemMetadata.wraps(item)
            self.mm_cache[item_hash] = metadata
            to_send.append(item)

        return to_send

    def reset(self) -> None:
        """Remove every entry from the cache."""
        self.mm_cache.clear()
|
|
|
|
|
|
class MultiModalInputCacheServer:
    """Core-side (P1) cache that restores multi-modal kwargs omitted by P0.

    Entries are keyed by ``mm_hash`` and hold the actual
    ``MultiModalKwargsItem`` objects, so a ``None`` placeholder in the
    incoming items can be swapped back for the stored item.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        super().__init__()

        # The registry decides whether input caching applies to this model.
        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        cache_size_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            cache_size_gb,
            MultiModalKwargsItem,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargsItem]:
        """Fill in ``None`` placeholders from the cache and store new items.

        Args:
            mm_kwargs: Incoming items aligned with ``mm_hashes``. A ``None``
                entry means the item is to be read from this cache.
            mm_hashes: The hash of each item.

        Returns:
            A list aligned with the inputs in which every position holds a
            concrete item. When caching is disabled, a plain list copy of
            ``mm_kwargs`` is returned after asserting it has no ``None``.

        Raises:
            AssertionError: If caching is disabled and an entry is not a
                ``MultiModalKwargsItem``, or if caching is enabled and the
                two inputs differ in length.
        """
        if not self.enabled:
            passthrough = list(mm_kwargs)
            # With caching off, no placeholders are expected.
            assert is_list_of(passthrough, MultiModalKwargsItem)
            return passthrough

        assert len(mm_kwargs) == len(mm_hashes)

        resolved: list[MultiModalKwargsItem] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            if item is None:
                # Placeholder: look the item up under its hash.
                item = self.mm_cache[item_hash]
            else:
                # Concrete item: store it under its hash.
                self.mm_cache[item_hash] = item
            resolved.append(item)

        return resolved

    def reset(self) -> None:
        """Remove every entry from the cache."""
        self.mm_cache.clear()
|