From 3556a414341033aad1bbb84674ec16b235324b25 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 15 Mar 2025 17:52:05 +0800 Subject: [PATCH] [VLM] Limit multimodal input cache by memory (#14805) Signed-off-by: DarkLight1337 --- .pre-commit-config.yaml | 2 +- requirements/common.txt | 1 + requirements/docs.txt | 1 + .../multimodal/processing/test_common.py | 2 +- vllm/envs.py | 11 ++- vllm/jsontree.py | 79 +++++++++++++++++++ vllm/model_executor/models/llava.py | 3 +- vllm/model_executor/models/molmo.py | 3 +- vllm/multimodal/inputs.py | 3 +- vllm/multimodal/processing.py | 51 ++++++++++-- vllm/multimodal/registry.py | 4 +- vllm/utils.py | 16 ---- vllm/v1/engine/mm_input_cache.py | 38 ++++----- 13 files changed, 159 insertions(+), 55 deletions(-) create mode 100644 vllm/jsontree.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 074ac9d122bfe..484cd171f5f52 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,7 +53,7 @@ repos: entry: tools/mypy.sh 0 "local" language: python types: [python] - additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests] + additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests] stages: [pre-commit] # Don't run in CI - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.9 diff --git a/requirements/common.txt b/requirements/common.txt index 3cd933f347f59..bb021d9e45499 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -1,3 +1,4 @@ +cachetools psutil sentencepiece # Required for LLaMA tokenizer. numpy < 2.0.0 diff --git a/requirements/docs.txt b/requirements/docs.txt index 1d669699f4b2a..7a9b921a11715 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -9,6 +9,7 @@ msgspec cloudpickle # packages to install to build the documentation +cachetools pydantic >= 2.8 -f https://download.pytorch.org/whl/cpu torch diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index aef5db9bc06bb..0e0d3711357e4 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -48,7 +48,7 @@ def _test_processing_correctness( tokenizer=cached_tokenizer_from_config(model_config), ) # Ensure that it can fit all of the data - cache = ProcessingCache(capacity=1 << 30) + cache = ProcessingCache(capacity_gb=2048) processing_info = factories.info(ctx) supported_mm_limits = processing_info.get_supported_mm_limits() diff --git a/vllm/envs.py b/vllm/envs.py index 463059dc06704..bf214f314c458 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -56,7 +56,7 @@ if TYPE_CHECKING: VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 - VLLM_MM_INPUT_CACHE_SIZE: int = 256 + VLLM_MM_INPUT_CACHE_GIB: int = 8 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None @@ -432,11 +432,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), - # Cache size for multimodal feature/input cache for multimodal models - # in unit of number of multimodal data items (e.g. image, video, audio). - # Default is 256 multimodal data items. - "VLLM_MM_INPUT_CACHE_SIZE": - lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256")), + # Cache size (in GiB) for multimodal input cache + # Default is 8GiB + "VLLM_MM_INPUT_CACHE_GIB": + lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "8")), # Path to the XLA persistent cache directory. # Only used for XLA devices such as TPUs. diff --git a/vllm/jsontree.py b/vllm/jsontree.py new file mode 100644 index 0000000000000..91cd7cb216d77 --- /dev/null +++ b/vllm/jsontree.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Helper functions to work with nested JSON structures.""" +from collections.abc import Iterable +from functools import reduce +from typing import Callable, TypeVar, Union, overload + +_T = TypeVar("_T") +_U = TypeVar("_U") + +JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"], + tuple["JSONTree[_T]", ...], _T] +"""A nested JSON structure where the leaves need not be JSON-serializable.""" + + +def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: + """Iterate through each leaf in a nested JSON structure.""" + if isinstance(value, dict): + for v in value.values(): + yield from json_iter_leaves(v) + elif isinstance(value, (list, tuple)): + for v in value: + yield from json_iter_leaves(v) + else: + yield value + + +def json_map_leaves( + func: Callable[[_T], _U], + value: JSONTree[_T], +) -> JSONTree[_U]: + """Apply a function to each leaf in a nested JSON structure.""" + if isinstance(value, dict): + return {k: json_map_leaves(func, v) for k, v in value.items()} + elif isinstance(value, list): + return [json_map_leaves(func, v) for v in value] + elif isinstance(value, tuple): + return tuple(json_map_leaves(func, v) for v in value) + else: + return func(value) + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: JSONTree[_T], + /, +) -> _T: + ... + + +@overload +def json_reduce_leaves( + func: Callable[[_U, _T], _U], + value: JSONTree[_T], + initial: _U, + /, +) -> _U: + ... + + +def json_reduce_leaves( + func: Callable[..., Union[_T, _U]], + value: JSONTree[_T], + initial: _U = ..., # type: ignore[assignment] + /, +) -> Union[_T, _U]: + """ + Apply a function of two arguments cumulatively to each leaf in a + nested JSON structure, from left to right, so as to reduce the + sequence to a single value. + """ + if initial is ...: + return reduce(func, json_iter_leaves(value)) # type: ignore[arg-type] + + return reduce( + func, # type: ignore[arg-type] + json_iter_leaves(value), + initial, + ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 478dbd83d3002..42bf6a5b2979a 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -18,6 +18,7 @@ from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig from vllm.inputs import InputProcessingContext +from vllm.jsontree import JSONTree, json_map_leaves from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -35,7 +36,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves +from vllm.utils import flatten_2d_lists from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 444b619437a09..e709b08815eaf 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -24,6 +24,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) +from vllm.jsontree import JSONTree, json_map_leaves from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, SiluAndMul) @@ -50,7 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, PromptInsertion, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.sequence import IntermediateTensors -from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves +from vllm.utils import flatten_2d_lists from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 7b186d89dad4a..3c609fd967650 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -16,7 +16,8 @@ from PIL.Image import Image from transformers import BatchFeature from typing_extensions import NotRequired, TypeAlias -from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves +from vllm.jsontree import JSONTree, json_map_leaves +from vllm.utils import full_groupby, is_list_of if TYPE_CHECKING: from .hasher import MultiModalHashDict diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 080a2362aac52..cdbbed27a5218 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - import re +import sys from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, @@ -11,14 +11,17 @@ from functools import lru_cache from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, TypeVar, Union, cast) +import torch +from cachetools import LRUCache from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from typing_extensions import assert_never from vllm.inputs import InputProcessingContext +from vllm.jsontree import json_map_leaves, json_reduce_leaves from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) -from vllm.utils import LRUCache, flatten_2d_lists, full_groupby +from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, @@ -812,25 +815,50 @@ def find_mm_placeholders( return dict(full_groupby_modality(it)) +_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]") + + class ProcessingCache: - def __init__(self, capacity: int) -> None: + @staticmethod + def get_lru_cache( + capacity_gb: int, + value_type: type[_V], + ) -> LRUCache[str, _V]: + + def get_size(leaf: object) -> int: + if isinstance(leaf, torch.Tensor): + return leaf.nbytes # sys.getsizeof doesn't work for tensors + + return sys.getsizeof(leaf) + + return LRUCache[str, _V]( + GiB_bytes * capacity_gb, + getsizeof=lambda x: json_reduce_leaves( + lambda a, b: a + b, + json_map_leaves(get_size, x), + ), + ) + + def __init__(self, capacity_gb: int) -> None: super().__init__() # DEBUG: Set to None to disable self.debug_cache_hit_ratio_steps: Optional[int] = None + self.debug_cache_hits = 0 + self.debug_cache_total = 0 - self._cache = LRUCache[str, MultiModalKwargsItem](capacity) + self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem) def _maybe_log_cache_stats(self) -> None: steps = self.debug_cache_hit_ratio_steps if not steps: return - cache_stats = self._cache.stat() - if cache_stats.total % steps == 0: + total = self.debug_cache_total + if total > 0 and total % steps == 0: logger.debug("ProcessingCache: hit_ratio = %.2f", - cache_stats.hit_ratio) + self.debug_cache_hits / total) def get( self, @@ -853,6 +881,13 @@ class ProcessingCache: cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, **{modality: input_item}, **input_kwargs) + + if self.debug_cache_hit_ratio_steps: + if cache_key in self._cache: + self.debug_cache_hits += 1 + + self.debug_cache_total += 1 + return self._cache.get(cache_key) def put( @@ -870,7 +905,7 @@ class ProcessingCache: cache_key = MultiModalHasher.hash_kwargs(model_id=model_id, **{modality: input_item}, **input_kwargs) - self._cache.put(cache_key, output_kwargs) + self._cache[cache_key] = output_kwargs class BaseProcessingInfo: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index febf3ad9eea42..24b8358982797 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar import torch.nn as nn -from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE +from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import (AnyTokenizer, @@ -119,7 +119,7 @@ class MultiModalRegistry: self._limits_by_model = _MultiModalLimits() - self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_SIZE) + self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) def register_plugin(self, plugin: MultiModalPlugin) -> None: """ diff --git a/vllm/utils.py b/vllm/utils.py index 9334741225008..632b3666e959c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -845,22 +845,6 @@ def is_list_of( assert_never(check) -JSONTree = Union[dict[str, "JSONTree[T]"], list["JSONTree[T]"], - tuple["JSONTree[T]", ...], T] -"""A nested JSON structure where the leaves need not be JSON-serializable.""" - - -def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: - if isinstance(value, dict): - return {k: json_map_leaves(func, v) for k, v in value.items()} - elif isinstance(value, list): - return [json_map_leaves(func, v) for v in value] - elif isinstance(value, tuple): - return tuple(json_map_leaves(func, v) for v in value) - else: - return func(value) - - def flatten_2d_lists(lists: list[list[T]]) -> list[T]: """Flatten a list of lists to a single list.""" return [item for sublist in lists for item in sublist] diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py index 0f66f68109b17..e2dda73ba4299 100644 --- a/vllm/v1/engine/mm_input_cache.py +++ b/vllm/v1/engine/mm_input_cache.py @@ -3,11 +3,11 @@ from typing import Any, Optional from vllm.config import ModelConfig -from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE +from vllm.envs import VLLM_MM_INPUT_CACHE_GIB from vllm.logger import init_logger from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalKwargs, MultiModalRegistry) -from vllm.utils import LRUCache +from vllm.multimodal.processing import ProcessingCache logger = init_logger(__name__) @@ -30,7 +30,7 @@ logger = init_logger(__name__) # Both Client and Server must use the same cache size # (to perform mirrored caching). This cache size is set by the environment -# variable VLLM_MM_INPUT_CACHE_SIZE. +# variable VLLM_MM_INPUT_CACHE_GIB. # TODO(ywang96): Deprecate this class once all multimodal models migrate to use @@ -50,18 +50,20 @@ class MMInputCacheClient: # Init cache self.use_cache = not model_config.disable_mm_preprocessor_cache - self.mm_cache = LRUCache[str, - MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE) + self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, + MultiModalKwargs) # DEBUG: Set to None to disable self.mm_debug_cache_hit_ratio_steps = None - self.mm_cache_hits = 0 - self.mm_cache_total = 0 + self.mm_debug_cache_hits = 0 + self.mm_debug_cache_total = 0 def cache_hit_ratio(self, steps): - if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0: + total = self.mm_debug_cache_total + + if total > 0 and total % steps == 0: logger.debug("MMInputMapper: cache_hit_ratio = %.2f ", - self.mm_cache_hits / self.mm_cache_total) + self.mm_debug_cache_hits / total) # NOTE: process_inputs only supports image inputs since all multimodal # models with other modalities have migrated to use merged preprocessor. @@ -71,7 +73,7 @@ class MMInputCacheClient: mm_hashes: Optional[list[str]], mm_processor_kwargs: Optional[dict[str, Any]], precomputed_mm_inputs: Optional[list[MultiModalKwargs]], - ) -> list[MultiModalKwargs]: + ) -> list[Optional[MultiModalKwargs]]: if precomputed_mm_inputs is None: image_inputs = mm_data["image"] if not isinstance(image_inputs, list): @@ -88,7 +90,7 @@ class MMInputCacheClient: # Process each image input separately, so that later we can schedule # them in a fine-grained manner. # Apply caching (if enabled) and reuse precomputed inputs (if provided) - ret_inputs: list[MultiModalKwargs] = [] + ret_inputs: list[Optional[MultiModalKwargs]] = [] for input_id in range(num_inputs): if self.mm_debug_cache_hit_ratio_steps is not None: self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps) @@ -99,7 +101,7 @@ class MMInputCacheClient: mm_hash = mm_hashes[input_id] mm_input = self.mm_cache.get(mm_hash) - self.mm_cache_total += 1 + self.mm_debug_cache_total += 1 if mm_input is None: if precomputed_mm_inputs is not None: # Reuse precomputed input (for merged preprocessor) @@ -114,9 +116,9 @@ class MMInputCacheClient: if self.use_cache: # Add to cache assert mm_hash is not None - self.mm_cache.put(mm_hash, mm_input) + self.mm_cache[mm_hash] = mm_input else: - self.mm_cache_hits += 1 + self.mm_debug_cache_hits += 1 mm_input = None # Avoids sending mm_input to Server ret_inputs.append(mm_input) @@ -128,14 +130,14 @@ class MMInputCacheServer: def __init__(self, model_config): self.use_cache = not model_config.disable_mm_preprocessor_cache - self.mm_cache = LRUCache[str, - MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE) + self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB, + MultiModalKwargs) def get_and_update( self, mm_inputs: list[Optional[MultiModalKwargs]], mm_hashes: list[str], - ) -> list[MultiModalKwargs]: + ) -> list[Optional[MultiModalKwargs]]: assert len(mm_inputs) == len(mm_hashes) if not self.use_cache: @@ -148,7 +150,7 @@ class MMInputCacheServer: mm_input = self.mm_cache.get(mm_hash) assert mm_input is not None else: - self.mm_cache.put(mm_hash, mm_input) + self.mm_cache[mm_hash] = mm_input full_mm_inputs.append(mm_input)