mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 18:25:40 +08:00
[BugFix][P/D] Fix for cases where _recving_transfers can be cleaned up when *all* transfers are done (#19874)
Signed-off-by: Linkun Chen <github@lkchen.net>
This commit is contained in:
parent 2ebff5b77c
commit 1bcd15edc7
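In essence, the change below makes NixlConnectorWorker._pop_done_transfers treat a request as finished only once every transfer handle recorded for it has completed, instead of popping the request (and deleting its _recving_transfers entry) as soon as any single handle reported DONE. A minimal sketch of the corrected aggregation, with a hypothetical check_state callable standing in for nixl_wrapper.check_xfer_state (illustration only, not the code in this commit):

def pop_done_transfers(transfers: dict, check_state) -> set:
    # Simplified: a request is finished only when *all* of its handles are
    # DONE; handle release and error handling are omitted here.
    done_req_ids = set()
    for req_id, handles in list(transfers.items()):
        if all(check_state(handle) == "DONE" for handle, _stime in handles):
            done_req_ids.add(req_id)
            del transfers[req_id]
    return done_req_ids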
@@ -1,8 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import time
+import uuid
+from collections import defaultdict
+from typing import Optional
+from unittest.mock import patch
+
+import pytest
+
+try:
+    from nixl._api import nixl_agent as NixlWrapper
+except ImportError:
+    NixlWrapper = None
+
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
-    NixlConnectorMetadata)
+    KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
+    NixlConnectorWorker)
+from vllm.forward_context import ForwardContext
 
 from .utils import create_request, create_scheduler, create_vllm_config
 
@@ -72,3 +87,160 @@ def test_prompt_less_than_block_size():
 
     # This request should be scheduled regularly.
     assert len(scheduler_output.scheduled_new_reqs) == 1
+
+
+class FakeNixlWrapper:
+    """Mock implementation of NixlWrapper for testing.
+
+    We don't inherit from NixlWrapper because NixlWrapper could be None.
+    """
+
+    AGENT_METADATA = b"fake_agent_metadata"
+    REMOTE_AGENT_NAME = "remote_agent"
+
+    def __init__(self, agent_name: str, *args, **kwargs):
+        self._cycles_before_xfer_done = 0
+        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
+            lambda: 0)
+
+    def get_reg_descs(self, caches_data, memory_type: str) -> list:
+        return [str(uuid.uuid4()) for _ in caches_data]
+
+    def register_memory(self, descs) -> None:
+        pass
+
+    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
+        return [str(uuid.uuid4()) for _ in blocks_data]
+
+    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
+        return uuid.uuid4().int
+
+    def get_agent_metadata(self) -> bytes:
+        return self.AGENT_METADATA
+
+    def add_remote_agent(self, agent_metadata: bytes) -> str:
+        return self.REMOTE_AGENT_NAME
+
+    def get_new_notifs(self) -> dict[str, list[bytes]]:
+        # Used to collect done_sending, which we don't test yet.
+        return {}
+
+    def check_xfer_state(self, handle: int) -> str:
+        if self._check_xfer_state_cycles[
+                handle] >= self._cycles_before_xfer_done:
+            return "DONE"
+        self._check_xfer_state_cycles[handle] += 1
+        return "PROC"
+
+    def release_xfer_handle(self, handle: int) -> None:
+        pass
+
+    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
+        pass
+
+    def make_prepped_xfer(self,
+                          xfer_type: str,
+                          local_xfer_side_handle: int,
+                          local_block_descs_ids: list[int],
+                          remote_xfer_side_handle: int,
+                          remote_block_descs_ids: list[int],
+                          notif_msg: Optional[bytes] = None) -> int:
+        return uuid.uuid4().int
+
+    def transfer(self, handle: int) -> str:
+        return "PROC"
+
+    ############################################################
+    # Follow are for changing the behavior during testing.
+    ############################################################
+
+    def set_cycles_before_xfer_done(self, cycles: int):
+        """Set the number of cycles before a transfer is considered done."""
+        self._cycles_before_xfer_done = cycles
+
+
+class FakeNixlConnectorWorker(NixlConnectorWorker):
+
+    REMOTE_ENGINE_ID = "remote_engine"
+
+    def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._hand_shake_latency = hand_shake_latency
+
+    def _nixl_handshake(self, host: str, port: int):
+        # Mimic slow _nixl_handshake, as well as bypass zmq communication.
+        time.sleep(self._hand_shake_latency)
+        # These should've been done in register_kv_caches(), called by
+        # gpu_model_runner. Here we just hardcode some dummy values.
+        self.slot_size_bytes = 4096
+        self.block_len = self.slot_size_bytes * self.block_size
+        self.num_blocks = 1
+        self.dst_num_blocks[self.engine_id] = self.num_blocks
+
+        self.add_remote_agent(
+            NixlAgentMetadata(
+                engine_id=self.REMOTE_ENGINE_ID,
+                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
+                kv_caches_base_addr=[0],
+                num_blocks=1,
+                tp_size=1,
+                block_len=self.block_len,
+                attn_backend_name=self.backend_name,
+            ))
+
+
+@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper)
+def test_multi_xfer_one_engine(
+        # dist_init is a fixture that initializes the distributed environment.
+        dist_init):
+    """Test case where multiple xfers are initiated to the same engine.
+
+    This test triggers the connector to load remote KV for the same
+    `request_id`. The transfer is not done immediately due to
+    `set_cycles_before_xfer_done`, so there is a state where there are multiple
+    transfer states for the same `request_id`, and `get_finished` should handle
+    it correctly (wait for all transfers to be done).
+    """
+    vllm_config = create_vllm_config()

+    request_id = "req_id"
+
+    # Test worker role in decode server.
+    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
+                                                         connector.engine_id,
+                                                         hand_shake_latency=0)
+    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
+    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
+    for i in range(4):
+        metadata = NixlConnectorMetadata()
+        metadata.add_new_req(request_id=request_id,
+                             local_block_ids=[i + 1, i + 2, i + 3],
+                             kv_transfer_params={
+                                 "remote_block_ids": [i + 4, i + 5, i + 6],
+                                 "remote_engine_id":
+                                 FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                                 "remote_host": "localhost",
+                                 "remote_port": 1234,
+                             })
+        connector.bind_connector_metadata(metadata)
+
+        dummy_ctx = ForwardContext(
+            no_compile_layers={},
+            attn_metadata={},
+            virtual_engine=0,
+        )
+        _before_load = time.perf_counter()
+        connector.start_load_kv(dummy_ctx)
+        _after_load = time.perf_counter()
+        assert _after_load - _before_load < 0.1, "start_load_kv took " \
+            f"{_after_load - _before_load} seconds"
+
+    while True:
+        _, done_recving = connector.get_finished(finished_req_ids=set())
+        if len(done_recving) > 0:
+            assert request_id in done_recving
+            break
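As a usage note, the fake wrapper's completion model above is purely cycle based: each handle reports "PROC" for the configured number of polls and "DONE" afterwards, which is what keeps several transfers for the same request_id in flight at once. A small sketch of that behavior (the handle value is arbitrary and chosen only for illustration):

fake = FakeNixlWrapper("agent")
fake.set_cycles_before_xfer_done(3)
handle = 42  # any int; poll counts are tracked per handle
states = [fake.check_xfer_state(handle) for _ in range(4)]
assert states == ["PROC", "PROC", "PROC", "DONE"]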
@@ -841,17 +841,20 @@ class NixlConnectorWorker:
         """
         done_req_ids: set[str] = set()
         for req_id, handles in list(transfers.items()):
-            for handle, xfer_stime in handles:
+            in_progress = False
+            for handle, _xfer_stime in handles:
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
                     self.nixl_wrapper.release_xfer_handle(handle)
-                    done_req_ids.add(req_id)
-                    del transfers[req_id]
                 elif xfer_state == "PROC":
+                    in_progress = True
                     continue
                 else:
                     raise RuntimeError("Transfer failed with state %s",
                                        xfer_state)
+            if not in_progress:
+                done_req_ids.add(req_id)
+                del transfers[req_id]
         return done_req_ids
 
     def start_load_kv(self, metadata: NixlConnectorMetadata):
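To see why the in_progress flag matters, consider a request with two handles where the first completes a poll earlier than the second: the old loop added the request to done_req_ids and deleted its entry on that first poll, while the new code only does so on a poll in which no handle reports "PROC". A minimal, self-contained simulation of the two behaviors (illustration only; done_at and the helper functions are invented for this example):

done_at = {1: 0, 2: 1}  # poll index at which each fake handle becomes DONE

def state(handle: int, poll: int) -> str:
    return "DONE" if poll >= done_at[handle] else "PROC"

def old_is_done(handles: list[int], poll: int) -> bool:
    # Old behavior: finished as soon as *any* handle reports DONE.
    return any(state(h, poll) == "DONE" for h in handles)

def new_is_done(handles: list[int], poll: int) -> bool:
    # New behavior: finished only when *no* handle is still PROC.
    return all(state(h, poll) == "DONE" for h in handles)

assert old_is_done([1, 2], poll=0)      # reported done too early
assert not new_is_done([1, 2], poll=0)  # still waiting on handle 2
assert new_is_done([1, 2], poll=1)      # done once both handles are DONE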