[NIXL][BUG FIX] Fix a failure and an accuracy issue with nixl + host_buffer on CUDA (#30419)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
parent 19c583398a
commit 6ca74bc11a
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -4,6 +4,7 @@
 KV cache helper for store.
 """
 
+from collections.abc import Iterator
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Literal
 
@@ -203,6 +204,26 @@ def copy_kv_blocks(
     copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
 
 
+def yield_req_data(
+    scheduler_output,
+) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
+    """
+    Yields:
+        (req_id, new_block_id_groups, preempted)
+    """
+    # new requests
+    for req_data in scheduler_output.scheduled_new_reqs:
+        yield req_data.req_id, req_data.block_ids, False
+
+    # cached requests
+    cached_reqs = scheduler_output.scheduled_cached_reqs
+    yield from zip(
+        cached_reqs.req_ids,
+        cached_reqs.new_block_ids,
+        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
+    )
+
+
 @dataclass
 class TpKVTopology:
     """
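For orientation, a minimal sketch of how the new helper behaves. The stand-in classes below (_NewReq, _CachedReqs, _SchedulerOutput) are illustrative, not vLLM types; only their field names mirror the attributes yield_req_data reads from vLLM's SchedulerOutput.

from dataclasses import dataclass

from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data


@dataclass
class _NewReq:  # hypothetical stand-in for one scheduled new request
    req_id: str
    block_ids: tuple[list[int], ...]


@dataclass
class _CachedReqs:  # hypothetical stand-in for the cached-requests batch
    req_ids: list[str]
    new_block_ids: list  # per request: a tuple of block-id groups, or None
    resumed_req_ids: set[str]


@dataclass
class _SchedulerOutput:  # hypothetical stand-in for vLLM's SchedulerOutput
    scheduled_new_reqs: list
    scheduled_cached_reqs: _CachedReqs


out = _SchedulerOutput(
    scheduled_new_reqs=[_NewReq("req-0", ([0, 1],))],
    scheduled_cached_reqs=_CachedReqs(
        req_ids=["req-1", "req-2"],
        new_block_ids=[([4, 5],), None],  # None: no new blocks this step
        resumed_req_ids={"req-2"},  # req-2 was preempted, then resumed
    ),
)

for req_id, groups, preempted in yield_req_data(out):
    print(req_id, groups, preempted)
# req-0 ([0, 1],) False
# req-1 ([4, 5],) False
# req-2 None True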
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -23,7 +23,11 @@ from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    EngineId,
+    TpKVTopology,
+    yield_req_data,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     CopyBlocksOp,
     KVConnectorBase_V1,
@@ -481,7 +485,7 @@ class NixlConnectorScheduler:
         # New requests are added by update_state_after_alloc in
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
-        self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
+        self._reqs_need_save: dict[ReqId, Request] = {}
         # Reqs to send and their expiration time
         self._reqs_need_send: dict[ReqId, float] = {}
         self._reqs_in_batch: set[ReqId] = set()
@@ -627,16 +631,7 @@ class NixlConnectorScheduler:
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
-
-            # save all blocks
-            block_ids = blocks.get_block_ids()[0]
-            # TODO: skip the blocks that are already in the host xfer buffer.
-            # Currently, the host xfer buffer block is 1-to-1 mapped to device
-            # kv blocks, so host blocks won't be flushed as long as its device
-            # block is not overwritten; and it will be safe to skip saving them
-            # to host xfer buffer.
-            if block_ids:
-                self._reqs_need_save[request.request_id] = (request, block_ids)
+            self._reqs_need_save[request.request_id] = request
         elif params.get("do_remote_prefill"):
             if params.get("remote_block_ids"):
                 if all(
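The effect of the two hunks above, sketched as a simplified assumption (not the method's full body): with host_buffer, the scheduler now records only the Request and defers block-id resolution to build_connector_meta, because a chunked prefill allocates blocks across several steps and a snapshot taken at first allocation would miss later blocks.

# Simplified sketch (assumed) of the scheduler-side registration after this fix.
def update_state_after_alloc_sketch(self, request, blocks, params) -> None:
    if self.use_host_buffer and params.get("do_remote_decode"):
        # No block-id snapshot here: a chunked prefill gains blocks over
        # multiple steps, so block ids are re-read from each SchedulerOutput
        # via yield_req_data in build_connector_meta instead.
        self._reqs_need_save[request.request_id] = request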
@@ -688,13 +683,32 @@ class NixlConnectorScheduler:
                 kv_transfer_params=req.kv_transfer_params,
             )
 
-        for req_id, (req, block_ids) in self._reqs_need_save.items():
+        # NOTE: For the prefill side, there is a chance that an early-added
+        # request is a chunked prefill, so check whether new blocks were added.
+        for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
+            req_to_save = self._reqs_need_save.get(req_id)
+            if req_to_save is None or new_block_id_groups is None:
+                continue
+            req = req_to_save
+
             assert req.kv_transfer_params is not None
             meta.add_new_req_to_save(
                 request_id=req_id,
-                local_block_ids=block_ids,
+                local_block_ids=new_block_id_groups[0],
                 kv_transfer_params=req.kv_transfer_params,
             )
+            assert scheduler_output.num_scheduled_tokens is not None
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            is_partial = (
+                req.num_computed_tokens + num_scheduled_tokens
+            ) < req.num_prompt_tokens
+            if not is_partial:
+                # For non-partial prefills, once new req_meta is scheduled, it
+                # can be removed from _reqs_need_save.
+                # For a partial prefill, retain the request in _reqs_need_save
+                # until all of its blocks have been scheduled with req_meta.
+                # Therefore, only pop if `not is_partial`.
+                self._reqs_need_save.pop(req_id)
 
         meta.reqs_to_send = self._reqs_need_send
         meta.reqs_in_batch = self._reqs_in_batch
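To see when a request is retained, here is the is_partial arithmetic with illustrative numbers (the token counts are assumptions for demonstration, not values from the commit): a 1000-token prompt prefilled in chunks of roughly 400 tokens stays in _reqs_need_save until the final chunk.

num_prompt_tokens = 1000

# step 1: 0 computed, 400 scheduled -> partial, request is retained
assert (0 + 400) < num_prompt_tokens

# step 2: 400 computed, 400 scheduled -> still partial, retained again
assert (400 + 400) < num_prompt_tokens

# step 3: 800 computed, final 200 scheduled -> complete, popped
assert not ((800 + 200) < num_prompt_tokens)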
@@ -702,7 +716,6 @@ class NixlConnectorScheduler:
 
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
-        self._reqs_need_save.clear()
         self._reqs_in_batch = set()
         self._reqs_not_processed = set()
         self._reqs_need_send = {}
@@ -748,6 +761,8 @@ class NixlConnectorScheduler:
             # Also include the case of a P/D Prefill request with immediate
             # block free (eg abort). Stop tracking this request.
             self._reqs_not_processed.add(request.request_id)
+            # Clear _reqs_need_save if a request is aborted as partial prefill.
+            self._reqs_need_save.pop(request.request_id, None)
             return False, None
 
         # TODO: check whether block_ids actually ever be 0. If not we could
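Taken together, the last two hunks change the lifecycle of a _reqs_need_save entry: it is no longer cleared wholesale after each step, but survives across scheduler steps until either the prefill completes (popped in build_connector_meta once not is_partial) or the request is aborted mid-prefill (popped in request_finished). This is what makes saves correct for chunked prefills that span multiple steps.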
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import islice
 from typing import Any, ClassVar
@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
+from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
 from vllm.distributed.kv_transfer.kv_connector.v1 import (
     KVConnectorBase_V1,
     KVConnectorRole,
@@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
         del self._store_jobs[req_id]
 
         return finished_sending, finished_recving
-
-
-def yield_req_data(
-    scheduler_output,
-) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
-    """
-    Yields:
-        (req_id, new_block_id_groups, preempted)
-    """
-    # new requests
-    for req_data in scheduler_output.scheduled_new_reqs:
-        yield req_data.req_id, req_data.block_ids, False
-
-    # cached requests
-    cached_reqs = scheduler_output.scheduled_cached_reqs
-    yield from zip(
-        cached_reqs.req_ids,
-        cached_reqs.new_block_ids,
-        (req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
-    )
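For completeness, a hedged sketch of how a connector can consume the preempted flag now that yield_req_data is shared from utils. The handler below is hypothetical, not OffloadingConnector's actual code; it only shows the shape of the iteration.

from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data


def build_meta_sketch(scheduler_output):
    # Hypothetical consumer: a resumed (previously preempted) request may need
    # its offloaded blocks restored, while new blocks are queued for offload.
    for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
        if preempted:
            print(f"{req_id}: resumed after preemption; reload offloaded blocks")
        if new_block_id_groups is not None:
            print(f"{req_id}: queue {new_block_id_groups[0]} for offload")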