[Core] Support resetting all running requests' KV while calling reset_prefix_cache (#28827)

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Zhuohan Li 2025-12-01 18:25:05 -08:00 committed by GitHub
parent fa8804ad9c
commit d0cd728907
16 changed files with 315 additions and 35 deletions

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates preempting running requests via `reset_prefix_cache`
while using the `LLMEngine` to process prompts with various sampling parameters.
"""
import argparse
import os
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
def create_test_prompts() -> list[tuple[str, SamplingParams]]:
"""Create a list of test prompts with their sampling parameters."""
return [
(
"A robot may not injure a human being " * 50,
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
),
),
(
"A robot may not injure a human being " * 50,
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
),
),
(
"To be or not to be,",
SamplingParams(
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
),
),
(
"What is the meaning of life?",
SamplingParams(
n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128
),
),
]
def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0
print("-" * 50)
step_id = 0
while test_prompts or engine.has_unfinished_requests():
print("-" * 50)
print(f"Step {step_id} (pid={os.getpid()})")
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1
if step_id == 10:
print(f"Resetting prefix cache at {step_id}")
engine.reset_prefix_cache(reset_running_requests=True)
request_outputs: list[RequestOutput] = engine.step()
for request_output in request_outputs:
if request_output.finished:
print("-" * 50)
print(request_output)
print("-" * 50)
step_id += 1
def initialize_engine(args: argparse.Namespace) -> LLMEngine:
"""Initialize the LLMEngine from the command line arguments."""
engine_args = EngineArgs.from_cli_args(args)
return LLMEngine.from_engine_args(engine_args)
def parse_args():
parser = FlexibleArgumentParser(
description="Demo on using the LLMEngine class directly"
)
parser = EngineArgs.add_cli_args(parser)
return parser.parse_args()
def main(args: argparse.Namespace):
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine(args)
test_prompts = create_test_prompts()
process_requests(engine, test_prompts)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import EngineArgs, LLMEngine, SamplingParams
PROMPTS = [
"A robot may not injure a human being ",
"To be or not to be,",
"What is the meaning of life?",
"What does the fox say? " * 20, # Test long prompt
]
def test_reset_prefix_cache_e2e():
engine_args = EngineArgs(
model="Qwen/Qwen3-0.6B",
gpu_memory_utilization=0.2,
async_scheduling=True,
max_num_batched_tokens=32,
max_model_len=2048,
compilation_config={"mode": 0},
)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=16,
)
# No preempt case:
for i, prompt in enumerate(PROMPTS):
engine.add_request("ground_truth_" + str(i), prompt, sampling_params)
ground_truth_results = {}
while engine.has_unfinished_requests():
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
ground_truth_results[request_output.request_id] = request_output
# Preempt case:
for i, prompt in enumerate(PROMPTS):
engine.add_request("preempted_" + str(i), prompt, sampling_params)
step_id = 0
preempted_results = {}
while engine.has_unfinished_requests():
if step_id == 10:
engine.reset_prefix_cache(reset_running_requests=True)
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
preempted_results[request_output.request_id] = request_output
step_id += 1
for i in range(len(PROMPTS)):
assert (
ground_truth_results["ground_truth_" + str(i)].outputs[0].text
== preempted_results["preempted_" + str(i)].outputs[0].text
), (
f"ground_truth_results['ground_truth_{i}'].outputs[0].text="
f"{ground_truth_results['ground_truth_' + str(i)].outputs[0].text} "
f"preempted_results['preempted_{i}'].outputs[0].text="
f"{preempted_results['preempted_' + str(i)].outputs[0].text}"
)

View File

@@ -728,6 +728,37 @@ def test_preempt_during_execution():
assert requests[1].output_token_ids[0] == 42
def test_scheduler_reset_prefix_cache():
scheduler = create_scheduler(enable_prefix_caching=True)
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Initial scheduling: requests should now be in the running state
_ = scheduler.schedule()
# Verify requests moved from waiting to running
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == len(requests)
for i, request in enumerate(requests):
assert scheduler.running[i] == request
# Resetting the prefix cache should fail since there are still running
# requests holding KV cache
assert not scheduler.reset_prefix_cache()
# Reset the prefix cache with reset_running_requests=True. All running requests
# should be pushed back to the waiting queue and their KV cache should be freed
assert scheduler.reset_prefix_cache(reset_running_requests=True)
# Verify requests moved from running to waiting
assert len(scheduler.waiting) == len(requests)
assert len(scheduler.running) == 0
for i, request in enumerate(requests):
assert scheduler.waiting[i] == request
# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",

View File

@@ -116,7 +116,7 @@ class EngineClient(ABC):
...
@abstractmethod
async def reset_prefix_cache(self) -> None:
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
"""Reset the prefix cache"""
...

View File

@@ -1492,8 +1492,8 @@ class LLM:
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
def reset_prefix_cache(self) -> None:
self.llm_engine.reset_prefix_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.llm_engine.reset_prefix_cache(reset_running_requests)
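As a quick usage sketch at this layer (the model name and prompt below are illustrative placeholders, not part of this change), the new keyword can be exercised from the offline `LLM` API roughly as follows:

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B", enable_prefix_caching=True)
llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
# With no requests in flight, a plain reset already succeeds.
assert llm.reset_prefix_cache()
# Forcing the reset also preempts any still-running requests first.
assert llm.reset_prefix_cache(reset_running_requests=True)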
def sleep(self, level: int = 1):
"""

View File

@@ -877,13 +877,15 @@ if envs.VLLM_SERVER_DEV_MODE:
return JSONResponse(content=server_info)
@router.post("/reset_prefix_cache")
async def reset_prefix_cache(raw_request: Request):
async def reset_prefix_cache(
raw_request: Request, reset_running_requests: bool = Query(default=False)
):
"""
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
"""
logger.info("Resetting prefix cache...")
await engine_client(raw_request).reset_prefix_cache()
await engine_client(raw_request).reset_prefix_cache(reset_running_requests)
return Response(status_code=200)
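A usage sketch for this endpoint (assumes a server started with VLLM_SERVER_DEV_MODE enabled and listening on localhost:8000; the `requests` calls are illustrative, not part of this change):

import requests

# Plain reset; the server returns 200 whether or not the reset succeeded.
requests.post("http://localhost:8000/reset_prefix_cache")
# Force-preempt running requests first so their KV blocks are freed.
requests.post(
    "http://localhost:8000/reset_prefix_cache",
    params={"reset_running_requests": "true"},
)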
@router.post("/reset_mm_cache")

View File

@@ -45,6 +45,12 @@ class AsyncScheduler(Scheduler):
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]:
if request.discard_latest_async_tokens:
# If the request was force-preempted in reset_prefix_cache, we
# should discard the latest async token.
request.discard_latest_async_tokens = False
return [], False
status_before_update = request.status
new_token_ids, stopped = super()._update_request_with_output(
request, new_token_ids

View File

@@ -152,10 +152,16 @@ class SchedulerInterface(ABC):
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def reset_prefix_cache(self) -> bool:
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.
Args:
reset_running_requests: If True, all running requests will be
preempted and moved to the waiting queue before the reset.
Otherwise, this method only resets the KV prefix cache when no
running request is holding KV cache.
"""
raise NotImplementedError
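A minimal calling sketch against this interface (assumes `scheduler` is an already-constructed implementation, e.g. the `create_scheduler()` helper used in the scheduler test above):

# Try a non-destructive reset first; it fails while running requests hold KV.
if not scheduler.reset_prefix_cache():
    # Fall back to force-preempting the running requests, which frees their
    # KV blocks and pushes them onto the waiting queue before the reset.
    assert scheduler.reset_prefix_cache(reset_running_requests=True)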

View File

@@ -347,17 +347,7 @@ class Scheduler(SchedulerInterface):
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
)
self.waiting.prepend_request(preempted_req)
self._preempt_request(preempted_req, scheduled_timestamp)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
@@ -756,6 +746,30 @@
self._update_after_schedule(scheduler_output)
return scheduler_output
def _preempt_request(
self,
request: Request,
timestamp: float,
) -> None:
"""Preempt a request and put it back to the waiting queue.
NOTE: The request should be popped from the running queue outside of this
method.
"""
assert request.status == RequestStatus.RUNNING, (
"Only running requests can be preempted"
)
self.kv_cache_manager.free(request)
self.encoder_cache_manager.free(request)
request.status = RequestStatus.PREEMPTED
request.num_computed_tokens = 0
request.num_preemptions += 1
if self.log_stats:
request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
# Put the request back to the waiting queue.
self.waiting.prepend_request(request)
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
@@ -1362,8 +1376,45 @@
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
"""Reset the KV prefix cache.
If reset_running_requests is True, all running requests will be
preempted and moved to the waiting queue.
Otherwise, this method only resets the KV prefix cache when no
running request is holding KV cache.
"""
if reset_running_requests:
# For logging.
timestamp = time.monotonic()
# Invalidate all currently running requests' KV by pushing them to the
# waiting queue. This drops the ref count of every KV block to 0, so the
# reset is guaranteed to succeed.
# Preempt in reverse order so that, once prepended to the waiting queue,
# the requests are re-scheduled in their original FIFO order.
while self.running:
request = self.running.pop()
self._preempt_request(request, timestamp)
# NOTE(zhuohan): For async scheduling, we need to discard the latest
# in-flight output token to avoid emitting a duplicate output token.
request.num_output_placeholders = 0
request.discard_latest_async_tokens = True
# Clear the cache of scheduled request IDs. Since we are forcing preemption
# + resumption in the same step, we must act as if these requests were
# not scheduled in the prior step. They will be flushed from the
# persistent batch in the model runner.
self.prev_step_scheduled_req_ids.clear()
reset_successful = self.kv_cache_manager.reset_prefix_cache()
if reset_running_requests and not reset_successful:
raise RuntimeError(
"Failed to reset KV cache even when all the running requests are "
"preempted and moved to the waiting queue. This is likely due to "
"the presence of running requests waiting for remote KV transfer, "
"which is not supported yet."
)
return reset_successful
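The reverse-order preemption noted above can be illustrated with plain Python lists standing in for the running and waiting queues (a toy model, not the actual queue types):

running = ["req-0", "req-1", "req-2"]  # req-0 was admitted first (FIFO order)
waiting: list[str] = []
while running:
    waiting.insert(0, running.pop())  # pop the newest, prepend to waiting
# The original FIFO order is preserved, so requests resume in admission order.
assert waiting == ["req-0", "req-1", "req-2"]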
def make_stats(
self,

View File

@@ -750,8 +750,8 @@ class AsyncLLM(EngineClient):
self.input_processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self) -> None:
await self.engine_core.reset_prefix_cache_async()
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return await self.engine_core.reset_prefix_cache_async(reset_running_requests)
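An async-side sketch of the same call (assumes an `AsyncLLM` instance constructed elsewhere, e.g. via `AsyncLLM.from_engine_args(...)`; only the call pattern comes from this diff):

async def flush_prefix_cache(engine) -> bool:
    # `engine` is assumed to be an AsyncLLM; force-preempt in-flight requests
    # so the reset cannot fail because running requests still hold KV blocks.
    return await engine.reset_prefix_cache(reset_running_requests=True)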
async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache()

View File

@@ -483,8 +483,8 @@ class EngineCore:
self.model_executor.reset_mm_cache()
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.scheduler.reset_prefix_cache(reset_running_requests)
def sleep(self, level: int = 1):
self.model_executor.sleep(level)

View File

@@ -138,7 +138,7 @@ class EngineCoreClient(ABC):
def reset_mm_cache(self) -> None:
raise NotImplementedError
def reset_prefix_cache(self) -> None:
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
raise NotImplementedError
def sleep(self, level: int = 1) -> None:
@@ -208,7 +208,9 @@ class EngineCoreClient(ABC):
async def reset_mm_cache_async(self) -> None:
raise NotImplementedError
async def reset_prefix_cache_async(self) -> None:
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False
) -> bool:
raise NotImplementedError
async def sleep_async(self, level: int = 1) -> None:
@@ -287,8 +289,8 @@ class InprocClient(EngineCoreClient):
def reset_mm_cache(self) -> None:
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self) -> None:
self.engine_core.reset_prefix_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.engine_core.reset_prefix_cache(reset_running_requests)
def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level)
@@ -751,8 +753,8 @@ class SyncMPClient(MPClient):
def reset_mm_cache(self) -> None:
self.call_utility("reset_mm_cache")
def reset_prefix_cache(self) -> None:
self.call_utility("reset_prefix_cache")
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.call_utility("reset_prefix_cache", reset_running_requests)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.call_utility("add_lora", lora_request)
@@ -955,8 +957,12 @@ class AsyncMPClient(MPClient):
async def reset_mm_cache_async(self) -> None:
await self.call_utility_async("reset_mm_cache")
async def reset_prefix_cache_async(self) -> None:
await self.call_utility_async("reset_prefix_cache")
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False
) -> bool:
return await self.call_utility_async(
"reset_prefix_cache", reset_running_requests
)
async def sleep_async(self, level: int = 1) -> None:
await self.call_utility_async("sleep", level)

View File

@@ -329,8 +329,8 @@ class LLMEngine:
self.input_processor.clear_mm_cache()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self):
self.engine_core.reset_prefix_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.engine_core.reset_prefix_cache(reset_running_requests)
def sleep(self, level: int = 1):
self.engine_core.sleep(level)

View File

@@ -93,7 +93,12 @@ class Request:
if self.prompt_token_ids is not None
else [0] * self.num_prompt_tokens
)
self.num_output_placeholders = 0 # Used in async scheduling.
# Used in async scheduling.
self.num_output_placeholders = 0
# Used in forced preemption (reset_prefix_cache) with async scheduling.
self.discard_latest_async_tokens = False
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.cache_salt: str | None = cache_salt

View File

@@ -482,6 +482,8 @@ class InputBatch:
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
if self.prev_req_id_to_index is not None:
self.prev_req_id_to_index.pop(req_id, None)
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:

View File

@@ -774,7 +774,14 @@ class GPUModelRunner(
# they will be scheduled again sometime in the future.
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
cached_req_ids = self.input_batch.req_id_to_index.keys()
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
# NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint,
# so `cached_req_ids - (scheduled_req_ids - resumed_req_ids)` normally equals
# the old `cached_req_ids - scheduled_req_ids`. The exception is the
# forced-preemption case in reset_prefix_cache, where resumed requests are
# still in the persistent batch; including them in the unscheduled set clears
# them from the batch before they are re-added via the normal resumed-request
# path.
unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids)
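A small worked example of the set arithmetic above, with made-up request IDs: in the normal case the resumed request is no longer cached, so the result matches the old `cached_req_ids - scheduled_req_ids`; in the forced-preemption case it is still cached and therefore lands in the unscheduled set.

# Normal case: "c" was preempted in an earlier step, so it is no longer cached.
cached = {"a", "b"}
scheduled = {"a", "c"}
resumed = {"c"}
assert cached - (scheduled - resumed) == cached - scheduled == {"b"}
# Forced preemption: "a" was preempted and resumed in the same step, so it is
# still in the persistent batch and must be flushed before being re-added.
cached = {"a", "b"}
scheduled = {"a", "b"}
resumed = {"a"}
assert cached - (scheduled - resumed) == {"a"}  # old expression gives set()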
# NOTE(woosuk): The persistent batch optimization assumes that
# consecutive batches contain mostly the same requests. If batches
# have low request overlap (e.g., alternating between two distinct