Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 08:04:58 +08:00)
[Bugfix][Nixl] Fix full prefix cache hit bug (#18632)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent 78dcf56cb3
commit c56ed8bb0e
@@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
     SharedStorageConnector)
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 
 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
 
@@ -32,7 +33,7 @@ class TestSharedStorageConnector(SharedStorageConnector):
         self.call_record: dict[str, int] = defaultdict(int)
         # Use a unique temp file per connector
         self._event_file = tempfile.gettempdir(
-        ) + f"/connector_{self.name}_events.log"
+        ) + f"/connector_{self.name}-{self.role.name}_events.log"
         # Start with an empty file
         with open(self._event_file, "w") as _:
             pass
@@ -52,10 +53,19 @@ class TestSharedStorageConnector(SharedStorageConnector):
 
         def wrapper(*args, **kwargs):
             self.call_record[name] += 1
+
+            # Include args that we're interested in
+            to_log = [name]
+            for arg in args:
+                if isinstance(arg, int):
+                    to_log.append(str(arg))
+                elif isinstance(arg, KVCacheBlocks):
+                    to_log.append(f"num_blocks={len(arg.blocks)}")
+
             # Log the event as a line to the file
             try:
                 with open(self._event_file, "a") as f:
-                    f.write(name + "\n")
+                    f.write(' '.join(to_log) + "\n")
             except Exception as e:
                 print(f"[ERROR] Could not log event {name} "
                       f"for {self.name}: {e}")
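With this change the test wrapper records not only which connector method ran but also its interesting arguments: one line per call, containing the method name, any int arguments, and a num_blocks=N summary for KVCacheBlocks arguments. A minimal standalone sketch of that line format, using a FakeBlocks stand-in rather than the real KVCacheBlocks type:

# Sketch of the event-line format the patched wrapper writes.
# FakeBlocks is an illustrative stand-in for KVCacheBlocks.
from dataclasses import dataclass, field


@dataclass
class FakeBlocks:
    blocks: list = field(default_factory=list)


def format_event(name, *args):
    to_log = [name]
    for arg in args:
        if isinstance(arg, int):
            to_log.append(str(arg))
        elif isinstance(arg, FakeBlocks):
            to_log.append(f"num_blocks={len(arg.blocks)}")
    return ' '.join(to_log)


# Produces the same strings the test asserts on below, e.g.
# "update_state_after_alloc num_blocks=7 96" and "get_num_new_matched_tokens 0".
print(format_event("update_state_after_alloc", FakeBlocks([object()] * 7), 96))
print(format_event("get_num_new_matched_tokens", 0))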
@@ -162,15 +172,23 @@ def test_multi_shared_storage_connector_consistency():
         f"{storage_1_path} and {storage_2_path}")
 
     events = get_connector_events()
-    # get_num_new_matched_tokens will be called on each connector in turn.
-    # neither of them have hits so update_state_after_alloc won't be called.
-    assert events["storage1"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    # get_num_new_matched_tokens and update_state_after_alloc will be called
+    # on each connector in turn.
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]
-    assert events["storage2"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    assert events["storage1-WORKER"][:5] == [
+        'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
+        'wait_for_layer_load', 'save_kv_layer'
     ]
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
+    ]
+    assert events["storage2-WORKER"][:5] == [
+        'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
+        'wait_for_layer_load', 'save_kv_layer'
+    ]
 
     # Reset prefix cache or else we'll just get the tokens back from there.
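The event lists are now keyed by "<connector name>-<ROLE>" because each connector/role pair writes its own temp file (see the file-name change above). The test's get_connector_events() helper is not shown in this diff; a hypothetical reader with the same shape might look like the sketch below. The glob pattern and return layout here are assumptions for illustration, not the test's actual code:

# Hypothetical sketch of a get_connector_events()-style helper: collect the
# per-connector event files written by the wrapper and return a mapping such
# as {"storage1-SCHEDULER": [...], "storage1-WORKER": [...], ...}.
# The file-name pattern mirrors connector_{name}-{role}_events.log above.
import glob
import os
import tempfile


def get_connector_events() -> dict[str, list[str]]:
    events: dict[str, list[str]] = {}
    pattern = os.path.join(tempfile.gettempdir(), "connector_*_events.log")
    for path in glob.glob(pattern):
        # "connector_storage1-SCHEDULER_events.log" -> "storage1-SCHEDULER"
        key = os.path.basename(path)[len("connector_"):-len("_events.log")]
        with open(path) as f:
            events[key] = [line.strip() for line in f if line.strip()]
    return events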
@@ -182,16 +200,16 @@ def test_multi_shared_storage_connector_consistency():
 
     events = get_connector_events()
     # get_num_new_matched_tokens will return new tokens from the first
-    # connector so update_state_after_alloc will be called once blocks
-    # are allocated for the first connector.
-    # get_num_new_matched_tokens *won't* be called on the second connector
-    # in this case.
-    assert events["storage1"][:4] == [
-        'get_num_new_matched_tokens', 'update_state_after_alloc',
-        'build_connector_meta', 'bind_connector_metadata'
+    # connector so update_state_after_alloc will be called with allocated
+    # blocks on that one but with zero blocks for the others (the first
+    # nonzero match is chosen).
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
     ]
-    assert events["storage2"][:2] == [
-        'build_connector_meta', 'bind_connector_metadata'
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]
 
     # Delete storage1 connector state
@@ -205,17 +223,17 @@ def test_multi_shared_storage_connector_consistency():
     _ = llm.generate(PROMPTS, SAMPLING_PARAMS)
 
     events = get_connector_events()
-    # get_num_new_matched_tokens will be called for the first connector but it
-    # won't have a hit so update_state_after_alloc won't be called.
-    # get_num_new_matched_tokens will also be called on the second connector,
-    # but it should have a hit so update_state_after_alloc will be called.
-    assert events["storage1"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    # get_num_new_matched_tokens will be called for both connectors but will
+    # return 0 from the first connector, but the second connector should have
+    # a hit, so update_state_after_alloc will only be called with allocated
+    # blocks for the second connector.
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
    ]
-    assert events["storage2"][:4] == [
-        'get_num_new_matched_tokens', 'update_state_after_alloc',
-        'build_connector_meta', 'bind_connector_metadata'
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
     ]
 
     # Clean up
@@ -12,12 +12,12 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.logger import init_logger
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import SchedulerOutput
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
-    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
 logger = init_logger(__name__)
@@ -51,8 +51,9 @@ class MultiConnector(KVConnectorBase_V1):
             self._connectors.append(
                 KVConnectorFactory.create_connector_v1(temp_config, role))
 
-        # A mapping from request id to the connector that is assigned to it.
-        self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
+        # A mapping from request id to the index of the connector chosen to
+        # load the request from (if any).
+        self._requests_to_connector: dict[str, int] = {}
 
         # Keeps track of *additional* remaining async saves (beyond 1) to be
         # finished per request. Not needed for async loads since we only allow
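KVCacheBlocks moves out of the TYPE_CHECKING block because the new update_state_after_alloc path below calls KVCacheBlocks.create_empty() at runtime, and names imported only under TYPE_CHECKING do not exist when the module actually executes. A generic illustration of that pitfall, using a standard-library class instead of anything from vllm:

# TYPE_CHECKING-guarded imports are only visible to type checkers; calling
# the guarded name at runtime raises NameError.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # available for annotations only


def passthrough(x: "Decimal") -> "Decimal":  # string annotations are fine
    return x


try:
    passthrough(Decimal("1.5"))  # the guarded import never ran at runtime
except NameError as err:
    print(err)  # name 'Decimal' is not defined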
@@ -136,25 +137,31 @@ class MultiConnector(KVConnectorBase_V1):
         request: "Request",
         num_computed_tokens: int,
     ) -> tuple[int, bool]:
-        for c in self._connectors:
+        to_return = (0, False)
+        for i, c in enumerate(self._connectors):
             toks, load_async = c.get_num_new_matched_tokens(
                 request, num_computed_tokens)
             # The first connector that has new matched tokens will be assigned
             # to this request.
-            if toks > 0:
-                self._requests_to_connector[request.request_id] = c
-                return toks, load_async
-        return 0, False
+            if to_return[0] == 0 and toks > 0:
+                self._requests_to_connector[request.request_id] = i
+                to_return = (toks, load_async)
+        return to_return
 
     def update_state_after_alloc(self, request: "Request",
                                  blocks: "KVCacheBlocks",
                                  num_external_tokens: int):
-        # If the request is not assigned to any connector, we do nothing.
-        if request.request_id not in self._requests_to_connector:
-            return
-        # We assume that the request is assigned to only one connector.
-        c = self._requests_to_connector.pop(request.request_id)
-        c.update_state_after_alloc(request, blocks, num_external_tokens)
+        chosen_connector = self._requests_to_connector.get(
+            request.request_id, -1)
+        for i, c in enumerate(self._connectors):
+            if i == chosen_connector:
+                # Forward call to the chosen connector (if any).
+                c.update_state_after_alloc(request, blocks,
+                                           num_external_tokens)
+            else:
+                # Call with empty blocks for other connectors.
+                c.update_state_after_alloc(request,
+                                           KVCacheBlocks.create_empty(), 0)
 
     def build_connector_meta(
         self,
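Taken together, the two methods above implement a simple policy: poll every connector, remember the index of the first one that reports a nonzero prefix-cache match, and at allocation time hand the real blocks to that connector while every other connector is still called, but with an empty KVCacheBlocks and zero external tokens. A self-contained simulation of that policy with stub classes (illustrative stand-ins, not vllm APIs):

# Stub classes that only mirror the MultiConnector call pattern.
class StubConnector:
    def __init__(self, name, hit_tokens):
        self.name = name
        self.hit_tokens = hit_tokens
        self.calls = []

    def get_num_new_matched_tokens(self, request_id, num_computed_tokens):
        return self.hit_tokens, False

    def update_state_after_alloc(self, request_id, blocks, num_external_tokens):
        self.calls.append((request_id, len(blocks), num_external_tokens))


class StubMultiConnector:
    def __init__(self, connectors):
        self.connectors = connectors
        self.requests_to_connector = {}  # request id -> connector index

    def get_num_new_matched_tokens(self, request_id, num_computed_tokens):
        to_return = (0, False)
        for i, c in enumerate(self.connectors):
            toks, load_async = c.get_num_new_matched_tokens(
                request_id, num_computed_tokens)
            # Remember the first connector with a nonzero hit, but keep
            # polling the rest so every connector sees the request.
            if to_return[0] == 0 and toks > 0:
                self.requests_to_connector[request_id] = i
                to_return = (toks, load_async)
        return to_return

    def update_state_after_alloc(self, request_id, blocks, num_external_tokens):
        chosen = self.requests_to_connector.get(request_id, -1)
        for i, c in enumerate(self.connectors):
            if i == chosen:
                c.update_state_after_alloc(request_id, blocks,
                                           num_external_tokens)
            else:
                c.update_state_after_alloc(request_id, [], 0)


multi = StubMultiConnector([StubConnector("storage1", 0),
                            StubConnector("storage2", 96)])
print(multi.get_num_new_matched_tokens("req-0", 0))  # (96, False)
multi.update_state_after_alloc("req-0", ["blk"] * 7, 96)
for c in multi.connectors:
    print(c.name, c.calls)
# storage1 [('req-0', 0, 0)]   <- called with empty blocks
# storage2 [('req-0', 7, 96)]  <- called with the real allocation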
@@ -170,7 +177,7 @@ class MultiConnector(KVConnectorBase_V1):
     def request_finished(
         self,
         request: "Request",
-        blocks: "KVCacheBlocks",
+        blocks: list[int],
     ) -> tuple[bool, Optional[dict[str, Any]]]:
         async_saves = 0
         kv_txfer_params = None
@@ -187,4 +194,8 @@ class MultiConnector(KVConnectorBase_V1):
                 kv_txfer_params = txfer_params
         if async_saves > 1:
             self._extra_async_saves[request.request_id] = async_saves - 1
+
+        # Clean up other state for this request.
+        self._requests_to_connector.pop(request.request_id, None)
+
         return async_saves > 0, kv_txfer_params
@@ -221,15 +221,6 @@ class NixlConnectorScheduler:
         if count > 0:
             return count, True
 
-        # NOTE: if count is 0 here, we have less than block_size
-        # tokens to pull after subtracting the local prefix cache hit.
-        # The remote only sends fully computed blocks, so there is
-        # nothing to transfer but we still need to notify the
-        # prefill worker so that the remote blocks are freed.
-        if all(p in params for p in ("remote_engine_id", "remote_host",
-                                     "remote_port")):
-            self._reqs_need_recv[request.request_id] = (request, [])
-
         # No remote prefill for this request.
         return 0, False
 
@@ -247,9 +238,14 @@ class NixlConnectorScheduler:
         if params.get("remote_block_ids"):
             if all(p in params for p in ("remote_engine_id", "remote_host",
                                          "remote_port")):
+                # If remote_blocks and num_external_tokens = 0, we have
+                # a full prefix cache hit on the D worker. We need to call
+                # send_notif in _read_blocks to free the memory on the P.
+                local_block_ids = (blocks.get_unhashed_block_ids()
+                                   if num_external_tokens > 0 else [])
                 # Get unhashed blocks to pull from remote.
                 self._reqs_need_recv[request.request_id] = (
-                    request, blocks.get_unhashed_block_ids())
+                    request, local_block_ids)
             else:
                 logger.warning(
                     "Got invalid KVTransferParams: %s. This "
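This is the core of the full-prefix-cache-hit fix on the scheduler side: when num_external_tokens is 0 the decode instance already has every prompt token locally, so there is nothing to pull, but the request is still recorded, with an empty block list, so the prefill worker can later be told to free its copy of the blocks. A tiny sketch of just that conditional, with a plain list standing in for KVCacheBlocks.get_unhashed_block_ids():

def blocks_to_pull(unhashed_block_ids, num_external_tokens):
    # Full local prefix cache hit: nothing to transfer, but keep the
    # request (with no blocks) so the prefill side still gets notified
    # and can free its KV blocks.
    return unhashed_block_ids if num_external_tokens > 0 else []


print(blocks_to_pull([12, 13, 14], 48))  # [12, 13, 14] -> real transfer
print(blocks_to_pull([12, 13, 14], 0))   # []           -> notify-only request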
@@ -268,15 +264,6 @@ class NixlConnectorScheduler:
         # Loop through scheduled reqs and convert to ReqMeta.
         for req_id, (req, block_ids) in self._reqs_need_recv.items():
             assert req.kv_transfer_params is not None
-            # For the case where there are no remote blocks to pull
-            # (block_ids is empty), we don't need to schedule
-            # an async read on the worker side.
-            if not block_ids:
-                logger.debug(
-                    "Skipping adding request %s to NixlConnectorMetadata, "
-                    "as there are no remote blocks to pull", req_id)
-                continue
-
             meta.add_new_req(
                 request_id=req_id,
                 local_block_ids=block_ids,
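Correspondingly, the connector metadata no longer drops requests whose block list is empty; they flow through to the worker so the notification can still be sent even though no blocks are read. A hypothetical sketch of that behaviour (the function and callback names below are made up for illustration and are not vllm's worker API):

# Hypothetical worker-side handling: skip the read when there is nothing to
# pull, but always deliver the notification to the prefill side.
def process_request(req_id, remote_block_ids, read_blocks, notify_prefill):
    if remote_block_ids:
        read_blocks(req_id, remote_block_ids)  # real KV transfer
    # Always notify, even for a full prefix cache hit with nothing to read.
    notify_prefill(req_id)


reads, notifs = [], []
process_request("req-0", [3, 4, 5], lambda r, b: reads.append((r, b)),
                lambda r: notifs.append(r))
process_request("req-1", [], lambda r, b: reads.append((r, b)),
                lambda r: notifs.append(r))
print(reads)   # [('req-0', [3, 4, 5])]
print(notifs)  # ['req-0', 'req-1']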
@@ -660,26 +647,26 @@ class NixlConnectorWorker:
 
         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
-        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
-            "Local TP size must be divisible by remote TP size."
+        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
+            "Local TP size must be divisible by remote TP size.")
         tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
-        assert tp_ratio > 0, "Decode TP cannot be smaller than"
-        " prefill TP"
+        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
         if self.use_mla:
             # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len / (
+            remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes)
             assert self.block_len == nixl_agent_meta.block_len
         else:
-            remote_block_size = nixl_agent_meta.block_len / (
+            remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
 
-        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
-            "Remote P worker KV layer cache must be of shape [2, N, \
-            local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            "Remote P worker KV layer cache must be of shape [2, N, "
+            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        )
 
-        assert self.block_size == remote_block_size, "Remote P worker with \
-            different block size is not supported"
+        assert self.block_size == remote_block_size, "Remote P worker with "
+        "different block size is not supported"
 
         assert self.num_blocks >= nixl_agent_meta.num_blocks
 
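A worked example of the heterogeneous-TP arithmetic these assertions guard. The sizes below are made up purely to make the numbers concrete (the real slot_size_bytes depends on the model's KV head count, head size, and dtype); the point is that tp_ratio counts decode ranks per prefill rank, the remote block_len is tp_ratio times the local one, and integer division (//) rather than true division (/) is needed so remote_block_size stays an int for the equality check:

# Hypothetical sizes, chosen only to make the arithmetic concrete.
local_tp, remote_tp = 4, 2           # decode (D) and prefill (P) TP sizes
block_size = 16                      # tokens per KV block
slot_size_bytes = 2 * 8 * 128 * 2    # assumed 2 (K+V) * kv_heads * head_dim * fp16

assert local_tp % remote_tp == 0, "Local TP size must be divisible by remote TP size."
tp_ratio = local_tp // remote_tp     # 2 decode workers read from one prefill worker

# The prefill worker holds tp_ratio times more KV heads per block, so its
# per-block byte length is tp_ratio times the local one.
local_block_len = block_size * slot_size_bytes      # 65536
remote_block_len = local_block_len * tp_ratio       # 131072

# Integer division recovers the token block size; true division would yield
# a float (16.0) and make the equality comparison fragile.
remote_block_size = remote_block_len // (slot_size_bytes * tp_ratio)
assert remote_block_size == block_size
print(tp_ratio, local_block_len, remote_block_len, remote_block_size)  # 2 65536 131072 16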
@@ -712,9 +699,9 @@ class NixlConnectorWorker:
                 # (addr, len, device id)
                 blocks_data.append((addr, self.block_len, remote_tp_rank))
             logger.debug(
-                "Created %s blocks for dst engine %s with remote rank %s and " \
-                "local rank %s",
-                len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
+                "Created %s blocks for dst engine %s with remote rank %s and "
+                "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
+                self.tp_rank)
 
             # Register with NIXL.
             descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -424,11 +424,11 @@ class Scheduler(SchedulerInterface):
                 # The request cannot be scheduled.
                 break
 
-            # KVConnector: update internal state after allocation.
+            # KVTransfer: the connector uses this info to determine
+            # if a load is needed. Note that
             # This information is used to determine if a load is
             # needed for this request.
-            if num_external_computed_tokens:
-                assert self.connector is not None
+            if self.connector is not None:
                 self.connector.update_state_after_alloc(
                     request,
                     new_computed_blocks + new_blocks,
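The guard changes from "only when external tokens were matched" to "whenever a connector is configured". With a full local prefix cache hit, num_external_computed_tokens is 0, yet the connector still needs the callback (for example, so the Nixl connector above can register its notify-only request). A condensed before/after sketch using a stub connector rather than the real Scheduler:

class RecordingConnector:
    def __init__(self):
        self.seen = []

    def update_state_after_alloc(self, request_id, blocks, num_external_tokens):
        self.seen.append((request_id, num_external_tokens))


def old_schedule_step(connector, request_id, num_external_computed_tokens):
    # Pre-fix behaviour: the callback is skipped when no external tokens
    # were matched, so a full prefix cache hit never reaches the connector.
    if num_external_computed_tokens:
        connector.update_state_after_alloc(request_id, [],
                                           num_external_computed_tokens)


def new_schedule_step(connector, request_id, num_external_computed_tokens):
    # Post-fix behaviour: always inform the connector when one is present,
    # even with zero external tokens.
    if connector is not None:
        connector.update_state_after_alloc(request_id, [],
                                           num_external_computed_tokens)


old_c, new_c = RecordingConnector(), RecordingConnector()
old_schedule_step(old_c, "req-0", 0)
new_schedule_step(new_c, "req-0", 0)
print(old_c.seen)  # []             -> connector never told about the request
print(new_c.seen)  # [('req-0', 0)] -> connector can still free remote blocks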
@@ -841,7 +841,7 @@ class Scheduler(SchedulerInterface):
         }
 
         finished_req_ids = self.finished_req_ids_dict
-        if finished_req_ids is not None:
+        if finished_req_ids:
            # Include ids of requests that finished since last outputs
            # were sent.
            for client_index, finished_set in finished_req_ids.items():
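Finally, the truthiness check: finished_req_ids_dict can be an initialized but empty dict, which is not None yet has nothing to send, so testing truthiness skips the loop in that case as well. A two-line demonstration:

finished_req_ids = {}                 # initialized, but nothing finished yet
print(finished_req_ids is not None)   # True  -> the old guard would enter the loop
print(bool(finished_req_ids))         # False -> the new guard skips it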