diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 79255b031eec..3fc4f6445df2 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -845,7 +845,8 @@ class LLMEngine:
 
     def reset_mm_cache(self) -> bool:
         """Reset the multi-modal cache."""
-        return self.input_preprocessor.mm_registry.reset_processor_cache()
+        return self.input_preprocessor.mm_registry.reset_processor_cache(
+            self.model_config)
 
     def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index dca04e9a1e22..565d54e1a264 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Mapping
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
 
 import torch.nn as nn
@@ -86,6 +87,13 @@ class _ProcessorFactories(Generic[_I]):
         return self.processor(info, dummy_inputs_builder, cache=cache)
 
 
+# Make sure a different cache is used for each model config
+# NOTE: ModelConfig is not hashable so it cannot be passed directly
+@lru_cache(maxsize=1)
+def _get_processor_cache(model_id: str, capacity_gb: int):
+    return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
+
+
 class MultiModalRegistry:
     """
     A registry that dispatches data processing according to the model.
@@ -95,22 +103,15 @@ class MultiModalRegistry:
         self._processor_factories = ClassRegistry[nn.Module,
                                                   _ProcessorFactories]()
 
-        self._processor_cache: Optional[ProcessingCache] = None
-
     def _get_processor_cache(self, model_config: "ModelConfig"):
+        model_id = model_config.model
         capacity_gb = model_config.mm_processor_cache_gb
-        if capacity_gb is None:
-            return None  # Overrides `disable_cache` argument
+        return _get_processor_cache(model_id, capacity_gb)
 
-        if self._processor_cache is None:
-            self._processor_cache = ProcessingCache(capacity_gb)
-
-        return self._processor_cache
-
-    def reset_processor_cache(self) -> bool:
+    def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
         """Reset the multi-modal processing cache."""
-        if self._processor_cache:
-            self._processor_cache.reset()
+        if processor_cache := self._get_processor_cache(model_config):
+            processor_cache.reset()
 
         return True  # Success
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 45f450291ab6..7b4ed90fd132 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -566,7 +566,7 @@ class AsyncLLM(EngineClient):
         await self.engine_core.profile_async(False)
 
     async def reset_mm_cache(self) -> None:
-        self.processor.mm_registry.reset_processor_cache()
+        self.processor.mm_registry.reset_processor_cache(self.model_config)
         self.processor.mm_input_cache_client.reset()
         await self.engine_core.reset_mm_cache_async()
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index efbdffbc0900..5a00a930951c 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -271,7 +271,7 @@ class LLMEngine:
         self.engine_core.profile(False)
 
     def reset_mm_cache(self):
-        self.processor.mm_registry.reset_processor_cache()
+        self.processor.mm_registry.reset_processor_cache(self.model_config)
         self.processor.mm_input_cache_client.reset()
         self.engine_core.reset_mm_cache()
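
For reviewers who want to sanity-check the caching behaviour outside vLLM, below is a minimal sketch of the pattern the `vllm/multimodal/registry.py` change relies on: because the config object is not hashable, the module-level `lru_cache` is keyed on hashable fields pulled out of it, so every caller that passes the same `(model_id, capacity_gb)` pair gets back the same cache instance. `FakeModelConfig` and `FakeProcessingCache` are hypothetical stand-ins, not vLLM's real `ModelConfig` or `ProcessingCache`.

```python
from dataclasses import dataclass, field
from functools import lru_cache


@dataclass
class FakeModelConfig:
    """Stand-in for ModelConfig; plain dataclasses (eq=True, frozen=False)
    set __hash__ to None, so instances cannot be lru_cache keys directly."""
    model: str
    mm_processor_cache_gb: int = 4
    extra: dict = field(default_factory=dict)


class FakeProcessingCache:
    """Stand-in for ProcessingCache with just enough behaviour to demo reset()."""

    def __init__(self, capacity_gb: int) -> None:
        self.capacity_gb = capacity_gb
        self.entries: dict = {}

    def reset(self) -> None:
        self.entries.clear()


# Keyed on hashable fields extracted from the config, so the same
# (model_id, capacity_gb) pair always maps to the same cache instance.
@lru_cache(maxsize=1)
def get_processor_cache(model_id: str, capacity_gb: int):
    return FakeProcessingCache(capacity_gb) if capacity_gb > 0 else None


config = FakeModelConfig(model="llava-hf/llava-1.5-7b-hf")
cache_a = get_processor_cache(config.model, config.mm_processor_cache_gb)
cache_b = get_processor_cache(config.model, config.mm_processor_cache_gb)
assert cache_a is cache_b  # same config -> same shared cache instance

# A capacity of 0 disables processor caching entirely.
assert get_processor_cache("some/other-model", 0) is None
```

`maxsize=1` mirrors the diff's assumption that one engine process only ever serves one model config; asking for a different config simply evicts the previous cache and builds a new one.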