[VLM] Limit multimodal input cache by memory (#14805)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-03-15 17:52:05 +08:00 committed by GitHub
parent 9ed6ee92d6
commit 3556a41434
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 159 additions and 55 deletions

View File

@ -53,7 +53,7 @@ repos:
entry: tools/mypy.sh 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
stages: [pre-commit] # Don't run in CI
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9

View File

@ -1,3 +1,4 @@
cachetools
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0

View File

@ -9,6 +9,7 @@ msgspec
cloudpickle
# packages to install to build the documentation
cachetools
pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu
torch

View File

@ -48,7 +48,7 @@ def _test_processing_correctness(
tokenizer=cached_tokenizer_from_config(model_config),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
cache = ProcessingCache(capacity_gb=2048)
processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()

View File

@ -56,7 +56,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MM_INPUT_CACHE_SIZE: int = 256
VLLM_MM_INPUT_CACHE_GIB: int = 8
VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
@ -432,11 +432,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Cache size for multimodal feature/input cache for multimodal models
# in unit of number of multimodal data items (e.g. image, video, audio).
# Default is 256 multimodal data items.
"VLLM_MM_INPUT_CACHE_SIZE":
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256")),
# Cache size (in GiB) for multimodal input cache
# Default is 8 GiB
"VLLM_MM_INPUT_CACHE_GIB":
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "8")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.

79
vllm/jsontree.py Normal file
View File

@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
"""Helper functions to work with nested JSON structures."""
from collections.abc import Iterable
from functools import reduce
from typing import Callable, TypeVar, Union, overload
_T = TypeVar("_T")
_U = TypeVar("_U")
JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
tuple["JSONTree[_T]", ...], _T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
    """Yield every leaf of a nested JSON structure, depth-first.

    Dicts are traversed by value and lists/tuples element-wise; any
    other object is treated as a leaf and yielded as-is.
    """
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, (list, tuple)):
        children = value
    else:
        # Not a recognized container: this is a leaf.
        yield value
        return

    for child in children:
        yield from json_iter_leaves(child)
def json_map_leaves(
    func: Callable[[_T], _U],
    value: JSONTree[_T],
) -> JSONTree[_U]:
    """Return a copy of ``value`` with ``func`` applied to every leaf.

    The container structure (dicts, lists, tuples) is rebuilt unchanged;
    only the leaves are transformed.
    """
    if isinstance(value, dict):
        return {
            key: json_map_leaves(func, item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [json_map_leaves(func, item) for item in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, item) for item in value)
    # Not a recognized container: apply the function to the leaf itself.
    return func(value)
# Overload: no initial value — the reduction starts from the first leaf,
# mirroring functools.reduce without an initializer.
@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: JSONTree[_T],
    /,
) -> _T:
    ...
# Overload: with an initial accumulator of a (possibly different) type.
@overload
def json_reduce_leaves(
    func: Callable[[_U, _T], _U],
    value: JSONTree[_T],
    initial: _U,
    /,
) -> _U:
    ...
def json_reduce_leaves(
    func: Callable[..., Union[_T, _U]],
    value: JSONTree[_T],
    initial: _U = ...,  # type: ignore[assignment]
    /,
) -> Union[_T, _U]:
    """
    Apply a function of two arguments cumulatively to each leaf in a
    nested JSON structure, from left to right, so as to reduce the
    sequence to a single value.

    ``Ellipsis`` is used as a sentinel default so that any real object
    (including ``None``) can be passed as ``initial``. Like
    :func:`functools.reduce`, calling this on a structure with no leaves
    and no ``initial`` raises ``TypeError``.
    """
    if initial is ...:
        # No initial accumulator: seed the reduction with the first leaf.
        return reduce(func, json_iter_leaves(value))  # type: ignore[arg-type]
    return reduce(
        func,  # type: ignore[arg-type]
        json_iter_leaves(value),
        initial,
    )

View File

@ -18,6 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@ -35,7 +36,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from vllm.utils import flatten_2d_lists
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

View File

@ -24,6 +24,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
SiluAndMul)
@ -50,7 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)

View File

@ -16,7 +16,8 @@ from PIL.Image import Image
from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.utils import full_groupby, is_list_of
if TYPE_CHECKING:
from .hasher import MultiModalHashDict

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import re
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
@ -11,14 +11,17 @@ from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast)
import torch
from cachetools import LRUCache
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never
from vllm.inputs import InputProcessingContext
from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
@ -812,25 +815,50 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it))
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
class ProcessingCache:
def __init__(self, capacity: int) -> None:
@staticmethod
def get_lru_cache(
capacity_gb: int,
value_type: type[_V],
) -> LRUCache[str, _V]:
def get_size(leaf: object) -> int:
if isinstance(leaf, torch.Tensor):
return leaf.nbytes # sys.getsizeof doesn't work for tensors
return sys.getsizeof(leaf)
return LRUCache[str, _V](
GiB_bytes * capacity_gb,
getsizeof=lambda x: json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(get_size, x),
),
)
def __init__(self, capacity_gb: int) -> None:
super().__init__()
# DEBUG: Set to None to disable
self.debug_cache_hit_ratio_steps: Optional[int] = None
self.debug_cache_hits = 0
self.debug_cache_total = 0
self._cache = LRUCache[str, MultiModalKwargsItem](capacity)
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps
if not steps:
return
cache_stats = self._cache.stat()
if cache_stats.total % steps == 0:
total = self.debug_cache_total
if total > 0 and total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f",
cache_stats.hit_ratio)
self.debug_cache_hits / total)
def get(
self,
@ -853,6 +881,13 @@ class ProcessingCache:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
if self.debug_cache_hit_ratio_steps:
if cache_key in self._cache:
self.debug_cache_hits += 1
self.debug_cache_total += 1
return self._cache.get(cache_key)
def put(
@ -870,7 +905,7 @@ class ProcessingCache:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
self._cache.put(cache_key, output_kwargs)
self._cache[cache_key] = output_kwargs
class BaseProcessingInfo:

View File

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar
import torch.nn as nn
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@ -119,7 +119,7 @@ class MultiModalRegistry:
self._limits_by_model = _MultiModalLimits()
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_SIZE)
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
def register_plugin(self, plugin: MultiModalPlugin) -> None:
"""

View File

@ -845,22 +845,6 @@ def is_list_of(
assert_never(check)
JSONTree = Union[dict[str, "JSONTree[T]"], list["JSONTree[T]"],
tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
    """Recursively apply ``func`` to each leaf of a nested JSON
    structure, preserving the dict/list/tuple container shape."""
    if isinstance(value, dict):
        return {
            name: json_map_leaves(func, child)
            for name, child in value.items()
        }
    if isinstance(value, list):
        return [json_map_leaves(func, child) for child in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, child) for child in value)
    # Anything that is not a recognized container is a leaf.
    return func(value)
def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
    """Concatenate a list of lists into one flat list, preserving order."""
    flat: list[T] = []
    for sublist in lists:
        flat.extend(sublist)
    return flat

View File

@ -3,11 +3,11 @@
from typing import Any, Optional
from vllm.config import ModelConfig
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.utils import LRUCache
from vllm.multimodal.processing import ProcessingCache
logger = init_logger(__name__)
@ -30,7 +30,7 @@ logger = init_logger(__name__)
# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_SIZE.
# variable VLLM_MM_INPUT_CACHE_GIB.
# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
@ -50,18 +50,20 @@ class MMInputCacheClient:
# Init cache
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str,
MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
self.mm_cache_hits = 0
self.mm_cache_total = 0
self.mm_debug_cache_hits = 0
self.mm_debug_cache_total = 0
def cache_hit_ratio(self, steps):
if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
total = self.mm_debug_cache_total
if total > 0 and total % steps == 0:
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_cache_hits / self.mm_cache_total)
self.mm_debug_cache_hits / total)
# NOTE: process_inputs only supports image inputs since all multimodal
# models with other modalities have migrated to use merged preprocessor.
@ -71,7 +73,7 @@ class MMInputCacheClient:
mm_hashes: Optional[list[str]],
mm_processor_kwargs: Optional[dict[str, Any]],
precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
) -> list[MultiModalKwargs]:
) -> list[Optional[MultiModalKwargs]]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
@ -88,7 +90,7 @@ class MMInputCacheClient:
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_inputs: list[MultiModalKwargs] = []
ret_inputs: list[Optional[MultiModalKwargs]] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
@ -99,7 +101,7 @@ class MMInputCacheClient:
mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash)
self.mm_cache_total += 1
self.mm_debug_cache_total += 1
if mm_input is None:
if precomputed_mm_inputs is not None:
# Reuse precomputed input (for merged preprocessor)
@ -114,9 +116,9 @@ class MMInputCacheClient:
if self.use_cache:
# Add to cache
assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input)
self.mm_cache[mm_hash] = mm_input
else:
self.mm_cache_hits += 1
self.mm_debug_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server
ret_inputs.append(mm_input)
@ -128,14 +130,14 @@ class MMInputCacheServer:
def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUCache[str,
MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def get_and_update(
self,
mm_inputs: list[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> list[MultiModalKwargs]:
) -> list[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
@ -148,7 +150,7 @@ class MMInputCacheServer:
mm_input = self.mm_cache.get(mm_hash)
assert mm_input is not None
else:
self.mm_cache.put(mm_hash, mm_input)
self.mm_cache[mm_hash] = mm_input
full_mm_inputs.append(mm_input)