[V1][BugFix] Free encoder cache for aborted requests (#12545)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-01-29 13:47:33 -08:00 (committed by GitHub)
parent 73aa6cfdf7
commit e0cc5f259a
2 changed files with 16 additions and 7 deletions

vllm/v1/core/encoder_cache_manager.py

@@ -38,7 +38,8 @@ class EncoderCacheManager:
     def get_cached_input_ids(self, request: Request) -> Set[int]:
         return self.cached.get(request.request_id, set())
 
-    def free(self, request: Request, input_id: int) -> None:
+    def free_encoder_input(self, request: Request, input_id: int) -> None:
+        """Free a single encoder input id for the request."""
         req_id = request.request_id
         if req_id not in self.cached:
             return
@@ -49,6 +50,12 @@ class EncoderCacheManager:
         self.num_free_slots += request.get_num_encoder_tokens(input_id)
         self.freed.append((req_id, input_id))
 
+    def free(self, request: Request) -> None:
+        """Free all cached input ids for the request."""
+        input_ids = self.get_cached_input_ids(request)
+        for input_id in input_ids:
+            self.free_encoder_input(request, input_id)
+
     def get_freed_ids(self) -> List[Tuple[str, int]]:
         freed = self.freed
         self.freed = []
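
For readers skimming the diff, here is a self-contained sketch of how the two-level API composes. It is a toy reconstruction, not vLLM's actual module: the discard/del bookkeeping inside free_encoder_input and the ToyRequest class are assumptions inferred from the surrounding lines. Note the .copy() in free(): free_encoder_input mutates the cached set, so iterating over the live set returned by get_cached_input_ids would otherwise change its size mid-loop.

from typing import Dict, List, Set, Tuple


class ToyRequest:
    """Stand-in for vllm.v1.request.Request (illustrative only)."""

    def __init__(self, request_id: str, encoder_tokens: Dict[int, int]):
        self.request_id = request_id
        self._encoder_tokens = encoder_tokens  # input_id -> token count

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._encoder_tokens[input_id]


class ToyEncoderCacheManager:
    def __init__(self, cache_size: int):
        self.cached: Dict[str, Set[int]] = {}   # req_id -> cached input ids
        self.num_free_slots = cache_size
        self.freed: List[Tuple[str, int]] = []  # entries to reclaim later

    def get_cached_input_ids(self, request: ToyRequest) -> Set[int]:
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: ToyRequest, input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        if req_id not in self.cached:
            return
        self.cached[req_id].discard(input_id)
        if not self.cached[req_id]:
            del self.cached[req_id]
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: ToyRequest) -> None:
        """Free all cached input ids for the request."""
        # Copy first: free_encoder_input mutates the underlying set.
        for input_id in self.get_cached_input_ids(request).copy():
            self.free_encoder_input(request, input_id)


req = ToyRequest("req-0", {0: 16, 1: 16})
mgr = ToyEncoderCacheManager(cache_size=64)
mgr.cached["req-0"] = {0, 1}
mgr.num_free_slots -= 32
mgr.free(req)
assert mgr.num_free_slots == 64 and "req-0" not in mgr.cached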

vllm/v1/core/scheduler.py

@@ -202,7 +202,7 @@ class Scheduler:
             # which have output tokens.
             num_new_tokens = request.num_tokens - num_computed_tokens
             if num_new_tokens == 0:
-                # The happens when prompt length is divisible by the block
+                # This happens when prompt length is divisible by the block
                 # size and all blocks are cached. Now we force to recompute
                 # the last block. Note that we have to re-compute an entire
                 # block because allocate_slots() assumes num_computed_tokens
@@ -269,6 +269,7 @@
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
+        num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
             num_common_prefix_blocks = (
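
The single added line above presumably guards the case where the running queue is empty, which can now happen once aborted requests are fully freed: without the default, the later read of num_common_prefix_blocks would hit an unbound variable. A minimal illustration with stand-in values (the constant 4 replaces the real kv_cache_manager computation):

def common_prefix_blocks(running: list) -> int:
    """Toy stand-in for the scheduler's cascade-attention probe."""
    num_common_prefix_blocks = 0  # the added default
    if running:
        # Stand-in for kv_cache_manager.get_num_common_prefix_blocks(...).
        num_common_prefix_blocks = 4
    return num_common_prefix_blocks


assert common_prefix_blocks([]) == 0        # empty queue: no UnboundLocalError
assert common_prefix_blocks(["req-0"]) == 4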
@@ -433,7 +434,8 @@
                 if start_pos + num_tokens <= request.num_computed_tokens:
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
-                    self.encoder_cache_manager.free(request, input_id)
+                    self.encoder_cache_manager.free_encoder_input(
+                        request, input_id)
 
             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
@@ -445,8 +447,10 @@
                 # TODO: Update the KV cache manager for prefix caching.
 
                 # Check for stop and update request state.
-                # This must be called before me make the EngineCoreOutput.
+                # This must be called before we make the EngineCoreOutput.
                 stopped = self._check_stop(request)
+                if stopped:
+                    self._free_request(request)
 
                 # Add EngineCoreOutput for this Request.
                 output = EngineCoreOutput(
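
The refactor above splits responsibilities: _check_stop now only decides and marks the finished status, while the caller performs the actual freeing. That gives stopped and aborted requests one shared release path (_free_request). A toy model of the split, with invented names and a boolean standing in for the real stop conditions:

from dataclasses import dataclass, field
from typing import Set


@dataclass
class MiniScheduler:
    """Toy model of the decide-then-free split (not vLLM's real class)."""

    freed: Set[str] = field(default_factory=set)

    def _check_stop(self, req_id: str, hit_stop: bool) -> bool:
        # Decision only: no resource release happens here anymore.
        return hit_stop

    def _free_request(self, req_id: str) -> None:
        # Single release path shared by stopped AND aborted requests.
        self.freed.add(req_id)

    def update_from_output(self, req_id: str, hit_stop: bool) -> None:
        stopped = self._check_stop(req_id, hit_stop)
        if stopped:
            self._free_request(req_id)


sched = MiniScheduler()
sched.update_from_output("req-0", hit_stop=True)
assert "req-0" in sched.freed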
@@ -472,7 +476,6 @@
         if (request.num_tokens >= self.max_model_len
                 or request.num_output_tokens >= request.max_tokens):
             request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            self._free_request(request)
             return True
 
         sampling_params = request.sampling_params
@@ -480,13 +483,11 @@
         if (not sampling_params.ignore_eos
                 and last_token_id == request.eos_token_id):
             request.status = RequestStatus.FINISHED_STOPPED
-            self._free_request(request)
             return True
 
         if last_token_id in (sampling_params.stop_token_ids or ()):
             request.status = RequestStatus.FINISHED_STOPPED
             request.stop_reason = last_token_id
-            self._free_request(request)
             return True
 
         return False
@@ -525,6 +526,7 @@
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
         self.running_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)
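
Putting it together, the one-line addition to _free_request is what closes the leak named in the commit title: every termination path, including aborts that never satisfy _check_stop's stop conditions, now releases encoder outputs along with KV-cache blocks. A hedged end-to-end sketch; the toy managers and the free_request function below are reconstructed for illustration, not quoted from this commit:

from typing import Dict, Set


class ToyKVCacheManager:
    def __init__(self) -> None:
        self.blocks: Dict[str, int] = {}

    def free(self, req_id: str) -> None:
        self.blocks.pop(req_id, None)


class ToyEncoderCache:
    def __init__(self) -> None:
        self.cached: Dict[str, Set[int]] = {}

    def free(self, req_id: str) -> None:
        # Before this commit, aborted requests never reached this call,
        # so their encoder entries lingered in the cache.
        self.cached.pop(req_id, None)


def free_request(kv: ToyKVCacheManager, enc: ToyEncoderCache,
                 req_id: str) -> None:
    kv.free(req_id)
    enc.free(req_id)  # the fix: also drop cached encoder outputs


kv, enc = ToyKVCacheManager(), ToyEncoderCache()
kv.blocks["aborted"] = 8
enc.cached["aborted"] = {0, 1}    # encoder outputs for a multimodal prompt
free_request(kv, enc, "aborted")  # the abort path now funnels through here
assert "aborted" not in enc.cached  # no leak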