diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index e0404186eb2d..b2ec2ddfb64d 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -43,6 +43,7 @@ def test_basic_lifecycle(): # STEP (1): Prefill. # (1a): schedule() scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 @@ -67,6 +68,7 @@ def test_basic_lifecycle(): assert len(scheduler.waiting) == 0 # ... but blocks should not be freed. + assert len(scheduler.requests) == 1 blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 ].req_to_blocks[request_id] @@ -76,6 +78,7 @@ def test_basic_lifecycle(): # STEP (2): Send Finished to PB. # (2a): schedule() - pass finished request to PB. scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 1 assert request_id in scheduler_output.finished_req_ids @@ -92,6 +95,7 @@ def test_basic_lifecycle(): # STEP (3): Finished sending. # (3a): schedule() - pass finished request to PB. scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 0 assert len(scheduler_output.finished_req_ids) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0 @@ -133,6 +137,7 @@ def test_short_prompt_lifecycle(): # STEP (1): Prefill. # (1a): schedule() scheduler_output = scheduler.schedule() + assert len(scheduler.requests) == 1 assert len(scheduler.running) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 @@ -178,7 +183,7 @@ def test_prefix_cache_lifecycle(): reqs=[request_normal], use_eos=True ) scheduler.update_from_output(scheduler_output, model_runner_output) - scheduler.schedule() + scheduler_output = scheduler.schedule() scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) ##################### @@ -213,3 +218,45 @@ def test_prefix_cache_lifecycle(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert_scheduler_empty(scheduler) + + +def test_abort_during_kv_transfer(): + """Test aborting request does not release blocks for remote decode.""" + + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + + # Prime the KVCache. + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_EXTERNAL_FULL_BLOCKS = 2 + NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) + + request = create_request( + request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS, + do_remote_decode=True, + ) + + scheduler.add_request(request) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request]) + scheduler.update_from_output(scheduler_output, model_runner_output) + scheduler_output = scheduler.schedule() + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + + # Request removed from PB but blocks should not be freed. + assert len(scheduler.requests) == 1 + + # Abort the request, and check the blocks are still not freed + scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED) + assert len(scheduler.requests) == 1 + + # Simulate a finished sending notification + scheduler_output = scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.kv_connector_output = KVConnectorOutput( + finished_sending=[request.request_id] + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert_scheduler_empty(scheduler) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 70225e95aed2..e871b3017d8b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -14,11 +14,12 @@ The class provides the following primitives: temporary buffer alloc by the CacheManager. update_connector_output() - update KVConnector state after output is received from worker-side connectors. - request_finished() - called when a request is finished, with - the computed kv cache blocks for the request. - Returns whether KV cache should be freed now or will be - freed asynchronously and optionally returns KV transfer - params. + request_finished() - called once when a request is finished, + with the computed kv cache blocks for the request. + Returns whether KV cache should be freed now or if the + connector now assumes responsibility for freeing the + the blocks asynchronously. Also optionally returns KV + transfer params. take_events() - returns new KV events that were collected by the connector since the last call. @@ -362,7 +363,11 @@ class KVConnectorBase_V1(ABC): block_ids: list[int], ) -> tuple[bool, Optional[dict[str, Any]]]: """ - Called when a request has finished, before its blocks are freed. + Called exactly once when a request has finished, before its blocks are + freed. + + The connector may assumes responsibility for freeing the the blocks + asynchronously by returning True. Returns: True if the request is being saved/sent asynchronously and blocks diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0d4744b9f4ab..365d1a1ff280 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1345,6 +1345,8 @@ class NixlConnectorWorker: # Remove all requests that are not to be processed (eg aborted). for req_id in metadata.reqs_not_processed: self._reqs_to_process.discard(req_id) + # We should never get an abort after setting an expiry timer + assert req_id not in self._reqs_to_send # Add to requests that are waiting to be read and track expiration. for req_id, expiration_time in metadata.reqs_to_send.items(): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index bb9c9aaadebf..5bc7a488bf83 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1187,7 +1187,7 @@ class Scheduler(SchedulerInterface): # First pass: collect requests to remove from queues for req_id in request_ids: request = self.requests.get(req_id) - if request is None: + if request is None or request.is_finished(): # Invalid request ID. continue @@ -1365,14 +1365,8 @@ class Scheduler(SchedulerInterface): self.finished_recving_kv_req_ids.add(req_id) for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) - if req_id not in self.requests: - logger.warning( - "Got finished sending KV transfer for request %s," - "but the request is already freed.", - req_id, - ) - else: - self._free_blocks(self.requests[req_id]) + assert req_id in self.requests + self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( self, requests: Iterable[Request], invalid_block_ids: set[int]