mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:47:18 +08:00
[V1][BugFix] Free encoder cache for aborted requests (#12545)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
73aa6cfdf7
commit
e0cc5f259a
@ -38,7 +38,8 @@ class EncoderCacheManager:
|
||||
def get_cached_input_ids(self, request: Request) -> Set[int]:
    """Return the encoder input ids currently cached for ``request``.

    When the request has an entry in ``self.cached``, the stored set
    object itself is returned (not a copy), so mutations of
    ``self.cached[request.request_id]`` are visible through the result.
    A new empty set is returned when the request has no entry.
    """
    req_id = request.request_id
    if req_id not in self.cached:
        return set()
    return self.cached[req_id]
|
||||
|
||||
def free(self, request: Request, input_id: int) -> None:
|
||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||
"""Free a single encoder input id for the request."""
|
||||
req_id = request.request_id
|
||||
if req_id not in self.cached:
|
||||
return
|
||||
@ -49,6 +50,12 @@ class EncoderCacheManager:
|
||||
self.num_free_slots += request.get_num_encoder_tokens(input_id)
|
||||
self.freed.append((req_id, input_id))
|
||||
|
||||
def free(self, request: Request) -> None:
    """Free all cached input ids for the request.

    Calls ``free_encoder_input`` once for every input id that
    ``get_cached_input_ids`` reports for ``request``. Does nothing when
    the request has no cached input ids.
    """
    # get_cached_input_ids() hands back the set stored in self.cached
    # itself, not a copy. Iterate over a snapshot so that any removal
    # free_encoder_input() performs on that set cannot invalidate the
    # iteration ("RuntimeError: Set changed size during iteration").
    for input_id in tuple(self.get_cached_input_ids(request)):
        self.free_encoder_input(request, input_id)
|
||||
|
||||
def get_freed_ids(self) -> List[Tuple[str, int]]:
|
||||
freed = self.freed
|
||||
self.freed = []
|
||||
|
||||
@ -202,7 +202,7 @@ class Scheduler:
|
||||
# which have output tokens.
|
||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||
if num_new_tokens == 0:
|
||||
# The happens when prompt length is divisible by the block
|
||||
# This happens when prompt length is divisible by the block
|
||||
# size and all blocks are cached. Now we force to recompute
|
||||
# the last block. Note that we have to re-compute an entire
|
||||
# block because allocate_slots() assumes num_computed_tokens
|
||||
@ -269,6 +269,7 @@ class Scheduler:
|
||||
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = 0
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
@ -433,7 +434,8 @@ class Scheduler:
|
||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
self.encoder_cache_manager.free(request, input_id)
|
||||
self.encoder_cache_manager.free_encoder_input(
|
||||
request, input_id)
|
||||
|
||||
if request.num_computed_tokens == request.num_tokens:
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
@ -445,8 +447,10 @@ class Scheduler:
|
||||
# TODO: Update the KV cache manager for prefix caching.
|
||||
|
||||
# Check for stop and update request state.
|
||||
# This must be called before me make the EngineCoreOutput.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
stopped = self._check_stop(request)
|
||||
if stopped:
|
||||
self._free_request(request)
|
||||
|
||||
# Add EngineCoreOutput for this Request.
|
||||
output = EngineCoreOutput(
|
||||
@ -472,7 +476,6 @@ class Scheduler:
|
||||
if (request.num_tokens >= self.max_model_len
|
||||
or request.num_output_tokens >= request.max_tokens):
|
||||
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
self._free_request(request)
|
||||
return True
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
@ -480,13 +483,11 @@ class Scheduler:
|
||||
if (not sampling_params.ignore_eos
|
||||
and last_token_id == request.eos_token_id):
|
||||
request.status = RequestStatus.FINISHED_STOPPED
|
||||
self._free_request(request)
|
||||
return True
|
||||
|
||||
if last_token_id in (sampling_params.stop_token_ids or ()):
|
||||
request.status = RequestStatus.FINISHED_STOPPED
|
||||
request.stop_reason = last_token_id
|
||||
self._free_request(request)
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -525,6 +526,7 @@ class Scheduler:
|
||||
def _free_request(self, request: Request) -> None:
|
||||
assert request.is_finished()
|
||||
self.kv_cache_manager.free(request)
|
||||
self.encoder_cache_manager.free(request)
|
||||
self.running_reqs_data.pop(request.request_id, None)
|
||||
del self.requests[request.request_id]
|
||||
self.finished_req_ids.add(request.request_id)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user