[NIXL][BUG FIX] Fix both the failure and the accuracy issue with nixl + host_buffer on CUDA (#30419)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Author: Chendi.Xue <chendi.xue@intel.com>
Date:   2025-12-18 16:10:02 -06:00
Commit: 6ca74bc11a (parent 19c583398a)
3 changed files with 53 additions and 36 deletions

vllm/distributed/kv_transfer/kv_connector/utils.py

@@ -4,6 +4,7 @@
 KV cache helper for store.
 """
+from collections.abc import Iterator
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Literal
@@ -203,6 +204,26 @@ def copy_kv_blocks(
         copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
 
 
+def yield_req_data(
+    scheduler_output,
+) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
+    """
+    Yields:
+        (req_id, new_block_id_groups, preempted)
+    """
+    # new requests
+    for req_data in scheduler_output.scheduled_new_reqs:
+        yield req_data.req_id, req_data.block_ids, False
+
+    # cached requests
+    cached_reqs = scheduler_output.scheduled_cached_reqs
+    yield from zip(
+        cached_reqs.req_ids,
+        cached_reqs.new_block_ids,
+        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
+    )
+
+
 @dataclass
 class TpKVTopology:
     """

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

@@ -23,7 +23,11 @@ from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    EngineId,
+    TpKVTopology,
+    yield_req_data,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp,
     KVConnectorBase_V1,
@@ -481,7 +485,7 @@ class NixlConnectorScheduler:
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
-        self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
+        self._reqs_need_save: dict[ReqId, Request] = {}
         # Reqs to send and their expiration time
         self._reqs_need_send: dict[ReqId, float] = {}
         self._reqs_in_batch: set[ReqId] = set()
@@ -627,16 +631,7 @@ class NixlConnectorScheduler:
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when the accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
-            # save all blocks
-            block_ids = blocks.get_block_ids()[0]
-            # TODO: skip the blocks that are already in the host xfer buffer.
-            # Currently, the host xfer buffer block is 1-to-1 mapped to device
-            # kv blocks, so host blocks won't be flushed as long as their device
-            # blocks are not overwritten; it is then safe to skip saving them
-            # to the host xfer buffer.
-            if block_ids:
-                self._reqs_need_save[request.request_id] = (request, block_ids)
+            self._reqs_need_save[request.request_id] = request
         elif params.get("do_remote_prefill"):
             if params.get("remote_block_ids"):
                 if all(
@@ -688,13 +683,32 @@ class NixlConnectorScheduler:
                 kv_transfer_params=req.kv_transfer_params,
             )
 
-        for req_id, (req, block_ids) in self._reqs_need_save.items():
+        # NOTE: On the prefill side, an earlier-added request may be a chunked
+        # prefill, so we need to check whether new blocks were added.
+        for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
+            req_to_save = self._reqs_need_save.get(req_id)
+            if req_to_save is None or new_block_id_groups is None:
+                continue
+            req = req_to_save
             assert req.kv_transfer_params is not None
             meta.add_new_req_to_save(
                 request_id=req_id,
-                local_block_ids=block_ids,
+                local_block_ids=new_block_id_groups[0],
                 kv_transfer_params=req.kv_transfer_params,
             )
+            assert scheduler_output.num_scheduled_tokens is not None
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            is_partial = (
+                req.num_computed_tokens + num_scheduled_tokens
+            ) < req.num_prompt_tokens
+            if not is_partial:
+                # For non-partial prefills, once the new req_meta is scheduled,
+                # the request can be removed from _reqs_need_save. For the
+                # partial prefill case, retain the request in _reqs_need_save
+                # until all of its blocks are scheduled with req_meta.
+                # Therefore, only pop if `not is_partial`.
+                self._reqs_need_save.pop(req_id)
 
         meta.reqs_to_send = self._reqs_need_send
         meta.reqs_in_batch = self._reqs_in_batch
@@ -702,7 +716,6 @@ class NixlConnectorScheduler:
 
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
-        self._reqs_need_save.clear()
         self._reqs_in_batch = set()
         self._reqs_not_processed = set()
         self._reqs_need_send = {}
@@ -748,6 +761,8 @@ class NixlConnectorScheduler:
             # Also include the case of a P/D Prefill request with immediate
             # block free (e.g. abort). Stop tracking this request.
             self._reqs_not_processed.add(request.request_id)
+            # Clear _reqs_need_save if a request is aborted as a partial prefill.
+            self._reqs_need_save.pop(request.request_id, None)
             return False, None
 
         # TODO: check whether block_ids actually ever be 0. If not we could
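A note on why the retention logic in the hunk above matters: with chunked prefill, a prompt is processed across several scheduler steps, and each step may allocate new blocks. The old code cleared _reqs_need_save after every step, so blocks allocated by later chunks never received a save req_meta, which is consistent with the accuracy issue this commit fixes on the host-buffer path. Below is a small walk-through of the is_partial arithmetic with made-up token counts; the variable values are illustrative only.

    # Illustrative numbers: a 1000-token prompt prefilled in two chunks.
    num_prompt_tokens = 1000

    # Step 1: nothing computed yet; the scheduler grants 512 tokens.
    num_computed_tokens, num_scheduled_tokens = 0, 512
    is_partial = (num_computed_tokens + num_scheduled_tokens) < num_prompt_tokens
    assert is_partial        # 512 < 1000 -> keep the request in _reqs_need_save

    # Step 2: 512 tokens computed; the remaining 488 are scheduled.
    num_computed_tokens, num_scheduled_tokens = 512, 488
    is_partial = (num_computed_tokens + num_scheduled_tokens) < num_prompt_tokens
    assert not is_partial    # 1000 >= 1000 -> pop from _reqs_need_save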

vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import islice
 from typing import Any, ClassVar
@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
+from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
     KVConnectorRole,
@@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
             del self._store_jobs[req_id]
 
         return finished_sending, finished_recving
-
-
-def yield_req_data(
-    scheduler_output,
-) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
-    """
-    Yields:
-        (req_id, new_block_id_groups, preempted)
-    """
-    # new requests
-    for req_data in scheduler_output.scheduled_new_reqs:
-        yield req_data.req_id, req_data.block_ids, False
-
-    # cached requests
-    cached_reqs = scheduler_output.scheduled_cached_reqs
-    yield from zip(
-        cached_reqs.req_ids,
-        cached_reqs.new_block_ids,
-        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
-    )