mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 02:37:04 +08:00
[VLM] Limit multimodal input cache by memory (#14805)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
9ed6ee92d6
commit
3556a41434
@ -53,7 +53,7 @@ repos:
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
cachetools
|
||||
psutil
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy < 2.0.0
|
||||
|
||||
@ -9,6 +9,7 @@ msgspec
|
||||
cloudpickle
|
||||
|
||||
# packages to install to build the documentation
|
||||
cachetools
|
||||
pydantic >= 2.8
|
||||
-f https://download.pytorch.org/whl/cpu
|
||||
torch
|
||||
|
||||
@ -48,7 +48,7 @@ def _test_processing_correctness(
|
||||
tokenizer=cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
# Ensure that it can fit all of the data
|
||||
cache = ProcessingCache(capacity=1 << 30)
|
||||
cache = ProcessingCache(capacity_gb=2048)
|
||||
|
||||
processing_info = factories.info(ctx)
|
||||
supported_mm_limits = processing_info.get_supported_mm_limits()
|
||||
|
||||
11
vllm/envs.py
11
vllm/envs.py
@ -56,7 +56,7 @@ if TYPE_CHECKING:
|
||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||
VLLM_MM_INPUT_CACHE_SIZE: int = 256
|
||||
VLLM_MM_INPUT_CACHE_GIB: int = 8
|
||||
VLLM_TARGET_DEVICE: str = "cuda"
|
||||
MAX_JOBS: Optional[str] = None
|
||||
NVCC_THREADS: Optional[str] = None
|
||||
@ -432,11 +432,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
||||
|
||||
# Cache size for multimodal feature/input cache for multimodal models
|
||||
# in unit of number of multimodal data items (e.g. image, video, audio).
|
||||
# Default is 256 multimodal data items.
|
||||
"VLLM_MM_INPUT_CACHE_SIZE":
|
||||
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_SIZE", "256")),
|
||||
# Cache size (in GiB) for multimodal input cache
|
||||
# Default is 8GiB
|
||||
"VLLM_MM_INPUT_CACHE_GIB":
|
||||
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "8")),
|
||||
|
||||
# Path to the XLA persistent cache directory.
|
||||
# Only used for XLA devices such as TPUs.
|
||||
|
||||
79
vllm/jsontree.py
Normal file
79
vllm/jsontree.py
Normal file
@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
"""Helper functions to work with nested JSON structures."""
from collections.abc import Iterable
from functools import reduce
from typing import Callable, TypeVar, Union, overload

_T = TypeVar("_T")
_U = TypeVar("_U")

JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"],
                 tuple["JSONTree[_T]", ...], _T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""


def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
    """Iterate through each leaf in a nested JSON structure."""
    # Normalize recognized containers into an iterable of children;
    # anything else is itself a leaf.
    if isinstance(value, dict):
        children: Iterable[JSONTree[_T]] = value.values()
    elif isinstance(value, (list, tuple)):
        children = value
    else:
        yield value
        return

    for child in children:
        yield from json_iter_leaves(child)


def json_map_leaves(
    func: Callable[[_T], _U],
    value: JSONTree[_T],
) -> JSONTree[_U]:
    """Apply a function to each leaf in a nested JSON structure.

    The container shape is preserved: dicts map to dicts (same keys),
    lists to lists, and tuples to tuples.
    """
    if isinstance(value, dict):
        return {key: json_map_leaves(func, child) for key, child in value.items()}
    if isinstance(value, list):
        return [json_map_leaves(func, child) for child in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, child) for child in value)
    # Leaf node: transform it directly.
    return func(value)


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: JSONTree[_T],
    /,
) -> _T:
    ...


@overload
def json_reduce_leaves(
    func: Callable[[_U, _T], _U],
    value: JSONTree[_T],
    initial: _U,
    /,
) -> _U:
    ...


def json_reduce_leaves(
    func: Callable[..., Union[_T, _U]],
    value: JSONTree[_T],
    initial: _U = ...,  # type: ignore[assignment]
    /,
) -> Union[_T, _U]:
    """
    Apply a function of two arguments cumulatively to each leaf in a
    nested JSON structure, from left to right, so as to reduce the
    sequence to a single value.
    """
    leaves = json_iter_leaves(value)

    # ``...`` is the sentinel for "no seed": the first leaf then acts
    # as the initial accumulator, mirroring functools.reduce.
    if initial is ...:
        return reduce(func, leaves)  # type: ignore[arg-type]

    return reduce(
        func,  # type: ignore[arg-type]
        leaves,
        initial,
    )
|
||||
@ -18,6 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -35,7 +36,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
|
||||
@ -24,6 +24,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
|
||||
SiluAndMul)
|
||||
@ -50,7 +51,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptInsertion, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
|
||||
from vllm.utils import flatten_2d_lists
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
|
||||
@ -16,7 +16,8 @@ from PIL.Image import Image
|
||||
from transformers import BatchFeature
|
||||
from typing_extensions import NotRequired, TypeAlias
|
||||
|
||||
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.utils import full_groupby, is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .hasher import MultiModalHashDict
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
|
||||
@ -11,14 +11,17 @@ from functools import lru_cache
|
||||
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
|
||||
TypeVar, Union, cast)
|
||||
|
||||
import torch
|
||||
from cachetools import LRUCache
|
||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.jsontree import json_map_leaves, json_reduce_leaves
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
|
||||
encode_tokens)
|
||||
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
|
||||
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
|
||||
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
@ -812,25 +815,50 @@ def find_mm_placeholders(
|
||||
return dict(full_groupby_modality(it))
|
||||
|
||||
|
||||
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
|
||||
|
||||
|
||||
class ProcessingCache:
|
||||
|
||||
def __init__(self, capacity: int) -> None:
|
||||
@staticmethod
|
||||
def get_lru_cache(
|
||||
capacity_gb: int,
|
||||
value_type: type[_V],
|
||||
) -> LRUCache[str, _V]:
|
||||
|
||||
def get_size(leaf: object) -> int:
|
||||
if isinstance(leaf, torch.Tensor):
|
||||
return leaf.nbytes # sys.getsizeof doesn't work for tensors
|
||||
|
||||
return sys.getsizeof(leaf)
|
||||
|
||||
return LRUCache[str, _V](
|
||||
GiB_bytes * capacity_gb,
|
||||
getsizeof=lambda x: json_reduce_leaves(
|
||||
lambda a, b: a + b,
|
||||
json_map_leaves(get_size, x),
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(self, capacity_gb: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
# DEBUG: Set to None to disable
|
||||
self.debug_cache_hit_ratio_steps: Optional[int] = None
|
||||
self.debug_cache_hits = 0
|
||||
self.debug_cache_total = 0
|
||||
|
||||
self._cache = LRUCache[str, MultiModalKwargsItem](capacity)
|
||||
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
|
||||
|
||||
def _maybe_log_cache_stats(self) -> None:
|
||||
steps = self.debug_cache_hit_ratio_steps
|
||||
if not steps:
|
||||
return
|
||||
|
||||
cache_stats = self._cache.stat()
|
||||
if cache_stats.total % steps == 0:
|
||||
total = self.debug_cache_total
|
||||
if total > 0 and total % steps == 0:
|
||||
logger.debug("ProcessingCache: hit_ratio = %.2f",
|
||||
cache_stats.hit_ratio)
|
||||
self.debug_cache_hits / total)
|
||||
|
||||
def get(
|
||||
self,
|
||||
@ -853,6 +881,13 @@ class ProcessingCache:
|
||||
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
|
||||
**{modality: input_item},
|
||||
**input_kwargs)
|
||||
|
||||
if self.debug_cache_hit_ratio_steps:
|
||||
if cache_key in self._cache:
|
||||
self.debug_cache_hits += 1
|
||||
|
||||
self.debug_cache_total += 1
|
||||
|
||||
return self._cache.get(cache_key)
|
||||
|
||||
def put(
|
||||
@ -870,7 +905,7 @@ class ProcessingCache:
|
||||
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
|
||||
**{modality: input_item},
|
||||
**input_kwargs)
|
||||
self._cache.put(cache_key, output_kwargs)
|
||||
self._cache[cache_key] = output_kwargs
|
||||
|
||||
|
||||
class BaseProcessingInfo:
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
@ -119,7 +119,7 @@ class MultiModalRegistry:
|
||||
|
||||
self._limits_by_model = _MultiModalLimits()
|
||||
|
||||
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_SIZE)
|
||||
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
|
||||
|
||||
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
||||
"""
|
||||
|
||||
@ -845,22 +845,6 @@ def is_list_of(
|
||||
assert_never(check)
|
||||
|
||||
|
||||
JSONTree = Union[dict[str, "JSONTree[T]"], list["JSONTree[T]"],
|
||||
tuple["JSONTree[T]", ...], T]
|
||||
"""A nested JSON structure where the leaves need not be JSON-serializable."""
|
||||
|
||||
|
||||
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
|
||||
if isinstance(value, dict):
|
||||
return {k: json_map_leaves(func, v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [json_map_leaves(func, v) for v in value]
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(json_map_leaves(func, v) for v in value)
|
||||
else:
|
||||
return func(value)
|
||||
|
||||
|
||||
def flatten_2d_lists(lists: list[list[T]]) -> list[T]:
|
||||
"""Flatten a list of lists to a single list."""
|
||||
return [item for sublist in lists for item in sublist]
|
||||
|
||||
@ -3,11 +3,11 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
|
||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalKwargs, MultiModalRegistry)
|
||||
from vllm.utils import LRUCache
|
||||
from vllm.multimodal.processing import ProcessingCache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -30,7 +30,7 @@ logger = init_logger(__name__)
|
||||
|
||||
# Both Client and Server must use the same cache size
|
||||
# (to perform mirrored caching). This cache size is set by the environment
|
||||
# variable VLLM_MM_INPUT_CACHE_SIZE.
|
||||
# variable VLLM_MM_INPUT_CACHE_GIB.
|
||||
|
||||
|
||||
# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
|
||||
@ -50,18 +50,20 @@ class MMInputCacheClient:
|
||||
|
||||
# Init cache
|
||||
self.use_cache = not model_config.disable_mm_preprocessor_cache
|
||||
self.mm_cache = LRUCache[str,
|
||||
MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
|
||||
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
|
||||
MultiModalKwargs)
|
||||
|
||||
# DEBUG: Set to None to disable
|
||||
self.mm_debug_cache_hit_ratio_steps = None
|
||||
self.mm_cache_hits = 0
|
||||
self.mm_cache_total = 0
|
||||
self.mm_debug_cache_hits = 0
|
||||
self.mm_debug_cache_total = 0
|
||||
|
||||
def cache_hit_ratio(self, steps):
|
||||
if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
|
||||
total = self.mm_debug_cache_total
|
||||
|
||||
if total > 0 and total % steps == 0:
|
||||
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
|
||||
self.mm_cache_hits / self.mm_cache_total)
|
||||
self.mm_debug_cache_hits / total)
|
||||
|
||||
# NOTE: process_inputs only supports image inputs since all multimodal
|
||||
# models with other modalities have migrated to use merged preprocessor.
|
||||
@ -71,7 +73,7 @@ class MMInputCacheClient:
|
||||
mm_hashes: Optional[list[str]],
|
||||
mm_processor_kwargs: Optional[dict[str, Any]],
|
||||
precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
|
||||
) -> list[MultiModalKwargs]:
|
||||
) -> list[Optional[MultiModalKwargs]]:
|
||||
if precomputed_mm_inputs is None:
|
||||
image_inputs = mm_data["image"]
|
||||
if not isinstance(image_inputs, list):
|
||||
@ -88,7 +90,7 @@ class MMInputCacheClient:
|
||||
# Process each image input separately, so that later we can schedule
|
||||
# them in a fine-grained manner.
|
||||
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
|
||||
ret_inputs: list[MultiModalKwargs] = []
|
||||
ret_inputs: list[Optional[MultiModalKwargs]] = []
|
||||
for input_id in range(num_inputs):
|
||||
if self.mm_debug_cache_hit_ratio_steps is not None:
|
||||
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
|
||||
@ -99,7 +101,7 @@ class MMInputCacheClient:
|
||||
mm_hash = mm_hashes[input_id]
|
||||
mm_input = self.mm_cache.get(mm_hash)
|
||||
|
||||
self.mm_cache_total += 1
|
||||
self.mm_debug_cache_total += 1
|
||||
if mm_input is None:
|
||||
if precomputed_mm_inputs is not None:
|
||||
# Reuse precomputed input (for merged preprocessor)
|
||||
@ -114,9 +116,9 @@ class MMInputCacheClient:
|
||||
if self.use_cache:
|
||||
# Add to cache
|
||||
assert mm_hash is not None
|
||||
self.mm_cache.put(mm_hash, mm_input)
|
||||
self.mm_cache[mm_hash] = mm_input
|
||||
else:
|
||||
self.mm_cache_hits += 1
|
||||
self.mm_debug_cache_hits += 1
|
||||
mm_input = None # Avoids sending mm_input to Server
|
||||
|
||||
ret_inputs.append(mm_input)
|
||||
@ -128,14 +130,14 @@ class MMInputCacheServer:
|
||||
|
||||
def __init__(self, model_config):
|
||||
self.use_cache = not model_config.disable_mm_preprocessor_cache
|
||||
self.mm_cache = LRUCache[str,
|
||||
MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
|
||||
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
|
||||
MultiModalKwargs)
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_inputs: list[Optional[MultiModalKwargs]],
|
||||
mm_hashes: list[str],
|
||||
) -> list[MultiModalKwargs]:
|
||||
) -> list[Optional[MultiModalKwargs]]:
|
||||
assert len(mm_inputs) == len(mm_hashes)
|
||||
|
||||
if not self.use_cache:
|
||||
@ -148,7 +150,7 @@ class MMInputCacheServer:
|
||||
mm_input = self.mm_cache.get(mm_hash)
|
||||
assert mm_input is not None
|
||||
else:
|
||||
self.mm_cache.put(mm_hash, mm_input)
|
||||
self.mm_cache[mm_hash] = mm_input
|
||||
|
||||
full_mm_inputs.append(mm_input)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user