mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 18:35:58 +08:00
[P/D] Avoid stranding blocks in P when aborted in D's waiting queue (#19223)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
2cc2069970
commit
55c65ab495
@ -298,8 +298,21 @@ class NixlConnectorScheduler:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
"NIXLConnector request_finished, request_status=%s, "
|
"NIXLConnector request_finished, request_status=%s, "
|
||||||
"kv_transfer_params=%s", request.status, params)
|
"kv_transfer_params=%s", request.status, params)
|
||||||
|
if not params:
|
||||||
|
return False, None
|
||||||
|
|
||||||
if (params is None or not params.get("do_remote_decode")
|
if params.get("do_remote_prefill"):
|
||||||
|
# If do_remote_prefill is still True when the request is finished,
|
||||||
|
# update_state_after_alloc must not have been called (the request
|
||||||
|
# must have been aborted before it was scheduled).
|
||||||
|
# To avoid stranding the prefill blocks in the prefill instance,
|
||||||
|
# we must add empty block_ids to _reqs_need_recv so that our
|
||||||
|
# worker side will notify and free blocks in the prefill instance.
|
||||||
|
self._reqs_need_recv[request.request_id] = (request, [])
|
||||||
|
params["do_remote_prefill"] = False
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
if (not params.get("do_remote_decode")
|
||||||
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user