mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 11:55:42 +08:00
[Frontend] Use engine argument to control MM cache size (#22441)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
8c9da6be22
commit
139d155781
@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
|
|||||||
|
|
||||||
If you run out of CPU RAM, try the following options:
|
If you run out of CPU RAM, try the following options:
|
||||||
|
|
||||||
- (Multi-modal models only) you can set the size of multi-modal processor cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB per API process + 4 GiB per engine core process)
|
- (Multi-modal models only) you can set the size of the multi-modal processor cache by setting the `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process)
|
||||||
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
|
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
|
||||||
|
|
||||||
## Multi-modal input limits
|
## Multi-modal input limits
|
||||||
|
|||||||
@ -161,12 +161,18 @@ By default, the multi-modal processor cache is enabled to avoid repeatedly proce
|
|||||||
the same multi-modal inputs via Hugging Face `AutoProcessor`,
|
the same multi-modal inputs via Hugging Face `AutoProcessor`,
|
||||||
which commonly occurs in multi-turn conversations.
|
which commonly occurs in multi-turn conversations.
|
||||||
|
|
||||||
You can adjust the size of the cache via `VLLM_MM_INPUT_CACHE_GIB` environment variable
|
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
|
||||||
(default 4 GiB per API process + 4 GiB per engine core process).
|
(default 4 GiB per API process + 4 GiB per engine core process).
|
||||||
|
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.
|
||||||
|
|
||||||
If you do not benefit much from the cache, you can disable it completely via `disable_mm_preprocessor_cache`:
|
Examples:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
# Use a larger cache
|
||||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
disable_mm_preprocessor_cache=True)
|
mm_processor_cache_gb=8)
|
||||||
|
|
||||||
|
# Disable the cache
|
||||||
|
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
|
mm_processor_cache_gb=0)
|
||||||
```
|
```
|
||||||
|
|||||||
@ -68,7 +68,7 @@ def run_simple_demo(args: argparse.Namespace):
|
|||||||
max_model_len=4096,
|
max_model_len=4096,
|
||||||
max_num_seqs=2,
|
max_num_seqs=2,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = "Describe this image in one sentence."
|
prompt = "Describe this image in one sentence."
|
||||||
@ -105,7 +105,7 @@ def run_advanced_demo(args: argparse.Namespace):
|
|||||||
limit_mm_per_prompt={"image": max_img_per_msg},
|
limit_mm_per_prompt={"image": max_img_per_msg},
|
||||||
max_model_len=max_img_per_msg * max_tokens_per_img,
|
max_model_len=max_img_per_msg * max_tokens_per_img,
|
||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = "Describe the following image."
|
prompt = "Describe the following image."
|
||||||
@ -164,7 +164,7 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-mm-preprocessor-cache",
|
"--disable-mm-processor-cache",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="If True, disables caching of multi-modal processor.",
|
help="If True, disables caching of multi-modal processor.",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1563,7 +1563,7 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-mm-preprocessor-cache",
|
"--disable-mm-processor-cache",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="If True, disables caching of multi-modal processor.",
|
help="If True, disables caching of multi-modal processor.",
|
||||||
)
|
)
|
||||||
@ -1603,7 +1603,7 @@ def main(args):
|
|||||||
|
|
||||||
engine_args = asdict(req_data.engine_args) | {
|
engine_args = asdict(req_data.engine_args) | {
|
||||||
"seed": args.seed,
|
"seed": args.seed,
|
||||||
"disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
|
"mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4,
|
||||||
}
|
}
|
||||||
llm = LLM(**engine_args)
|
llm = LLM(**engine_args)
|
||||||
|
|
||||||
|
|||||||
@ -62,9 +62,7 @@ def run_test(
|
|||||||
# if we run HF first, the cuda initialization will be done and it
|
# if we run HF first, the cuda initialization will be done and it
|
||||||
# will hurt multiprocessing backend with fork method (the default method).
|
# will hurt multiprocessing backend with fork method (the default method).
|
||||||
|
|
||||||
vllm_runner_kwargs_: dict[str, Any] = {
|
vllm_runner_kwargs_: dict[str, Any] = {"mm_processor_cache_gb": 0}
|
||||||
"disable_mm_preprocessor_cache": True,
|
|
||||||
}
|
|
||||||
if model_info.tokenizer:
|
if model_info.tokenizer:
|
||||||
vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer
|
vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer
|
||||||
if model_info.tokenizer_mode:
|
if model_info.tokenizer_mode:
|
||||||
|
|||||||
@ -15,14 +15,14 @@ from ...utils import build_model_context
|
|||||||
["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
|
["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
|
||||||
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
|
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("num_imgs", [1, 5])
|
@pytest.mark.parametrize("num_imgs", [1, 5])
|
||||||
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
|
@pytest.mark.parametrize("mm_processor_cache_gb", [0, 4])
|
||||||
@pytest.mark.parametrize("tokenized_prompt", [True, False])
|
@pytest.mark.parametrize("tokenized_prompt", [True, False])
|
||||||
def test_processor_override(
|
def test_processor_override(
|
||||||
image_assets: ImageTestAssets,
|
image_assets: ImageTestAssets,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
mm_processor_kwargs: dict,
|
mm_processor_kwargs: dict,
|
||||||
num_imgs: int,
|
num_imgs: int,
|
||||||
disable_mm_preprocessor_cache: bool,
|
mm_processor_cache_gb: int,
|
||||||
tokenized_prompt: bool,
|
tokenized_prompt: bool,
|
||||||
):
|
):
|
||||||
"""Ensure llama4 processor works properly."""
|
"""Ensure llama4 processor works properly."""
|
||||||
@ -30,7 +30,7 @@ def test_processor_override(
|
|||||||
model_id,
|
model_id,
|
||||||
mm_processor_kwargs=mm_processor_kwargs,
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
limit_mm_per_prompt={"image": num_imgs},
|
limit_mm_per_prompt={"image": num_imgs},
|
||||||
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
|
mm_processor_cache_gb=mm_processor_cache_gb,
|
||||||
)
|
)
|
||||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||||
config = processor.info.get_hf_config()
|
config = processor.info.get_hf_config()
|
||||||
|
|||||||
@ -261,7 +261,7 @@ def build_model_context(
|
|||||||
model_config_kwargs: Optional[dict[str, Any]] = None,
|
model_config_kwargs: Optional[dict[str, Any]] = None,
|
||||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||||
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||||
disable_mm_preprocessor_cache: bool = True,
|
mm_processor_cache_gb: int = 0,
|
||||||
):
|
):
|
||||||
"""Creates an InputContext for a given model.
|
"""Creates an InputContext for a given model.
|
||||||
|
|
||||||
@ -291,7 +291,7 @@ def build_model_context(
|
|||||||
seed=0,
|
seed=0,
|
||||||
mm_processor_kwargs=mm_processor_kwargs,
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||||
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
|
mm_processor_cache_gb=mm_processor_cache_gb,
|
||||||
hf_overrides=model_info.hf_overrides,
|
hf_overrides=model_info.hf_overrides,
|
||||||
**model_config_kwargs,
|
**model_config_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -443,8 +443,15 @@ class ModelConfig:
|
|||||||
from `AutoProcessor.from_pretrained`. The available overrides depend on the
|
from `AutoProcessor.from_pretrained`. The available overrides depend on the
|
||||||
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
|
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
|
||||||
"""
|
"""
|
||||||
disable_mm_preprocessor_cache: bool = False
|
mm_processor_cache_gb: int = 4
|
||||||
"""If `True`, disable caching of the multi-modal processor."""
|
"""The size (in GiB) of the multi-modal processor cache, which is used to
|
||||||
|
avoid re-processing past multi-modal inputs.
|
||||||
|
|
||||||
|
This cache is duplicated for each API process and engine core process,
|
||||||
|
resulting in a total memory usage of
|
||||||
|
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
|
||||||
|
|
||||||
|
Set to `0` to disable this cache completely (not recommended)."""
|
||||||
override_neuron_config: dict[str, Any] = field(default_factory=dict)
|
override_neuron_config: dict[str, Any] = field(default_factory=dict)
|
||||||
"""Initialize non-default neuron config or override default neuron config
|
"""Initialize non-default neuron config or override default neuron config
|
||||||
that are specific to Neuron devices, this argument will be used to
|
that are specific to Neuron devices, this argument will be used to
|
||||||
@ -881,17 +888,16 @@ class ModelConfig:
|
|||||||
limit_per_prompt=self.limit_mm_per_prompt,
|
limit_per_prompt=self.limit_mm_per_prompt,
|
||||||
media_io_kwargs=self.media_io_kwargs,
|
media_io_kwargs=self.media_io_kwargs,
|
||||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||||
disable_mm_preprocessor_cache=self.
|
mm_processor_cache_gb=self.mm_processor_cache_gb,
|
||||||
disable_mm_preprocessor_cache,
|
|
||||||
interleave_mm_strings=self.interleave_mm_strings)
|
interleave_mm_strings=self.interleave_mm_strings)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set_disable_mm_preprocessor_cache(self, value: bool) -> None:
|
def set_mm_processor_cache_gb(self, value: int) -> None:
|
||||||
mm_config = self.get_multimodal_config()
|
mm_config = self.get_multimodal_config()
|
||||||
|
|
||||||
self.disable_mm_preprocessor_cache = value
|
self.mm_processor_cache_gb = value
|
||||||
mm_config.disable_mm_preprocessor_cache = value
|
mm_config.mm_processor_cache_gb = value
|
||||||
|
|
||||||
def _get_encoder_config(self):
|
def _get_encoder_config(self):
|
||||||
return get_sentence_transformer_tokenizer_config(
|
return get_sentence_transformer_tokenizer_config(
|
||||||
@ -1698,7 +1704,16 @@ class ModelConfig:
|
|||||||
if mm_config is None:
|
if mm_config is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return not mm_config.disable_mm_preprocessor_cache
|
return mm_config.mm_processor_cache_gb > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enable_mm_processor_cache(self) -> bool:
|
||||||
|
"""Whether the multi-modal processor cache should be enabled."""
|
||||||
|
mm_config = self.multimodal_config
|
||||||
|
if mm_config is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return mm_config.mm_processor_cache_gb > 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enable_mm_input_cache(self) -> bool:
|
def enable_mm_input_cache(self) -> bool:
|
||||||
@ -1707,7 +1722,7 @@ class ModelConfig:
|
|||||||
if mm_config is None:
|
if mm_config is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return not mm_config.disable_mm_preprocessor_cache
|
return mm_config.mm_processor_cache_gb > 0
|
||||||
|
|
||||||
def get_mm_input_cache_gb(self) -> int:
|
def get_mm_input_cache_gb(self) -> int:
|
||||||
mm_config = self.multimodal_config
|
mm_config = self.multimodal_config
|
||||||
@ -3391,9 +3406,15 @@ class MultiModalConfig:
|
|||||||
`{"num_crops": 4}`.
|
`{"num_crops": 4}`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
disable_mm_preprocessor_cache: bool = False
|
mm_processor_cache_gb: int = 4
|
||||||
"""
|
"""
|
||||||
If `True`, disable caching of the multi-modal processor.
|
The size (in GiB) of the multi-modal processor cache, which is used to avoid re-processing past multi-modal inputs.
|
||||||
|
|
||||||
|
This cache is duplicated for each API process and engine core process,
|
||||||
|
resulting in a total memory usage of
|
||||||
|
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
|
||||||
|
|
||||||
|
Set to `0` to disable this cache completely (not recommended).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
interleave_mm_strings: bool = False
|
interleave_mm_strings: bool = False
|
||||||
|
|||||||
@ -358,8 +358,8 @@ class EngineArgs:
|
|||||||
"media_io_kwargs")
|
"media_io_kwargs")
|
||||||
mm_processor_kwargs: Optional[Dict[str, Any]] = \
|
mm_processor_kwargs: Optional[Dict[str, Any]] = \
|
||||||
MultiModalConfig.mm_processor_kwargs
|
MultiModalConfig.mm_processor_kwargs
|
||||||
disable_mm_preprocessor_cache: bool = \
|
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
||||||
MultiModalConfig.disable_mm_preprocessor_cache
|
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
|
||||||
# LoRA fields
|
# LoRA fields
|
||||||
enable_lora: bool = False
|
enable_lora: bool = False
|
||||||
enable_lora_bias: bool = LoRAConfig.bias_enabled
|
enable_lora_bias: bool = LoRAConfig.bias_enabled
|
||||||
@ -720,8 +720,11 @@ class EngineArgs:
|
|||||||
"--mm-processor-kwargs",
|
"--mm-processor-kwargs",
|
||||||
**multimodal_kwargs["mm_processor_kwargs"])
|
**multimodal_kwargs["mm_processor_kwargs"])
|
||||||
multimodal_group.add_argument(
|
multimodal_group.add_argument(
|
||||||
"--disable-mm-preprocessor-cache",
|
"--mm-processor-cache-gb",
|
||||||
**multimodal_kwargs["disable_mm_preprocessor_cache"])
|
**multimodal_kwargs["mm_processor_cache_gb"])
|
||||||
|
multimodal_group.add_argument("--disable-mm-preprocessor-cache",
|
||||||
|
type=bool,
|
||||||
|
deprecated=True)
|
||||||
multimodal_group.add_argument(
|
multimodal_group.add_argument(
|
||||||
"--interleave-mm-strings",
|
"--interleave-mm-strings",
|
||||||
**multimodal_kwargs["interleave_mm_strings"])
|
**multimodal_kwargs["interleave_mm_strings"])
|
||||||
@ -886,6 +889,23 @@ class EngineArgs:
|
|||||||
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
|
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
|
||||||
self.load_format = "runai_streamer"
|
self.load_format = "runai_streamer"
|
||||||
|
|
||||||
|
if self.disable_mm_preprocessor_cache:
|
||||||
|
logger.warning(
|
||||||
|
"`--disable-mm-preprocessor-cache` is deprecated "
|
||||||
|
"and will be removed in v0.13. "
|
||||||
|
"Please use `--mm-processor-cache-gb 0` instead.", )
|
||||||
|
|
||||||
|
self.mm_processor_cache_gb = 0
|
||||||
|
elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
|
||||||
|
logger.warning(
|
||||||
|
"`VLLM_MM_INPUT_CACHE_GIB` is deprecated "
|
||||||
|
"and will be removed in v0.13. "
|
||||||
|
"Please use `--mm-processor-cache-gb %d` instead.",
|
||||||
|
envs.VLLM_MM_INPUT_CACHE_GIB,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
|
||||||
|
|
||||||
return ModelConfig(
|
return ModelConfig(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
hf_config_path=self.hf_config_path,
|
hf_config_path=self.hf_config_path,
|
||||||
@ -922,7 +942,7 @@ class EngineArgs:
|
|||||||
use_async_output_proc=not self.disable_async_output_proc,
|
use_async_output_proc=not self.disable_async_output_proc,
|
||||||
config_format=self.config_format,
|
config_format=self.config_format,
|
||||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||||
disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
|
mm_processor_cache_gb=self.mm_processor_cache_gb,
|
||||||
override_neuron_config=self.override_neuron_config,
|
override_neuron_config=self.override_neuron_config,
|
||||||
override_pooler_config=self.override_pooler_config,
|
override_pooler_config=self.override_pooler_config,
|
||||||
logits_processor_pattern=self.logits_processor_pattern,
|
logits_processor_pattern=self.logits_processor_pattern,
|
||||||
@ -1234,13 +1254,13 @@ class EngineArgs:
|
|||||||
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
|
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
|
||||||
or data_parallel_external_lb)
|
or data_parallel_external_lb)
|
||||||
if (not dp_supports_mm_processor_cache
|
if (not dp_supports_mm_processor_cache
|
||||||
and not model_config.disable_mm_preprocessor_cache):
|
and model_config.mm_processor_cache_gb > 0):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Multi-modal processor cache is disabled because "
|
"Multi-modal processor cache is disabled because "
|
||||||
"it is not compatible with data parallelism when "
|
"it is not compatible with data parallelism when "
|
||||||
"there does not exist a one-to-one correspondence "
|
"there does not exist a one-to-one correspondence "
|
||||||
"between API and engine core processes.")
|
"between API and engine core processes.")
|
||||||
model_config.set_disable_mm_preprocessor_cache(True)
|
model_config.set_mm_processor_cache_gb(0)
|
||||||
|
|
||||||
speculative_config = self.create_speculative_config(
|
speculative_config = self.create_speculative_config(
|
||||||
target_model_config=model_config,
|
target_model_config=model_config,
|
||||||
|
|||||||
@ -138,13 +138,13 @@ def run_multi_api_server(args: argparse.Namespace):
|
|||||||
num_api_servers = args.api_server_count
|
num_api_servers = args.api_server_count
|
||||||
assert num_api_servers > 0
|
assert num_api_servers > 0
|
||||||
|
|
||||||
orig_disable_mm_preprocessor_cache = args.disable_mm_preprocessor_cache
|
orig_mm_processor_cache_gb = args.mm_processor_cache_gb
|
||||||
|
|
||||||
if num_api_servers > 1:
|
if num_api_servers > 1:
|
||||||
setup_multiprocess_prometheus()
|
setup_multiprocess_prometheus()
|
||||||
|
|
||||||
# Not compatible with API server scale-out
|
# Not compatible with API server scale-out
|
||||||
args.disable_mm_preprocessor_cache = True
|
args.mm_processor_cache_gb = 0
|
||||||
|
|
||||||
listen_address, sock = setup_server(args)
|
listen_address, sock = setup_server(args)
|
||||||
|
|
||||||
@ -161,8 +161,7 @@ def run_multi_api_server(args: argparse.Namespace):
|
|||||||
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
|
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
|
||||||
"with api_server_count > 1")
|
"with api_server_count > 1")
|
||||||
|
|
||||||
if model_config.is_multimodal_model and not (
|
if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0:
|
||||||
orig_disable_mm_preprocessor_cache):
|
|
||||||
logger.warning("Multi-modal processor cache is disabled because "
|
logger.warning("Multi-modal processor cache is disabled because "
|
||||||
"it is not compatible with `api_server_count > 1`.")
|
"it is not compatible with `api_server_count > 1`.")
|
||||||
|
|
||||||
|
|||||||
@ -561,7 +561,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_VIDEO_LOADER_BACKEND":
|
"VLLM_VIDEO_LOADER_BACKEND":
|
||||||
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
|
lambda: os.getenv("VLLM_VIDEO_LOADER_BACKEND", "opencv"),
|
||||||
|
|
||||||
# Cache size (in GiB per process) for multimodal input cache
|
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
|
||||||
# Default is 4 GiB per API process + 4 GiB per engine core process
|
# Default is 4 GiB per API process + 4 GiB per engine core process
|
||||||
"VLLM_MM_INPUT_CACHE_GIB":
|
"VLLM_MM_INPUT_CACHE_GIB":
|
||||||
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
|
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
|
|||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
|
||||||
from vllm.inputs import InputProcessingContext
|
from vllm.inputs import InputProcessingContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||||
@ -96,11 +95,22 @@ class MultiModalRegistry:
|
|||||||
self._processor_factories = ClassRegistry[nn.Module,
|
self._processor_factories = ClassRegistry[nn.Module,
|
||||||
_ProcessorFactories]()
|
_ProcessorFactories]()
|
||||||
|
|
||||||
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
|
self._processor_cache: Optional[ProcessingCache] = None
|
||||||
|
|
||||||
|
def _get_processor_cache(self, model_config: "ModelConfig"):
|
||||||
|
capacity_gb = model_config.mm_processor_cache_gb
|
||||||
|
if capacity_gb is None:
|
||||||
|
return None # Overrides `disable_cache` argument
|
||||||
|
|
||||||
|
if self._processor_cache is None:
|
||||||
|
self._processor_cache = ProcessingCache(capacity_gb)
|
||||||
|
|
||||||
|
return self._processor_cache
|
||||||
|
|
||||||
def reset_processor_cache(self) -> bool:
|
def reset_processor_cache(self) -> bool:
|
||||||
"""Reset the multi-modal processing cache."""
|
"""Reset the multi-modal processing cache."""
|
||||||
self._processing_cache.reset()
|
if self._processor_cache:
|
||||||
|
self._processor_cache.reset()
|
||||||
|
|
||||||
return True # Success
|
return True # Success
|
||||||
|
|
||||||
@ -244,14 +254,14 @@ class MultiModalRegistry:
|
|||||||
if tokenizer is None and not model_config.skip_tokenizer_init:
|
if tokenizer is None and not model_config.skip_tokenizer_init:
|
||||||
tokenizer = cached_tokenizer_from_config(model_config)
|
tokenizer = cached_tokenizer_from_config(model_config)
|
||||||
if disable_cache is None:
|
if disable_cache is None:
|
||||||
mm_config = model_config.get_multimodal_config()
|
disable_cache = not model_config.enable_mm_processor_cache
|
||||||
disable_cache = mm_config.disable_mm_preprocessor_cache
|
|
||||||
|
|
||||||
model_cls = self._get_model_cls(model_config)
|
model_cls = self._get_model_cls(model_config)
|
||||||
factories = self._processor_factories[model_cls]
|
factories = self._processor_factories[model_cls]
|
||||||
|
|
||||||
ctx = InputProcessingContext(model_config, tokenizer)
|
ctx = InputProcessingContext(model_config, tokenizer)
|
||||||
cache = None if disable_cache else self._processing_cache
|
cache = None if disable_cache else self._get_processor_cache(
|
||||||
|
model_config)
|
||||||
|
|
||||||
return factories.build_processor(ctx, cache=cache)
|
return factories.build_processor(ctx, cache=cache)
|
||||||
|
|
||||||
|
|||||||
@ -430,7 +430,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The number of multi-modal positions and hashes must match. This "
|
"The number of multi-modal positions and hashes must match. This "
|
||||||
"is likely because you did not enable MM hashing. "
|
"is likely because you did not enable MM hashing. "
|
||||||
"Please set `disable_mm_preprocessor_cache=False`.")
|
"Please set `mm_processor_cache_gb > 0`.")
|
||||||
|
|
||||||
# Note that we assume mm_positions is sorted by offset.
|
# Note that we assume mm_positions is sorted by offset.
|
||||||
# We do not need to check all mm inputs if the start token index is out of
|
# We do not need to check all mm inputs if the start token index is out of
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user