[Core] Store only the keys for multi-modal data in P0 (#22198)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-07 16:45:04 +08:00 committed by GitHub
parent 289b18e670
commit 766bc8162c
17 changed files with 325 additions and 234 deletions

View File

@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
- (Multi-modal models only) you can set the size of the multi-modal processor cache using the `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process + 4 GiB per engine core process).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
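
As a quick illustration of these variables (the script name is hypothetical), both are set in the environment before launching vLLM:

```console
# Cap the multi-modal processor cache at 2 GiB per process and
# give the CPU backend 8 GiB of KV cache space.
export VLLM_MM_INPUT_CACHE_GIB=2
export VLLM_CPU_KVCACHE_SPACE=8
python my_offline_inference.py  # hypothetical script that constructs LLM(...)
```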
## Multi-modal input limits
@ -129,20 +129,18 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.
Here are some examples:
??? code

    ```python
    from vllm import LLM

    # Available for Qwen2-VL series models
    llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
              mm_processor_kwargs={
                  "max_pixels": 768 * 768,  # Default is 1280 * 28 * 28
              })

    # Available for InternVL series models
    llm = LLM(model="OpenGVLab/InternVL2-2B",
              mm_processor_kwargs={
                  "max_dynamic_patch": 4,  # Default is 12
              })
    ```

View File

@ -2,6 +2,9 @@
This guide covers optimization strategies and performance tuning for vLLM V1.
!!! tip
Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory.
## Preemption
Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
@ -126,62 +129,44 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
## Reducing Memory Usage

If you encounter out-of-memory issues, consider these strategies:

### Context Length and Batch Size

You can reduce memory usage by limiting the context length and batch size:

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=2048,  # Limit context window
    max_num_seqs=4       # Limit batch size
)
```

### Adjust CUDA Graph Compilation

CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level:

```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        cudagraph_capture_sizes=[1, 2, 4, 8]  # Capture fewer batch sizes
    )
)
```

Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`:

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    enforce_eager=True  # Disable CUDA graph compilation
)
```

### Multimodal Models

For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request:

```python
from vllm import LLM

# Accept up to 2 images per prompt
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    limit_mm_per_prompt={"image": 2}
)
```

## Input Processing

### Parallel Processing

You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing).
This is useful when input processing (which is run inside the API server)
becomes a bottleneck compared to model execution (which is run inside engine core)
and you have excess CPU capacity.

```console
# Run 4 API processes and 1 engine core process
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4

# Run 4 API processes and 2 engine core processes
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
```

!!! note
    API server scale-out is only available for online inference.

!!! note
    [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
    because it requires a one-to-one correspondence between API and engine core processes.

## Multi-Modal Caching

### Processor Cache

By default, the multi-modal processor cache is enabled to avoid repeatedly processing
the same multi-modal inputs via Hugging Face `AutoProcessor`,
which commonly occurs in multi-turn conversations.

You can adjust the size of the cache via the `VLLM_MM_INPUT_CACHE_GIB` environment variable
(default 4 GiB per API process + 4 GiB per engine core process).

If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`:

```python
from vllm import LLM

llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          disable_mm_preprocessor_cache=True)
```
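
For example, the per-process cache size mentioned above can be tuned at launch time (the model name here is only an assumption):

```console
# Allow up to 8 GiB of processed multi-modal inputs per API/engine core process
VLLM_MM_INPUT_CACHE_GIB=8 vllm serve Qwen/Qwen2.5-VL-3B-Instruct
```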

View File

@ -166,7 +166,7 @@ def parse_args():
parser.add_argument(
"--disable-mm-preprocessor-cache",
action="store_true",
help="If True, disables caching of multi-modal preprocessor/mapper.",
help="If True, disables caching of multi-modal processor.",
)
return parser.parse_args()

View File

@ -1565,7 +1565,7 @@ def parse_args():
parser.add_argument(
"--disable-mm-preprocessor-cache",
action="store_true",
help="If True, disables caching of multi-modal preprocessor/mapper.",
help="If True, disables caching of multi-modal processor.",
)
parser.add_argument(

View File

@ -9,7 +9,7 @@ import torch
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, RunnerOption
from vllm.config import ModelConfig, ModelDType, RunnerOption
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
@ -257,7 +257,7 @@ def check_logprobs_close(
def build_model_context(
model_id: str,
runner: RunnerOption = "auto",
dtype: Union[str, torch.dtype] = "auto",
dtype: ModelDType = "auto",
model_config_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
@ -279,6 +279,7 @@ def build_model_context(
model_info.check_transformers_version(on_fail="skip")
model_config_kwargs = model_config_kwargs or {}
limit_mm_per_prompt = limit_mm_per_prompt or {}
model_config = ModelConfig(
model_id,
runner=runner,

View File

@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField)
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs.from_items([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])
# yapf: disable
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
cache = MultiModalCache.get_lru_cache(2048, type(item))
cache[""] = item
assert cache.currsize == expected_size
cache[""] = MultiModalCacheItemMetadata.wraps(item)
assert cache.currsize == expected_size

View File

@ -6,20 +6,15 @@ from typing import Optional, cast
import numpy as np
import pytest
import torch
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
ProcessingCache, PromptIndexTargets,
PromptInsertion, PromptReplacement,
apply_text_matches,
PromptIndexTargets, PromptInsertion,
PromptReplacement, apply_text_matches,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
@ -902,45 +897,6 @@ def test_find_mm_placeholders(
assert result == expected
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs.from_items([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])
# yapf: disable
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
cache = ProcessingCache.get_lru_cache(2048, type(item))
cache[""] = item
assert cache.currsize == expected_size
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("limit", "num_supported", "is_valid"),

View File

@ -444,8 +444,7 @@ class ModelConfig:
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
"""
disable_mm_preprocessor_cache: bool = False
"""If `True`, disable caching of the multi-modal preprocessor/mapper (not
recommended)."""
"""If `True`, disable caching of the multi-modal processor."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
@ -1692,6 +1691,31 @@ class ModelConfig:
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
@property
def processor_return_mm_hashes(self) -> bool:
"""Whether the multi-modal processor should output hashes."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return not mm_config.disable_mm_preprocessor_cache
@property
def enable_mm_input_cache(self) -> bool:
"""Whether the multi-modal input cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return not mm_config.disable_mm_preprocessor_cache
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:
return 0
return envs.VLLM_MM_INPUT_CACHE_GIB
@property
def is_cross_encoder(self) -> bool:
return (self._model_info.supports_cross_encoding
@ -3369,7 +3393,7 @@ class MultiModalConfig:
disable_mm_preprocessor_cache: bool = False
"""
If `True`, disable caching of the processed multi-modal inputs.
If `True`, disable caching of the multi-modal processor.
"""
interleave_mm_strings: bool = False
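
The three new `ModelConfig` accessors all key off `disable_mm_preprocessor_cache`. A minimal sketch of the expected behavior, assuming default settings and that the listed models resolve as multi-modal and text-only respectively:

```python
from vllm.config import ModelConfig

mm = ModelConfig("Qwen/Qwen2.5-VL-3B-Instruct")  # assumed multi-modal model
assert mm.processor_return_mm_hashes      # processor emits mm hashes by default
assert mm.enable_mm_input_cache           # P0/P1 input cache is enabled
assert mm.get_mm_input_cache_gb() == 4    # VLLM_MM_INPUT_CACHE_GIB default

text = ModelConfig("meta-llama/Llama-3.1-8B-Instruct")  # no multimodal_config
assert not text.processor_return_mm_hashes
assert not text.enable_mm_input_cache
assert text.get_mm_input_cache_gb() == 0
```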

View File

@ -1230,17 +1230,17 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel,
)
supports_mm_preprocessor_cache = (self.data_parallel_size == 1
or data_parallel_external_lb)
if (not supports_mm_preprocessor_cache
and model_config.is_multimodal_model
and not model_config.disable_mm_preprocessor_cache):
logger.warning(
"Multi-modal preprocessor cache is not compatible "
"with data parallelism when there does not exist a "
"one-to-one correspondance between API process and "
"EngineCore process, so the cache will be disabled.")
model_config.set_disable_mm_preprocessor_cache(True)
if model_config.is_multimodal_model:
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
or data_parallel_external_lb)
if (not dp_supports_mm_processor_cache
and not model_config.disable_mm_preprocessor_cache):
logger.warning(
"Multi-modal processor cache is disabled because "
"it is not compatible with data parallelism when "
"there does not exist a one-to-one correspondance "
"between API and engine core processes.")
model_config.set_disable_mm_preprocessor_cache(True)
speculative_config = self.create_speculative_config(
target_model_config=model_config,

View File

@ -163,9 +163,8 @@ def run_multi_api_server(args: argparse.Namespace):
if model_config.is_multimodal_model and not (
orig_disable_mm_preprocessor_cache):
logger.warning(
"Multi-modal preprocessor cache is not compatible "
"with api_server_count > 1, so the cache will be disabled.")
logger.warning("Multi-modal processor cache is disabled because "
"it is not compatible with `api_server_count > 1`.")
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats

View File

@ -65,7 +65,7 @@ if TYPE_CHECKING:
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MM_INPUT_CACHE_GIB: int = 8
VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda"
MAX_JOBS: Optional[str] = None
NVCC_THREADS: Optional[str] = None
@ -561,8 +561,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_VIDEO_LOADER_BACKEND":
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
# Cache size (in GiB) for multimodal input cache
# Default is 4 GiB
# Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB":
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),

vllm/multimodal/cache.py (new file, 95 lines)
View File

@ -0,0 +1,95 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TypeVar, Union
import torch
from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors
logger = init_logger(__name__)
@dataclass
class MultiModalCacheItemMetadata:
size: int
@classmethod
def wraps(cls, value: "MultiModalCacheValue"):
return cls(size=MultiModalCache.get_item_size(value))
MultiModalCacheValue = Union[
MultiModalKwargs,
MultiModalKwargsItem,
Mapping[str, NestedTensors],
MultiModalCacheItemMetadata,
]
_V = TypeVar("_V", bound=MultiModalCacheValue)
class MultiModalCache:
@classmethod
def get_leaf_size(
cls,
leaf: object,
*,
debug: bool = False,
) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data, debug=debug)
# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return cls.get_item_size(leaf_data, debug=debug)
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
if isinstance(leaf, MultiModalCacheItemMetadata):
return leaf.size
return sys.getsizeof(leaf)
@classmethod
def get_item_size(
cls,
value: MultiModalCacheValue,
*,
debug: bool = False,
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug),
value),
)
if debug:
logger.debug("Calculated size of %s to be %.2f GiB", type(value),
size / GiB_bytes)
return size
@classmethod
def get_lru_cache(
cls,
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:
return LRUCache(
GiB_bytes * capacity_gb,
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
)
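
A small usage sketch of the new cache module, mirroring the helpers in the test file above: storing the full item and storing only its size metadata report the same cache footprint, which is what lets P0 keep just the keys.

```python
import torch

from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
                                    MultiModalSharedField)

# Build a dummy 1 KiB item, following the helpers in the test above.
elem = MultiModalFieldElem(
    modality="image",
    key="pixel_values",
    data=torch.empty((1024, ), dtype=torch.int8),
    field=MultiModalSharedField(1),
)
item = MultiModalKwargsItem.from_elems([elem])

# P1-style cache: stores the full item, tensors included.
server_cache = MultiModalCache.get_lru_cache(4.0, MultiModalKwargsItem)
server_cache["mm_hash_0"] = item

# P0-style cache: stores only the size metadata under the same key, so
# eviction stays mirrored with P1 without keeping the tensors in memory.
client_cache = MultiModalCache.get_lru_cache(4.0, MultiModalCacheItemMetadata)
client_cache["mm_hash_0"] = MultiModalCacheItemMetadata.wraps(item)

assert server_cache.currsize == client_cache.currsize == 1024
```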

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
@ -16,16 +15,16 @@ import torch
from typing_extensions import assert_never
from vllm.inputs import InputProcessingContext
from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from .cache import MultiModalCache
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, NestedTensors, PlaceholderRange)
MultiModalKwargsItem, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
@ -888,9 +887,6 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it))
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
class ProcessingCacheOptionalItem(NamedTuple):
key: str
value: Optional[MultiModalKwargsItem]
@ -901,48 +897,7 @@ class ProcessingCacheItem(NamedTuple):
value: MultiModalKwargsItem
class ProcessingCache:
@staticmethod
def get_lru_cache(
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:
def get_leaf_size(leaf: object) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return get_item_size(leaf.data)
# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return get_item_size(leaf_data)
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
return sys.getsizeof(leaf)
def get_item_size(
value: Union[MultiModalKwargs, MultiModalKwargsItem,
Mapping[str, NestedTensors]]
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(get_leaf_size, value),
)
if debug:
logger.debug("Calculated size of %s to be %.2f GiB",
type(value), size / GiB_bytes)
return size
return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
class ProcessingCache(MultiModalCache):
def __init__(
self,

View File

@ -429,8 +429,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set disable_mm_preprocessor_cache=False.")
"is likely because you did not enable MM hashing. "
"Please set `disable_mm_preprocessor_cache=False`.")
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of

View File

@ -35,7 +35,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
@ -124,8 +124,7 @@ class EngineCore:
log_stats=self.log_stats,
)
# Setup MM Input Mapper.
self.mm_input_cache_server = MirroredProcessingCache(
self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config)
# Setup batch queue for pipeline parallelism.
@ -413,7 +412,7 @@ class EngineCore:
# Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init,
# and will only accessed in the input processing thread afterwards.
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)

View File

@ -1,54 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional
from typing import TYPE_CHECKING, Optional
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.processing import ProcessingCache
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.utils import is_list_of
# The idea of multimodal preprocessing caching is based on having a client and
if TYPE_CHECKING:
from vllm.config import ModelConfig
# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- Client:
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
# with built-in caching functionality, with mm_hash as its identifier.
# - MirroredProcessingCache to keep track of the cached entries and
# determine whether to send the MultiModalKwargs to P1.
# -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_inputs`,
# which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_inputs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_inputs`, but not the `mm_inputs` themselves, to avoid taking
# up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_inputs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
#
# -- Server:
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
# -- P1:
# - If the `mm_hash` is cached (i.e. `mm_inputs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_inputs`.
# - If the `mm_hash` is not cached (i.e. `mm_inputs` are sent from P0),
# MultiModalInputCacheServer stores `mm_inputs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_inputs` are sent to
# the engine for model execution.
#
# The caching for both client and server is mirrored, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# client (=P0) and server (=P1) processes if the mm_hash is found in the client
# cache.
# Both Client and Server must use the same cache size
# (to perform mirrored caching). This cache size is set by the environment
# variable VLLM_MM_INPUT_CACHE_GIB.
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.
class MirroredProcessingCache:
class MultiModalInputCacheClient:
"""Used by P0 to check whether multi-modal kwargs are cached in P1."""
def __init__(self, model_config):
mm_config = model_config.multimodal_config
disable_mm_preprocessor_cache = (
mm_config is not None and mm_config.disable_mm_preprocessor_cache)
self.use_cache = not disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
def get_and_update_p0(
self.enabled = model_config.enable_mm_input_cache
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalCacheItemMetadata,
)
def get_and_update(
self,
mm_inputs: Sequence[MultiModalKwargs],
mm_hashes: list[str],
) -> Sequence[Optional[MultiModalKwargs]]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
@ -57,20 +71,37 @@ class MirroredProcessingCache:
if self.mm_cache.get(mm_hash) is not None:
mm_input = None
else:
self.mm_cache[mm_hash] = mm_input
self.mm_cache[mm_hash] = \
MultiModalCacheItemMetadata.wraps(mm_input)
full_mm_inputs.append(mm_input)
return full_mm_inputs
def get_and_update_p1(
def reset(self) -> None:
self.mm_cache.clear()
class MultiModalInputCacheServer:
"""Used by P1 to avoid requiring past multi-modal kwargs from P0."""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
self.enabled = model_config.enable_mm_input_cache
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalKwargs,
)
def get_and_update(
self,
mm_inputs: Sequence[Optional[MultiModalKwargs]],
mm_hashes: list[str],
) -> Sequence[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
if not self.enabled:
assert is_list_of(mm_inputs, MultiModalKwargs)
return mm_inputs
@ -85,7 +116,5 @@ class MirroredProcessingCache:
return full_mm_inputs
def reset(self) -> bool:
def reset(self) -> None:
self.mm_cache.clear()
return True
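
A sketch of the P0/P1 round trip described in the comments above. The model name is an assumption and the dummy kwargs follow the test helpers; the point is that a repeated item travels as a hash only.

```python
import torch

from vllm.config import ModelConfig
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
                                    MultiModalKwargsItem, MultiModalSharedField)
from vllm.v1.engine.mm_input_cache import (MultiModalInputCacheClient,
                                           MultiModalInputCacheServer)


def dummy_kwargs(size: int) -> MultiModalKwargs:
    elem = MultiModalFieldElem(modality="image", key="pixel_values",
                               data=torch.empty((size, ), dtype=torch.int8),
                               field=MultiModalSharedField(1))
    return MultiModalKwargs.from_items([MultiModalKwargsItem.from_elems([elem])])


# Assumed: this config resolves to a multi-modal model with the cache enabled.
model_config = ModelConfig("Qwen/Qwen2.5-VL-3B-Instruct")
client = MultiModalInputCacheClient(model_config)  # P0: keys + sizes only
server = MultiModalInputCacheServer(model_config)  # P1: keys + full kwargs

mm_inputs, mm_hashes = [dummy_kwargs(1024)], ["mm_hash_0"]

# First time the item is seen: P0 forwards the full kwargs to P1.
sent = client.get_and_update(mm_inputs, mm_hashes)
assert sent[0] is not None
server.get_and_update(sent, mm_hashes)

# Same item again: P0 sends only the hash (None placeholder), and P1
# restores the kwargs from its own mirrored cache.
sent = client.get_and_update(mm_inputs, mm_hashes)
assert sent[0] is None
restored = server.get_and_update(sent, mm_hashes)
assert restored[0] is not None
```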

View File

@ -19,7 +19,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_outlines import (
@ -50,11 +50,8 @@ class Processor:
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = self.mm_input_cache_client.use_cache or \
self.cache_config.enable_prefix_caching
self.mm_input_cache_client = MultiModalInputCacheClient(
self.model_config)
@property
def mm_registry(self):
@ -256,11 +253,13 @@ class Processor:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
return_mm_hashes = (self.model_config.processor_return_mm_hashes
or bool(self.cache_config.enable_prefix_caching))
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=self.use_hash,
return_mm_hashes=return_mm_hashes,
)
from vllm.platforms import current_platform
current_platform.validate_request(
@ -312,7 +311,7 @@ class Processor:
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
decoder_inputs["mm_placeholders"],
decoder_inputs["mm_hashes"] if self.use_hash else None,
decoder_inputs["mm_hashes"] if return_mm_hashes else None,
)
# The output of merged multi-modal processor (`decoder_mm_inputs`)
@ -339,7 +338,7 @@ class Processor:
]
if sorted_mm_hashes is not None:
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs, sorted_mm_hashes)
else:
sorted_mm_inputs = orig_sorted_mm_inputs