From dd5ede444032cfb00d8377722b7f24ec9157666b Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Thu, 13 Feb 2025 20:19:03 -0800
Subject: [PATCH] [V1] Consolidate MM cache size to vllm.envs (#13239)

---
 vllm/envs.py                     | 11 +++++++++--
 vllm/multimodal/registry.py      |  6 ++----
 vllm/v1/engine/mm_input_cache.py | 12 +++++++-----
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index d99c794e69e6c..f8a18cc662ab0 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -55,6 +55,7 @@ if TYPE_CHECKING:
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
     VLLM_VIDEO_FETCH_TIMEOUT: int = 30
     VLLM_AUDIO_FETCH_TIMEOUT: int = 10
+    VLLM_MM_INPUT_CACHE_SIZE: int = 256
     VLLM_TARGET_DEVICE: str = "cuda"
     MAX_JOBS: Optional[str] = None
     NVCC_THREADS: Optional[str] = None
@@ -401,15 +402,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
 
     # Timeout for fetching videos when serving multimodal models
-    # Default is 15 seconds
+    # Default is 30 seconds
     "VLLM_VIDEO_FETCH_TIMEOUT":
-    lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "15")),
+    lambda: int(os.getenv("VLLM_VIDEO_FETCH_TIMEOUT", "30")),
 
     # Timeout for fetching audio when serving multimodal models
     # Default is 10 seconds
     "VLLM_AUDIO_FETCH_TIMEOUT":
     lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
 
+    # Cache size for multimodal feature/input cache for multimodal models
+    # in unit of number of multimodal data items (e.g. image, video, audio).
+    # Default is 256 multimodal data items.
+    "VLLM_MM_INPUT_CACHE_SIZE":
+    lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256")),
+
     # Path to the XLA persistent cache directory.
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH":
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 04141114288c9..613d1db416720 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
 
 import torch.nn as nn
 
+from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
 from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -28,9 +29,6 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
-# TODO: Tune the MM cache size
-MM_CACHE_SIZE = 256
-
 N = TypeVar("N", bound=Type[nn.Module])
 _I = TypeVar("_I", bound=BaseProcessingInfo)
 _I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
@@ -121,7 +119,7 @@ class MultiModalRegistry:
 
         self._limits_by_model = _MultiModalLimits()
 
-        self._processing_cache = ProcessingCache(MM_CACHE_SIZE)
+        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_SIZE)
 
     def register_plugin(self, plugin: MultiModalPlugin) -> None:
         """
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index e1b6679c284b4..a1d802bf818a2 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -3,6 +3,7 @@
 from typing import Any, Dict, List, Optional
 
 from vllm.config import ModelConfig
+from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
@@ -28,9 +29,8 @@ logger = init_logger(__name__)
 # client (=P0) and server (=P1) processes.
 #
 # Both Client and Server must use the same cache size
-# (to perform mirrored caching)
-# TODO: Tune the MM cache size
-MM_CACHE_SIZE = 256
+# (to perform mirrored caching). This cache size is set by the environment
+# variable VLLM_MM_INPUT_CACHE_SIZE.
 
 
 # TODO(ywang96): Deprecate this class once all multimodal models migrate to use
@@ -50,7 +50,8 @@ class MMInputCacheClient:
 
         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str,
+                                 MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
 
         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None
@@ -127,7 +128,8 @@ class MMInputCacheServer:
 
     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str,
+                                 MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
 
     def get_and_update(
         self,
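
Note (not part of the patch): the change above replaces the hard-coded MM_CACHE_SIZE = 256 with a value read from the VLLM_MM_INPUT_CACHE_SIZE environment variable, so the P0 client cache, the P1 server cache, and the registry's processing cache are all sized from one place. The standalone Python sketch below illustrates the mirrored-sizing pattern only; SimpleLRUCache is a hypothetical stand-in for vLLM's internal LRU cache and none of this is vLLM code.

    import os
    from collections import OrderedDict

    # Same variable name and default (256 multimodal data items) as the patch.
    VLLM_MM_INPUT_CACHE_SIZE = int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256"))

    class SimpleLRUCache:
        """Toy LRU cache used only for this sketch."""

        def __init__(self, capacity: int) -> None:
            self.capacity = capacity
            self._data: OrderedDict = OrderedDict()

        def put(self, key: str, value: object) -> None:
            self._data[key] = value
            self._data.move_to_end(key)         # mark as most recently used
            if len(self._data) > self.capacity:
                self._data.popitem(last=False)  # evict least recently used

        def get(self, key: str):
            if key not in self._data:
                return None
            self._data.move_to_end(key)
            return self._data[key]

    # Client (P0) and server (P1) caches must have identical capacity for the
    # mirrored caching described in mm_input_cache.py to stay consistent.
    client_mm_cache = SimpleLRUCache(VLLM_MM_INPUT_CACHE_SIZE)
    server_mm_cache = SimpleLRUCache(VLLM_MM_INPUT_CACHE_SIZE)

In a deployment, the cache size would presumably be tuned the same way as the other variables in vllm/envs.py, e.g. by exporting VLLM_MM_INPUT_CACHE_SIZE=512 before launching the server.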