From 55c65ab495f5d270f65f89dcc737e9694b278002 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 25 Jun 2025 15:19:44 -0700 Subject: [PATCH] [P/D] Avoid stranding blocks in P when aborted in D's waiting queue (#19223) Signed-off-by: Nick Hill --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a962a9241d73e..92a9184d318c7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -298,8 +298,21 @@ class NixlConnectorScheduler: logger.debug( "NIXLConnector request_finished, request_status=%s, " "kv_transfer_params=%s", request.status, params) + if not params: + return False, None - if (params is None or not params.get("do_remote_decode") + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if (not params.get("do_remote_decode") or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): return False, None