From d0cd728907d82a109a165612b8790ddaf5496f59 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 1 Dec 2025 18:25:05 -0800 Subject: [PATCH] [Core] Support reseting all running requests' KV while calling `reset_prefix_cache` (#28827) Signed-off-by: Zhuohan Li Signed-off-by: Nick Hill Co-authored-by: Nick Hill --- .../offline_inference/llm_engine_reset_kv.py | 98 +++++++++++++++++++ tests/v1/core/test_reset_prefix_cache_e2e.py | 66 +++++++++++++ tests/v1/core/test_scheduler.py | 31 ++++++ vllm/engine/protocol.py | 2 +- vllm/entrypoints/llm.py | 4 +- vllm/entrypoints/openai/api_server.py | 6 +- vllm/v1/core/sched/async_scheduler.py | 6 ++ vllm/v1/core/sched/interface.py | 8 +- vllm/v1/core/sched/scheduler.py | 77 ++++++++++++--- vllm/v1/engine/async_llm.py | 4 +- vllm/v1/engine/core.py | 4 +- vllm/v1/engine/core_client.py | 22 +++-- vllm/v1/engine/llm_engine.py | 4 +- vllm/v1/request.py | 7 +- vllm/v1/worker/gpu_input_batch.py | 2 + vllm/v1/worker/gpu_model_runner.py | 9 +- 16 files changed, 315 insertions(+), 35 deletions(-) create mode 100644 examples/offline_inference/llm_engine_reset_kv.py create mode 100644 tests/v1/core/test_reset_prefix_cache_e2e.py diff --git a/examples/offline_inference/llm_engine_reset_kv.py b/examples/offline_inference/llm_engine_reset_kv.py new file mode 100644 index 000000000000..3fbe7fa7545e --- /dev/null +++ b/examples/offline_inference/llm_engine_reset_kv.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This file demonstrates preempt requests when using the `LLMEngine` +for processing prompts with various sampling parameters. +""" + +import argparse + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.utils.argparse_utils import FlexibleArgumentParser + + +def create_test_prompts() -> list[tuple[str, SamplingParams]]: + """Create a list of test prompts with their sampling parameters.""" + return [ + ( + "A robot may not injure a human being " * 50, + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16 + ), + ), + ( + "A robot may not injure a human being " * 50, + SamplingParams( + temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16 + ), + ), + ( + "To be or not to be,", + SamplingParams( + temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128 + ), + ), + ( + "What is the meaning of life?", + SamplingParams( + n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128 + ), + ), + ] + + +def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + print("-" * 50) + step_id = 0 + while test_prompts or engine.has_unfinished_requests(): + print("-" * 50) + import os + + print(f"Step {step_id} (pid={os.getpid()})") + + if test_prompts: + prompt, sampling_params = test_prompts.pop(0) + engine.add_request(str(request_id), prompt, sampling_params) + request_id += 1 + + if step_id == 10: + print(f"Resetting prefix cache at {step_id}") + engine.reset_prefix_cache(reset_running_requests=True) + + request_outputs: list[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + print("-" * 50) + print(request_output) + print("-" * 50) + step_id += 1 + + +def initialize_engine(args: argparse.Namespace) -> LLMEngine: + """Initialize the LLMEngine from the command line arguments.""" + engine_args = EngineArgs.from_cli_args(args) + 
return LLMEngine.from_engine_args(engine_args) + + +def parse_args(): + parser = FlexibleArgumentParser( + description="Demo on using the LLMEngine class directly" + ) + parser = EngineArgs.add_cli_args(parser) + return parser.parse_args() + + +def main(args: argparse.Namespace): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine(args) + test_prompts = create_test_prompts() + process_requests(engine, test_prompts) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tests/v1/core/test_reset_prefix_cache_e2e.py b/tests/v1/core/test_reset_prefix_cache_e2e.py new file mode 100644 index 000000000000..e543c30a156e --- /dev/null +++ b/tests/v1/core/test_reset_prefix_cache_e2e.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm import EngineArgs, LLMEngine, SamplingParams + +PROMPTS = [ + "A robot may not injure a human being ", + "To be or not to be,", + "What is the meaning of life?", + "What does the fox say? " * 20, # Test long prompt +] + + +def test_reset_prefix_cache_e2e(): + engine_args = EngineArgs( + model="Qwen/Qwen3-0.6B", + gpu_memory_utilization=0.2, + async_scheduling=True, + max_num_batched_tokens=32, + max_model_len=2048, + compilation_config={"mode": 0}, + ) + engine = LLMEngine.from_engine_args(engine_args) + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=16, + ) + + # No preempt case: + for i, prompt in enumerate(PROMPTS): + engine.add_request("ground_truth_" + str(i), prompt, sampling_params) + + ground_truth_results = {} + while engine.has_unfinished_requests(): + request_outputs = engine.step() + for request_output in request_outputs: + if request_output.finished: + ground_truth_results[request_output.request_id] = request_output + + # Preempt case: + for i, prompt in enumerate(PROMPTS): + engine.add_request("preempted_" + str(i), prompt, sampling_params) + + step_id = 0 + preempted_results = {} + while engine.has_unfinished_requests(): + if step_id == 10: + engine.reset_prefix_cache(reset_running_requests=True) + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + preempted_results[request_output.request_id] = request_output + step_id += 1 + + for i in range(len(PROMPTS)): + assert ( + ground_truth_results["ground_truth_" + str(i)].outputs[0].text + == preempted_results["preempted_" + str(i)].outputs[0].text + ), ( + f"ground_truth_results['ground_truth_{i}'].outputs[0].text=" + f"{ground_truth_results['ground_truth_' + str(i)].outputs[0].text} " + f"preempted_results['preempted_{i}'].outputs[0].text=" + f"{preempted_results['preempted_' + str(i)].outputs[0].text}" + ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index fe4153e60997..0051c11d18d8 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -728,6 +728,37 @@ def test_preempt_during_execution(): assert requests[1].output_token_ids[0] == 42 +def test_scheduler_reset_prefix_cache(): + scheduler = create_scheduler(enable_prefix_caching=True) + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + # Initial scheduling, requests should be at the running state now + _ = scheduler.schedule() + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == len(requests) + for i, request in enumerate(requests): + 
assert scheduler.running[i] == request + + # Reset prefix cache should fail since there are still running requests + # and they are taking KV cache + assert not scheduler.reset_prefix_cache() + + # Reset prefix cache with reset_running_requests=True. All running requests + # Should be pushed back to the waiting queue and kv cache should be freed + assert scheduler.reset_prefix_cache(reset_running_requests=True) + + # Verify requests moved from running to waiting + assert len(scheduler.waiting) == len(requests) + assert len(scheduler.running) == 0 + + for i, request in enumerate(requests): + assert scheduler.waiting[i] == request + + # Note - these test cases mirror some of those in test_rejection_sampler.py @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index f2b19c845018..1b6330c9f9b6 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -116,7 +116,7 @@ class EngineClient(ABC): ... @abstractmethod - async def reset_prefix_cache(self) -> None: + async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: """Reset the prefix cache""" ... diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f005605c08d7..c121fa71f019 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1492,8 +1492,8 @@ class LLM: def stop_profile(self) -> None: self.llm_engine.stop_profile() - def reset_prefix_cache(self) -> None: - self.llm_engine.reset_prefix_cache() + def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + return self.llm_engine.reset_prefix_cache(reset_running_requests) def sleep(self, level: int = 1): """ diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 92161f67f1cf..cdc316b65ba7 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -877,13 +877,15 @@ if envs.VLLM_SERVER_DEV_MODE: return JSONResponse(content=server_info) @router.post("/reset_prefix_cache") - async def reset_prefix_cache(raw_request: Request): + async def reset_prefix_cache( + raw_request: Request, reset_running_requests: bool = Query(default=False) + ): """ Reset the prefix cache. Note that we currently do not check if the prefix cache is successfully reset in the API server. """ logger.info("Resetting prefix cache...") - await engine_client(raw_request).reset_prefix_cache() + await engine_client(raw_request).reset_prefix_cache(reset_running_requests) return Response(status_code=200) @router.post("/reset_mm_cache") diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py index 7916fafdae1f..df61eebb395e 100644 --- a/vllm/v1/core/sched/async_scheduler.py +++ b/vllm/v1/core/sched/async_scheduler.py @@ -45,6 +45,12 @@ class AsyncScheduler(Scheduler): request: Request, new_token_ids: list[int], ) -> tuple[list[int], bool]: + if request.discard_latest_async_tokens: + # If the request is force preempted in reset_prefix_cache, we + # should discard the latest async token. 
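+            # (The discarded token belongs to the step that was in flight when
+            # this request was force-preempted; the request restarts from
+            # num_computed_tokens == 0 and will sample a token at the same
+            # position again, so keeping it here would duplicate the output.)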
+ request.discard_latest_async_tokens = False + return [], False + status_before_update = request.status new_token_ids, stopped = super()._update_request_with_output( request, new_token_ids diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 88d99d940282..c2f503ef2354 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -152,10 +152,16 @@ class SchedulerInterface(ABC): return self.has_unfinished_requests() or self.has_finished_requests() @abstractmethod - def reset_prefix_cache(self) -> bool: + def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: """Reset the prefix cache for KV cache. This is particularly required when the model weights are live-updated. + + Args: + reset_running_requests: If True, all the running requests will be + preempted and moved to the waiting queue. Otherwise, this method + will only reset the KV prefix cache when there is no running request + taking KV cache. """ raise NotImplementedError diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c1ead200ba8d..52b98ef65459 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -347,17 +347,7 @@ class Scheduler(SchedulerInterface): else: preempted_req = self.running.pop() - self.kv_cache_manager.free(preempted_req) - self.encoder_cache_manager.free(preempted_req) - preempted_req.status = RequestStatus.PREEMPTED - preempted_req.num_computed_tokens = 0 - preempted_req.num_preemptions += 1 - if self.log_stats: - preempted_req.record_event( - EngineCoreEventType.PREEMPTED, scheduled_timestamp - ) - - self.waiting.prepend_request(preempted_req) + self._preempt_request(preempted_req, scheduled_timestamp) preempted_reqs.append(preempted_req) if preempted_req == request: # No more request to preempt. Cannot schedule this request. @@ -756,6 +746,30 @@ class Scheduler(SchedulerInterface): self._update_after_schedule(scheduler_output) return scheduler_output + def _preempt_request( + self, + request: Request, + timestamp: float, + ) -> None: + """Preempt a request and put it back to the waiting queue. + + NOTE: The request should be popped from the running queue outside of this + method. + """ + assert request.status == RequestStatus.RUNNING, ( + "Only running requests can be preempted" + ) + self.kv_cache_manager.free(request) + self.encoder_cache_manager.free(request) + request.status = RequestStatus.PREEMPTED + request.num_computed_tokens = 0 + request.num_preemptions += 1 + if self.log_stats: + request.record_event(EngineCoreEventType.PREEMPTED, timestamp) + + # Put the request back to the waiting queue. + self.waiting.prepend_request(request) + def _update_after_schedule( self, scheduler_output: SchedulerOutput, @@ -1362,8 +1376,45 @@ class Scheduler(SchedulerInterface): def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 - def reset_prefix_cache(self) -> bool: - return self.kv_cache_manager.reset_prefix_cache() + def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + """Reset the KV prefix cache. + + If reset_running_requests is True, all the running requests will be + preempted and moved to the waiting queue. + Otherwise, this method will only reset the KV prefix cache when there + is no running requests taking KV cache. + """ + if reset_running_requests: + # For logging. + timestamp = time.monotonic() + # Invalidate all the current running requests KV's by pushing them to + # the waiting queue. 
In this case, we can reduce the ref count of all + # the kv blocks to 0 and thus we can make sure the reset is successful. + # Preempt in reverse order so the requests will be added back to the + # running queue in FIFO order. + while self.running: + request = self.running.pop() + self._preempt_request(request, timestamp) + # NOTE(zhuohan): For async scheduling, we need to discard the latest + # output token on the fly to avoid a redundant repetitive output token. + request.num_output_placeholders = 0 + request.discard_latest_async_tokens = True + + # Clear scheduled request ids cache. Since we are forcing preemption + # + resumption in the same step, we must act as if these requests were + # not scheduled in the prior step. They will be flushed from the + # persistent batch in the model runner. + self.prev_step_scheduled_req_ids.clear() + + reset_successful = self.kv_cache_manager.reset_prefix_cache() + if reset_running_requests and not reset_successful: + raise RuntimeError( + "Failed to reset KV cache even when all the running requests are " + "preempted and moved to the waiting queue. This is likely due to " + "the presence of running requests waiting for remote KV transfer, " + "which is not supported yet." + ) + return reset_successful def make_stats( self, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index d0708a8a046d..17a271ca42e2 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -750,8 +750,8 @@ class AsyncLLM(EngineClient): self.input_processor.clear_mm_cache() await self.engine_core.reset_mm_cache_async() - async def reset_prefix_cache(self) -> None: - await self.engine_core.reset_prefix_cache_async() + async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + return await self.engine_core.reset_prefix_cache_async(reset_running_requests) async def sleep(self, level: int = 1) -> None: await self.reset_prefix_cache() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e3a5f51a8fc5..61b8422dd663 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -483,8 +483,8 @@ class EngineCore: self.model_executor.reset_mm_cache() - def reset_prefix_cache(self): - self.scheduler.reset_prefix_cache() + def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + return self.scheduler.reset_prefix_cache(reset_running_requests) def sleep(self, level: int = 1): self.model_executor.sleep(level) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 9b440505bd9d..afa0593921d0 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -138,7 +138,7 @@ class EngineCoreClient(ABC): def reset_mm_cache(self) -> None: raise NotImplementedError - def reset_prefix_cache(self) -> None: + def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: raise NotImplementedError def sleep(self, level: int = 1) -> None: @@ -208,7 +208,9 @@ class EngineCoreClient(ABC): async def reset_mm_cache_async(self) -> None: raise NotImplementedError - async def reset_prefix_cache_async(self) -> None: + async def reset_prefix_cache_async( + self, reset_running_requests: bool = False + ) -> bool: raise NotImplementedError async def sleep_async(self, level: int = 1) -> None: @@ -287,8 +289,8 @@ class InprocClient(EngineCoreClient): def reset_mm_cache(self) -> None: self.engine_core.reset_mm_cache() - def reset_prefix_cache(self) -> None: - self.engine_core.reset_prefix_cache() + def reset_prefix_cache(self, reset_running_requests: bool = 
False) -> bool: + return self.engine_core.reset_prefix_cache(reset_running_requests) def sleep(self, level: int = 1) -> None: self.engine_core.sleep(level) @@ -751,8 +753,8 @@ class SyncMPClient(MPClient): def reset_mm_cache(self) -> None: self.call_utility("reset_mm_cache") - def reset_prefix_cache(self) -> None: - self.call_utility("reset_prefix_cache") + def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + return self.call_utility("reset_prefix_cache", reset_running_requests) def add_lora(self, lora_request: LoRARequest) -> bool: return self.call_utility("add_lora", lora_request) @@ -955,8 +957,12 @@ class AsyncMPClient(MPClient): async def reset_mm_cache_async(self) -> None: await self.call_utility_async("reset_mm_cache") - async def reset_prefix_cache_async(self) -> None: - await self.call_utility_async("reset_prefix_cache") + async def reset_prefix_cache_async( + self, reset_running_requests: bool = False + ) -> bool: + return await self.call_utility_async( + "reset_prefix_cache", reset_running_requests + ) async def sleep_async(self, level: int = 1) -> None: await self.call_utility_async("sleep", level) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a3bde7ba8d64..e7dfc554e76f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -329,8 +329,8 @@ class LLMEngine: self.input_processor.clear_mm_cache() self.engine_core.reset_mm_cache() - def reset_prefix_cache(self): - self.engine_core.reset_prefix_cache() + def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: + return self.engine_core.reset_prefix_cache(reset_running_requests) def sleep(self, level: int = 1): self.engine_core.sleep(level) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 366cdadf5a58..f2dfd2eed03c 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -93,7 +93,12 @@ class Request: if self.prompt_token_ids is not None else [0] * self.num_prompt_tokens ) - self.num_output_placeholders = 0 # Used in async scheduling. + + # Used in async scheduling. + self.num_output_placeholders = 0 + # Used in forced preemption (reset_prefix_cache) with async scheduling. + self.discard_latest_async_tokens = False + self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 self.cache_salt: str | None = cache_salt diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e7991baeaa1b..516c76a5e4b1 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -482,6 +482,8 @@ class InputBatch: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.in_progress_prompt_logprobs_cpu.pop(req_id, None) + if self.prev_req_id_to_index is not None: + self.prev_req_id_to_index.pop(req_id, None) self.has_allowed_token_ids.discard(req_id) if self.allowed_token_ids_mask_cpu_tensor is not None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2218e4f023f9..9eacd2138978 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -774,7 +774,14 @@ class GPUModelRunner( # they will be scheduled again sometime in the future. 
         scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
         cached_req_ids = self.input_batch.req_id_to_index.keys()
-        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+        resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
+        # NOTE(zhuohan): resumed_req_ids is usually disjoint from cached_req_ids
+        # (a resumed request left the persistent batch when it was preempted), so
+        # subtracting it from scheduled_req_ids below normally changes nothing.
+        # The exception is forced preemption in reset_prefix_cache, where requests
+        # are preempted and resumed without leaving the persistent batch; treating
+        # them as unscheduled clears them so they are re-added as resumed requests.
+        unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids)
         # NOTE(woosuk): The persistent batch optimization assumes that
         # consecutive batches contain mostly the same requests. If batches
         # have low request overlap (e.g., alternating between two distinct
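
The set arithmetic in the hunk above is easier to see with concrete request IDs. The sketch below is illustrative only, using plain Python sets rather than the runner's actual data structures; "req-0" stands for a request that was force-preempted by `reset_prefix_cache` and resumed in the next step without ever leaving the persistent batch.

# Normal case: the resumed request ("req-0") was dropped from the persistent
# batch when it was preempted, so subtracting resumed_req_ids changes nothing.
cached_req_ids = {"req-1", "req-2"}
scheduled_req_ids = {"req-0", "req-1"}
resumed_req_ids = {"req-0"}
assert cached_req_ids - (scheduled_req_ids - resumed_req_ids) == {"req-2"}
assert cached_req_ids - scheduled_req_ids == {"req-2"}

# Forced-preemption case: "req-0" is still in the persistent batch, so it must
# end up in unscheduled_req_ids to be cleared before being re-added through the
# normal resumed-request path.
cached_req_ids = {"req-0", "req-1"}
scheduled_req_ids = {"req-0", "req-1"}
resumed_req_ids = {"req-0"}
assert cached_req_ids - (scheduled_req_ids - resumed_req_ids) == {"req-0"}

To try the new flag end to end, a minimal offline sketch follows. The `reset_prefix_cache(reset_running_requests=...)` signature and boolean return value come from this patch; the model name is a placeholder, and since the synchronous `LLM` API has no in-flight requests between `generate` calls, the flag only makes a difference in a live `LLMEngine`/`AsyncLLM` loop such as the example script added above.

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B")  # placeholder model
llm.generate(["To be or not to be,"], SamplingParams(max_tokens=16))

# Plain reset: succeeds only if no running request is holding KV blocks.
print(llm.reset_prefix_cache())

# With the new flag, any running requests are force-preempted back to the
# waiting queue first and their blocks freed; the scheduler raises if it still
# cannot free everything (e.g. requests waiting on remote KV transfer).
print(llm.reset_prefix_cache(reset_running_requests=True))

Against a running API server the same behavior is reachable via `POST /reset_prefix_cache?reset_running_requests=true`; note the route is only registered when `VLLM_SERVER_DEV_MODE` is enabled.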