[Bugfix][Nixl] Fix full prefix cache hit bug (#18632)

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Robert Shaw 2025-06-04 22:07:32 -04:00 committed by GitHub
parent 78dcf56cb3
commit c56ed8bb0e
4 changed files with 97 additions and 81 deletions


@@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@@ -32,7 +33,7 @@ class TestSharedStorageConnector(SharedStorageConnector):
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}_events.log"
) + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
@@ -52,10 +53,19 @@ class TestSharedStorageConnector(SharedStorageConnector):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(f"num_blocks={len(arg.blocks)}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(name + "\n")
f.write(' '.join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
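The wrapper now records selected argument values (plain integers and the block count of any KVCacheBlocks argument) alongside the method name, and writes them to a per-connector, per-role log file. A minimal standalone sketch of the same call-recording pattern follows; the class and file naming here are illustrative and not taken from the test:

import tempfile
from collections import defaultdict
from functools import wraps


class CallRecorder:
    """Illustrative sketch of the test's event logging, not the real class."""

    def __init__(self, name: str, role: str):
        self.call_record: dict[str, int] = defaultdict(int)
        # One log file per (connector name, role), mirroring the renamed
        # connector_<name>-<ROLE>_events.log files in the test.
        self._event_file = (tempfile.gettempdir() +
                            f"/connector_{name}-{role}_events.log")
        open(self._event_file, "w").close()

    def record(self, func):

        @wraps(func)
        def wrapper(*args, **kwargs):
            self.call_record[func.__name__] += 1
            # Log the call name plus any integer args, one line per event.
            to_log = [func.__name__] + [
                str(a) for a in args if isinstance(a, int)
            ]
            with open(self._event_file, "a") as f:
                f.write(" ".join(to_log) + "\n")
            return func(*args, **kwargs)

        return wrapper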
@@ -162,15 +172,23 @@ def test_multi_shared_storage_connector_consistency():
f"{storage_1_path} and {storage_2_path}")
events = get_connector_events()
# get_num_new_matched_tokens will be called on each connector in turn.
# neither of them have hits so update_state_after_alloc won't be called.
assert events["storage1"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
# get_num_new_matched_tokens and update_state_after_alloc will be called
# on each connector in turn.
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]
assert events["storage2"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
assert events["storage1-WORKER"][:5] == [
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
'wait_for_layer_load', 'save_kv_layer'
]
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]
assert events["storage2-WORKER"][:5] == [
'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
'wait_for_layer_load', 'save_kv_layer'
]
# Reset prefix cache or else we'll just get the tokens back from there.
@@ -182,16 +200,16 @@ def test_multi_shared_storage_connector_consistency():
events = get_connector_events()
# get_num_new_matched_tokens will return new tokens from the first
# connector so update_state_after_alloc will be called once blocks
# are allocated for the first connector.
# get_num_new_matched_tokens *won't* be called on the second connector
# in this case.
assert events["storage1"][:4] == [
'get_num_new_matched_tokens', 'update_state_after_alloc',
'build_connector_meta', 'bind_connector_metadata'
# connector so update_state_after_alloc will be called with allocated
# blocks on that one but with zero blocks for the others (the first
# nonzero match is chosen).
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
]
assert events["storage2"][:2] == [
'build_connector_meta', 'bind_connector_metadata'
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]
# Delete storage1 connector state
@@ -205,17 +223,17 @@ def test_multi_shared_storage_connector_consistency():
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
events = get_connector_events()
# get_num_new_matched_tokens will be called for the first connector but it
# won't have a hit so update_state_after_alloc won't be called.
# get_num_new_matched_tokens will also be called on the second connector,
# but it should have a hit so update_state_after_alloc will be called.
assert events["storage1"][:3] == [
'get_num_new_matched_tokens', 'build_connector_meta',
'bind_connector_metadata'
# get_num_new_matched_tokens will be called for both connectors; it will
# return 0 for the first connector, but the second connector should have
# a hit, so update_state_after_alloc will only be called with allocated
# blocks for the second connector.
assert events["storage1-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
]
assert events["storage2"][:4] == [
'get_num_new_matched_tokens', 'update_state_after_alloc',
'build_connector_meta', 'bind_connector_metadata'
assert events["storage2-SCHEDULER"][:3] == [
'get_num_new_matched_tokens 0',
'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
]
# Clean up
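Because each connector role now logs to its own file, the assertions above key events by "<connector-name>-<ROLE>" (for example "storage1-SCHEDULER") and check the logged argument values such as num_blocks. A hypothetical reader for such logs is sketched below; the test's actual get_connector_events() helper may be implemented differently:

import glob
import os
import tempfile


def read_connector_events() -> dict[str, list[str]]:
    # Hypothetical helper for illustration only: collect logged events keyed
    # by "<connector-name>-<ROLE>", matching the per-role log files named
    # connector_<name>-<ROLE>_events.log.
    events: dict[str, list[str]] = {}
    pattern = os.path.join(tempfile.gettempdir(), "connector_*_events.log")
    for path in glob.glob(pattern):
        key = os.path.basename(path)[len("connector_"):-len("_events.log")]
        with open(path) as f:
            events[key] = [line.strip() for line in f if line.strip()]
    return events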


@@ -12,12 +12,12 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -51,8 +51,9 @@ class MultiConnector(KVConnectorBase_V1):
self._connectors.append(
KVConnectorFactory.create_connector_v1(temp_config, role))
# A mapping from request id to the connector that is assigned to it.
self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
self._requests_to_connector: dict[str, int] = {}
# Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow
@@ -136,25 +137,31 @@ class MultiConnector(KVConnectorBase_V1):
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
for c in self._connectors:
to_return = (0, False)
for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens)
# The first connector that has new matched tokens will be assigned
# to this request.
if toks > 0:
self._requests_to_connector[request.request_id] = c
return toks, load_async
return 0, False
if to_return[0] == 0 and toks > 0:
self._requests_to_connector[request.request_id] = i
to_return = (toks, load_async)
return to_return
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
# If the request is not assigned to any connector, we do nothing.
if request.request_id not in self._requests_to_connector:
return
# We assume that the request is assigned to only one connector.
c = self._requests_to_connector.pop(request.request_id)
c.update_state_after_alloc(request, blocks, num_external_tokens)
chosen_connector = self._requests_to_connector.get(
request.request_id, -1)
for i, c in enumerate(self._connectors):
if i == chosen_connector:
# Forward call to the chosen connector (if any).
c.update_state_after_alloc(request, blocks,
num_external_tokens)
else:
# Call with empty blocks for other connectors.
c.update_state_after_alloc(request,
KVCacheBlocks.create_empty(), 0)
def build_connector_meta(
self,
@@ -170,7 +177,7 @@ class MultiConnector(KVConnectorBase_V1):
def request_finished(
self,
request: "Request",
blocks: "KVCacheBlocks",
blocks: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
async_saves = 0
kv_txfer_params = None
@@ -187,4 +194,8 @@ class MultiConnector(KVConnectorBase_V1):
kv_txfer_params = txfer_params
if async_saves > 1:
self._extra_async_saves[request.request_id] = async_saves - 1
# Clean up other state for this request.
self._requests_to_connector.pop(request.request_id, None)
return async_saves > 0, kv_txfer_params
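The MultiConnector rework above is what makes the always-called update_state_after_alloc contract work: every child connector is consulted for matched tokens, the index of the first connector reporting a nonzero hit is remembered, and after allocation the real blocks go only to that connector while the others receive empty blocks and a zero token count. A condensed sketch of that flow, with simplified signatures that do not match the real KVConnectorBase_V1 API:

class MiniMultiConnector:
    """Simplified sketch of the selection logic, not the real vLLM class."""

    def __init__(self, connectors):
        self._connectors = connectors
        self._requests_to_connector: dict[str, int] = {}

    def get_num_new_matched_tokens(self, request_id, num_computed_tokens):
        result = (0, False)
        for i, c in enumerate(self._connectors):
            toks, load_async = c.get_num_new_matched_tokens(
                request_id, num_computed_tokens)
            # Remember the first connector with a nonzero hit, but keep
            # consulting the rest so none are skipped.
            if result[0] == 0 and toks > 0:
                self._requests_to_connector[request_id] = i
                result = (toks, load_async)
        return result

    def update_state_after_alloc(self, request_id, blocks,
                                 num_external_tokens):
        chosen = self._requests_to_connector.get(request_id, -1)
        for i, c in enumerate(self._connectors):
            if i == chosen:
                # Only the chosen connector sees the allocated blocks.
                c.update_state_after_alloc(request_id, blocks,
                                           num_external_tokens)
            else:
                # The others are still notified, with no blocks.
                c.update_state_after_alloc(request_id, [], 0)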


@@ -221,15 +221,6 @@ class NixlConnectorScheduler:
if count > 0:
return count, True
# NOTE: if count is 0 here, we have less than block_size
# tokens to pull after subtracting the local prefix cache hit.
# The remote only sends fully computed blocks, so there is
# nothing to transfer but we still need to notify the
# prefill worker so that the remote blocks are freed.
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
self._reqs_need_recv[request.request_id] = (request, [])
# No remote prefill for this request.
return 0, False
@@ -247,9 +238,14 @@ class NixlConnectorScheduler:
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (blocks.get_unhashed_block_ids()
if num_external_tokens > 0 else [])
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, blocks.get_unhashed_block_ids())
request, local_block_ids)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
@@ -268,15 +264,6 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
if not block_ids:
logger.debug(
"Skipping adding request %s to NixlConnectorMetadata, "
"as there are no remote blocks to pull", req_id)
continue
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
@@ -660,26 +647,26 @@ class NixlConnectorWorker:
# Number of D TP workers reading from a single P TP worker. This is
# 1 when P and D `--tensor-parallel-size` match.
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
"Local TP size must be divisible by remote TP size."
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
"Local TP size must be divisible by remote TP size.")
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
assert tp_ratio > 0, "Decode TP cannot be smaller than"
" prefill TP"
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
if self.use_mla:
# With MLA the only difference is in the number of blocks.
remote_block_size = nixl_agent_meta.block_len / (
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes)
assert self.block_len == nixl_agent_meta.block_len
else:
remote_block_size = nixl_agent_meta.block_len / (
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
"Remote P worker KV layer cache must be of shape [2, N, \
local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
assert self.block_size == remote_block_size, "Remote P worker with \
different block size is not supported"
assert self.block_size == remote_block_size, "Remote P worker with "
"different block size is not supported"
assert self.num_blocks >= nixl_agent_meta.num_blocks
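In the handshake hunk above, remote_block_size is now computed with floor division so it stays an integer, matching the integer block size it is later compared against. A worked example with made-up sizes; the numbers are illustrative and not taken from the commit:

# Illustrative numbers only, not from the commit.
remote_block_len = 4 * 1024 * 1024   # bytes per remote KV block (assumed)
slot_size_bytes = 32 * 1024          # bytes per token slot (assumed)
tp_ratio = 2                         # decode TP size / prefill TP size

# "/" would yield a float (64.0); "//" keeps the value an int, so the later
# equality check against the local integer block_size behaves as intended.
remote_block_size = remote_block_len // (slot_size_bytes * tp_ratio)
assert remote_block_size == 64 and isinstance(remote_block_size, int)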
@@ -712,9 +699,9 @@ class NixlConnectorWorker:
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and " \
"local rank %s",
len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
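The NixlConnectorScheduler change is the headline fix: when the decode worker already holds the full prefix in its cache (num_external_tokens == 0) there is nothing to transfer, but the request is still handed to the worker with an empty local block list so that _read_blocks can send the notification that lets the prefill worker free its blocks. A tiny sketch of that decision, using a hypothetical helper name:

def blocks_to_pull(unhashed_block_ids: list[int],
                   num_external_tokens: int) -> list[int]:
    # Hypothetical helper illustrating the new behavior: a full prefix cache
    # hit on the decode side pulls nothing, but the request is still
    # registered (with an empty list) so the prefill-side blocks get freed.
    return unhashed_block_ids if num_external_tokens > 0 else []


assert blocks_to_pull([7, 8, 9], 256) == [7, 8, 9]  # partial hit: pull blocks
assert blocks_to_pull([7, 8, 9], 0) == []           # full hit: notify only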


@@ -424,11 +424,11 @@ class Scheduler(SchedulerInterface):
# The request cannot be scheduled.
break
# KVConnector: update internal state after allocation.
# KVTransfer: the connector uses this info to determine
# if a load is needed. Note that
# This information is used to determine if a load is
# needed for this request.
if num_external_computed_tokens:
assert self.connector is not None
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
@@ -841,7 +841,7 @@ class Scheduler(SchedulerInterface):
}
finished_req_ids = self.finished_req_ids_dict
if finished_req_ids is not None:
if finished_req_ids:
# Include ids of requests that finished since last outputs
# were sent.
for client_index, finished_set in finished_req_ids.items():
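With the Scheduler change in the final file, update_state_after_alloc is invoked whenever a connector is configured rather than only when external tokens were matched, which is what allows connectors to observe (and clean up after) zero-token matches. A condensed, illustrative version of the call site; the standalone function below is not the actual Scheduler method:

def notify_connector_after_alloc(connector, request, new_computed_blocks,
                                 new_blocks, num_external_computed_tokens):
    # Previously this callback only ran when num_external_computed_tokens was
    # nonzero; now every configured connector is informed after allocation,
    # even when the external token count is zero.
    if connector is not None:
        connector.update_state_after_alloc(
            request,
            new_computed_blocks + new_blocks,
            num_external_computed_tokens,
        )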