diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index d37ec25675b72..8e9182a9bca4c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC): expose connector transfer stats via Prometheus. """ return None + + def reset_cache(self) -> bool | None: + """ + Reset the connector's internal cache. + + Returns: + bool: True if the cache was successfully reset, False otherwise. + """ + logger.debug( + "Connector cache reset requested, but %s does not implement reset_cache().", + type(self).__name__, + ) + + return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 51d5df6c6ba15..c80dc1a567fdb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1): per_engine_labelvalues, prom_metrics, ) + + def reset_cache(self) -> bool: + results = [c.reset_cache() is not False for c in self._connectors] + return all(results) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 1b6330c9f9b65..d94951a0cffc8 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -116,8 +116,10 @@ class EngineClient(ABC): ... @abstractmethod - async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - """Reset the prefix cache""" + async def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + """Reset the prefix cache and optionally any configured connector cache""" ... @abstractmethod diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 481a47a97f7d4..add9176340a9e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1491,8 +1491,12 @@ class LLM: def stop_profile(self) -> None: self.llm_engine.stop_profile() - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - return self.llm_engine.reset_prefix_cache(reset_running_requests) + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.llm_engine.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1): """ diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2fa6afa2bacb5..7be601d824f34 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -663,14 +663,27 @@ if envs.VLLM_SERVER_DEV_MODE: @router.post("/reset_prefix_cache") async def reset_prefix_cache( - raw_request: Request, reset_running_requests: bool = Query(default=False) + raw_request: Request, + reset_running_requests: bool = Query(default=False), + reset_external: bool = Query(default=False), ): """ - Reset the prefix cache. Note that we currently do not check if the - prefix cache is successfully reset in the API server. + Reset the local prefix cache. + + Optionally, if the query parameter `reset_external=true` + also resets the external (connector-managed) prefix cache. + + Note that we currently do not check if the prefix cache + is successfully reset in the API server. + + Example: + POST /reset_prefix_cache?reset_external=true """ logger.info("Resetting prefix cache...") - await engine_client(raw_request).reset_prefix_cache(reset_running_requests) + + await engine_client(raw_request).reset_prefix_cache( + reset_running_requests, reset_external + ) return Response(status_code=200) @router.post("/reset_mm_cache") diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index c2f503ef2354e..596ab05ad320a 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -152,7 +152,9 @@ class SchedulerInterface(ABC): return self.has_unfinished_requests() or self.has_finished_requests() @abstractmethod - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: """Reset the prefix cache for KV cache. This is particularly required when the model weights are live-updated. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 75a7385df38b1..0a8efa2fd512f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1380,7 +1380,9 @@ class Scheduler(SchedulerInterface): def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: """Reset the KV prefix cache. If reset_running_requests is True, all the running requests will be @@ -1418,8 +1420,26 @@ class Scheduler(SchedulerInterface): "the presence of running requests waiting for remote KV transfer, " "which is not supported yet." ) + + if reset_connector: + reset_successful = self.reset_connector_cache() and reset_successful + return reset_successful + def reset_connector_cache(self) -> bool: + if self.connector is None: + logger.warning("reset_connector called but no KV connector is configured.") + return False + + if self.connector.reset_cache() is False: + return False + + if self.log_stats: + assert self.connector_prefix_cache_stats is not None + self.connector_prefix_cache_stats.reset = True + + return True + def make_stats( self, spec_decoding_stats: SpecDecodingStats | None = None, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ec5d6e95ce3aa..fd7e04dc02082 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -749,8 +749,12 @@ class AsyncLLM(EngineClient): self.input_processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - return await self.engine_core.reset_prefix_cache_async(reset_running_requests) + async def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return await self.engine_core.reset_prefix_cache_async( + reset_running_requests, reset_connector + ) async def sleep(self, level: int = 1) -> None: await self.reset_prefix_cache() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 8e34dfcea7f61..3d3a1e138ddef 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -503,8 +503,12 @@ class EngineCore: self.model_executor.reset_mm_cache() - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - return self.scheduler.reset_prefix_cache(reset_running_requests) + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.scheduler.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1): self.model_executor.sleep(level) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index afa0593921d06..c936646aa7993 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -138,7 +138,9 @@ class EngineCoreClient(ABC): def reset_mm_cache(self) -> None: raise NotImplementedError - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: raise NotImplementedError def sleep(self, level: int = 1) -> None: @@ -209,7 +211,7 @@ class EngineCoreClient(ABC): raise NotImplementedError async def reset_prefix_cache_async( - self, reset_running_requests: bool = False + self, reset_running_requests: bool = False, reset_connector: bool = False ) -> bool: raise NotImplementedError @@ -289,8 +291,12 @@ class InprocClient(EngineCoreClient): def reset_mm_cache(self) -> None: self.engine_core.reset_mm_cache() - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - return self.engine_core.reset_prefix_cache(reset_running_requests) + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.engine_core.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1) -> None: self.engine_core.sleep(level) @@ -753,8 +759,12 @@ class SyncMPClient(MPClient): def reset_mm_cache(self) -> None: self.call_utility("reset_mm_cache") - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - return self.call_utility("reset_prefix_cache", reset_running_requests) + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.call_utility( + "reset_prefix_cache", reset_running_requests, reset_connector + ) def add_lora(self, lora_request: LoRARequest) -> bool: return self.call_utility("add_lora", lora_request) @@ -958,10 +968,10 @@ class AsyncMPClient(MPClient): await self.call_utility_async("reset_mm_cache") async def reset_prefix_cache_async( - self, reset_running_requests: bool = False + self, reset_running_requests: bool = False, reset_connector: bool = False ) -> bool: return await self.call_utility_async( - "reset_prefix_cache", reset_running_requests + "reset_prefix_cache", reset_running_requests, reset_connector ) async def sleep_async(self, level: int = 1) -> None: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 8772f2e488dc0..4c31291005477 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -328,8 +328,12 @@ class LLMEngine: self.input_processor.clear_mm_cache() self.engine_core.reset_mm_cache() - def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: - return self.engine_core.reset_prefix_cache(reset_running_requests) + def reset_prefix_cache( + self, reset_running_requests: bool = False, reset_connector: bool = False + ) -> bool: + return self.engine_core.reset_prefix_cache( + reset_running_requests, reset_connector + ) def sleep(self, level: int = 1): self.engine_core.sleep(level)