mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 19:53:35 +08:00
[NIXL][BUG FIX] Fix both failing issue and accuracy issue with nixl + host_buffer on CUDA (#30419)
Signed-off-by: Chendi Xue <chendi.xue@intel.com> Signed-off-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
parent
19c583398a
commit
6ca74bc11a
@ -4,6 +4,7 @@
|
|||||||
KV cache helper for store.
|
KV cache helper for store.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
@ -203,6 +204,26 @@ def copy_kv_blocks(
|
|||||||
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
|
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def yield_req_data(
|
||||||
|
scheduler_output,
|
||||||
|
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
|
||||||
|
"""
|
||||||
|
Yields:
|
||||||
|
(req_id, new_block_id_groups, preempted)
|
||||||
|
"""
|
||||||
|
# new requests
|
||||||
|
for req_data in scheduler_output.scheduled_new_reqs:
|
||||||
|
yield req_data.req_id, req_data.block_ids, False
|
||||||
|
|
||||||
|
# cached requests
|
||||||
|
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||||
|
yield from zip(
|
||||||
|
cached_reqs.req_ids,
|
||||||
|
cached_reqs.new_block_ids,
|
||||||
|
(req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TpKVTopology:
|
class TpKVTopology:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -23,7 +23,11 @@ from vllm import envs
|
|||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import EngineId, TpKVTopology
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||||
|
EngineId,
|
||||||
|
TpKVTopology,
|
||||||
|
yield_req_data,
|
||||||
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
CopyBlocksOp,
|
CopyBlocksOp,
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
@ -481,7 +485,7 @@ class NixlConnectorScheduler:
|
|||||||
# New requests are added by update_state_after_alloc in
|
# New requests are added by update_state_after_alloc in
|
||||||
# the scheduler. Used to make metadata passed to Worker.
|
# the scheduler. Used to make metadata passed to Worker.
|
||||||
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||||
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
|
self._reqs_need_save: dict[ReqId, Request] = {}
|
||||||
# Reqs to send and their expiration time
|
# Reqs to send and their expiration time
|
||||||
self._reqs_need_send: dict[ReqId, float] = {}
|
self._reqs_need_send: dict[ReqId, float] = {}
|
||||||
self._reqs_in_batch: set[ReqId] = set()
|
self._reqs_in_batch: set[ReqId] = set()
|
||||||
@ -627,16 +631,7 @@ class NixlConnectorScheduler:
|
|||||||
if self.use_host_buffer and params.get("do_remote_decode"):
|
if self.use_host_buffer and params.get("do_remote_decode"):
|
||||||
# NOTE: when accelerator is not directly supported by Nixl,
|
# NOTE: when accelerator is not directly supported by Nixl,
|
||||||
# prefilled blocks need to be saved to host memory before transfer.
|
# prefilled blocks need to be saved to host memory before transfer.
|
||||||
|
self._reqs_need_save[request.request_id] = request
|
||||||
# save all blocks
|
|
||||||
block_ids = blocks.get_block_ids()[0]
|
|
||||||
# TODO: skip the blocks that are already in the host xfer buffer.
|
|
||||||
# Currently, the host xfer buffer block is 1-to-1 mapped to device
|
|
||||||
# kv blocks, so host blocks won't be flushed as long as its device
|
|
||||||
# block is not overwritten; and it will be safe to skip saving them
|
|
||||||
# to host xfer buffer.
|
|
||||||
if block_ids:
|
|
||||||
self._reqs_need_save[request.request_id] = (request, block_ids)
|
|
||||||
elif params.get("do_remote_prefill"):
|
elif params.get("do_remote_prefill"):
|
||||||
if params.get("remote_block_ids"):
|
if params.get("remote_block_ids"):
|
||||||
if all(
|
if all(
|
||||||
@ -688,13 +683,32 @@ class NixlConnectorScheduler:
|
|||||||
kv_transfer_params=req.kv_transfer_params,
|
kv_transfer_params=req.kv_transfer_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
for req_id, (req, block_ids) in self._reqs_need_save.items():
|
# NOTE: For the prefill side, there might be a chance that an early added
|
||||||
|
# request is a chunked prefill, so we need to check if new blocks are added
|
||||||
|
for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
|
||||||
|
req_to_save = self._reqs_need_save.get(req_id)
|
||||||
|
if req_to_save is None or new_block_id_groups is None:
|
||||||
|
continue
|
||||||
|
req = req_to_save
|
||||||
|
|
||||||
assert req.kv_transfer_params is not None
|
assert req.kv_transfer_params is not None
|
||||||
meta.add_new_req_to_save(
|
meta.add_new_req_to_save(
|
||||||
request_id=req_id,
|
request_id=req_id,
|
||||||
local_block_ids=block_ids,
|
local_block_ids=new_block_id_groups[0],
|
||||||
kv_transfer_params=req.kv_transfer_params,
|
kv_transfer_params=req.kv_transfer_params,
|
||||||
)
|
)
|
||||||
|
assert scheduler_output.num_scheduled_tokens is not None
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||||
|
is_partial = (
|
||||||
|
req.num_computed_tokens + num_scheduled_tokens
|
||||||
|
) < req.num_prompt_tokens
|
||||||
|
if not is_partial:
|
||||||
|
# For non-partial prefills, once new req_meta is scheduled, it
|
||||||
|
# can be removed from _reqs_need_save.
|
||||||
|
# For partial prefill case, we will retain the request in
|
||||||
|
# _reqs_need_save until all blocks are scheduled with req_meta.
|
||||||
|
# Therefore, only pop if `not is_partial`.
|
||||||
|
self._reqs_need_save.pop(req_id)
|
||||||
|
|
||||||
meta.reqs_to_send = self._reqs_need_send
|
meta.reqs_to_send = self._reqs_need_send
|
||||||
meta.reqs_in_batch = self._reqs_in_batch
|
meta.reqs_in_batch = self._reqs_in_batch
|
||||||
@ -702,7 +716,6 @@ class NixlConnectorScheduler:
|
|||||||
|
|
||||||
# Clear the list once workers start the transfers
|
# Clear the list once workers start the transfers
|
||||||
self._reqs_need_recv.clear()
|
self._reqs_need_recv.clear()
|
||||||
self._reqs_need_save.clear()
|
|
||||||
self._reqs_in_batch = set()
|
self._reqs_in_batch = set()
|
||||||
self._reqs_not_processed = set()
|
self._reqs_not_processed = set()
|
||||||
self._reqs_need_send = {}
|
self._reqs_need_send = {}
|
||||||
@ -748,6 +761,8 @@ class NixlConnectorScheduler:
|
|||||||
# Also include the case of a P/D Prefill request with immediate
|
# Also include the case of a P/D Prefill request with immediate
|
||||||
# block free (eg abort). Stop tracking this request.
|
# block free (eg abort). Stop tracking this request.
|
||||||
self._reqs_not_processed.add(request.request_id)
|
self._reqs_not_processed.add(request.request_id)
|
||||||
|
# Clear _reqs_need_save if a request is aborted as partial prefill.
|
||||||
|
self._reqs_need_save.pop(request.request_id, None)
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
# TODO: check whether block_ids actually ever be 0. If not we could
|
# TODO: check whether block_ids actually ever be 0. If not we could
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Iterable, Iterator
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar
|
||||||
@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
|||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||||
KVConnectorBase_V1,
|
KVConnectorBase_V1,
|
||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
|
|||||||
del self._store_jobs[req_id]
|
del self._store_jobs[req_id]
|
||||||
|
|
||||||
return finished_sending, finished_recving
|
return finished_sending, finished_recving
|
||||||
|
|
||||||
|
|
||||||
def yield_req_data(
|
|
||||||
scheduler_output,
|
|
||||||
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
|
|
||||||
"""
|
|
||||||
Yields:
|
|
||||||
(req_id, new_block_id_groups, preempted)
|
|
||||||
"""
|
|
||||||
# new requests
|
|
||||||
for req_data in scheduler_output.scheduled_new_reqs:
|
|
||||||
yield req_data.req_id, req_data.block_ids, False
|
|
||||||
|
|
||||||
# cached requests
|
|
||||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
|
||||||
yield from zip(
|
|
||||||
cached_reqs.req_ids,
|
|
||||||
cached_reqs.new_block_ids,
|
|
||||||
(req_id in cached_reqs.resumed_req_ids for req_id in cached_reqs.req_ids),
|
|
||||||
)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user