[KVConnector][Feature] Support KV connector cache reset via /reset_prefix_cache (#27170)

Signed-off-by: tovam <tovam@pliops.com>
Signed-off-by: Tova Movshovitz <tovam@pliops.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Tova Movshovitz <tovam@pliops.com>, committed by GitHub
Date: 2025-12-05 20:33:26 +02:00
Commit: adb315060c (parent: 4e26d3b09e)
11 changed files with 105 additions and 24 deletions

View File

@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC):
expose connector transfer stats via Prometheus.
"""
return None
def reset_cache(self) -> bool | None:
"""
Reset the connector's internal cache.
Returns:
True if the cache was successfully reset, False if the reset failed,
or None if the connector does not implement cache reset.
"""
logger.debug(
"Connector cache reset requested, but %s does not implement reset_cache().",
type(self).__name__,
)
return None
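
As an illustration of the contract documented above, a connector that keeps host-side bookkeeping could override reset_cache() roughly as follows. This is a hedged sketch: ExampleConnector and its _block_index attribute are hypothetical names, not part of this change, and the other abstract methods of KVConnectorBase_V1 are omitted.

    # Hypothetical connector override; _block_index is an illustrative stand-in
    # for whatever lookup structure a connector keeps for externally cached blocks.
    class ExampleConnector(KVConnectorBase_V1):
        def reset_cache(self) -> bool:
            try:
                # Drop host-side bookkeeping of externally cached KV blocks.
                self._block_index.clear()
                return True
            except Exception:
                # Report failure instead of raising so callers can aggregate results.
                return False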

View File

@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1):
per_engine_labelvalues,
prom_metrics,
)
def reset_cache(self) -> bool:
results = [c.reset_cache() is not False for c in self._connectors]
return all(results)
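
Note that the aggregation above treats a None return (a child connector that does not implement reset_cache()) as non-failure: only an explicit False from some child makes the combined reset report failure. A small standalone illustration of that behavior, using stand-in return values rather than real connectors:

    # Stand-in return values from child connectors: True, None (not implemented), False.
    def combine(results):
        # Mirrors MultiConnector.reset_cache(): only an explicit False counts as failure.
        return all(r is not False for r in results)

    assert combine([True, None]) is True      # an unimplemented reset does not fail the call
    assert combine([True, False]) is False    # any explicit failure propagates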

View File

@@ -116,8 +116,10 @@ class EngineClient(ABC):
...
@abstractmethod
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
"""Reset the prefix cache"""
async def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the prefix cache and optionally any configured connector cache"""
...
@abstractmethod

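Any custom EngineClient implementation must now accept the widened signature. A minimal sketch, assuming a hypothetical client that simply forwards both flags to an underlying engine-core handle (MyEngineClient and self._core are illustrative names; the other abstract methods are omitted):

    # Hypothetical EngineClient subclass; `self._core` is a stand-in for whatever
    # engine-core handle the client wraps.
    class MyEngineClient(EngineClient):
        async def reset_prefix_cache(
            self, reset_running_requests: bool = False, reset_connector: bool = False
        ) -> bool:
            return await self._core.reset_prefix_cache_async(
                reset_running_requests, reset_connector
            )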
View File

@@ -1491,8 +1491,12 @@ class LLM:
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.llm_engine.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.llm_engine.reset_prefix_cache(
reset_running_requests, reset_connector
)
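
For offline use, the flag is threaded through the LLM facade the same way, so one call can reset both the local prefix cache and the connector-managed cache. A usage sketch; the model name is a placeholder, and a KV connector must separately be configured (e.g. via KVTransferConfig) for the connector reset to succeed:

    from vllm import LLM

    # Placeholder model; without a configured KV connector, reset_connector=True
    # will report failure because there is nothing to reset.
    llm = LLM(model="facebook/opt-125m")
    ok = llm.reset_prefix_cache(reset_running_requests=False, reset_connector=True)
    print("prefix + connector cache reset:", ok)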
def sleep(self, level: int = 1):
"""

View File

@@ -663,14 +663,27 @@ if envs.VLLM_SERVER_DEV_MODE:
@router.post("/reset_prefix_cache")
async def reset_prefix_cache(
raw_request: Request, reset_running_requests: bool = Query(default=False)
raw_request: Request,
reset_running_requests: bool = Query(default=False),
reset_external: bool = Query(default=False),
):
"""
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
Reset the local prefix cache.
Optionally, if the query parameter `reset_external=true` is set,
this also resets the external (connector-managed) prefix cache.
Note that we currently do not check if the prefix cache
is successfully reset in the API server.
Example:
POST /reset_prefix_cache?reset_external=true
"""
logger.info("Resetting prefix cache...")
await engine_client(raw_request).reset_prefix_cache(reset_running_requests)
await engine_client(raw_request).reset_prefix_cache(
reset_running_requests, reset_external
)
return Response(status_code=200)
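
Since /reset_prefix_cache is only registered when VLLM_SERVER_DEV_MODE is enabled, a client-side call might look like the following sketch; the host and port are placeholders and the requests dependency is assumed:

    import requests

    # Assumes a vLLM OpenAI-compatible server started with VLLM_SERVER_DEV_MODE=1,
    # listening on localhost:8000 (placeholder address).
    resp = requests.post(
        "http://localhost:8000/reset_prefix_cache",
        params={"reset_external": "true"},  # also reset the connector-managed cache
    )
    assert resp.status_code == 200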
@router.post("/reset_mm_cache")

View File

@@ -152,7 +152,9 @@ class SchedulerInterface(ABC):
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.

View File

@@ -1380,7 +1380,9 @@ class Scheduler(SchedulerInterface):
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the KV prefix cache.
If reset_running_requests is True, all the running requests will be
@@ -1418,8 +1420,26 @@
"the presence of running requests waiting for remote KV transfer, "
"which is not supported yet."
)
if reset_connector:
reset_successful = self.reset_connector_cache() and reset_successful
return reset_successful
def reset_connector_cache(self) -> bool:
if self.connector is None:
logger.warning("reset_connector called but no KV connector is configured.")
return False
if self.connector.reset_cache() is False:
return False
if self.log_stats:
assert self.connector_prefix_cache_stats is not None
self.connector_prefix_cache_stats.reset = True
return True
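
To make the control flow of reset_connector_cache() above concrete: no configured connector yields False (with a warning), an explicit False from connector.reset_cache() yields False, and True or None yields True, additionally flagging the connector prefix-cache stats when stats logging is enabled. A hedged, self-contained sketch of that decision table with a stub, not the real Scheduler class:

    # Stub mirroring the branches of Scheduler.reset_connector_cache(); this is an
    # illustration of the decision logic only.
    def connector_reset_outcome(connector, log_stats: bool, stats) -> bool:
        if connector is None:
            return False                  # nothing to reset, reported as failure
        if connector.reset_cache() is False:
            return False                  # explicit failure from the connector
        if log_stats and stats is not None:
            stats.reset = True            # mark the connector prefix-cache stats as reset
        return True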
def make_stats(
self,
spec_decoding_stats: SpecDecodingStats | None = None,

View File

@@ -749,8 +749,12 @@ class AsyncLLM(EngineClient):
self.input_processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return await self.engine_core.reset_prefix_cache_async(reset_running_requests)
async def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return await self.engine_core.reset_prefix_cache_async(
reset_running_requests, reset_connector
)
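
On the async path the call is simply awaited; a sketch assuming an already constructed and running AsyncLLM (or any EngineClient) instance, with construction details omitted:

    import asyncio

    async def reset_caches(engine) -> bool:
        # `engine` is assumed to be a running AsyncLLM / EngineClient instance;
        # both flags are forwarded through to the engine core.
        return await engine.reset_prefix_cache(
            reset_running_requests=False, reset_connector=True
        )

    # Example: asyncio.run(reset_caches(my_async_llm))  # my_async_llm is a placeholder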
async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache()

View File

@@ -503,8 +503,12 @@ class EngineCore:
self.model_executor.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.scheduler.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.scheduler.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1):
self.model_executor.sleep(level)

View File

@@ -138,7 +138,9 @@ class EngineCoreClient(ABC):
def reset_mm_cache(self) -> None:
raise NotImplementedError
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
raise NotImplementedError
def sleep(self, level: int = 1) -> None:
@@ -209,7 +211,7 @@ class EngineCoreClient(ABC):
raise NotImplementedError
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
raise NotImplementedError
@@ -289,8 +291,12 @@ class InprocClient(EngineCoreClient):
def reset_mm_cache(self) -> None:
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.engine_core.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.engine_core.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level)
@@ -753,8 +759,12 @@ class SyncMPClient(MPClient):
def reset_mm_cache(self) -> None:
self.call_utility("reset_mm_cache")
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.call_utility("reset_prefix_cache", reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.call_utility(
"reset_prefix_cache", reset_running_requests, reset_connector
)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.call_utility("add_lora", lora_request)
@@ -958,10 +968,10 @@ class AsyncMPClient(MPClient):
await self.call_utility_async("reset_mm_cache")
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return await self.call_utility_async(
"reset_prefix_cache", reset_running_requests
"reset_prefix_cache", reset_running_requests, reset_connector
)
async def sleep_async(self, level: int = 1) -> None:

View File

@@ -328,8 +328,12 @@ class LLMEngine:
self.input_processor.clear_mm_cache()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.engine_core.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.engine_core.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1):
self.engine_core.sleep(level)