From 61e0a506a3a30445fddff21355936e9f83725c97 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Tue, 13 May 2025 13:40:19 +0800
Subject: [PATCH] [Bugfix] Avoid repeatedly creating dummy data during engine
 startup (#17935)

Signed-off-by: DarkLight1337
---
 vllm/engine/async_llm_engine.py         |  3 +++
 vllm/engine/llm_engine.py               |  7 +++++++
 vllm/engine/multiprocessing/__init__.py |  5 +++++
 vllm/engine/multiprocessing/client.py   |  8 ++++++++
 vllm/engine/multiprocessing/engine.py   |  6 ++++++
 vllm/engine/protocol.py                 |  5 +++++
 vllm/entrypoints/openai/api_server.py   |  4 ++++
 vllm/multimodal/processing.py           |  5 +++++
 vllm/multimodal/registry.py             | 14 ++++++++++----
 vllm/v1/engine/async_llm.py             |  5 +++++
 vllm/v1/engine/core.py                  |  9 +++++++++
 vllm/v1/engine/core_client.py           | 15 +++++++++++++++
 vllm/v1/engine/llm_engine.py            |  8 ++++++++
 vllm/v1/engine/mm_input_cache.py        |  5 +++++
 vllm/v1/engine/processor.py             |  4 ++++
 15 files changed, 99 insertions(+), 4 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 37bb12d44287..56b9e49d24d9 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1232,6 +1232,9 @@ class AsyncLLMEngine(EngineClient):
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
+    async def reset_mm_cache(self) -> None:
+        self.engine.reset_mm_cache()
+
     async def reset_prefix_cache(self,
                                  device: Optional[Device] = None) -> None:
         self.engine.reset_prefix_cache(device)
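
The V0 async wrapper above only forwards reset_mm_cache() to the underlying engine; the substance of the patch is that dummy multi-modal data built during startup now goes through a shared cache that is cleared once startup finishes. A minimal, self-contained sketch of that idea (the function and names below are illustrative stand-ins, not vLLM APIs):

    # Sketch only: a cached dummy-data builder that is reused during startup
    # and then cleared, mirroring reset_mm_cache() in this patch.
    from functools import lru_cache


    @lru_cache(maxsize=None)
    def build_dummy_image(width: int, height: int) -> bytes:
        # Stands in for the expensive dummy multi-modal data creation done
        # while profiling memory usage.
        return bytes(width * height * 3)


    def startup() -> None:
        # Profiling, mm-limit checks and dummy-data runs all hit the same
        # cache, so the expensive construction happens only once per shape.
        for _ in range(4):
            build_dummy_image(336, 336)
        assert build_dummy_image.cache_info().misses == 1

        # Equivalent of reset_mm_cache(): don't keep the dummy data in memory.
        build_dummy_image.cache_clear()


    if __name__ == "__main__":
        startup()
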
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index bed696d3dc00..2a27afe9757e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -409,6 +412,9 @@ class LLMEngine:
         # the next step without re-scheduling.
         self._skip_scheduling_next_step = False
 
+        # Don't keep the dummy data in memory
+        self.reset_mm_cache()
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -913,6 +916,10 @@ class LLMEngine:
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
+    def reset_mm_cache(self) -> bool:
+        """Reset the multi-modal cache."""
+        return self.input_preprocessor.mm_registry.reset_processor_cache()
+
     def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
 
diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py
index cafd8150bc01..af72c8e6b776 100644
--- a/vllm/engine/multiprocessing/__init__.py
+++ b/vllm/engine/multiprocessing/__init__.py
@@ -123,6 +123,10 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
+class RPCResetMultiModalCacheRequest(Enum):
+    RESET = 1
+
+
 @dataclass
 class RPCResetPrefixCacheRequest:
     device: Device
@@ -164,6 +168,7 @@ class RPCAdapterLoadedResponse:
 
 RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
                       RPCUProfileRequest, RPCLoadAdapterRequest,
+                      RPCResetMultiModalCacheRequest,
                       RPCResetPrefixCacheRequest, RPCSleepRequest,
                       RPCWakeUpRequest, RPCIsSleepingRequest]
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
index 505d3d06b3ca..eea89a9a055f 100644
--- a/vllm/engine/multiprocessing/client.py
+++ b/vllm/engine/multiprocessing/client.py
@@ -31,6 +31,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          RPCIsSleepingResponse,
                                          RPCLoadAdapterRequest,
                                          RPCProcessRequest,
+                                         RPCResetMultiModalCacheRequest,
                                          RPCResetPrefixCacheRequest,
                                          RPCSleepRequest, RPCStartupRequest,
                                          RPCStartupResponse,
@@ -687,6 +688,13 @@ class MQLLMEngineClient(EngineClient):
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
+    async def reset_mm_cache(self) -> None:
+        """Reset the multi-modal cache"""
+
+        await self._send_one_way_rpc_request(
+            request=RPCResetMultiModalCacheRequest.RESET,
+            socket=self.input_socket)
+
     async def reset_prefix_cache(self,
                                  device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index a5dcf9e2d945..ac234d25373d 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -22,6 +22,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          RPCIsSleepingResponse,
                                          RPCLoadAdapterRequest,
                                          RPCProcessRequest,
+                                         RPCResetMultiModalCacheRequest,
                                          RPCResetPrefixCacheRequest,
                                          RPCSleepRequest, RPCStartupRequest,
                                          RPCStartupResponse,
@@ -269,6 +270,8 @@ class MQLLMEngine:
                         self.stop_profile()
                 elif isinstance(request, RPCLoadAdapterRequest):
                     self._handle_load_adapter_request(request)
+                elif isinstance(request, RPCResetMultiModalCacheRequest):
+                    self.reset_mm_cache()
                 elif isinstance(request, RPCResetPrefixCacheRequest):
                     self.reset_prefix_cache()
                 elif isinstance(request, RPCSleepRequest):
@@ -409,6 +412,9 @@ class MQLLMEngine:
     def stop_profile(self) -> None:
         self.engine.stop_profile()
 
+    def reset_mm_cache(self) -> bool:
+        return self.engine.reset_mm_cache()
+
     def reset_prefix_cache(self) -> bool:
         return self.engine.reset_prefix_cache()
 
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index e9350612ee57..a837a2d288a9 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -278,6 +278,11 @@ class EngineClient(ABC):
         """Start profiling the engine"""
         ...
 
+    @abstractmethod
+    async def reset_mm_cache(self) -> None:
+        """Reset the multi-modal cache"""
+        ...
+
     @abstractmethod
     async def reset_prefix_cache(self,
                                  device: Optional[Device] = None) -> None:
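
The V0 changes above thread the new operation through the multiprocessing front-end: MQLLMEngineClient sends a one-way RPC message (RPCResetMultiModalCacheRequest.RESET), and MQLLMEngine's request loop dispatches it by type, while EngineClient gains the matching abstract method. A rough, self-contained sketch of that dispatch pattern, using toy classes rather than the actual vLLM ones:

    # Sketch only: a single-member Enum acting as a one-way RPC message,
    # matched by isinstance() in the engine's request loop.
    from enum import Enum
    from typing import Union


    class ResetMultiModalCache(Enum):
        RESET = 1


    class ResetPrefixCache(Enum):
        RESET = 1


    Request = Union[ResetMultiModalCache, ResetPrefixCache]


    class ToyEngine:
        def __init__(self) -> None:
            self.mm_cache: dict = {"dummy": b"..."}

        def handle_request(self, request: Request) -> None:
            if isinstance(request, ResetMultiModalCache):
                self.mm_cache.clear()
            elif isinstance(request, ResetPrefixCache):
                pass  # reset the prefix cache here


    engine = ToyEngine()
    engine.handle_request(ResetMultiModalCache.RESET)
    assert not engine.mm_cache
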
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 25b6f98bb769..e809579c2b17 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -150,6 +150,10 @@ async def build_async_engine_client(
 
     async with build_async_engine_client_from_engine_args(
             engine_args, args.disable_frontend_multiprocessing) as engine:
+
+        # Don't keep the dummy data in memory
+        await engine.reset_mm_cache()
+
         yield engine
 
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 27b059b3ee62..92f9e70b5234 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1026,6 +1026,11 @@ class ProcessingCache:
     def put_item(self, item: ProcessingCacheItem) -> None:
         self._cache[item.key] = item.value
 
+    def reset(self) -> bool:
+        self._cache.clear()
+
+        return True
+
 
 class BaseProcessingInfo:
     """Base class to provide the information necessary for data processing."""
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 3e62f4c43e10..67d0d7fc1183 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -88,6 +88,12 @@ class MultiModalRegistry:
 
         self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
 
+    def reset_processor_cache(self) -> bool:
+        """Reset the multi-modal processing cache."""
+        self._processing_cache.reset()
+
+        return True  # Success
+
     @deprecated("Legacy input processor/mapper pipeline has been removed. "
                 "Please update your model runner to use "
                 "`seq_group_metadata.multi_modal_data` directly without "
@@ -106,7 +112,7 @@ class MultiModalRegistry:
         if not model_config.is_multimodal_model:
             return {}
 
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len
@@ -190,7 +196,7 @@ class MultiModalRegistry:
         if not model_config.is_multimodal_model:
             return {}
 
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
         return profiler.get_mm_limits()
@@ -286,7 +292,7 @@ class MultiModalRegistry:
 
         The model is identified by ``model_config``.
         """
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
 
         dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
@@ -310,7 +316,7 @@ class MultiModalRegistry:
 
         The model is identified by ``model_config``.
         """
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
 
         dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
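
The registry hunks above flip disable_cache=True to disable_cache=False so the repeated profiling passes at startup reuse one ProcessingCache, and the api_server hunk clears that cache once the engine client has been built. A hedged sketch of the startup hook, with a stand-in client class and builder in place of the real vLLM ones:

    # Sketch only: clear the multi-modal cache once after the engine client
    # is constructed, so startup dummy data is not kept alive by the server.
    import asyncio
    from contextlib import asynccontextmanager
    from typing import AsyncIterator


    class ToyEngineClient:
        def __init__(self) -> None:
            self.mm_cache = {"profiling-dummy": object()}

        async def reset_mm_cache(self) -> None:
            self.mm_cache.clear()


    @asynccontextmanager
    async def build_async_engine_client() -> AsyncIterator[ToyEngineClient]:
        engine = ToyEngineClient()

        # Don't keep the dummy data in memory.
        await engine.reset_mm_cache()

        yield engine


    async def main() -> None:
        async with build_async_engine_client() as engine:
            assert not engine.mm_cache


    asyncio.run(main())
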
""" - processor = self.create_processor(model_config, disable_cache=True) + processor = self.create_processor(model_config, disable_cache=False) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 00ceb7d3d0c4..0d646d8dd575 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -476,6 +476,11 @@ class AsyncLLM(EngineClient): async def stop_profile(self) -> None: await self.engine_core.profile_async(False) + async def reset_mm_cache(self) -> None: + self.processor.mm_registry.reset_processor_cache() + self.processor.mm_input_cache_client.reset() + await self.engine_core.reset_mm_cache_async() + async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: if device == Device.CPU: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index fde60bbfa51f..5a493db8a5fe 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -286,6 +286,15 @@ class EngineCore: def profile(self, is_start: bool = True): self.model_executor.profile(is_start) + def reset_mm_cache(self): + # NOTE: Since this is mainly for debugging, we don't attempt to + # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) + if self.scheduler.get_num_unfinished_requests(): + logger.warning("Resetting the multi-modal cache when requests are " + "in progress may lead to desynced internal caches.") + + self.mm_input_cache_server.reset() + def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 91a0a75a3081..c33317edcbb0 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -88,6 +88,9 @@ class EngineCoreClient(ABC): def profile(self, is_start: bool = True) -> None: raise NotImplementedError + def reset_mm_cache(self) -> None: + raise NotImplementedError + def reset_prefix_cache(self) -> None: raise NotImplementedError @@ -143,6 +146,9 @@ class EngineCoreClient(ABC): async def profile_async(self, is_start: bool = True) -> None: raise NotImplementedError + async def reset_mm_cache_async(self) -> None: + raise NotImplementedError + async def reset_prefix_cache_async(self) -> None: raise NotImplementedError @@ -214,6 +220,9 @@ class InprocClient(EngineCoreClient): def profile(self, is_start: bool = True) -> None: self.engine_core.profile(is_start) + def reset_mm_cache(self) -> None: + self.engine_core.reset_mm_cache() + def reset_prefix_cache(self) -> None: self.engine_core.reset_prefix_cache() @@ -600,6 +609,9 @@ class SyncMPClient(MPClient): def profile(self, is_start: bool = True) -> None: self.call_utility("profile", is_start) + def reset_mm_cache(self) -> None: + self.call_utility("reset_mm_cache") + def reset_prefix_cache(self) -> None: self.call_utility("reset_prefix_cache") @@ -787,6 +799,9 @@ class AsyncMPClient(MPClient): async def profile_async(self, is_start: bool = True) -> None: await self.call_utility_async("profile", is_start) + async def reset_mm_cache_async(self) -> None: + await self.call_utility_async("reset_mm_cache") + async def reset_prefix_cache_async(self) -> None: await self.call_utility_async("reset_prefix_cache") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index b471b153657f..112896d6c767 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -101,6 +101,9 @@ class LLMEngine: # for v0 compatibility self.model_executor = 
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index b471b153657f..112896d6c767 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -101,6 +101,9 @@ class LLMEngine:
         # for v0 compatibility
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
 
+        # Don't keep the dummy data in memory
+        self.reset_mm_cache()
+
     @classmethod
     def from_vllm_config(
         cls,
@@ -240,6 +243,11 @@ class LLMEngine:
     def stop_profile(self):
         self.engine_core.profile(False)
 
+    def reset_mm_cache(self):
+        self.processor.mm_registry.reset_processor_cache()
+        self.processor.mm_input_cache_client.reset()
+        self.engine_core.reset_mm_cache()
+
     def reset_prefix_cache(self, device: Optional[Device] = None):
         self.engine_core.reset_prefix_cache()
 
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index 64ece840fc4c..fcb90bebdb62 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -83,3 +83,8 @@ class MirroredProcessingCache:
                 full_mm_inputs.append(mm_input)
 
         return full_mm_inputs
+
+    def reset(self) -> bool:
+        self.mm_cache.clear()
+
+        return True
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 66be88738535..64a756148780 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -54,6 +54,10 @@ class Processor:
         self.use_hash = self.mm_input_cache_client.use_cache or \
             self.cache_config.enable_prefix_caching
 
+    @property
+    def mm_registry(self):
+        return self.input_preprocessor.mm_registry
+
     def _validate_logprobs(
         self,
         params: SamplingParams,
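
After this patch, both engine generations expose reset_mm_cache() and call it once at the end of their own construction. A hedged usage sketch from offline code; the model id below is only an example of a multi-modal checkpoint (an assumption, not part of the patch), and downloading and running it requires a suitable GPU environment:

    # Sketch only: reset_mm_cache() is already invoked during engine
    # construction by this patch; calling it again manually is safe, e.g.
    # after heavy multi-modal processing in a debugging session.
    from vllm import EngineArgs, LLMEngine

    engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")  # example model id
    engine = LLMEngine.from_engine_args(engine_args)

    engine.reset_mm_cache()
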