[Bugfix] Fix _reqs_to_process leak on abort (#26012)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-10-04 13:39:31 +02:00 committed by GitHub
parent f05fea1f5e
commit 2a6dc67eb5
2 changed files with 82 additions and 2 deletions
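Per the commit title, the worker-side _reqs_to_process set leaked entries when a request was aborted: once the scheduler had reported a request as in-batch, nothing told the worker to stop tracking it if the request was aborted instead of finishing normally. The diff below has the scheduler collect such request IDs in a new reqs_not_processed field of the connector metadata, which the worker then discards from its tracking set. A minimal sketch of that pattern, using hypothetical SchedulerSide/WorkerSide classes and method names rather than the actual NixlConnector code:

# Simplified sketch (hypothetical names, not the real vLLM classes) of the
# pattern this commit introduces: the scheduler flags requests that will never
# be processed, and the worker drops them instead of tracking them forever.

class SchedulerSide:
    def __init__(self) -> None:
        self.reqs_in_batch: set[str] = set()        # scheduled this step
        self.reqs_not_processed: set[str] = set()    # aborted / never sent

    def on_request_finished(self, req_id: str, length_capped: bool) -> None:
        # A remote-decode request that finished for any reason other than
        # hitting its length cap (e.g. an abort) will never be sent.
        if not length_capped:
            self.reqs_not_processed.add(req_id)


class WorkerSide:
    def __init__(self) -> None:
        self._reqs_to_process: set[str] = set()

    def apply_metadata(self, meta: SchedulerSide) -> None:
        # Grow the tracking set from the current batch...
        self._reqs_to_process.update(meta.reqs_in_batch)
        # ...and, with this fix, shrink it for aborted requests.
        for req_id in meta.reqs_not_processed:
            self._reqs_to_process.discard(req_id)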


@@ -33,6 +33,7 @@ from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from .utils import create_request, create_scheduler, create_vllm_config
@@ -1023,3 +1024,68 @@ def test_shutdown_cleans_up_resources(dist_init):
    assert mock_dereg.call_count == 2
    mock_dereg.assert_any_call("desc1")
    mock_dereg.assert_any_call("desc2")

@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
    """
    Create and schedule a request so that P adds it to in-batch tracking via
    the real scheduler, then simulate an abort (request not in next scheduler
    iteration) and verify the worker no longer tracks it as in-batch.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)
    # KVConnector Worker in P
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
                                                         connector.engine_id,
                                                         hand_shake_latency=0)

    # Create a request that triggers do_remote_decode so that
    # the scheduler adds it to reqs_in_batch
    req = create_request(request_id=1, do_remote_decode=True, max_tokens=1)
    scheduler.add_request(req)

    # First scheduling pass - examine build_connector_meta output
    sched_out = scheduler.schedule()
    kv_meta = sched_out.kv_connector_metadata
    assert kv_meta is not None
    assert isinstance(kv_meta, NixlConnectorMetadata)
    assert req.request_id in kv_meta.reqs_in_batch

    #### Model Runner start ####
    # Bind scheduler-produced metadata and start worker processing.
    connector.bind_connector_metadata(kv_meta)
    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)
    # Ensure it was tracked by the worker
    assert req.request_id in connector.connector_worker._reqs_to_process
    #### Model Runner end ####

    # Abort the request - triggers request_finished in the connector scheduler
    scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)

    # Second scheduling pass - build metadata with the aborted request
    sched_out2 = scheduler.schedule()
    kv_meta2 = sched_out2.kv_connector_metadata
    assert kv_meta2 is not None
    assert isinstance(kv_meta2, NixlConnectorMetadata)
    assert req.request_id not in kv_meta2.reqs_in_batch

    # Bind empty/abort metadata and run a second worker step
    #### Model Runner start ####
    connector.bind_connector_metadata(kv_meta2)
    connector.start_load_kv(dummy_ctx)
    # After abort, the worker should not keep tracking it as "in-batch"
    assert req.request_id not in connector.connector_worker._reqs_to_process
    #### Model Runner end ####
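Assuming the test lives in the NixlConnector unit-test module alongside the FakeNixlWrapper and FakeNixlConnectorWorker doubles it uses (they are not part of this diff), the new case can be run on its own with pytest's -k filter, for example: pytest -k test_aborted_request_removed_from_worker_in_batch.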


@@ -113,6 +113,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
        self.reqs_to_send: dict[ReqId, float] = {}
        self.reqs_in_batch: set[ReqId] = set()
        self.reqs_not_processed: set[ReqId] = set()

    def add_new_req(
        self,
@@ -287,6 +288,9 @@ class NixlConnectorScheduler:
        # Reqs to send and their expiration time
        self._reqs_need_send: dict[ReqId, float] = {}
        self._reqs_in_batch: set[ReqId] = set()
        # Reqs to remove from the processed set because they will not be sent
        # after remote prefill, or were aborted.
        self._reqs_not_processed: set[ReqId] = set()

    def get_num_new_matched_tokens(
        self, request: "Request",
@@ -401,11 +405,13 @@
        meta.reqs_to_send = self._reqs_need_send
        meta.reqs_in_batch = self._reqs_in_batch
        meta.reqs_not_processed = self._reqs_not_processed

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        self._reqs_need_save.clear()
        self._reqs_in_batch = set()
        self._reqs_not_processed = set()
        self._reqs_need_send = {}

        return meta
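One detail worth noting in the hunk above: _reqs_in_batch, _reqs_not_processed and _reqs_need_send are rebound to fresh containers rather than cleared in place the way _reqs_need_recv and _reqs_need_save are, presumably because meta now holds a reference to the very same objects and an in-place clear() would also wipe the metadata just handed to the worker. A small illustration of that aliasing behaviour (plain Python, not vLLM code):

# Illustration only: rebinding vs. clearing a set that something else references.
scheduler_set = {"req-1"}
meta_view = scheduler_set     # like meta.reqs_not_processed = self._reqs_not_processed

scheduler_set.clear()         # in-place clear empties the shared object...
assert meta_view == set()     # ...so the "metadata" view is emptied too

scheduler_set = {"req-1"}
meta_view = scheduler_set
scheduler_set = set()         # rebinding gives the scheduler a fresh set
assert meta_view == {"req-1"} # while the "metadata" view keeps the old one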
@@ -439,8 +445,12 @@
params["do_remote_prefill"] = False
return False, None
if (not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
if not params.get("do_remote_decode"):
return False, None
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id)
return False, None
# TODO: check whether block_ids actually ever be 0. If not we could
@@ -1234,6 +1244,10 @@ class NixlConnectorWorker:
        for req_id in metadata.reqs_in_batch:
            self._reqs_to_process.add(req_id)

        # Remove all requests that are not to be processed (e.g. aborted).
        for req_id in metadata.reqs_not_processed:
            self._reqs_to_process.discard(req_id)

        # Add to requests that are waiting to be read and track expiration.
        for req_id, expiration_time in metadata.reqs_to_send.items():
            if req_id in self._reqs_to_process:
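A small design note on the hunk above: the worker uses set.discard() rather than set.remove() for reqs_not_processed, so an ID the worker never started tracking (for instance a request aborted before its batch ever reached this worker) is dropped silently instead of raising KeyError.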