From e0cc5f259a8bec0d66ed0bc3e25ca245377679a1 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 29 Jan 2025 13:47:33 -0800
Subject: [PATCH] [V1][BugFix] Free encoder cache for aborted requests (#12545)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/core/encoder_cache_manager.py |  9 ++++++++-
 vllm/v1/core/scheduler.py             | 14 ++++++++------
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index 0cd8c806a3e47..9d570b334c6cf 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -38,7 +38,8 @@ class EncoderCacheManager:
     def get_cached_input_ids(self, request: Request) -> Set[int]:
         return self.cached.get(request.request_id, set())
 
-    def free(self, request: Request, input_id: int) -> None:
+    def free_encoder_input(self, request: Request, input_id: int) -> None:
+        """Free a single encoder input id for the request."""
         req_id = request.request_id
         if req_id not in self.cached:
             return
@@ -49,6 +50,12 @@ class EncoderCacheManager:
         self.num_free_slots += request.get_num_encoder_tokens(input_id)
         self.freed.append((req_id, input_id))
 
+    def free(self, request: Request) -> None:
+        """Free all cached input ids for the request."""
+        input_ids = self.get_cached_input_ids(request)
+        for input_id in input_ids:
+            self.free_encoder_input(request, input_id)
+
     def get_freed_ids(self) -> List[Tuple[str, int]]:
         freed = self.freed
         self.freed = []
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index 7a88cc9433b32..da2e31b1fb75b 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -202,7 +202,7 @@ class Scheduler:
                 # which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
                 if num_new_tokens == 0:
-                    # The happens when prompt length is divisible by the block
+                    # This happens when prompt length is divisible by the block
                     # size and all blocks are cached. Now we force to recompute
                     # the last block. Note that we have to re-compute an entire
                     # block because allocate_slots() assumes num_computed_tokens
@@ -269,6 +269,7 @@ class Scheduler:
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
+        num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
             num_common_prefix_blocks = (
@@ -433,7 +434,8 @@ class Scheduler:
                 if start_pos + num_tokens <= request.num_computed_tokens:
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
-                    self.encoder_cache_manager.free(request, input_id)
+                    self.encoder_cache_manager.free_encoder_input(
+                        request, input_id)
 
             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
@@ -445,8 +447,10 @@ class Scheduler:
                 # TODO: Update the KV cache manager for prefix caching.
 
                 # Check for stop and update request state.
-                # This must be called before me make the EngineCoreOutput.
+                # This must be called before we make the EngineCoreOutput.
                 stopped = self._check_stop(request)
+                if stopped:
+                    self._free_request(request)
 
                 # Add EngineCoreOutput for this Request.
                 output = EngineCoreOutput(
@@ -472,7 +476,6 @@ class Scheduler:
         if (request.num_tokens >= self.max_model_len
                 or request.num_output_tokens >= request.max_tokens):
             request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            self._free_request(request)
            return True
 
         sampling_params = request.sampling_params
@@ -480,13 +483,11 @@ class Scheduler:
         if (not sampling_params.ignore_eos
                 and last_token_id == request.eos_token_id):
             request.status = RequestStatus.FINISHED_STOPPED
-            self._free_request(request)
             return True
 
         if last_token_id in (sampling_params.stop_token_ids or ()):
             request.status = RequestStatus.FINISHED_STOPPED
             request.stop_reason = last_token_id
-            self._free_request(request)
             return True
 
         return False
@@ -525,6 +526,7 @@ class Scheduler:
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
         self.running_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)
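
For reference, below is a minimal, self-contained sketch of the cache-freeing contract this patch establishes. It is not the real vLLM code: Request, allocate(), cache_size, and the encoder token counts are simplified stand-ins invented for illustration, and only free_encoder_input() and free() mirror the methods touched by the diff. The point it shows is that free_encoder_input() releases a single encoder input, while the new free() releases every input a request still holds, which is what allows _free_request() to reclaim encoder cache space for requests that never run to completion.

from typing import Dict, List, Set, Tuple


class Request:
    """Simplified stand-in for vLLM's Request (hypothetical, for illustration)."""

    def __init__(self, request_id: str, encoder_token_counts: Dict[int, int]):
        self.request_id = request_id
        # Maps encoder input id -> number of encoder tokens it occupies.
        self._encoder_token_counts = encoder_token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._encoder_token_counts[input_id]


class EncoderCacheManager:
    """Sketch of the freeing behavior introduced by this patch."""

    def __init__(self, cache_size: int):
        self.num_free_slots = cache_size
        # request_id -> set of encoder input ids currently cached.
        self.cached: Dict[str, Set[int]] = {}
        # (request_id, input_id) pairs freed since the last drain.
        self.freed: List[Tuple[str, int]] = []

    def allocate(self, request: Request, input_id: int) -> None:
        # Hypothetical helper so the example is runnable end to end.
        self.cached.setdefault(request.request_id, set()).add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: Request) -> Set[int]:
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        if req_id not in self.cached:
            return
        self.cached[req_id].discard(input_id)
        if not self.cached[req_id]:
            del self.cached[req_id]
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: Request) -> None:
        """Free all cached input ids for the request (e.g. on abort)."""
        # Copy before iterating: free_encoder_input mutates the underlying set.
        for input_id in list(self.get_cached_input_ids(request)):
            self.free_encoder_input(request, input_id)


if __name__ == "__main__":
    manager = EncoderCacheManager(cache_size=100)
    request = Request("req-0", {0: 16, 1: 32})
    manager.allocate(request, 0)
    manager.allocate(request, 1)
    assert manager.num_free_slots == 52

    # Aborting the request: release everything it still holds in the cache.
    manager.free(request)
    assert manager.num_free_slots == 100
    assert manager.get_cached_input_ids(request) == set()
    print("freed:", manager.freed)  # ('req-0', 0) and ('req-0', 1)

Note also the design shift visible in scheduler.py: _check_stop() now only marks the finish status, the caller in update_from_output() invokes _free_request() when stopped is set, and _free_request() additionally calls encoder_cache_manager.free(request). As a result, any request that reaches _free_request(), including aborted ones, releases its encoder cache entries along with its KV cache blocks.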