[BugFix][P/D] Fix for cases where _recving_transfers can be cleaned up when *all* transfer done (#19874)

Signed-off-by: Linkun Chen <github@lkchen.net>
This commit is contained in:
lkchen 2025-06-22 22:41:53 -07:00 committed by GitHub
parent 2ebff5b77c
commit 1bcd15edc7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 179 additions and 4 deletions

View File

@ -1,8 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import uuid
from collections import defaultdict
from typing import Optional
from unittest.mock import patch
import pytest
try:
from nixl._api import nixl_agent as NixlWrapper
except ImportError:
NixlWrapper = None
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata)
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from .utils import create_request, create_scheduler, create_vllm_config
@ -72,3 +87,160 @@ def test_prompt_less_than_block_size():
# This request should be scheduled regularly.
assert len(scheduler_output.scheduled_new_reqs) == 1
class FakeNixlWrapper:
"""Mock implementation of NixlWrapper for testing.
We don't inherit from NixlWrapper because NixlWrapper could be None.
"""
AGENT_METADATA = b"fake_agent_metadata"
REMOTE_AGENT_NAME = "remote_agent"
def __init__(self, agent_name: str, *args, **kwargs):
self._cycles_before_xfer_done = 0
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
lambda: 0)
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]
def register_memory(self, descs) -> None:
pass
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in blocks_data]
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
return uuid.uuid4().int
def get_agent_metadata(self) -> bytes:
return self.AGENT_METADATA
def add_remote_agent(self, agent_metadata: bytes) -> str:
return self.REMOTE_AGENT_NAME
def get_new_notifs(self) -> dict[str, list[bytes]]:
# Used to collect done_sending, which we don't test yet.
return {}
def check_xfer_state(self, handle: int) -> str:
if self._check_xfer_state_cycles[
handle] >= self._cycles_before_xfer_done:
return "DONE"
self._check_xfer_state_cycles[handle] += 1
return "PROC"
def release_xfer_handle(self, handle: int) -> None:
pass
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass
def make_prepped_xfer(self,
xfer_type: str,
local_xfer_side_handle: int,
local_block_descs_ids: list[int],
remote_xfer_side_handle: int,
remote_block_descs_ids: list[int],
notif_msg: Optional[bytes] = None) -> int:
return uuid.uuid4().int
def transfer(self, handle: int) -> str:
return "PROC"
############################################################
# Follow are for changing the behavior during testing.
############################################################
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
class FakeNixlConnectorWorker(NixlConnectorWorker):
REMOTE_ENGINE_ID = "remote_engine"
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
def _nixl_handshake(self, host: str, port: int):
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
# gpu_model_runner. Here we just hardcode some dummy values.
self.slot_size_bytes = 4096
self.block_len = self.slot_size_bytes * self.block_size
self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.add_remote_agent(
NixlAgentMetadata(
engine_id=self.REMOTE_ENGINE_ID,
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
tp_size=1,
block_len=self.block_len,
attn_backend_name=self.backend_name,
))
@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_multi_xfer_one_engine(
# dist_init is a fixture that initializes the distributed environment.
dist_init):
"""Test case where multiple xfers are initiated to the same engine.
This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are multiple
transfer states for the same `request_id`, and `get_finished` should handle
it correctly (wait for all transfers to be done).
"""
vllm_config = create_vllm_config()
request_id = "req_id"
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
for i in range(4):
metadata = NixlConnectorMetadata()
metadata.add_new_req(request_id=request_id,
local_block_ids=[i + 1, i + 2, i + 3],
kv_transfer_params={
"remote_block_ids": [i + 4, i + 5, i + 6],
"remote_engine_id":
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
})
connector.bind_connector_metadata(metadata)
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
_before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx)
_after_load = time.perf_counter()
assert _after_load - _before_load < 0.1, "start_load_kv took " \
f"{_after_load - _before_load} seconds"
while True:
_, done_recving = connector.get_finished(finished_req_ids=set())
if len(done_recving) > 0:
assert request_id in done_recving
break

View File

@ -841,17 +841,20 @@ class NixlConnectorWorker:
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
for handle, xfer_stime in handles:
in_progress = False
for handle, _xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
self.nixl_wrapper.release_xfer_handle(handle)
done_req_ids.add(req_id)
del transfers[req_id]
elif xfer_state == "PROC":
in_progress = True
continue
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)
if not in_progress:
done_req_ids.add(req_id)
del transfers[req_id]
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):