mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 01:49:19 +08:00

[Core] Support resetting all running requests' KV while calling reset_prefix_cache (#28827)

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

This commit is contained in:
parent fa8804ad9c
commit d0cd728907

examples/offline_inference/llm_engine_reset_kv.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file demonstrates how to force-preempt running requests by resetting the
prefix cache while using the `LLMEngine` to process prompts with various
sampling parameters.
"""

import argparse

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser


def create_test_prompts() -> list[tuple[str, SamplingParams]]:
    """Create a list of test prompts with their sampling parameters."""
    return [
        (
            "A robot may not injure a human being " * 50,
            SamplingParams(
                temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
            ),
        ),
        (
            "A robot may not injure a human being " * 50,
            SamplingParams(
                temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=16
            ),
        ),
        (
            "To be or not to be,",
            SamplingParams(
                temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
            ),
        ),
        (
            "What is the meaning of life?",
            SamplingParams(
                n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1, max_tokens=128
            ),
        ),
    ]


def process_requests(engine: LLMEngine, test_prompts: list[tuple[str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    print("-" * 50)
    step_id = 0
    while test_prompts or engine.has_unfinished_requests():
        print("-" * 50)
        import os

        print(f"Step {step_id} (pid={os.getpid()})")

        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        if step_id == 10:
            print(f"Resetting prefix cache at {step_id}")
            engine.reset_prefix_cache(reset_running_requests=True)

        request_outputs: list[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print("-" * 50)
                print(request_output)
                print("-" * 50)
        step_id += 1


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
    """Initialize the LLMEngine from the command line arguments."""
    engine_args = EngineArgs.from_cli_args(args)
    return LLMEngine.from_engine_args(engine_args)


def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using the LLMEngine class directly"
    )
    parser = EngineArgs.add_cli_args(parser)
    return parser.parse_args()


def main(args: argparse.Namespace):
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine(args)
    test_prompts = create_test_prompts()
    process_requests(engine, test_prompts)


if __name__ == "__main__":
    args = parse_args()
    main(args)
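Not shown in the diff itself: the example accepts the standard EngineArgs CLI flags via FlexibleArgumentParser, so it can be launched like any other offline-inference script, e.g. `python examples/offline_inference/llm_engine_reset_kv.py --model Qwen/Qwen3-0.6B` (the model name here is only an illustration).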
tests/v1/core/test_reset_prefix_cache_e2e.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import EngineArgs, LLMEngine, SamplingParams

PROMPTS = [
    "A robot may not injure a human being ",
    "To be or not to be,",
    "What is the meaning of life?",
    "What does the fox say? " * 20,  # Test long prompt
]


def test_reset_prefix_cache_e2e():
    engine_args = EngineArgs(
        model="Qwen/Qwen3-0.6B",
        gpu_memory_utilization=0.2,
        async_scheduling=True,
        max_num_batched_tokens=32,
        max_model_len=2048,
        compilation_config={"mode": 0},
    )
    engine = LLMEngine.from_engine_args(engine_args)
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=16,
    )

    # No preempt case:
    for i, prompt in enumerate(PROMPTS):
        engine.add_request("ground_truth_" + str(i), prompt, sampling_params)

    ground_truth_results = {}
    while engine.has_unfinished_requests():
        request_outputs = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                ground_truth_results[request_output.request_id] = request_output

    # Preempt case:
    for i, prompt in enumerate(PROMPTS):
        engine.add_request("preempted_" + str(i), prompt, sampling_params)

    step_id = 0
    preempted_results = {}
    while engine.has_unfinished_requests():
        if step_id == 10:
            engine.reset_prefix_cache(reset_running_requests=True)

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                preempted_results[request_output.request_id] = request_output
        step_id += 1

    for i in range(len(PROMPTS)):
        assert (
            ground_truth_results["ground_truth_" + str(i)].outputs[0].text
            == preempted_results["preempted_" + str(i)].outputs[0].text
        ), (
            f"ground_truth_results['ground_truth_{i}'].outputs[0].text="
            f"{ground_truth_results['ground_truth_' + str(i)].outputs[0].text} "
            f"preempted_results['preempted_{i}'].outputs[0].text="
            f"{preempted_results['preempted_' + str(i)].outputs[0].text}"
        )
@@ -728,6 +728,37 @@ def test_preempt_during_execution():
     assert requests[1].output_token_ids[0] == 42


+def test_scheduler_reset_prefix_cache():
+    scheduler = create_scheduler(enable_prefix_caching=True)
+    requests = create_requests(num_requests=10)
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Initial scheduling; requests should be in the running state now
+    _ = scheduler.schedule()
+
+    # Verify requests moved from waiting to running
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == len(requests)
+    for i, request in enumerate(requests):
+        assert scheduler.running[i] == request
+
+    # Reset prefix cache should fail since there are still running requests
+    # holding KV cache
+    assert not scheduler.reset_prefix_cache()
+
+    # Reset prefix cache with reset_running_requests=True. All running requests
+    # should be pushed back to the waiting queue and their KV cache freed
+    assert scheduler.reset_prefix_cache(reset_running_requests=True)
+
+    # Verify requests moved from running to waiting
+    assert len(scheduler.waiting) == len(requests)
+    assert len(scheduler.running) == 0
+
+    for i, request in enumerate(requests):
+        assert scheduler.waiting[i] == request
+
+
 # Note - these test cases mirror some of those in test_rejection_sampler.py
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
@@ -116,7 +116,7 @@ class EngineClient(ABC):
         ...

     @abstractmethod
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
         """Reset the prefix cache"""
         ...

@@ -1492,8 +1492,8 @@ class LLM:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()

-    def reset_prefix_cache(self) -> None:
-        self.llm_engine.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.llm_engine.reset_prefix_cache(reset_running_requests)

     def sleep(self, level: int = 1):
         """
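For orientation, a minimal, hedged sketch of the updated offline entrypoint follows. The model name is a placeholder, and because `generate()` blocks there are no running requests at the call site, so the flag is shown only to illustrate the new signature and return value.

# Sketch of the updated LLM.reset_prefix_cache() signature (not part of the diff).
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B")  # placeholder model
llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))

# Now returns a bool. With reset_running_requests=True, any in-flight requests
# would be force-preempted first (none here, since generate() blocks).
ok = llm.reset_prefix_cache(reset_running_requests=True)
print("prefix cache reset:", ok)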
@@ -877,13 +877,15 @@ if envs.VLLM_SERVER_DEV_MODE:
         return JSONResponse(content=server_info)

     @router.post("/reset_prefix_cache")
-    async def reset_prefix_cache(raw_request: Request):
+    async def reset_prefix_cache(
+        raw_request: Request, reset_running_requests: bool = Query(default=False)
+    ):
         """
         Reset the prefix cache. Note that we currently do not check if the
         prefix cache is successfully reset in the API server.
         """
         logger.info("Resetting prefix cache...")
-        await engine_client(raw_request).reset_prefix_cache()
+        await engine_client(raw_request).reset_prefix_cache(reset_running_requests)
         return Response(status_code=200)

     @router.post("/reset_mm_cache")
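As a hedged illustration of the new query parameter, the dev-mode endpoint could be exercised roughly as follows; the base URL, the `requests` dependency, and a server running with VLLM_SERVER_DEV_MODE enabled are assumptions, not part of this diff.

# Sketch only: assumes an OpenAI-compatible vLLM server in dev mode on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/reset_prefix_cache",
    params={"reset_running_requests": "true"},  # new query parameter added above
)
print(resp.status_code)  # 200 on success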
@@ -45,6 +45,12 @@ class AsyncScheduler(Scheduler):
         request: Request,
         new_token_ids: list[int],
     ) -> tuple[list[int], bool]:
+        if request.discard_latest_async_tokens:
+            # If the request is force-preempted in reset_prefix_cache, we
+            # should discard the latest async token.
+            request.discard_latest_async_tokens = False
+            return [], False
+
         status_before_update = request.status
         new_token_ids, stopped = super()._update_request_with_output(
             request, new_token_ids
@@ -152,10 +152,16 @@ class SchedulerInterface(ABC):
         return self.has_unfinished_requests() or self.has_finished_requests()

     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
         """Reset the prefix cache for KV cache.

         This is particularly required when the model weights are live-updated.
+
+        Args:
+            reset_running_requests: If True, all the running requests will be
+                preempted and moved to the waiting queue. Otherwise, this method
+                will only reset the KV prefix cache when there is no running
+                request taking KV cache.
         """
         raise NotImplementedError

@@ -347,17 +347,7 @@ class Scheduler(SchedulerInterface):
             else:
                 preempted_req = self.running.pop()

-                self.kv_cache_manager.free(preempted_req)
-                self.encoder_cache_manager.free(preempted_req)
-                preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
-                preempted_req.num_preemptions += 1
-                if self.log_stats:
-                    preempted_req.record_event(
-                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
-                    )
-
-                self.waiting.prepend_request(preempted_req)
+                self._preempt_request(preempted_req, scheduled_timestamp)
                 preempted_reqs.append(preempted_req)
                 if preempted_req == request:
                     # No more request to preempt. Cannot schedule this request.
@@ -756,6 +746,30 @@ class Scheduler(SchedulerInterface):
         self._update_after_schedule(scheduler_output)
         return scheduler_output

+    def _preempt_request(
+        self,
+        request: Request,
+        timestamp: float,
+    ) -> None:
+        """Preempt a request and put it back to the waiting queue.
+
+        NOTE: The request should be popped from the running queue outside of
+        this method.
+        """
+        assert request.status == RequestStatus.RUNNING, (
+            "Only running requests can be preempted"
+        )
+        self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
+        request.status = RequestStatus.PREEMPTED
+        request.num_computed_tokens = 0
+        request.num_preemptions += 1
+        if self.log_stats:
+            request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
+
+        # Put the request back to the waiting queue.
+        self.waiting.prepend_request(request)
+
     def _update_after_schedule(
         self,
         scheduler_output: SchedulerOutput,
@@ -1362,8 +1376,45 @@ class Scheduler(SchedulerInterface):
     def has_finished_requests(self) -> bool:
         return len(self.finished_req_ids) > 0

-    def reset_prefix_cache(self) -> bool:
-        return self.kv_cache_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        """Reset the KV prefix cache.
+
+        If reset_running_requests is True, all the running requests will be
+        preempted and moved to the waiting queue.
+        Otherwise, this method will only reset the KV prefix cache when there
+        is no running request taking KV cache.
+        """
+        if reset_running_requests:
+            # For logging.
+            timestamp = time.monotonic()
+            # Invalidate all the currently running requests' KV by pushing them
+            # to the waiting queue. In this case, we can reduce the ref count of
+            # all the kv blocks to 0 and thus make sure the reset is successful.
+            # Preempt in reverse order so the requests will be added back to the
+            # running queue in FIFO order.
+            while self.running:
+                request = self.running.pop()
+                self._preempt_request(request, timestamp)
+                # NOTE(zhuohan): For async scheduling, we need to discard the
+                # latest output token on the fly to avoid a redundant repetitive
+                # output token.
+                request.num_output_placeholders = 0
+                request.discard_latest_async_tokens = True
+
+            # Clear scheduled request ids cache. Since we are forcing preemption
+            # + resumption in the same step, we must act as if these requests
+            # were not scheduled in the prior step. They will be flushed from
+            # the persistent batch in the model runner.
+            self.prev_step_scheduled_req_ids.clear()
+
+        reset_successful = self.kv_cache_manager.reset_prefix_cache()
+        if reset_running_requests and not reset_successful:
+            raise RuntimeError(
+                "Failed to reset KV cache even when all the running requests are "
+                "preempted and moved to the waiting queue. This is likely due to "
+                "the presence of running requests waiting for remote KV transfer, "
+                "which is not supported yet."
+            )
+        return reset_successful

     def make_stats(
         self,
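A small self-contained sketch (plain Python lists standing in for the scheduler's request queues) of why popping from the tail and prepending, as in the loop above, restores the original FIFO order in the waiting queue:

# Toy illustration of the preempt-in-reverse-order logic; not vLLM code.
running = ["req0", "req1", "req2"]  # FIFO order: req0 is the oldest
waiting: list[str] = []

while running:
    req = running.pop()       # pop the newest first: req2, req1, req0
    waiting.insert(0, req)    # prepend, mirroring waiting.prepend_request()

assert waiting == ["req0", "req1", "req2"]  # original FIFO order preserved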
@@ -750,8 +750,8 @@ class AsyncLLM(EngineClient):
         self.input_processor.clear_mm_cache()
         await self.engine_core.reset_mm_cache_async()

-    async def reset_prefix_cache(self) -> None:
-        await self.engine_core.reset_prefix_cache_async()
+    async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return await self.engine_core.reset_prefix_cache_async(reset_running_requests)

     async def sleep(self, level: int = 1) -> None:
         await self.reset_prefix_cache()
@@ -483,8 +483,8 @@ class EngineCore:

         self.model_executor.reset_mm_cache()

-    def reset_prefix_cache(self):
-        self.scheduler.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.scheduler.reset_prefix_cache(reset_running_requests)

     def sleep(self, level: int = 1):
         self.model_executor.sleep(level)
@@ -138,7 +138,7 @@ class EngineCoreClient(ABC):
     def reset_mm_cache(self) -> None:
         raise NotImplementedError

-    def reset_prefix_cache(self) -> None:
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
         raise NotImplementedError

     def sleep(self, level: int = 1) -> None:
@@ -208,7 +208,9 @@ class EngineCoreClient(ABC):
     async def reset_mm_cache_async(self) -> None:
         raise NotImplementedError

-    async def reset_prefix_cache_async(self) -> None:
+    async def reset_prefix_cache_async(
+        self, reset_running_requests: bool = False
+    ) -> bool:
         raise NotImplementedError

     async def sleep_async(self, level: int = 1) -> None:
@@ -287,8 +289,8 @@ class InprocClient(EngineCoreClient):
     def reset_mm_cache(self) -> None:
         self.engine_core.reset_mm_cache()

-    def reset_prefix_cache(self) -> None:
-        self.engine_core.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.engine_core.reset_prefix_cache(reset_running_requests)

     def sleep(self, level: int = 1) -> None:
         self.engine_core.sleep(level)
@@ -751,8 +753,8 @@ class SyncMPClient(MPClient):
     def reset_mm_cache(self) -> None:
         self.call_utility("reset_mm_cache")

-    def reset_prefix_cache(self) -> None:
-        self.call_utility("reset_prefix_cache")
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.call_utility("reset_prefix_cache", reset_running_requests)

     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.call_utility("add_lora", lora_request)
@@ -955,8 +957,12 @@ class AsyncMPClient(MPClient):
     async def reset_mm_cache_async(self) -> None:
         await self.call_utility_async("reset_mm_cache")

-    async def reset_prefix_cache_async(self) -> None:
-        await self.call_utility_async("reset_prefix_cache")
+    async def reset_prefix_cache_async(
+        self, reset_running_requests: bool = False
+    ) -> bool:
+        return await self.call_utility_async(
+            "reset_prefix_cache", reset_running_requests
+        )

     async def sleep_async(self, level: int = 1) -> None:
         await self.call_utility_async("sleep", level)
@@ -329,8 +329,8 @@ class LLMEngine:
         self.input_processor.clear_mm_cache()
         self.engine_core.reset_mm_cache()

-    def reset_prefix_cache(self):
-        self.engine_core.reset_prefix_cache()
+    def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
+        return self.engine_core.reset_prefix_cache(reset_running_requests)

     def sleep(self, level: int = 1):
         self.engine_core.sleep(level)
@@ -93,7 +93,12 @@ class Request:
             if self.prompt_token_ids is not None
             else [0] * self.num_prompt_tokens
         )
-        self.num_output_placeholders = 0  # Used in async scheduling.
+
+        # Used in async scheduling.
+        self.num_output_placeholders = 0
+        # Used in forced preemption (reset_prefix_cache) with async scheduling.
+        self.discard_latest_async_tokens = False
+
         self.spec_token_ids: list[int] = []
         self.num_computed_tokens = 0
         self.cache_salt: str | None = cache_salt
@@ -482,6 +482,8 @@ class InputBatch:
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
+        if self.prev_req_id_to_index is not None:
+            self.prev_req_id_to_index.pop(req_id, None)

         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
@@ -774,7 +774,14 @@ class GPUModelRunner(
         # they will be scheduled again sometime in the future.
         scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
         cached_req_ids = self.input_batch.req_id_to_index.keys()
-        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+        resumed_req_ids = scheduler_output.scheduled_cached_reqs.resumed_req_ids
+        # NOTE(zhuohan): cached_req_ids and resumed_req_ids are usually disjoint,
+        # so subtracting resumed_req_ids from scheduled_req_ids normally makes no
+        # difference here; the exception is the forced-preemption case in
+        # reset_prefix_cache. In that case we include the resumed_req_ids in the
+        # unscheduled set so that they get cleared from the persistent batch
+        # before being re-scheduled via the normal resumed-request path.
+        unscheduled_req_ids = cached_req_ids - (scheduled_req_ids - resumed_req_ids)
         # NOTE(woosuk): The persistent batch optimization assumes that
         # consecutive batches contain mostly the same requests. If batches
         # have low request overlap (e.g., alternating between two distinct
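A toy set-algebra check (hypothetical request ids, not vLLM code) of why the subtraction above only changes behavior in the forced-preemption case, where a request can sit in the persistent batch and be resumed in the same step:

cached = {"a", "b", "c"}     # requests already in the persistent batch
scheduled = {"a", "b", "c"}  # all scheduled again this step
resumed = {"a", "b", "c"}    # all were force-preempted and resumed in this step

old_rule = cached - scheduled                  # set(): nothing would be cleared
new_rule = cached - (scheduled - resumed)      # {"a", "b", "c"}: all cleared first

assert old_rule == set()
assert new_rule == {"a", "b", "c"}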