[NIXL][BUG FIX] Fix both failing issue and accuracy issue with nixl + host_buffer on CUDA (#30419)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
Chendi.Xue 2025-12-18 16:10:02 -06:00 committed by GitHub
parent 19c583398a
commit 6ca74bc11a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 36 deletions

View File

@ -4,6 +4,7 @@
KV cache helper for store. KV cache helper for store.
""" """
from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
@ -203,6 +204,26 @@ def copy_kv_blocks(
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """Iterate over every request scheduled in this step.

    Yields:
        (req_id, new_block_id_groups, preempted)
    """
    # Newly scheduled requests can never be resumed-from-preemption.
    for new_req in scheduler_output.scheduled_new_reqs:
        yield new_req.req_id, new_req.block_ids, False
    # Cached requests: preempted iff the id is in resumed_req_ids.
    cached = scheduler_output.scheduled_cached_reqs
    for rid, block_groups in zip(cached.req_ids, cached.new_block_ids):
        yield rid, block_groups, rid in cached.resumed_req_ids
@dataclass @dataclass
class TpKVTopology: class TpKVTopology:
""" """

View File

@ -23,7 +23,11 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp, CopyBlocksOp,
KVConnectorBase_V1, KVConnectorBase_V1,
@ -481,7 +485,7 @@ class NixlConnectorScheduler:
# New requests are added by update_state_after_alloc in # New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker. # the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, Request] = {}
# Reqs to send and their expiration time # Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {} self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set() self._reqs_in_batch: set[ReqId] = set()
@ -627,16 +631,7 @@ class NixlConnectorScheduler:
if self.use_host_buffer and params.get("do_remote_decode"): if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl, # NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer. # prefilled blocks need to be saved to host memory before transfer.
self._reqs_need_save[request.request_id] = request
# save all blocks
block_ids = blocks.get_block_ids()[0]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if block_ids:
self._reqs_need_save[request.request_id] = (request, block_ids)
elif params.get("do_remote_prefill"): elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"): if params.get("remote_block_ids"):
if all( if all(
@ -688,13 +683,32 @@ class NixlConnectorScheduler:
kv_transfer_params=req.kv_transfer_params, kv_transfer_params=req.kv_transfer_params,
) )
for req_id, (req, block_ids) in self._reqs_need_save.items(): # NOTE: For the prefill side, there might be a chance that an early added
# request is a chunked prefill, so we need to check if new blocks are added
for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
req_to_save = self._reqs_need_save.get(req_id)
if req_to_save is None or new_block_id_groups is None:
continue
req = req_to_save
assert req.kv_transfer_params is not None assert req.kv_transfer_params is not None
meta.add_new_req_to_save( meta.add_new_req_to_save(
request_id=req_id, request_id=req_id,
local_block_ids=block_ids, local_block_ids=new_block_id_groups[0],
kv_transfer_params=req.kv_transfer_params, kv_transfer_params=req.kv_transfer_params,
) )
assert scheduler_output.num_scheduled_tokens is not None
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
is_partial = (
req.num_computed_tokens + num_scheduled_tokens
) < req.num_prompt_tokens
if not is_partial:
# For non-partial prefills, once new req_meta is scheduled, it
# can be removed from _reqs_need_save.
# For partial prefill case, we will retain the request in
# _reqs_need_save until all blocks are scheduled with req_meta.
# Therefore, only pop if `not is_partial`.
self._reqs_need_save.pop(req_id)
meta.reqs_to_send = self._reqs_need_send meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch meta.reqs_in_batch = self._reqs_in_batch
@ -702,7 +716,6 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers # Clear the list once workers start the transfers
self._reqs_need_recv.clear() self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_in_batch = set() self._reqs_in_batch = set()
self._reqs_not_processed = set() self._reqs_not_processed = set()
self._reqs_need_send = {} self._reqs_need_send = {}
@ -748,6 +761,8 @@ class NixlConnectorScheduler:
# Also include the case of a P/D Prefill request with immediate # Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request. # block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id) self._reqs_not_processed.add(request.request_id)
# Clear _reqs_need_save if a request is aborted as partial prefill.
self._reqs_need_save.pop(request.request_id, None)
return False, None return False, None
# TODO: check whether block_ids actually ever be 0. If not we could # TODO: check whether block_ids actually ever be 0. If not we could

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Iterator from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import islice from itertools import islice
from typing import Any, ClassVar from typing import Any, ClassVar
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole,
@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
del self._store_jobs[req_id] del self._store_jobs[req_id]
return finished_sending, finished_recving return finished_sending, finished_recving
def yield_req_data(
    scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
    """Iterate over all requests scheduled in this step.

    Yields:
        (req_id, new_block_id_groups, preempted) — ``preempted`` is True only
        for a cached request whose id appears in ``resumed_req_ids``.
    """
    # new requests: never resumed from preemption, so the flag is always False
    for req_data in scheduler_output.scheduled_new_reqs:
        yield req_data.req_id, req_data.block_ids, False
    # cached requests: zip the parallel lists of ids and new block-id groups,
    # pairing each with its resumed-from-preemption membership test
    cached_reqs = scheduler_output.scheduled_cached_reqs
    yield from zip(
        cached_reqs.req_ids,
        cached_reqs.new_block_ids,
        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
    )