Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 12:35:40 +08:00)
[NIXL] Ignore abort on already-finished request (#25067)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent 606b00e80f
commit 784c231151
@@ -43,6 +43,7 @@ def test_basic_lifecycle():
     # STEP (1): Prefill.
     # (1a): schedule()
     scheduler_output = scheduler.schedule()
+    assert len(scheduler.requests) == 1
     assert len(scheduler.running) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 1
 
@@ -67,6 +68,7 @@ def test_basic_lifecycle():
     assert len(scheduler.waiting) == 0
 
     # ... but blocks should not be freed.
+    assert len(scheduler.requests) == 1
     blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
         0
     ].req_to_blocks[request_id]
@@ -76,6 +78,7 @@ def test_basic_lifecycle():
     # STEP (2): Send Finished to PB.
     # (2a): schedule() - pass finished request to PB.
     scheduler_output = scheduler.schedule()
+    assert len(scheduler.requests) == 1
     assert len(scheduler.running) == 0
     assert len(scheduler_output.finished_req_ids) == 1
     assert request_id in scheduler_output.finished_req_ids
@@ -92,6 +95,7 @@ def test_basic_lifecycle():
     # STEP (3): Finished sending.
     # (3a): schedule() - pass finished request to PB.
     scheduler_output = scheduler.schedule()
+    assert len(scheduler.requests) == 1
     assert len(scheduler.running) == 0
     assert len(scheduler_output.finished_req_ids) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 0
@@ -133,6 +137,7 @@ def test_short_prompt_lifecycle():
     # STEP (1): Prefill.
     # (1a): schedule()
     scheduler_output = scheduler.schedule()
+    assert len(scheduler.requests) == 1
     assert len(scheduler.running) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 1
 
@@ -178,7 +183,7 @@ def test_prefix_cache_lifecycle():
         reqs=[request_normal], use_eos=True
     )
     scheduler.update_from_output(scheduler_output, model_runner_output)
-    scheduler.schedule()
+    scheduler_output = scheduler.schedule()
     scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
 
     #####################
@@ -213,3 +218,45 @@ def test_prefix_cache_lifecycle():
     )
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert_scheduler_empty(scheduler)
+
+
+def test_abort_during_kv_transfer():
+    """Test aborting request does not release blocks for remote decode."""
+
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config)
+
+    # Prime the KVCache.
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    NUM_EXTERNAL_FULL_BLOCKS = 2
+    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
+
+    request = create_request(
+        request_id=1,
+        block_size=BLOCK_SIZE,
+        num_tokens=NUM_TOKENS,
+        do_remote_decode=True,
+    )
+
+    scheduler.add_request(request)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    scheduler_output = scheduler.schedule()
+    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
+
+    # Request removed from PB but blocks should not be freed.
+    assert len(scheduler.requests) == 1
+
+    # Abort the request, and check the blocks are still not freed
+    scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
+    assert len(scheduler.requests) == 1
+
+    # Simulate a finished sending notification
+    scheduler_output = scheduler.schedule()
+    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request.request_id]
+    )
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert_scheduler_empty(scheduler)
@@ -14,11 +14,12 @@ The class provides the following primitives:
         temporary buffer alloc by the CacheManager.
     update_connector_output() - update KVConnector state after
         output is received from worker-side connectors.
-    request_finished() - called when a request is finished, with
-        the computed kv cache blocks for the request.
-        Returns whether KV cache should be freed now or will be
-        freed asynchronously and optionally returns KV transfer
-        params.
+    request_finished() - called once when a request is finished,
+        with the computed kv cache blocks for the request.
+        Returns whether KV cache should be freed now or if the
+        connector now assumes responsibility for freeing the
+        blocks asynchronously. Also optionally returns KV
+        transfer params.
     take_events() - returns new KV events that were collected
         by the connector since the last call.
 
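How the scheduler side consumes this primitive, in rough terms: request_finished() is called once with the request's computed block IDs, and the caller either frees the blocks immediately or leaves them allocated until the worker reports finished_sending. A minimal sketch, assuming an illustrative free_blocks callback that is not part of this commit:

from typing import Any, Callable, Optional


def on_request_finished(
    connector: Any,                      # scheduler-side KV connector
    request: Any,                        # the finished request
    block_ids: list[int],                # its computed KV cache blocks
    free_blocks: Callable[[list[int]], None],
) -> Optional[dict[str, Any]]:
    # Ask the connector whether it now owns the lifetime of these blocks.
    delay_free, kv_transfer_params = connector.request_finished(request, block_ids)
    if not delay_free:
        # Connector does not need the blocks: return them to the pool now.
        free_blocks(block_ids)
    # Otherwise the blocks stay allocated until the worker-side connector
    # reports finished_sending for this request (see the Scheduler hunks below).
    return kv_transfer_params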
@@ -362,7 +363,11 @@ class KVConnectorBase_V1(ABC):
         block_ids: list[int],
     ) -> tuple[bool, Optional[dict[str, Any]]]:
         """
-        Called when a request has finished, before its blocks are freed.
+        Called exactly once when a request has finished, before its blocks are
+        freed.
+
+        The connector may assume responsibility for freeing the blocks
+        asynchronously by returning True.
 
         Returns:
             True if the request is being saved/sent asynchronously and blocks
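A connector honoring this contract might look roughly like the following. This is a hedged sketch only: SketchConnector does not subclass the real base class, the do_remote_decode attribute and the returned parameter key are illustrative, and real connectors report completed sends through their worker-side output rather than a poll method.

from typing import Any, Optional


class SketchConnector:
    def __init__(self) -> None:
        # request_id -> block IDs whose KV data is still being sent.
        self._pending_sends: dict[str, list[int]] = {}

    def request_finished(
        self, request: Any, block_ids: list[int]
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        if not getattr(request, "do_remote_decode", False):
            # Nothing to transfer: the caller may free the blocks right away.
            return False, None
        # Take responsibility for the blocks until the transfer completes;
        # the returned params would be handed to the remote decode instance.
        self._pending_sends[request.request_id] = block_ids
        return True, {"block_ids": block_ids}

    def poll_finished_sending(self) -> list[str]:
        # Illustrative stand-in for the worker-side finished_sending report.
        done = list(self._pending_sends)
        self._pending_sends.clear()
        return done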
@@ -1345,6 +1345,8 @@ class NixlConnectorWorker:
         # Remove all requests that are not to be processed (eg aborted).
         for req_id in metadata.reqs_not_processed:
             self._reqs_to_process.discard(req_id)
+            # We should never get an abort after setting an expiry timer
+            assert req_id not in self._reqs_to_send
 
         # Add to requests that are waiting to be read and track expiration.
         for req_id, expiration_time in metadata.reqs_to_send.items():
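The new assert encodes an invariant about the two structures touched here: _reqs_to_process holds requests the worker may still act on, _reqs_to_send maps requests awaiting a remote read to their expiry deadline, and an abort can only arrive before an expiry timer has been set. A standalone sketch of that bookkeeping, with module-level stand-ins for the worker's attributes (not the real class):

import time

reqs_to_process: set[str] = set()    # requests the worker may still act on
reqs_to_send: dict[str, float] = {}  # request_id -> expiration timestamp


def apply_metadata(reqs_not_processed: set[str],
                   new_reqs_to_send: dict[str, float]) -> None:
    # Remove all requests that are not to be processed (e.g. aborted).
    for req_id in reqs_not_processed:
        reqs_to_process.discard(req_id)
        # Invariant from the commit: an abort never arrives once an expiry
        # timer has been set for the request.
        assert req_id not in reqs_to_send
    # Record requests that are waiting to be read and track their expiration.
    for req_id, expiration_time in new_reqs_to_send.items():
        reqs_to_send[req_id] = expiration_time


apply_metadata({"aborted-req"}, {"served-req": time.monotonic() + 30.0})
assert "served-req" in reqs_to_send and "aborted-req" not in reqs_to_send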
@@ -1187,7 +1187,7 @@ class Scheduler(SchedulerInterface):
         # First pass: collect requests to remove from queues
         for req_id in request_ids:
             request = self.requests.get(req_id)
-            if request is None:
+            if request is None or request.is_finished():
                 # Invalid request ID.
                 continue
 
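The behavioral change here is that finish_requests() now ignores requests that are already finished, not just unknown IDs. That covers an abort arriving for a remote-decode request that has finished generating but whose KV blocks are still being sent. A self-contained toy model of the guard (this is not vLLM's Scheduler, and the RequestStatus values other than FINISHED_ABORTED are invented for the illustration):

from dataclasses import dataclass, field
from enum import Enum, auto


class RequestStatus(Enum):
    RUNNING = auto()
    FINISHED_REMOTE_DECODE = auto()
    FINISHED_ABORTED = auto()


@dataclass
class Request:
    request_id: str
    status: RequestStatus = RequestStatus.RUNNING

    def is_finished(self) -> bool:
        return self.status is not RequestStatus.RUNNING


@dataclass
class ToyScheduler:
    requests: dict[str, Request] = field(default_factory=dict)

    def finish_requests(self, request_ids: list[str], status: RequestStatus) -> None:
        for req_id in request_ids:
            request = self.requests.get(req_id)
            if request is None or request.is_finished():
                # Unknown or already-finished request (e.g. one whose KV
                # blocks are still being sent): ignore it.
                continue
            request.status = status


scheduler = ToyScheduler()
scheduler.requests["req-1"] = Request("req-1", RequestStatus.FINISHED_REMOTE_DECODE)

# An abort arriving after the request already finished is a no-op.
scheduler.finish_requests(["req-1"], RequestStatus.FINISHED_ABORTED)
assert scheduler.requests["req-1"].status is RequestStatus.FINISHED_REMOTE_DECODE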
@@ -1365,14 +1365,8 @@ class Scheduler(SchedulerInterface):
                 self.finished_recving_kv_req_ids.add(req_id)
         for req_id in kv_connector_output.finished_sending or ():
             logger.debug("Finished sending KV transfer for request %s", req_id)
-            if req_id not in self.requests:
-                logger.warning(
-                    "Got finished sending KV transfer for request %s,"
-                    "but the request is already freed.",
-                    req_id,
-                )
-            else:
-                self._free_blocks(self.requests[req_id])
+            assert req_id in self.requests
+            self._free_blocks(self.requests[req_id])
 
     def _update_requests_with_invalid_blocks(
         self, requests: Iterable[Request], invalid_block_ids: set[int]