mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 01:54:28 +08:00
[PD][Nixl] Remote consumer READ timeout for clearing request blocks (#20139)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent 72d14d0eed
commit 71d1d75b7a
tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -9,10 +9,13 @@ from unittest.mock import patch
 import pytest
 
+from vllm import LLM
+from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
 from vllm.forward_context import ForwardContext
+from vllm.sampling_params import SamplingParams
 
 from .utils import create_request, create_scheduler, create_vllm_config
@@ -41,9 +44,9 @@ def test_basic_interface():
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
 
-    assert len(kv_connector_metadata.requests) == 1
-    assert request_id in kv_connector_metadata.requests
-    req_meta = kv_connector_metadata.requests[request_id]
+    assert len(kv_connector_metadata.reqs_to_recv) == 1
+    assert request_id in kv_connector_metadata.reqs_to_recv
+    req_meta = kv_connector_metadata.reqs_to_recv[request_id]
 
     for block_id, block in zip(
             req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
@@ -78,7 +81,7 @@ def test_prompt_less_than_block_size():
     kv_connector_metadata = scheduler_output.kv_connector_metadata
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
-    assert len(kv_connector_metadata.requests) == 0
+    assert len(kv_connector_metadata.reqs_to_recv) == 0
 
     # This request should be scheduled regularly.
     assert len(scheduler_output.scheduled_new_reqs) == 1
@@ -371,3 +374,70 @@ class TestNixlHandshake:
             if cnt_finished_reqs == total_reqs:
                 return
         raise TimeoutError("Took too long to complete async handshake.")
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper)
+def test_abort_timeout_on_prefiller(monkeypatch):
+    """
+    Test lifecycle of an aborted Remote Prefill request hitting the timeout.
+    -----> P
+           | {process request}
+    <-\--- | {result is NOT delivered, e.g. proxy is down}
+           |
+           |
+           | {eventually free blocks}
+    """
+    model_name = "Qwen/Qwen3-0.6B"
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="NixlConnector",
+        kv_role="kv_both",
+    )
+    timeout = 6
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+    monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
+    llm = LLM(
+        model=model_name,
+        enforce_eager=True,
+        gpu_memory_utilization=0.5,
+        kv_transfer_config=kv_transfer_config,
+    )
+    remote_prefill_opts = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None,
+    }
+    # Simulate a sidecar request.
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=1,
+        extra_args={"kv_transfer_params": remote_prefill_opts})
+    scheduler = llm.llm_engine.engine_core.engine_core.scheduler
+    req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
+        0].req_to_blocks
+
+    padding = (
+        "Just making this request a little longer so that we're sure "
+        "we're not hitting the small-request lower bound beneath which we "
+        "don't actually trigger the whole kv transfer, but rather just "
+        "recompute the blocks on D.")
+    _ = llm.generate([f"What is the capital of Japan? {padding}"],
+                     sampling_params)
+
+    # Request finished but not freed.
+    assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks
+    # Run some other request; request 0 is still not freed.
+    _ = llm.generate([f"What is the capital of Italy? {padding}"],
+                     sampling_params)
+    assert '0' in req_to_blocks
+    assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks
+
+    # Wait for the timeout, then trigger another scheduler loop.
+    time.sleep(timeout)
+    _ = llm.generate([f"What is the capital of France? {padding}"],
+                     sampling_params)
+    # Request 0 times out and its blocks are cleared!
+    assert '0' not in req_to_blocks
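To exercise just this lifecycle locally, the test can be selected by name; the file path below assumes the repository's unit-test layout for the NIXL connector:

    pytest tests/v1/kv_connector/unit/test_nixl_connector.py -k test_abort_timeout_on_prefiller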
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -79,7 +79,8 @@ class ReqMeta:
 class NixlConnectorMetadata(KVConnectorMetadata):
 
     def __init__(self):
-        self.requests: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_send: dict[ReqId, float] = {}
 
     def add_new_req(
         self,
@@ -87,7 +88,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
     ):
-        self.requests[request_id] = ReqMeta(
+        self.reqs_to_recv[request_id] = ReqMeta(
             local_block_ids=local_block_ids,
             remote_block_ids=kv_transfer_params["remote_block_ids"],
             remote_engine_id=kv_transfer_params["remote_engine_id"],
@@ -194,10 +195,12 @@ class NixlConnectorScheduler:
             vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)
 
-        # Requests that need to start recv.
+        # Requests that need to start recv/send.
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
+        # Reqs to send and their expiration time.
+        self._reqs_need_send: dict[ReqId, float] = {}
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -284,6 +287,9 @@ class NixlConnectorScheduler:
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
 
+        meta.reqs_to_send = self._reqs_need_send
+        self._reqs_need_send = {}
+
         return meta
 
     def request_finished(
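The two added assignments follow a take-and-reset hand-off: the scheduler moves its accumulated send-expiry map into the metadata object wholesale and rebinds a fresh dict, so each worker step sees every pending entry exactly once. A minimal standalone sketch of the pattern (names are illustrative, not vLLM's):

    class SchedulerSide:
        def __init__(self) -> None:
            self._reqs_need_send: dict[str, float] = {}

        def build_metadata(self) -> dict[str, float]:
            # Hand the whole map to the consumer and start a fresh one;
            # rebinding (rather than clear()) keeps the handed-off dict intact.
            out = self._reqs_need_send
            self._reqs_need_send = {}
            return out

    sched = SchedulerSide()
    sched._reqs_need_send["req-0"] = 123.0
    meta = sched.build_metadata()
    assert meta == {"req-0": 123.0} and sched._reqs_need_send == {}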
@@ -325,6 +331,11 @@ class NixlConnectorScheduler:
         # If prompt < block_size, no xfer so free blocks immediately.
         delay_free_blocks = len(computed_block_ids) > 0
 
+        if delay_free_blocks:
+            # Prefill request on remote. It will be read from D upon completion.
+            self._reqs_need_send[request.request_id] = time.perf_counter(
+            ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+
         return delay_free_blocks, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
@@ -394,6 +405,8 @@ class NixlConnectorWorker:
         # In progress transfers.
         # [req_id -> list[handle]]
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
+        # Track the expiration time of requests that are waiting to be sent.
+        self._reqs_to_send: dict[ReqId, float] = {}
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
@@ -826,6 +839,16 @@ class NixlConnectorWorker:
             "and %s requests done recving", self.tp_rank,
             len(done_sending), len(done_recving))
 
+        # Handle timeout to avoid stranding blocks on remote.
+        now = time.perf_counter()
+        while self._reqs_to_send:
+            req_id, expires = next(iter(self._reqs_to_send.items()))
+            # Insertion-ordered dict: oldest entries come first, so exit early.
+            if now < expires:
+                break
+            del self._reqs_to_send[req_id]
+            done_sending.add(req_id)
+
         if self.world_size == 1:
             return done_sending, done_recving
 
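The sweep above relies on Python dicts preserving insertion order: because every request is stamped with now plus the same timeout, insertion order equals deadline order, and the scan can stop at the first entry that has not yet expired. A self-contained sketch of the pattern (illustrative names, not vLLM code):

    import time

    reqs_to_send: dict[str, float] = {}
    TIMEOUT_S = 0.05  # stand-in for envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT

    def stamp(req_id: str) -> None:
        # perf_counter() is monotonic, so deadlines survive wall-clock jumps.
        reqs_to_send[req_id] = time.perf_counter() + TIMEOUT_S

    def sweep_expired() -> set[str]:
        expired: set[str] = set()
        now = time.perf_counter()
        while reqs_to_send:
            # Oldest (smallest-deadline) entry comes first.
            req_id, expires = next(iter(reqs_to_send.items()))
            if now < expires:
                break  # every later entry expires even later
            del reqs_to_send[req_id]
            expired.add(req_id)
        return expired

    stamp("req-0")
    stamp("req-1")
    time.sleep(2 * TIMEOUT_S)
    print(sweep_expired())  # {'req-0', 'req-1'}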
@@ -857,7 +880,7 @@ class NixlConnectorWorker:
 
         all_done_sending: set[str] = set()
         for req_id in list(self._done_sending_count.keys()):
-            if self._done_sending_count[req_id] == self.world_size:
+            if self._done_sending_count[req_id] >= self.world_size:
                 del self._done_sending_count[req_id]
                 all_done_sending.add(req_id)
 
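A likely motivation for relaxing == to >= here: with the new timeout path, a request can be reported done sending both by the expiry sweep and by a late consumer notification, so rank 0's tally may overshoot world_size; strict equality would then never match and the counter entry would linger indefinitely.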
@@ -887,6 +910,7 @@ class NixlConnectorWorker:
                     tp_ratio):
                 notified_req_ids.add(req_id)
                 del self.consumer_notification_counts_by_req[req_id]
+                del self._reqs_to_send[req_id]
         return notified_req_ids
 
     def _pop_done_transfers(
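Deleting the request from _reqs_to_send at notification time is what cancels the pending timeout in the common (successful) case. A hedged sketch of the surrounding counting logic, with simplified names standing in for consumer_notification_counts_by_req and the tp_ratio threshold:

    from collections import defaultdict

    notif_counts: defaultdict[str, int] = defaultdict(int)
    reqs_to_send: dict[str, float] = {"req-0": 1e9}  # deadline far in the future
    EXPECTED_NOTIFS = 2  # e.g. one READ notification per consumer rank

    def on_read_notification(req_id: str, notified: set[str]) -> None:
        notif_counts[req_id] += 1
        if notif_counts[req_id] >= EXPECTED_NOTIFS:
            notified.add(req_id)
            del notif_counts[req_id]
            # The consumer has read the blocks: drop the expiry entry so the
            # timeout sweep never frees this request's blocks prematurely.
            reqs_to_send.pop(req_id, None)

    done: set[str] = set()
    on_read_notification("req-0", done)
    on_read_notification("req-0", done)
    print(done, reqs_to_send)  # {'req-0'} {}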
@@ -921,7 +945,7 @@ class NixlConnectorWorker:
         Start loading by triggering non-blocking nixl_xfer.
         We check for these transfers to complete in each step().
         """
-        for req_id, meta in metadata.requests.items():
+        for req_id, meta in metadata.reqs_to_recv.items():
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
@@ -943,6 +967,9 @@ class NixlConnectorWorker:
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())
 
+        # Add to requests that are waiting to be read and track expiration.
+        self._reqs_to_send.update(metadata.reqs_to_send)
+
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",
vllm/envs.py
@@ -138,6 +138,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
+    VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
 
 
 def get_default_cache_root():
@@ -953,7 +954,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # generations on machines < 100 for compressed-tensors
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
-    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
+
+    # Time (in seconds) after which the KV cache on the producer side is
+    # automatically cleared if no READ notification is received from the
+    # consumer. This is only applicable when using NixlConnector in a
+    # disaggregated decode-prefill setup.
+    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
+    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
 }
 
 # --8<-- [end:env-vars-definition]
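The variable is read lazily through vllm.envs, so it can be set in the environment before the producer process starts; a short example (the value is illustrative):

    import os

    # Set before vllm.envs evaluates the variable.
    os.environ["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"] = "30"  # seconds

    import vllm.envs as envs

    print(envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)  # 30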