From 6ca74bc11afc1a985ab80f4b94b2d3bcda764630 Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Thu, 18 Dec 2025 16:10:02 -0600
Subject: [PATCH] [NIXL][BUG FIX] Fix both failing issue and accuracy issue
 with nixl + host_buffer on CUDA (#30419)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Chendi Xue
Signed-off-by: Chendi.Xue
Co-authored-by: Nicolò Lucchesi
---
 .../kv_transfer/kv_connector/utils.py         | 21 +++++++++
 .../kv_connector/v1/nixl_connector.py         | 45 ++++++++++++-------
 .../kv_connector/v1/offloading_connector.py   | 23 +---------
 3 files changed, 53 insertions(+), 36 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index a026cccb85372..4f1ea1a0240c4 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -4,6 +4,7 @@
 KV cache helper for store.
 """
 
+from collections.abc import Iterator
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Literal
 
@@ -203,6 +204,26 @@ def copy_kv_blocks(
         copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
 
 
+def yield_req_data(
+    scheduler_output,
+) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
+    """
+    Yields:
+        (req_id, new_block_id_groups, preempted)
+    """
+    # new requests
+    for req_data in scheduler_output.scheduled_new_reqs:
+        yield req_data.req_id, req_data.block_ids, False
+
+    # cached requests
+    cached_reqs = scheduler_output.scheduled_cached_reqs
+    yield from zip(
+        cached_reqs.req_ids,
+        cached_reqs.new_block_ids,
+        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
+    )
+
+
 @dataclass
 class TpKVTopology:
     """
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index be56eb4e93c10..757ca41e9844b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -23,7 +23,11 @@ from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    EngineId,
+    TpKVTopology,
+    yield_req_data,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp,
     KVConnectorBase_V1,
@@ -481,7 +485,7 @@ class NixlConnectorScheduler:
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
-        self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
+        self._reqs_need_save: dict[ReqId, Request] = {}
         # Reqs to send and their expiration time
         self._reqs_need_send: dict[ReqId, float] = {}
         self._reqs_in_batch: set[ReqId] = set()
@@ -627,16 +631,7 @@ class NixlConnectorScheduler:
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
-
-            # save all blocks
-            block_ids = blocks.get_block_ids()[0]
-            # TODO: skip the blocks that are already in the host xfer buffer.
-            # Currently, the host xfer buffer block is 1-to-1 mapped to device
-            # kv blocks, so host blocks won't be flushed as long as its device
-            # block is not overwritten; and it will be safe to skip saving them
-            # to host xfer buffer.
-            if block_ids:
-                self._reqs_need_save[request.request_id] = (request, block_ids)
+            self._reqs_need_save[request.request_id] = request
         elif params.get("do_remote_prefill"):
             if params.get("remote_block_ids"):
                 if all(
@@ -688,13 +683,32 @@ class NixlConnectorScheduler:
                 kv_transfer_params=req.kv_transfer_params,
             )
 
-        for req_id, (req, block_ids) in self._reqs_need_save.items():
+        # NOTE: On the prefill side, an earlier-added request may be a chunked
+        # prefill, so check whether new blocks were added for it.
+        for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
+            req_to_save = self._reqs_need_save.get(req_id)
+            if req_to_save is None or new_block_id_groups is None:
+                continue
+            req = req_to_save
+            assert req.kv_transfer_params is not None
             meta.add_new_req_to_save(
                 request_id=req_id,
-                local_block_ids=block_ids,
+                local_block_ids=new_block_id_groups[0],
                 kv_transfer_params=req.kv_transfer_params,
             )
+            assert scheduler_output.num_scheduled_tokens is not None
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            is_partial = (
+                req.num_computed_tokens + num_scheduled_tokens
+            ) < req.num_prompt_tokens
+            if not is_partial:
+                # For non-partial prefills, once the new req_meta is scheduled,
+                # the request can be removed from _reqs_need_save.
+                # For the partial prefill case, retain the request in
+                # _reqs_need_save until all blocks are scheduled with req_meta.
+                # Therefore, only pop if `not is_partial`.
+                self._reqs_need_save.pop(req_id)
 
         meta.reqs_to_send = self._reqs_need_send
         meta.reqs_in_batch = self._reqs_in_batch
@@ -702,7 +716,6 @@ class NixlConnectorScheduler:
 
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
-        self._reqs_need_save.clear()
         self._reqs_in_batch = set()
         self._reqs_not_processed = set()
         self._reqs_need_send = {}
@@ -748,6 +761,8 @@ class NixlConnectorScheduler:
             # Also include the case of a P/D Prefill request with immediate
             # block free (eg abort). Stop tracking this request.
             self._reqs_not_processed.add(request.request_id)
+            # Clear _reqs_need_save if a request is aborted as a partial prefill.
+            self._reqs_need_save.pop(request.request_id, None)
             return False, None
 
         # TODO: check whether block_ids actually ever be 0. If not we could
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 0ad9d4ae1b39f..a6d86bc9e1a19 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import islice
 from typing import Any, ClassVar
@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
+from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
     KVConnectorRole,
@@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
             del self._store_jobs[req_id]
 
         return finished_sending, finished_recving
-
-
-def yield_req_data(
-    scheduler_output,
-) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
-    """
-    Yields:
-        (req_id, new_block_id_groups, preempted)
-    """
-    # new requests
-    for req_data in scheduler_output.scheduled_new_reqs:
-        yield req_data.req_id, req_data.block_ids, False
-
-    # cached requests
-    cached_reqs = scheduler_output.scheduled_cached_reqs
-    yield from zip(
-        cached_reqs.req_ids,
-        cached_reqs.new_block_ids,
-        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
-    )
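
Reviewer note: a minimal sketch of how the relocated yield_req_data helper
behaves, assuming this patch is applied. The Fake* dataclasses and request
IDs below are hypothetical fixtures that approximate SchedulerOutput only as
far as yield_req_data touches it; they are not vLLM's real scheduler types.

# Hedged sketch, not part of the patch: exercise yield_req_data with
# hand-rolled stand-ins for the scheduler output.
from dataclasses import dataclass, field

from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data


@dataclass
class FakeNewRequestData:  # hypothetical stand-in for NewRequestData
    req_id: str
    block_ids: tuple[list[int], ...]


@dataclass
class FakeCachedRequestData:  # hypothetical stand-in for CachedRequestData
    req_ids: list[str] = field(default_factory=list)
    new_block_ids: list = field(default_factory=list)
    resumed_req_ids: set[str] = field(default_factory=set)


@dataclass
class FakeSchedulerOutput:  # hypothetical stand-in for SchedulerOutput
    scheduled_new_reqs: list = field(default_factory=list)
    scheduled_cached_reqs: FakeCachedRequestData = field(
        default_factory=FakeCachedRequestData
    )


out = FakeSchedulerOutput(
    scheduled_new_reqs=[FakeNewRequestData("req-0", ([0, 1],))],
    scheduled_cached_reqs=FakeCachedRequestData(
        req_ids=["req-1", "req-2"],
        new_block_ids=[([2, 3],), None],  # None: no new blocks this step
        resumed_req_ids={"req-2"},  # req-2 was preempted, then resumed
    ),
)

for req_id, new_block_id_groups, preempted in yield_req_data(out):
    print(req_id, new_block_id_groups, preempted)
# req-0 ([0, 1],) False
# req-1 ([2, 3],) False
# req-2 None True

The preempted flag is unused by the nixl save loop above (bound to `_`); that
loop keys on whether new_block_id_groups is None to skip requests that got no
new blocks in this step.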
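Reviewer note: the accuracy fix hinges on keeping a chunked-prefill request in
_reqs_need_save until the step that completes its prompt. A hedged sketch of
just that predicate follows; the helper name and token counts are ours, only
mirroring the is_partial check added in build_connector_meta.

# Hypothetical helper (not in the patch) isolating the is_partial predicate.
def completes_prefill(
    num_computed_tokens: int,  # tokens already prefilled in earlier steps
    num_scheduled_tokens: int,  # tokens scheduled for this step
    num_prompt_tokens: int,  # total prompt length
) -> bool:
    """True when this step finishes the prompt, i.e. not a partial prefill."""
    is_partial = (num_computed_tokens + num_scheduled_tokens) < num_prompt_tokens
    return not is_partial


# A 1000-token prompt chunked as 512 + 488 tokens:
assert not completes_prefill(0, 512, 1000)  # chunk 1: stays in _reqs_need_save
assert completes_prefill(512, 488, 1000)  # final chunk: emit meta, then pop

Before this change, _reqs_need_save was cleared unconditionally at the end of
build_connector_meta, which appears to be how blocks allocated by later
prefill chunks could go unsaved to the host buffer; retaining partial-prefill
entries, and popping them on abort, closes that gap.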