mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 04:25:03 +08:00
[Bugfix] Avoid repeatedly creating dummy data during engine startup (#17935)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
1df491c522
commit
61e0a506a3
@ -1232,6 +1232,9 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
async def stop_profile(self) -> None:
|
async def stop_profile(self) -> None:
|
||||||
self.engine.stop_profile()
|
self.engine.stop_profile()
|
||||||
|
|
||||||
|
async def reset_mm_cache(self) -> None:
|
||||||
|
self.engine.reset_mm_cache()
|
||||||
|
|
||||||
async def reset_prefix_cache(self,
|
async def reset_prefix_cache(self,
|
||||||
device: Optional[Device] = None) -> None:
|
device: Optional[Device] = None) -> None:
|
||||||
self.engine.reset_prefix_cache(device)
|
self.engine.reset_prefix_cache(device)
|
||||||
|
|||||||
@ -409,6 +409,9 @@ class LLMEngine:
|
|||||||
# the next step without re-scheduling.
|
# the next step without re-scheduling.
|
||||||
self._skip_scheduling_next_step = False
|
self._skip_scheduling_next_step = False
|
||||||
|
|
||||||
|
# Don't keep the dummy data in memory
|
||||||
|
self.reset_mm_cache()
|
||||||
|
|
||||||
def _initialize_kv_caches(self) -> None:
|
def _initialize_kv_caches(self) -> None:
|
||||||
"""Initialize the KV cache in the worker(s).
|
"""Initialize the KV cache in the worker(s).
|
||||||
|
|
||||||
@ -913,6 +916,10 @@ class LLMEngine:
|
|||||||
"""
|
"""
|
||||||
return self.scheduler[virtual_engine].has_unfinished_seqs()
|
return self.scheduler[virtual_engine].has_unfinished_seqs()
|
||||||
|
|
||||||
|
def reset_mm_cache(self) -> bool:
|
||||||
|
"""Reset the multi-modal cache."""
|
||||||
|
return self.input_preprocessor.mm_registry.reset_processor_cache()
|
||||||
|
|
||||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||||
"""Reset prefix cache for all devices."""
|
"""Reset prefix cache for all devices."""
|
||||||
|
|
||||||
|
|||||||
@ -123,6 +123,10 @@ class RPCUProfileRequest(Enum):
|
|||||||
STOP_PROFILE = 2
|
STOP_PROFILE = 2
|
||||||
|
|
||||||
|
|
||||||
|
class RPCResetMultiModalCacheRequest(Enum):
|
||||||
|
RESET = 1
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RPCResetPrefixCacheRequest:
|
class RPCResetPrefixCacheRequest:
|
||||||
device: Device
|
device: Device
|
||||||
@ -164,6 +168,7 @@ class RPCAdapterLoadedResponse:
|
|||||||
|
|
||||||
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
|
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
|
||||||
RPCUProfileRequest, RPCLoadAdapterRequest,
|
RPCUProfileRequest, RPCLoadAdapterRequest,
|
||||||
|
RPCResetMultiModalCacheRequest,
|
||||||
RPCResetPrefixCacheRequest, RPCSleepRequest,
|
RPCResetPrefixCacheRequest, RPCSleepRequest,
|
||||||
RPCWakeUpRequest, RPCIsSleepingRequest]
|
RPCWakeUpRequest, RPCIsSleepingRequest]
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
|||||||
RPCIsSleepingResponse,
|
RPCIsSleepingResponse,
|
||||||
RPCLoadAdapterRequest,
|
RPCLoadAdapterRequest,
|
||||||
RPCProcessRequest,
|
RPCProcessRequest,
|
||||||
|
RPCResetMultiModalCacheRequest,
|
||||||
RPCResetPrefixCacheRequest,
|
RPCResetPrefixCacheRequest,
|
||||||
RPCSleepRequest, RPCStartupRequest,
|
RPCSleepRequest, RPCStartupRequest,
|
||||||
RPCStartupResponse,
|
RPCStartupResponse,
|
||||||
@ -687,6 +688,13 @@ class MQLLMEngineClient(EngineClient):
|
|||||||
await self._send_one_way_rpc_request(
|
await self._send_one_way_rpc_request(
|
||||||
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
|
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
|
||||||
|
|
||||||
|
async def reset_mm_cache(self) -> None:
|
||||||
|
"""Reset the multi-modal cache"""
|
||||||
|
|
||||||
|
await self._send_one_way_rpc_request(
|
||||||
|
request=RPCResetMultiModalCacheRequest.RESET,
|
||||||
|
socket=self.input_socket)
|
||||||
|
|
||||||
async def reset_prefix_cache(self,
|
async def reset_prefix_cache(self,
|
||||||
device: Optional[Device] = None) -> None:
|
device: Optional[Device] = None) -> None:
|
||||||
"""Reset the prefix cache"""
|
"""Reset the prefix cache"""
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
|||||||
RPCIsSleepingResponse,
|
RPCIsSleepingResponse,
|
||||||
RPCLoadAdapterRequest,
|
RPCLoadAdapterRequest,
|
||||||
RPCProcessRequest,
|
RPCProcessRequest,
|
||||||
|
RPCResetMultiModalCacheRequest,
|
||||||
RPCResetPrefixCacheRequest,
|
RPCResetPrefixCacheRequest,
|
||||||
RPCSleepRequest, RPCStartupRequest,
|
RPCSleepRequest, RPCStartupRequest,
|
||||||
RPCStartupResponse,
|
RPCStartupResponse,
|
||||||
@ -269,6 +270,8 @@ class MQLLMEngine:
|
|||||||
self.stop_profile()
|
self.stop_profile()
|
||||||
elif isinstance(request, RPCLoadAdapterRequest):
|
elif isinstance(request, RPCLoadAdapterRequest):
|
||||||
self._handle_load_adapter_request(request)
|
self._handle_load_adapter_request(request)
|
||||||
|
elif isinstance(request, RPCResetMultiModalCacheRequest):
|
||||||
|
self.reset_mm_cache()
|
||||||
elif isinstance(request, RPCResetPrefixCacheRequest):
|
elif isinstance(request, RPCResetPrefixCacheRequest):
|
||||||
self.reset_prefix_cache()
|
self.reset_prefix_cache()
|
||||||
elif isinstance(request, RPCSleepRequest):
|
elif isinstance(request, RPCSleepRequest):
|
||||||
@ -409,6 +412,9 @@ class MQLLMEngine:
|
|||||||
def stop_profile(self) -> None:
|
def stop_profile(self) -> None:
|
||||||
self.engine.stop_profile()
|
self.engine.stop_profile()
|
||||||
|
|
||||||
|
def reset_mm_cache(self) -> bool:
|
||||||
|
return self.engine.reset_mm_cache()
|
||||||
|
|
||||||
def reset_prefix_cache(self) -> bool:
|
def reset_prefix_cache(self) -> bool:
|
||||||
return self.engine.reset_prefix_cache()
|
return self.engine.reset_prefix_cache()
|
||||||
|
|
||||||
|
|||||||
@ -278,6 +278,11 @@ class EngineClient(ABC):
|
|||||||
"""Start profiling the engine"""
|
"""Start profiling the engine"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def reset_mm_cache(self) -> None:
|
||||||
|
"""Reset the multi-modal cache"""
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def reset_prefix_cache(self,
|
async def reset_prefix_cache(self,
|
||||||
device: Optional[Device] = None) -> None:
|
device: Optional[Device] = None) -> None:
|
||||||
|
|||||||
@ -150,6 +150,10 @@ async def build_async_engine_client(
|
|||||||
|
|
||||||
async with build_async_engine_client_from_engine_args(
|
async with build_async_engine_client_from_engine_args(
|
||||||
engine_args, args.disable_frontend_multiprocessing) as engine:
|
engine_args, args.disable_frontend_multiprocessing) as engine:
|
||||||
|
|
||||||
|
# Don't keep the dummy data in memory
|
||||||
|
await engine.reset_mm_cache()
|
||||||
|
|
||||||
yield engine
|
yield engine
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1026,6 +1026,11 @@ class ProcessingCache:
|
|||||||
def put_item(self, item: ProcessingCacheItem) -> None:
|
def put_item(self, item: ProcessingCacheItem) -> None:
|
||||||
self._cache[item.key] = item.value
|
self._cache[item.key] = item.value
|
||||||
|
|
||||||
|
def reset(self) -> bool:
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class BaseProcessingInfo:
|
class BaseProcessingInfo:
|
||||||
"""Base class to provide the information necessary for data processing."""
|
"""Base class to provide the information necessary for data processing."""
|
||||||
|
|||||||
@ -88,6 +88,12 @@ class MultiModalRegistry:
|
|||||||
|
|
||||||
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
|
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
|
||||||
|
|
||||||
|
def reset_processor_cache(self) -> bool:
|
||||||
|
"""Reset the multi-modal processing cache."""
|
||||||
|
self._processing_cache.reset()
|
||||||
|
|
||||||
|
return True # Success
|
||||||
|
|
||||||
@deprecated("Legacy input processor/mapper pipeline has been removed. "
|
@deprecated("Legacy input processor/mapper pipeline has been removed. "
|
||||||
"Please update your model runner to use "
|
"Please update your model runner to use "
|
||||||
"`seq_group_metadata.multi_modal_data` directly without "
|
"`seq_group_metadata.multi_modal_data` directly without "
|
||||||
@ -106,7 +112,7 @@ class MultiModalRegistry:
|
|||||||
if not model_config.is_multimodal_model:
|
if not model_config.is_multimodal_model:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
processor = self.create_processor(model_config, disable_cache=True)
|
processor = self.create_processor(model_config, disable_cache=False)
|
||||||
profiler = MultiModalProfiler(processor)
|
profiler = MultiModalProfiler(processor)
|
||||||
|
|
||||||
seq_len = model_config.max_model_len
|
seq_len = model_config.max_model_len
|
||||||
@ -190,7 +196,7 @@ class MultiModalRegistry:
|
|||||||
if not model_config.is_multimodal_model:
|
if not model_config.is_multimodal_model:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
processor = self.create_processor(model_config, disable_cache=True)
|
processor = self.create_processor(model_config, disable_cache=False)
|
||||||
profiler = MultiModalProfiler(processor)
|
profiler = MultiModalProfiler(processor)
|
||||||
return profiler.get_mm_limits()
|
return profiler.get_mm_limits()
|
||||||
|
|
||||||
@ -286,7 +292,7 @@ class MultiModalRegistry:
|
|||||||
|
|
||||||
The model is identified by ``model_config``.
|
The model is identified by ``model_config``.
|
||||||
"""
|
"""
|
||||||
processor = self.create_processor(model_config, disable_cache=True)
|
processor = self.create_processor(model_config, disable_cache=False)
|
||||||
profiler = MultiModalProfiler(processor)
|
profiler = MultiModalProfiler(processor)
|
||||||
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
|
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
|
||||||
|
|
||||||
@ -310,7 +316,7 @@ class MultiModalRegistry:
|
|||||||
|
|
||||||
The model is identified by ``model_config``.
|
The model is identified by ``model_config``.
|
||||||
"""
|
"""
|
||||||
processor = self.create_processor(model_config, disable_cache=True)
|
processor = self.create_processor(model_config, disable_cache=False)
|
||||||
profiler = MultiModalProfiler(processor)
|
profiler = MultiModalProfiler(processor)
|
||||||
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
|
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
|
||||||
|
|
||||||
|
|||||||
@ -476,6 +476,11 @@ class AsyncLLM(EngineClient):
|
|||||||
async def stop_profile(self) -> None:
|
async def stop_profile(self) -> None:
|
||||||
await self.engine_core.profile_async(False)
|
await self.engine_core.profile_async(False)
|
||||||
|
|
||||||
|
async def reset_mm_cache(self) -> None:
|
||||||
|
self.processor.mm_registry.reset_processor_cache()
|
||||||
|
self.processor.mm_input_cache_client.reset()
|
||||||
|
await self.engine_core.reset_mm_cache_async()
|
||||||
|
|
||||||
async def reset_prefix_cache(self,
|
async def reset_prefix_cache(self,
|
||||||
device: Optional[Device] = None) -> None:
|
device: Optional[Device] = None) -> None:
|
||||||
if device == Device.CPU:
|
if device == Device.CPU:
|
||||||
|
|||||||
@ -286,6 +286,15 @@ class EngineCore:
|
|||||||
def profile(self, is_start: bool = True):
|
def profile(self, is_start: bool = True):
|
||||||
self.model_executor.profile(is_start)
|
self.model_executor.profile(is_start)
|
||||||
|
|
||||||
|
def reset_mm_cache(self):
|
||||||
|
# NOTE: Since this is mainly for debugging, we don't attempt to
|
||||||
|
# re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
|
||||||
|
if self.scheduler.get_num_unfinished_requests():
|
||||||
|
logger.warning("Resetting the multi-modal cache when requests are "
|
||||||
|
"in progress may lead to desynced internal caches.")
|
||||||
|
|
||||||
|
self.mm_input_cache_server.reset()
|
||||||
|
|
||||||
def reset_prefix_cache(self):
|
def reset_prefix_cache(self):
|
||||||
self.scheduler.reset_prefix_cache()
|
self.scheduler.reset_prefix_cache()
|
||||||
|
|
||||||
|
|||||||
@ -88,6 +88,9 @@ class EngineCoreClient(ABC):
|
|||||||
def profile(self, is_start: bool = True) -> None:
|
def profile(self, is_start: bool = True) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def reset_mm_cache(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def reset_prefix_cache(self) -> None:
|
def reset_prefix_cache(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -143,6 +146,9 @@ class EngineCoreClient(ABC):
|
|||||||
async def profile_async(self, is_start: bool = True) -> None:
|
async def profile_async(self, is_start: bool = True) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def reset_mm_cache_async(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def reset_prefix_cache_async(self) -> None:
|
async def reset_prefix_cache_async(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -214,6 +220,9 @@ class InprocClient(EngineCoreClient):
|
|||||||
def profile(self, is_start: bool = True) -> None:
|
def profile(self, is_start: bool = True) -> None:
|
||||||
self.engine_core.profile(is_start)
|
self.engine_core.profile(is_start)
|
||||||
|
|
||||||
|
def reset_mm_cache(self) -> None:
|
||||||
|
self.engine_core.reset_mm_cache()
|
||||||
|
|
||||||
def reset_prefix_cache(self) -> None:
|
def reset_prefix_cache(self) -> None:
|
||||||
self.engine_core.reset_prefix_cache()
|
self.engine_core.reset_prefix_cache()
|
||||||
|
|
||||||
@ -600,6 +609,9 @@ class SyncMPClient(MPClient):
|
|||||||
def profile(self, is_start: bool = True) -> None:
|
def profile(self, is_start: bool = True) -> None:
|
||||||
self.call_utility("profile", is_start)
|
self.call_utility("profile", is_start)
|
||||||
|
|
||||||
|
def reset_mm_cache(self) -> None:
|
||||||
|
self.call_utility("reset_mm_cache")
|
||||||
|
|
||||||
def reset_prefix_cache(self) -> None:
|
def reset_prefix_cache(self) -> None:
|
||||||
self.call_utility("reset_prefix_cache")
|
self.call_utility("reset_prefix_cache")
|
||||||
|
|
||||||
@ -787,6 +799,9 @@ class AsyncMPClient(MPClient):
|
|||||||
async def profile_async(self, is_start: bool = True) -> None:
|
async def profile_async(self, is_start: bool = True) -> None:
|
||||||
await self.call_utility_async("profile", is_start)
|
await self.call_utility_async("profile", is_start)
|
||||||
|
|
||||||
|
async def reset_mm_cache_async(self) -> None:
|
||||||
|
await self.call_utility_async("reset_mm_cache")
|
||||||
|
|
||||||
async def reset_prefix_cache_async(self) -> None:
|
async def reset_prefix_cache_async(self) -> None:
|
||||||
await self.call_utility_async("reset_prefix_cache")
|
await self.call_utility_async("reset_prefix_cache")
|
||||||
|
|
||||||
|
|||||||
@ -101,6 +101,9 @@ class LLMEngine:
|
|||||||
# for v0 compatibility
|
# for v0 compatibility
|
||||||
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
||||||
|
|
||||||
|
# Don't keep the dummy data in memory
|
||||||
|
self.reset_mm_cache()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_vllm_config(
|
def from_vllm_config(
|
||||||
cls,
|
cls,
|
||||||
@ -240,6 +243,11 @@ class LLMEngine:
|
|||||||
def stop_profile(self):
|
def stop_profile(self):
|
||||||
self.engine_core.profile(False)
|
self.engine_core.profile(False)
|
||||||
|
|
||||||
|
def reset_mm_cache(self):
|
||||||
|
self.processor.mm_registry.reset_processor_cache()
|
||||||
|
self.processor.mm_input_cache_client.reset()
|
||||||
|
self.engine_core.reset_mm_cache()
|
||||||
|
|
||||||
def reset_prefix_cache(self, device: Optional[Device] = None):
|
def reset_prefix_cache(self, device: Optional[Device] = None):
|
||||||
self.engine_core.reset_prefix_cache()
|
self.engine_core.reset_prefix_cache()
|
||||||
|
|
||||||
|
|||||||
@ -83,3 +83,8 @@ class MirroredProcessingCache:
|
|||||||
full_mm_inputs.append(mm_input)
|
full_mm_inputs.append(mm_input)
|
||||||
|
|
||||||
return full_mm_inputs
|
return full_mm_inputs
|
||||||
|
|
||||||
|
def reset(self) -> bool:
|
||||||
|
self.mm_cache.clear()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|||||||
@ -54,6 +54,10 @@ class Processor:
|
|||||||
self.use_hash = self.mm_input_cache_client.use_cache or \
|
self.use_hash = self.mm_input_cache_client.use_cache or \
|
||||||
self.cache_config.enable_prefix_caching
|
self.cache_config.enable_prefix_caching
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mm_registry(self):
|
||||||
|
return self.input_preprocessor.mm_registry
|
||||||
|
|
||||||
def _validate_logprobs(
|
def _validate_logprobs(
|
||||||
self,
|
self,
|
||||||
params: SamplingParams,
|
params: SamplingParams,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user