From 2a6dc67eb520ddb9c4138d8b35ed6fe6226997fb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Sat, 4 Oct 2025 13:39:31 +0200
Subject: [PATCH] [Bugfix] Fix `_reqs_to_process` leak on abort (#26012)

Signed-off-by: NickLucche
---
 .../kv_connector/unit/test_nixl_connector.py | 66 +++++++++++++++++++
 .../kv_connector/v1/nixl_connector.py        | 18 ++++-
 2 files changed, 82 insertions(+), 2 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py
index 21953b5533ece..08c0fdefdfc9b 100644
--- a/tests/v1/kv_connector/unit/test_nixl_connector.py
+++ b/tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -33,6 +33,7 @@ from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
 from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
+from vllm.v1.request import RequestStatus
 
 from .utils import create_request, create_scheduler, create_vllm_config
 
@@ -1023,3 +1024,68 @@ def test_shutdown_cleans_up_resources(dist_init):
     assert mock_dereg.call_count == 2
     mock_dereg.assert_any_call("desc1")
     mock_dereg.assert_any_call("desc2")
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper)
+def test_aborted_request_removed_from_worker_in_batch(dist_init):
+    """
+    Create and schedule a request so that P adds it to in-batch tracking via
+    the real scheduler, then simulate an abort (request not in next scheduler
+    iteration) and verify the worker no longer tracks it as in-batch.
+    """
+    vllm_config = create_vllm_config()
+
+    scheduler = create_scheduler(vllm_config)
+    # KVConnector Worker in P
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
+                                                         connector.engine_id,
+                                                         hand_shake_latency=0)
+
+    # Create a request that triggers do_remote_decode so that
+    # the scheduler adds it to reqs_in_batch
+    req = create_request(request_id=1, do_remote_decode=True, max_tokens=1)
+    scheduler.add_request(req)
+
+    # First scheduling pass - examine build_connector_meta output
+    sched_out = scheduler.schedule()
+    kv_meta = sched_out.kv_connector_metadata
+    assert kv_meta is not None
+    assert isinstance(kv_meta, NixlConnectorMetadata)
+    assert req.request_id in kv_meta.reqs_in_batch
+
+    #### Model Runner start ####
+    # Bind scheduler-produced metadata and start worker processing.
+    connector.bind_connector_metadata(kv_meta)
+
+    dummy_ctx = ForwardContext(
+        no_compile_layers={},
+        attn_metadata={},
+        virtual_engine=0,
+    )
+    connector.start_load_kv(dummy_ctx)
+
+    # Ensure it was tracked by the worker
+    assert req.request_id in connector.connector_worker._reqs_to_process
+
+    #### Model Runner end ####
+
+    # Abort the request - triggers request_finished in the connector scheduler
+    scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
+    # Second scheduling pass - build metadata with the aborted request
+    sched_out2 = scheduler.schedule()
+    kv_meta2 = sched_out2.kv_connector_metadata
+    assert kv_meta2 is not None
+    assert isinstance(kv_meta2, NixlConnectorMetadata)
+    assert req.request_id not in kv_meta2.reqs_in_batch
+
+    # Bind empty/abort metadata and run a worker step
+    #### Model Runner start ####
+    connector.bind_connector_metadata(kv_meta2)
+    connector.start_load_kv(dummy_ctx)
+
+    # After abort, the worker should not keep tracking it as "in-batch"
+    assert req.request_id not in connector.connector_worker._reqs_to_process
+    #### Model Runner end ####
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index fdfcc39666ad2..c9a472ce86cbd 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -113,6 +113,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
         self.reqs_to_send: dict[ReqId, float] = {}
         self.reqs_in_batch: set[ReqId] = set()
+        self.reqs_not_processed: set[ReqId] = set()
 
     def add_new_req(
         self,
@@ -287,6 +288,9 @@ class NixlConnectorScheduler:
         # Reqs to send and their expiration time
         self._reqs_need_send: dict[ReqId, float] = {}
         self._reqs_in_batch: set[ReqId] = set()
+        # Reqs to remove from the processed set because they are not to be
+        # sent after remote prefill, or were aborted.
+        self._reqs_not_processed: set[ReqId] = set()
 
     def get_num_new_matched_tokens(
         self, request: "Request",
@@ -401,11 +405,13 @@ class NixlConnectorScheduler:
 
         meta.reqs_to_send = self._reqs_need_send
         meta.reqs_in_batch = self._reqs_in_batch
+        meta.reqs_not_processed = self._reqs_not_processed
 
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
         self._reqs_in_batch = set()
+        self._reqs_not_processed = set()
         self._reqs_need_send = {}
 
         return meta
@@ -439,8 +445,12 @@ class NixlConnectorScheduler:
             params["do_remote_prefill"] = False
             return False, None
 
-        if (not params.get("do_remote_decode")
-                or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
+        if not params.get("do_remote_decode"):
+            return False, None
+        if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
+            # Also covers the case of a P/D prefill request whose blocks are
+            # freed immediately (e.g. on abort). Stop tracking this request.
+            self._reqs_not_processed.add(request.request_id)
             return False, None
 
         # TODO: check whether block_ids actually ever be 0. If not we could
@@ -1234,6 +1244,10 @@ class NixlConnectorWorker:
         for req_id in metadata.reqs_in_batch:
             self._reqs_to_process.add(req_id)
 
+        # Remove all requests that are not to be processed (e.g. aborted).
+        for req_id in metadata.reqs_not_processed:
+            self._reqs_to_process.discard(req_id)
+
         # Add to requests that are waiting to be read and track expiration.
         for req_id, expiration_time in metadata.reqs_to_send.items():
             if req_id in self._reqs_to_process:
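
For reference, the bookkeeping this patch introduces reduces to a small set-reconciliation pattern: each scheduling step publishes both the requests that entered the batch and the requests that will never be sent (e.g. aborted before remote decode), and the worker applies both deltas to its tracking set. The standalone sketch below illustrates that pattern under simplified assumptions; StepMetadata, Worker, and apply() are hypothetical stand-ins for NixlConnectorMetadata, NixlConnectorWorker, and the metadata-binding step, not the vLLM API.

# Minimal sketch (not part of the patch) of the scheduler -> worker set
# reconciliation the fix implements. All names here are hypothetical
# stand-ins, simplified for illustration.
from dataclasses import dataclass, field


@dataclass
class StepMetadata:
    # Published by the scheduler once per scheduling step.
    reqs_in_batch: set[str] = field(default_factory=set)
    reqs_not_processed: set[str] = field(default_factory=set)


class Worker:
    def __init__(self) -> None:
        self._reqs_to_process: set[str] = set()

    def apply(self, meta: StepMetadata) -> None:
        # Track requests that entered the batch this step.
        self._reqs_to_process |= meta.reqs_in_batch
        # Forget requests the scheduler marked as never-to-be-sent
        # (e.g. aborted). Set difference, like discard(), silently
        # ignores ids the worker never tracked, so unknown requests
        # cannot raise KeyError.
        self._reqs_to_process -= meta.reqs_not_processed


worker = Worker()
# Step 1: "req-1" is scheduled; the worker starts tracking it.
worker.apply(StepMetadata(reqs_in_batch={"req-1"}))
assert "req-1" in worker._reqs_to_process
# Step 2: "req-1" was aborted, so it arrives in reqs_not_processed and
# the worker drops it - without this second delta the id would stay in
# the tracking set forever, which is the leak the patch fixes.
worker.apply(StepMetadata(reqs_not_processed={"req-1"}))
assert "req-1" not in worker._reqs_to_process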