[Frontend] Use engine argument to control MM cache size (#22441)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-08-08 00:47:10 +08:00 committed by GitHub
parent 8c9da6be22
commit 139d155781
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 101 additions and 47 deletions

View File

@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
If you run out of CPU RAM, try the following options: If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process + 4 GiB per engine core process) - (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process)
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
## Multi-modal input limits ## Multi-modal input limits

View File

@ -161,12 +161,18 @@ By default, the multi-modal processor cache is enabled to avoid repeatedly proce
the same multi-modal inputs via Hugging Face `AutoProcessor`, the same multi-modal inputs via Hugging Face `AutoProcessor`,
which commonly occurs in multi-turn conversations. which commonly occurs in multi-turn conversations.
You can adjust the size of the cache via `VLLM_MM_INPUT_CACHE_GIB` environment variable You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
(default 4 GiB per API process + 4 GiB per engine core process). (default 4 GiB per API process + 4 GiB per engine core process).
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.
If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`: Examples:
```python ```python
# Use a larger cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
disable_mm_preprocessor_cache=True) mm_processor_cache_gb=8)
# Disable the cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=0)
``` ```

View File

@ -68,7 +68,7 @@ def run_simple_demo(args: argparse.Namespace):
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
tensor_parallel_size=2, tensor_parallel_size=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
) )
prompt = "Describe this image in one sentence." prompt = "Describe this image in one sentence."
@ -105,7 +105,7 @@ def run_advanced_demo(args: argparse.Namespace):
limit_mm_per_prompt={"image": max_img_per_msg}, limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img, max_model_len=max_img_per_msg * max_tokens_per_img,
tensor_parallel_size=2, tensor_parallel_size=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
) )
prompt = "Describe the following image." prompt = "Describe the following image."
@ -164,7 +164,7 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--disable-mm-preprocessor-cache", "--disable-mm-processor-cache",
action="store_true", action="store_true",
help="If True, disables caching of multi-modal processor.", help="If True, disables caching of multi-modal processor.",
) )

View File

@ -1563,7 +1563,7 @@ def parse_args():
) )
parser.add_argument( parser.add_argument(
"--disable-mm-preprocessor-cache", "--disable-mm-processor-cache",
action="store_true", action="store_true",
help="If True, disables caching of multi-modal processor.", help="If True, disables caching of multi-modal processor.",
) )
@ -1603,7 +1603,7 @@ def main(args):
engine_args = asdict(req_data.engine_args) | { engine_args = asdict(req_data.engine_args) | {
"seed": args.seed, "seed": args.seed,
"disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache, "mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4,
} }
llm = LLM(**engine_args) llm = LLM(**engine_args)

View File

@ -62,9 +62,7 @@ def run_test(
# if we run HF first, the cuda initialization will be done and it # if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method). # will hurt multiprocessing backend with fork method (the default method).
vllm_runner_kwargs_: dict[str, Any] = { vllm_runner_kwargs_: dict[str, Any] = {"mm_processor_cache_gb": 0}
"disable_mm_preprocessor_cache": True,
}
if model_info.tokenizer: if model_info.tokenizer:
vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer
if model_info.tokenizer_mode: if model_info.tokenizer_mode:

View File

@ -15,14 +15,14 @@ from ...utils import build_model_context
["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
@pytest.mark.parametrize("mm_processor_kwargs", [{}]) @pytest.mark.parametrize("mm_processor_kwargs", [{}])
@pytest.mark.parametrize("num_imgs", [1, 5]) @pytest.mark.parametrize("num_imgs", [1, 5])
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False]) @pytest.mark.parametrize("mm_processor_cache_gb", [0, 4])
@pytest.mark.parametrize("tokenized_prompt", [True, False]) @pytest.mark.parametrize("tokenized_prompt", [True, False])
def test_processor_override( def test_processor_override(
image_assets: ImageTestAssets, image_assets: ImageTestAssets,
model_id: str, model_id: str,
mm_processor_kwargs: dict, mm_processor_kwargs: dict,
num_imgs: int, num_imgs: int,
disable_mm_preprocessor_cache: bool, mm_processor_cache_gb: int,
tokenized_prompt: bool, tokenized_prompt: bool,
): ):
"""Ensure llama4 processor works properly.""" """Ensure llama4 processor works properly."""
@ -30,7 +30,7 @@ def test_processor_override(
model_id, model_id,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": num_imgs}, limit_mm_per_prompt={"image": num_imgs},
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, mm_processor_cache_gb=mm_processor_cache_gb,
) )
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
config = processor.info.get_hf_config() config = processor.info.get_hf_config()

View File

@ -261,7 +261,7 @@ def build_model_context(
model_config_kwargs: Optional[dict[str, Any]] = None, model_config_kwargs: Optional[dict[str, Any]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None, limit_mm_per_prompt: Optional[dict[str, int]] = None,
disable_mm_preprocessor_cache: bool = True, mm_processor_cache_gb: int = 0,
): ):
"""Creates an InputContext for a given model. """Creates an InputContext for a given model.
@ -291,7 +291,7 @@ def build_model_context(
seed=0, seed=0,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt=limit_mm_per_prompt, limit_mm_per_prompt=limit_mm_per_prompt,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, mm_processor_cache_gb=mm_processor_cache_gb,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
**model_config_kwargs, **model_config_kwargs,
) )

View File

@ -443,8 +443,15 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
""" """
disable_mm_preprocessor_cache: bool = False mm_processor_cache_gb: int = 4
"""If `True`, disable caching of the multi-modal processor.""" """The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
This cache is duplicated for each API process and engine core process,
resulting in a total memory usage of
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended)."""
override_neuron_config: dict[str, Any] = field(default_factory=dict) override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config """Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to that are specific to Neuron devices, this argument will be used to
@ -881,17 +888,16 @@ class ModelConfig:
limit_per_prompt=self.limit_mm_per_prompt, limit_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs, media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=self. mm_processor_cache_gb=self.mm_processor_cache_gb,
disable_mm_preprocessor_cache,
interleave_mm_strings=self.interleave_mm_strings) interleave_mm_strings=self.interleave_mm_strings)
return None return None
def set_disable_mm_preprocessor_cache(self, value: bool) -> None: def set_mm_processor_cache_gb(self, value: int) -> None:
mm_config = self.get_multimodal_config() mm_config = self.get_multimodal_config()
self.disable_mm_preprocessor_cache = value self.mm_processor_cache_gb = value
mm_config.disable_mm_preprocessor_cache = value mm_config.mm_processor_cache_gb = value
def _get_encoder_config(self): def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config( return get_sentence_transformer_tokenizer_config(
@ -1698,7 +1704,16 @@ class ModelConfig:
if mm_config is None: if mm_config is None:
return False return False
return not mm_config.disable_mm_preprocessor_cache return mm_config.mm_processor_cache_gb > 0
@property
def enable_mm_processor_cache(self) -> bool:
"""Whether the multi-modal processor cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
@property @property
def enable_mm_input_cache(self) -> bool: def enable_mm_input_cache(self) -> bool:
@ -1707,7 +1722,7 @@ class ModelConfig:
if mm_config is None: if mm_config is None:
return False return False
return not mm_config.disable_mm_preprocessor_cache return mm_config.mm_processor_cache_gb > 0
def get_mm_input_cache_gb(self) -> int: def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config mm_config = self.multimodal_config
@ -3391,9 +3406,15 @@ class MultiModalConfig:
`{"num_crops": 4}`. `{"num_crops": 4}`.
""" """
disable_mm_preprocessor_cache: bool = False mm_processor_cache_gb: int = 4
""" """
If `True`, disable caching of the multi-modal processor. The size (in GiB) of the multi-modal processor cache, which is used to
This cache is duplicated for each API process and engine core process,
resulting in a total memory usage of
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
Set to `0` to disable this cache completely (not recommended).
""" """
interleave_mm_strings: bool = False interleave_mm_strings: bool = False

View File

@ -358,8 +358,8 @@ class EngineArgs:
"media_io_kwargs") "media_io_kwargs")
mm_processor_kwargs: Optional[Dict[str, Any]] = \ mm_processor_kwargs: Optional[Dict[str, Any]] = \
MultiModalConfig.mm_processor_kwargs MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = \ disable_mm_preprocessor_cache: bool = False # DEPRECATED
MultiModalConfig.disable_mm_preprocessor_cache mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
# LoRA fields # LoRA fields
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled enable_lora_bias: bool = LoRAConfig.bias_enabled
@ -720,8 +720,11 @@ class EngineArgs:
"--mm-processor-kwargs", "--mm-processor-kwargs",
**multimodal_kwargs["mm_processor_kwargs"]) **multimodal_kwargs["mm_processor_kwargs"])
multimodal_group.add_argument( multimodal_group.add_argument(
"--disable-mm-preprocessor-cache", "--mm-processor-cache-gb",
**multimodal_kwargs["disable_mm_preprocessor_cache"]) **multimodal_kwargs["mm_processor_cache_gb"])
multimodal_group.add_argument("--disable-mm-preprocessor-cache",
type=bool,
deprecated=True)
multimodal_group.add_argument( multimodal_group.add_argument(
"--interleave-mm-strings", "--interleave-mm-strings",
**multimodal_kwargs["interleave_mm_strings"]) **multimodal_kwargs["interleave_mm_strings"])
@ -886,6 +889,23 @@ class EngineArgs:
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
self.load_format = "runai_streamer" self.load_format = "runai_streamer"
if self.disable_mm_preprocessor_cache:
logger.warning(
"`--disable-mm-preprocessor-cache` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb 0` instead.", )
self.mm_processor_cache_gb = 0
elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
logger.warning(
"VLLM_MM_INPUT_CACHE_GIB` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb %d` instead.",
envs.VLLM_MM_INPUT_CACHE_GIB,
)
self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
return ModelConfig( return ModelConfig(
model=self.model, model=self.model,
hf_config_path=self.hf_config_path, hf_config_path=self.hf_config_path,
@ -922,7 +942,7 @@ class EngineArgs:
use_async_output_proc=not self.disable_async_output_proc, use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache, mm_processor_cache_gb=self.mm_processor_cache_gb,
override_neuron_config=self.override_neuron_config, override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config, override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern, logits_processor_pattern=self.logits_processor_pattern,
@ -1234,13 +1254,13 @@ class EngineArgs:
dp_supports_mm_processor_cache = (self.data_parallel_size == 1 dp_supports_mm_processor_cache = (self.data_parallel_size == 1
or data_parallel_external_lb) or data_parallel_external_lb)
if (not dp_supports_mm_processor_cache if (not dp_supports_mm_processor_cache
and not model_config.disable_mm_preprocessor_cache): and model_config.mm_processor_cache_gb > 0):
logger.warning( logger.warning(
"Multi-modal processor cache is disabled because " "Multi-modal processor cache is disabled because "
"it is not compatible with data parallelism when " "it is not compatible with data parallelism when "
"there does not exist a one-to-one correspondance " "there does not exist a one-to-one correspondance "
"between API and engine core processes.") "between API and engine core processes.")
model_config.set_disable_mm_preprocessor_cache(True) model_config.set_mm_processor_cache_gb(0)
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
target_model_config=model_config, target_model_config=model_config,

View File

@ -138,13 +138,13 @@ def run_multi_api_server(args: argparse.Namespace):
num_api_servers = args.api_server_count num_api_servers = args.api_server_count
assert num_api_servers > 0 assert num_api_servers > 0
orig_disable_mm_preprocessor_cache = args.disable_mm_preprocessor_cache orig_mm_processor_cache_gb = args.mm_processor_cache_gb
if num_api_servers > 1: if num_api_servers > 1:
setup_multiprocess_prometheus() setup_multiprocess_prometheus()
# Not compatible with API server scale-out # Not compatible with API server scale-out
args.disable_mm_preprocessor_cache = True args.mm_processor_cache_gb = 0
listen_address, sock = setup_server(args) listen_address, sock = setup_server(args)
@ -161,8 +161,7 @@ def run_multi_api_server(args: argparse.Namespace):
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used " raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1") "with api_server_count > 1")
if model_config.is_multimodal_model and not ( if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0:
orig_disable_mm_preprocessor_cache):
logger.warning("Multi-modal processor cache is disabled because " logger.warning("Multi-modal processor cache is disabled because "
"it is not compatible with `api_server_count > 1`.") "it is not compatible with `api_server_count > 1`.")

View File

@ -561,7 +561,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_VIDEO_LOADER_BACKEND": "VLLM_VIDEO_LOADER_BACKEND":
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"), lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
# Cache size (in GiB per process) for multimodal input cache # [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB per API process + 4 GiB per engine core process # Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB": "VLLM_MM_INPUT_CACHE_GIB":
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")), lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),

View File

@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
import torch.nn as nn import torch.nn as nn
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@ -96,11 +95,22 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module, self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]() _ProcessorFactories]()
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB) self._processor_cache: Optional[ProcessingCache] = None
def _get_processor_cache(self, model_config: "ModelConfig"):
capacity_gb = model_config.mm_processor_cache_gb
if capacity_gb is None:
return None # Overrides `disable_cache` argument
if self._processor_cache is None:
self._processor_cache = ProcessingCache(capacity_gb)
return self._processor_cache
def reset_processor_cache(self) -> bool: def reset_processor_cache(self) -> bool:
"""Reset the multi-modal processing cache.""" """Reset the multi-modal processing cache."""
self._processing_cache.reset() if self._processor_cache:
self._processor_cache.reset()
return True # Success return True # Success
@ -244,14 +254,14 @@ class MultiModalRegistry:
if tokenizer is None and not model_config.skip_tokenizer_init: if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None: if disable_cache is None:
mm_config = model_config.get_multimodal_config() disable_cache = not model_config.enable_mm_processor_cache
disable_cache = mm_config.disable_mm_preprocessor_cache
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls] factories = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer) ctx = InputProcessingContext(model_config, tokenizer)
cache = None if disable_cache else self._processing_cache cache = None if disable_cache else self._get_processor_cache(
model_config)
return factories.build_processor(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)

View File

@ -430,7 +430,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
raise ValueError( raise ValueError(
"The number of multi-modal positions and hashes must match. This " "The number of multi-modal positions and hashes must match. This "
"is likely because you did not enable MM hashing. " "is likely because you did not enable MM hashing. "
"Please set `disable_mm_preprocessor_cache=False`.") "Please set `mm_processor_cache_gb > 0`.")
# Note that we assume mm_positions is sorted by offset. # Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of # We do not need to check all mm inputs if the start token index is out of