[NIXL] Ignore abort on already-finished request (#25067)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Mark McLoughlin 2025-10-10 11:21:56 +01:00 committed by GitHub
parent 606b00e80f
commit 784c231151
4 changed files with 64 additions and 16 deletions

View File

@ -43,6 +43,7 @@ def test_basic_lifecycle():
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
@ -67,6 +68,7 @@ def test_basic_lifecycle():
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
assert len(scheduler.requests) == 1
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
@ -76,6 +78,7 @@ def test_basic_lifecycle():
# STEP (2): Send Finished to PB.
# (2a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
@ -92,6 +95,7 @@ def test_basic_lifecycle():
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
@ -133,6 +137,7 @@ def test_short_prompt_lifecycle():
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
@ -178,7 +183,7 @@ def test_prefix_cache_lifecycle():
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
#####################
@ -213,3 +218,45 @@ def test_prefix_cache_lifecycle():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
def test_abort_during_kv_transfer():
"""Test aborting request does not release blocks for remote decode."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request])
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
# Request removed from PB but blocks should not be freed.
assert len(scheduler.requests) == 1
# Abort the request, and check the blocks are still not freed
scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
assert len(scheduler.requests) == 1
# Simulate a finished sending notification
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request.request_id]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)

View File

@ -14,11 +14,12 @@ The class provides the following primitives:
temporary buffer alloc by the CacheManager.
update_connector_output() - update KVConnector state after
output is received from worker-side connectors.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
request_finished() - called once when a request is finished,
with the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or if the
connector now assumes responsibility for freeing the
blocks asynchronously. Also optionally returns KV
transfer params.
take_events() - returns new KV events that were collected
by the connector since the last call.
@ -362,7 +363,11 @@ class KVConnectorBase_V1(ABC):
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assume responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks

View File

@ -1345,6 +1345,8 @@ class NixlConnectorWorker:
# Remove all requests that are not to be processed (eg aborted).
for req_id in metadata.reqs_not_processed:
self._reqs_to_process.discard(req_id)
# We should never get an abort after setting an expiry timer
assert req_id not in self._reqs_to_send
# Add to requests that are waiting to be read and track expiration.
for req_id, expiration_time in metadata.reqs_to_send.items():

View File

@ -1187,7 +1187,7 @@ class Scheduler(SchedulerInterface):
# First pass: collect requests to remove from queues
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
if request is None or request.is_finished():
# Invalid request ID or the request already finished.
continue
@ -1365,14 +1365,8 @@ class Scheduler(SchedulerInterface):
self.finished_recving_kv_req_ids.add(req_id)
for req_id in kv_connector_output.finished_sending or ():
logger.debug("Finished sending KV transfer for request %s", req_id)
if req_id not in self.requests:
logger.warning(
"Got finished sending KV transfer for request %s,"
"but the request is already freed.",
req_id,
)
else:
self._free_blocks(self.requests[req_id])
assert req_id in self.requests
self._free_blocks(self.requests[req_id])
def _update_requests_with_invalid_blocks(
self, requests: Iterable[Request], invalid_block_ids: set[int]