[Core] Support resetting all running requests' KV while calling reset_prefix_cache (#28827)
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
parent fa8804ad9c
commit d0cd728907
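
Usage sketch (not part of the diff): the new flag is plumbed from the scheduler up to the offline API, so a caller can force a cache reset roughly as follows. The model name is only an example; any model works, and the flag matters most when requests are in flight (as in the new example file below).

    # Sketch only: exercises the reset_running_requests flag added in this commit.
    from vllm import LLM

    llm = LLM(model="Qwen/Qwen3-0.6B")  # example model

    # With reset_running_requests=True the scheduler preempts every running
    # request (freeing its KV blocks) and re-queues it, so the reset always
    # succeeds. With the default False, the reset only succeeds when no
    # running request is holding KV cache.
    ok = llm.reset_prefix_cache(reset_running_requests=True)
    print("prefix cache reset:", ok)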
examples/offline_inference/llm_engine_reset_kv.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates preempting requests when using the `LLMEngine`
for processing prompts with various sampling parameters.
"""

import argparse

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser


def create_test_prompts() -> list[tuple[str, SamplingParams]]:
    """Create a list of test prompts with their sampling parameters."""
    return [
        (
            "A robot may not injure a human being " * 50,
            SamplingParams(
                temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
            ),
        ),
        (
            "A robot may not injure a human being " * 50,
            SamplingParams(
                temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
            ),
        ),
        (
            "To be or not to be,",
            SamplingParams(
                temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
            ),
        ),
        (
            "What is the meaning of life?",
            SamplingParams(
                n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128
            ),
        ),
    ]


def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    print("-" * 50)
    step_id = 0
    while test_prompts or engine.has_unfinished_requests():
        print("-" * 50)
        import os

        print(f"Step {step_id} (pid={os.getpid()})")

        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        if step_id == 10:
            print(f"Resetting prefix cache at {step_id}")
            engine.reset_prefix_cache(reset_running_requests=True)

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print("-" * 50)
                print(request_output)
                print("-" * 50)
        step_id += 1


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
    """Initialize the LLMEngine from the command line arguments."""
    engine_args = EngineArgs.from_cli_args(args)
    return LLMEngine.from_engine_args(engine_args)


def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using the LLMEngine class directly"
    )
    parser = EngineArgs.add_cli_args(parser)
    return parser.parse_args()


def main(args: argparse.Namespace):
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine(args)
    test_prompts = create_test_prompts()
    process_requests(engine, test_prompts)


if __name__ == "__main__":
    args = parse_args()
    main(args)
tests/v1/core/test_reset_prefix_cache_e2e.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import EngineArgs, LLMEngine, SamplingParams

PROMPTS = [
    "A robot may not injure a human being ",
    "To be or not to be,",
    "What is the meaning of life?",
    "What does the fox say? " * 20,  # Test long prompt
]


def test_reset_prefix_cache_e2e():
    engine_args = EngineArgs(
        model="Qwen/Qwen3-0.6B",
        gpu_memory_utilization=0.2,
        async_scheduling=True,
        max_num_batched_tokens=32,
        max_model_len=2048,
        compilation_config={"mode": 0},
    )
    engine = LLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=16,
    )

    # No-preempt case:
    for i, prompt in enumerate(PROMPTS):
        engine.add_request("ground_truth_" + str(i), prompt, sampling_params)

    ground_truth_results = {}
    while engine.has_unfinished_requests():
        request_outputs = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                ground_truth_results[request_output.request_id] = request_output

    # Preempt case:
    for i, prompt in enumerate(PROMPTS):
        engine.add_request("preempted_" + str(i), prompt, sampling_params)

    step_id = 0
    preempted_results = {}
    while engine.has_unfinished_requests():
        if step_id == 10:
            engine.reset_prefix_cache(reset_running_requests=True)

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                preempted_results[request_output.request_id] = request_output
        step_id += 1

    for i in range(len(PROMPTS)):
        assert (
            ground_truth_results["ground_truth_" + str(i)].outputs[0].text
            == preempted_results["preempted_" + str(i)].outputs[0].text
        ), (
            f"ground_truth_results['ground_truth_{i}'].outputs[0].text="
            f"{ground_truth_results['ground_truth_' + str(i)].outputs[0].text} "
            f"preempted_results['preempted_{i}'].outputs[0].text="
            f"{preempted_results['preempted_' + str(i)].outputs[0].text}"
        )
@@ -728,6 +728,37 @@ def test_preempt_during_execution():
     assert requests[1].output_token_ids[0] == 42
 
 
+def test_scheduler_reset_prefix_cache():
+    scheduler = create_scheduler(enable_prefix_caching=True)
+    requests = create_requests(num_requests=10)
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Initial scheduling; requests should be in the running state now.
+    _ = scheduler.schedule()
+
+    # Verify requests moved from waiting to running.
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == len(requests)
+    for i, request in enumerate(requests):
+        assert scheduler.running[i] == request
+
+    # Resetting the prefix cache should fail since there are still running
+    # requests taking KV cache.
+    assert not scheduler.reset_prefix_cache()
+
+    # Reset the prefix cache with reset_running_requests=True. All running
+    # requests should be pushed back to the waiting queue and their KV cache freed.
+    assert scheduler.reset_prefix_cache(reset_running_requests=True)
+
+    # Verify requests moved from running to waiting.
+    assert len(scheduler.waiting) == len(requests)
+    assert len(scheduler.running) == 0
+
+    for i, request in enumerate(requests):
+        assert scheduler.waiting[i] == request
+
+
 # Note - these test cases mirror some of those in test_rejection_sampler.py
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
@@ -116,7 +116,7 @@ class EngineClient(ABC):
         ...
 
     @abstractmethod
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
        """Reset the prefix cache"""
         ...
 
@@ -1492,8 +1492,8 @@ class LLM:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()
 
-    def reset_prefix_cache(self) -> None:
-        self.llm_engine.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.llm_engine.reset_prefix_cache(reset_running_requests)
 
     def sleep(self, level: int = 1):
         """
@@ -877,13 +877,15 @@ if envs.VLLM_SERVER_DEV_MODE:
         return JSONResponse(content=server_info)
 
     @router.post("/reset_prefix_cache")
-    async def reset_prefix_cache(raw_request: Request):
+    async def reset_prefix_cache(
+        raw_request: Request, reset_running_requests: bool = Query(default=False)
+    ):
         """
         Reset the prefix cache. Note that we currently do not check if the
         prefix cache is successfully reset in the API server.
         """
         logger.info("Resetting prefix cache...")
-        await engine_client(raw_request).reset_prefix_cache()
+        await engine_client(raw_request).reset_prefix_cache(reset_running_requests)
         return Response(status_code=200)
 
     @router.post("/reset_mm_cache")
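For the dev-mode HTTP endpoint above, a rough client-side sketch (not part of the diff). The host and port are assumptions, and VLLM_SERVER_DEV_MODE must be enabled on the server.

    # Sketch: call the dev-mode endpoint with the new query parameter.
    import urllib.request

    req = urllib.request.Request(
        "http://localhost:8000/reset_prefix_cache?reset_running_requests=true",
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        # Per the docstring above, the API server returns 200 regardless of
        # whether the reset actually succeeded.
        print(resp.status)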
@@ -45,6 +45,12 @@ class AsyncScheduler(Scheduler):
         request: Request,
         new_token_ids: list[int],
     ) -> tuple[list[int], bool]:
+        if request.discard_latest_async_tokens:
+            # If the request is force preempted in reset_prefix_cache, we
+            # should discard the latest async token.
+            request.discard_latest_async_tokens = False
+            return [], False
+
         status_before_update = request.status
         new_token_ids, stopped = super()._update_request_with_output(
             request, new_token_ids
@@ -152,10 +152,16 @@ class SchedulerInterface(ABC):
         return self.has_unfinished_requests() or self.has_finished_requests()
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
         """Reset the prefix cache for KV cache.
 
         This is particularly required when the model weights are live-updated.
+
+        Args:
+            reset_running_requests: If True, all the running requests will be
+                preempted and moved to the waiting queue. Otherwise, this method
+                will only reset the KV prefix cache when there is no running
+                request taking KV cache.
         """
         raise NotImplementedError
 
@@ -347,17 +347,7 @@ class Scheduler(SchedulerInterface):
                 else:
                     preempted_req = self.running.pop()
 
-                self.kv_cache_manager.free(preempted_req)
-                self.encoder_cache_manager.free(preempted_req)
-                preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
-                preempted_req.num_preemptions += 1
-                if self.log_stats:
-                    preempted_req.record_event(
-                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
-                    )
-
-                self.waiting.prepend_request(preempted_req)
+                self._preempt_request(preempted_req, scheduled_timestamp)
                 preempted_reqs.append(preempted_req)
                 if preempted_req == request:
                     # No more request to preempt. Cannot schedule this request.
@@ -756,6 +746,30 @@ class Scheduler(SchedulerInterface):
         self._update_after_schedule(scheduler_output)
         return scheduler_output
 
+    def _preempt_request(
+        self,
+        request: Request,
+        timestamp: float,
+    ) -> None:
+        """Preempt a request and put it back to the waiting queue.
+
+        NOTE: The request should be popped from the running queue outside of this
+        method.
+        """
+        assert request.status == RequestStatus.RUNNING, (
+            "Only running requests can be preempted"
+        )
+        self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
+        request.status = RequestStatus.PREEMPTED
+        request.num_computed_tokens = 0
+        request.num_preemptions += 1
+        if self.log_stats:
+            request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
+
+        # Put the request back to the waiting queue.
+        self.waiting.prepend_request(request)
+
     def _update_after_schedule(
         self,
         scheduler_output: SchedulerOutput,
@@ -1362,8 +1376,45 @@ class Scheduler(SchedulerInterface):
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0
 
-    def reset_prefix_cache(self) -> bool:
-        return self.kv_cache_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        """Reset the KV prefix cache.
+
+        If reset_running_requests is True, all the running requests will be
+        preempted and moved to the waiting queue.
+        Otherwise, this method will only reset the KV prefix cache when there
+        are no running requests taking KV cache.
+        """
+        if reset_running_requests:
+            # For logging.
+            timestamp = time.monotonic()
+            # Invalidate the KV blocks of all currently running requests by
+            # pushing the requests to the waiting queue. This reduces the ref
+            # count of all the KV blocks to 0 and thus makes sure the reset
+            # is successful.
+            # Preempt in reverse order so the requests will be added back to the
+            # running queue in FIFO order.
+            while self.running:
+                request = self.running.pop()
+                self._preempt_request(request, timestamp)
+                # NOTE(zhuohan): For async scheduling, we need to discard the
+                # latest output token on the fly to avoid a redundant output token.
+                request.num_output_placeholders = 0
+                request.discard_latest_async_tokens = True
+
+            # Clear scheduled request ids cache. Since we are forcing preemption
+            # + resumption in the same step, we must act as if these requests were
+            # not scheduled in the prior step. They will be flushed from the
+            # persistent batch in the model runner.
+            self.prev_step_scheduled_req_ids.clear()
+
+        reset_successful = self.kv_cache_manager.reset_prefix_cache()
+        if reset_running_requests and not reset_successful:
+            raise RuntimeError(
+                "Failed to reset KV cache even when all the running requests are "
+                "preempted and moved to the waiting queue. This is likely due to "
+                "the presence of running requests waiting for remote KV transfer, "
+                "which is not supported yet."
+            )
+        return reset_successful
 
     def make_stats(
         self,
@@ -750,8 +750,8 @@ class AsyncLLM(EngineClient):
         self.input_processor.clear_mm_cache()
         await self.engine_core.reset_mm_cache_async()
 
-    async def reset_prefix_cache(self) -> None:
-        await self.engine_core.reset_prefix_cache_async()
+    async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return await self.engine_core.reset_prefix_cache_async(reset_running_requests)
 
     async def sleep(self, level: int = 1) -> None:
         await self.reset_prefix_cache()
@@ -483,8 +483,8 @@ class EngineCore:
 
         self.model_executor.reset_mm_cache()
 
-    def reset_prefix_cache(self):
-        self.scheduler.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.scheduler.reset_prefix_cache(reset_running_requests)
 
     def sleep(self, level: int = 1):
         self.model_executor.sleep(level)
@@ -138,7 +138,7 @@ class EngineCoreClient(ABC):
     def reset_mm_cache(self) -> None:
         raise NotImplementedError
 
-    def reset_prefix_cache(self) -> None:
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
         raise NotImplementedError
 
     def sleep(self, level: int = 1) -> None:
@@ -208,7 +208,9 @@ class EngineCoreClient(ABC):
     async def reset_mm_cache_async(self) -> None:
         raise NotImplementedError
 
-    async def reset_prefix_cache_async(self) -> None:
+    async def reset_prefix_cache_async(
+        self, reset_running_requests: bool = False
+    ) -> bool:
         raise NotImplementedError
 
     async def sleep_async(self, level: int = 1) -> None:
@@ -287,8 +289,8 @@ class InprocClient(EngineCoreClient):
     def reset_mm_cache(self) -> None:
         self.engine_core.reset_mm_cache()
 
-    def reset_prefix_cache(self) -> None:
-        self.engine_core.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.engine_core.reset_prefix_cache(reset_running_requests)
 
     def sleep(self, level: int = 1) -> None:
         self.engine_core.sleep(level)
@@ -751,8 +753,8 @@ class SyncMPClient(MPClient):
     def reset_mm_cache(self) -> None:
         self.call_utility("reset_mm_cache")
 
-    def reset_prefix_cache(self) -> None:
-        self.call_utility("reset_prefix_cache")
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.call_utility("reset_prefix_cache", reset_running_requests)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.call_utility("add_lora", lora_request)
@@ -955,8 +957,12 @@ class AsyncMPClient(MPClient):
     async def reset_mm_cache_async(self) -> None:
         await self.call_utility_async("reset_mm_cache")
 
-    async def reset_prefix_cache_async(self) -> None:
-        await self.call_utility_async("reset_prefix_cache")
+    async def reset_prefix_cache_async(
+        self, reset_running_requests: bool = False
+    ) -> bool:
+        return await self.call_utility_async(
+            "reset_prefix_cache", reset_running_requests
+        )
 
     async def sleep_async(self, level: int = 1) -> None:
         await self.call_utility_async("sleep", level)
@@ -329,8 +329,8 @@ class LLMEngine:
         self.input_processor.clear_mm_cache()
         self.engine_core.reset_mm_cache()
 
-    def reset_prefix_cache(self):
-        self.engine_core.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.engine_core.reset_prefix_cache(reset_running_requests)
 
     def sleep(self, level: int = 1):
         self.engine_core.sleep(level)
@@ -93,7 +93,12 @@ class Request:
             if self.prompt_token_ids is not None
             else [0] * self.num_prompt_tokens
         )
-        self.num_output_placeholders = 0  # Used in async scheduling.
+
+        # Used in async scheduling.
+        self.num_output_placeholders = 0
+        # Used in forced preemption (reset_prefix_cache) with async scheduling.
+        self.discard_latest_async_tokens = False
+
         self.spec_token_ids: list[int] = []
         self.num_computed_tokens = 0
         self.cache_salt: str | None = cache_salt
@@ -482,6 +482,8 @@ class InputBatch:
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
+        if self.prev_req_id_to_index is not None:
+            self.prev_req_id_to_index.pop(req_id, None)
 
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
@@ -774,7 +774,14 @@ class GPUModelRunner(
         # they will be scheduled again sometime in the future.
         scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
         cached_req_ids = self.input_batch.req_id_to_index.keys()
-        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+        resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
+        # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint,
+        # so `cached_req_ids - (scheduled_req_ids - resumed_req_ids)` equals
+        # `cached_req_ids - scheduled_req_ids`, apart from the forced-preemption
+        # case in reset_prefix_cache. In that case we include the resumed_req_ids
+        # in the unscheduled set so that they get cleared from the persistent
+        # batch before being re-scheduled through the normal resumed-request path.
+        unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids)
         # NOTE(woosuk): The persistent batch optimization assumes that
         # consecutive batches contain mostly the same requests. If batches
         # have low request overlap (e.g., alternating between two distinct
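To make the set arithmetic in the NOTE above concrete, a small illustrative sketch with made-up request ids (not part of the diff):

    # Illustrative only: request ids are made up.
    cached = {"a", "b", "c"}  # requests already in the persistent batch

    # Normal step: resumed requests were preempted in an earlier step, so they
    # are no longer in the persistent batch, and the new formula matches the old.
    scheduled, resumed = {"a", "b", "d"}, {"d"}
    assert cached - (scheduled - resumed) == cached - scheduled == {"c"}

    # Forced preemption via reset_prefix_cache: "a" and "b" are preempted and
    # resumed within the same step, so they are still cached *and* resumed.
    scheduled, resumed = {"a", "b"}, {"a", "b"}
    # They now land in the unscheduled set, get flushed from the persistent
    # batch, and are re-added through the resumed-request path.
    assert cached - (scheduled - resumed) == {"a", "b", "c"}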