mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 18:44:30 +08:00
[Bugfix] Fix _reqs_to_process leak on abort (#26012)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
f05fea1f5e
commit
2a6dc67eb5
@@ -33,6 +33,7 @@ from vllm.platforms.interface import Platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
from .utils import create_request, create_scheduler, create_vllm_config
|
||||
|
||||
@@ -1023,3 +1024,68 @@ def test_shutdown_cleans_up_resources(dist_init):
|
||||
assert mock_dereg.call_count == 2
|
||||
mock_dereg.assert_any_call("desc1")
|
||||
mock_dereg.assert_any_call("desc2")
|
||||
|
||||
|
||||
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
    """
    Regression test for a ``_reqs_to_process`` leak on abort.

    Create and schedule a request so that P adds it to in-batch tracking via
    the real scheduler, then simulate an abort (request not in next scheduler
    iteration) and verify the worker no longer tracks it as in-batch.
    """
    vllm_config = create_vllm_config()

    scheduler = create_scheduler(vllm_config)
    # KVConnector Worker in P
    # (hand_shake_latency=0 so the fake worker handshakes immediately)
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
                                                         connector.engine_id,
                                                         hand_shake_latency=0)

    # Create a request that triggers do_remote_decode so that
    # the scheduler adds it to reqs_in_batch
    req = create_request(request_id=1, do_remote_decode=True, max_tokens=1)
    scheduler.add_request(req)

    # First scheduling pass - examine build_connector_meta output
    sched_out = scheduler.schedule()
    kv_meta = sched_out.kv_connector_metadata
    assert kv_meta is not None
    assert isinstance(kv_meta, NixlConnectorMetadata)
    assert req.request_id in kv_meta.reqs_in_batch

    #### Model Runner start ####
    # Bind scheduler-produced metadata and start worker processing.
    connector.bind_connector_metadata(kv_meta)

    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)

    # Ensure it was tracked by the worker
    assert req.request_id in connector.connector_worker._reqs_to_process

    #### Model Runner end ####

    # Abort request - request_finished call in connector scheduler
    scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
    # Second scheduling pass - build metadata with aborted request
    sched_out2 = scheduler.schedule()
    kv_meta2 = sched_out2.kv_connector_metadata
    assert kv_meta2 is not None
    assert isinstance(kv_meta2, NixlConnectorMetadata)
    assert req.request_id not in kv_meta2.reqs_in_batch

    # Bind empty/abort metadata and run worker step
    #### Model Runner start ####
    connector.bind_connector_metadata(kv_meta2)
    connector.start_load_kv(dummy_ctx)

    # After abort, the worker should not keep tracking it as "in-batch"
    # (this is the leak being fixed: before, the id stayed in the set forever)
    assert req.request_id not in connector.connector_worker._reqs_to_process
    #### Model Runner end ####
|
||||
|
||||
@@ -113,6 +113,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
|
||||
self.reqs_to_send: dict[ReqId, float] = {}
|
||||
self.reqs_in_batch: set[ReqId] = set()
|
||||
self.reqs_not_processed: set[ReqId] = set()
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
@@ -287,6 +288,9 @@ class NixlConnectorScheduler:
|
||||
# Reqs to send and their expiration time
|
||||
self._reqs_need_send: dict[ReqId, float] = {}
|
||||
self._reqs_in_batch: set[ReqId] = set()
|
||||
# Reqs to remove from processed set because they're not to send after
|
||||
# remote prefill or aborted.
|
||||
self._reqs_not_processed: set[ReqId] = set()
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
@@ -401,11 +405,13 @@ class NixlConnectorScheduler:
|
||||
|
||||
meta.reqs_to_send = self._reqs_need_send
|
||||
meta.reqs_in_batch = self._reqs_in_batch
|
||||
meta.reqs_not_processed = self._reqs_not_processed
|
||||
|
||||
# Clear the list once workers start the transfers
|
||||
self._reqs_need_recv.clear()
|
||||
self._reqs_need_save.clear()
|
||||
self._reqs_in_batch = set()
|
||||
self._reqs_not_processed = set()
|
||||
self._reqs_need_send = {}
|
||||
|
||||
return meta
|
||||
@@ -439,8 +445,12 @@ class NixlConnectorScheduler:
|
||||
params["do_remote_prefill"] = False
|
||||
return False, None
|
||||
|
||||
if (not params.get("do_remote_decode")
|
||||
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
|
||||
if not params.get("do_remote_decode"):
|
||||
return False, None
|
||||
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
|
||||
# Also include the case of a P/D Prefill request with immediate
|
||||
# block free (eg abort). Stop tracking this request.
|
||||
self._reqs_not_processed.add(request.request_id)
|
||||
return False, None
|
||||
|
||||
# TODO: check whether block_ids actually ever be 0. If not we could
|
||||
@@ -1234,6 +1244,10 @@ class NixlConnectorWorker:
|
||||
for req_id in metadata.reqs_in_batch:
|
||||
self._reqs_to_process.add(req_id)
|
||||
|
||||
# Remove all requests that are not to be processed (eg aborted).
|
||||
for req_id in metadata.reqs_not_processed:
|
||||
self._reqs_to_process.discard(req_id)
|
||||
|
||||
# Add to requests that are waiting to be read and track expiration.
|
||||
for req_id, expiration_time in metadata.reqs_to_send.items():
|
||||
if req_id in self._reqs_to_process:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user