[Bugfix][Nixl] Fix full prefix cache hit bug (#18632)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
78dcf56cb3
commit
c56ed8bb0e
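
This change addresses the case where the decode (D) worker gets a full local prefix cache hit against a remote prefill: there are no remote blocks left to pull, but the prefill (P) worker still has to be notified so it can free its blocks. The diff moves the empty-blocks handling into update_state_after_alloc (which the scheduler now calls whenever a connector is configured), has MultiConnector fan that call out to every child connector (with empty blocks for the ones without a hit), and stops skipping requests with an empty block list when building the NIXL connector metadata, so the worker side still sends the free notification.

A minimal, self-contained sketch of the resulting decision on the decode side; the helper name plan_remote_read is made up for illustration and is not vLLM's API:

# Minimal sketch (hypothetical helper, not vLLM's actual API): on a full
# prefix cache hit the decode side has nothing to read, but the request must
# still be queued with an empty block list so the worker notifies the prefill
# side to free its blocks.

def plan_remote_read(num_external_tokens: int,
                     unhashed_block_ids: list[int],
                     has_remote_params: bool) -> list[int] | None:
    """Return local block ids to read, or None if there is no remote prefill."""
    if not has_remote_params:
        return None                    # no remote prefill for this request
    if num_external_tokens > 0:
        return unhashed_block_ids      # partial hit: pull the missing blocks
    return []                          # full hit: nothing to pull, notify only


if __name__ == "__main__":
    assert plan_remote_read(96, [3, 4], True) == [3, 4]
    assert plan_remote_read(0, [], True) == []    # still triggers the notify/free path
    assert plan_remote_read(0, [], False) is None

With num_external_tokens == 0 the request is still queued, just with nothing to read, which is what lets the worker-side read path send the free notification to the prefill worker.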
@@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
     SharedStorageConnector)
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks

 MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

@@ -32,7 +33,7 @@ class TestSharedStorageConnector(SharedStorageConnector):
         self.call_record: dict[str, int] = defaultdict(int)
         # Use a unique temp file per connector
         self._event_file = tempfile.gettempdir(
-        ) + f"/connector_{self.name}_events.log"
+        ) + f"/connector_{self.name}-{self.role.name}_events.log"
         # Start with an empty file
         with open(self._event_file, "w") as _:
             pass
@@ -52,10 +53,19 @@ class TestSharedStorageConnector(SharedStorageConnector):

         def wrapper(*args, **kwargs):
             self.call_record[name] += 1
+
+            # Include args that we're interested in
+            to_log = [name]
+            for arg in args:
+                if isinstance(arg, int):
+                    to_log.append(str(arg))
+                elif isinstance(arg, KVCacheBlocks):
+                    to_log.append(f"num_blocks={len(arg.blocks)}")
+
             # Log the event as a line to the file
             try:
                 with open(self._event_file, "a") as f:
-                    f.write(name + "\n")
+                    f.write(' '.join(to_log) + "\n")
             except Exception as e:
                 print(f"[ERROR] Could not log event {name} "
                       f"for {self.name}: {e}")
@@ -162,15 +172,23 @@ def test_multi_shared_storage_connector_consistency():
                             f"{storage_1_path} and {storage_2_path}")

     events = get_connector_events()
-    # get_num_new_matched_tokens will be called on each connector in turn.
-    # neither of them have hits so update_state_after_alloc won't be called.
-    assert events["storage1"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    # get_num_new_matched_tokens and update_state_after_alloc will be called
+    # on each connector in turn.
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]
-    assert events["storage2"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    assert events["storage1-WORKER"][:5] == [
+        'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
+        'wait_for_layer_load', 'save_kv_layer'
+    ]
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
+    ]
+    assert events["storage2-WORKER"][:5] == [
+        'register_kv_caches', 'bind_connector_metadata', 'start_load_kv',
+        'wait_for_layer_load', 'save_kv_layer'
     ]

     # Reset prefix cache or else we'll just get the tokens back from there.
@@ -182,16 +200,16 @@ def test_multi_shared_storage_connector_consistency():

     events = get_connector_events()
     # get_num_new_matched_tokens will return new tokens from the first
-    # connector so update_state_after_alloc will be called once blocks
-    # are allocated for the first connector.
-    # get_num_new_matched_tokens *won't* be called on the second connector
-    # in this case.
-    assert events["storage1"][:4] == [
-        'get_num_new_matched_tokens', 'update_state_after_alloc',
-        'build_connector_meta', 'bind_connector_metadata'
+    # connector so update_state_after_alloc will be with allocated blocks
+    # on that one but with zero blocks for others (first nonzero match is
+    # chosen).
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
     ]
-    assert events["storage2"][:2] == [
-        'build_connector_meta', 'bind_connector_metadata'
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]

     # Delete storage1 connector state
@@ -205,17 +223,17 @@ def test_multi_shared_storage_connector_consistency():
     _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

     events = get_connector_events()
-    # get_num_new_matched_tokens will be called for the first connector but it
-    # won't have a hit so update_state_after_alloc won't be called.
-    # get_num_new_matched_tokens will also be called on the second connector,
-    # but it should have a hit so update_state_after_alloc will be called.
-    assert events["storage1"][:3] == [
-        'get_num_new_matched_tokens', 'build_connector_meta',
-        'bind_connector_metadata'
+    # get_num_new_matched_tokens will be called for both connectors but will
+    # return 0 from the first connector, but the second connector should have
+    # a hit, so update_state_after_alloc will only be called with allocated
+    # blocks for the second connector.
+    assert events["storage1-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=0 0', 'build_connector_meta'
     ]
-    assert events["storage2"][:4] == [
-        'get_num_new_matched_tokens', 'update_state_after_alloc',
-        'build_connector_meta', 'bind_connector_metadata'
+    assert events["storage2-SCHEDULER"][:3] == [
+        'get_num_new_matched_tokens 0',
+        'update_state_after_alloc num_blocks=7 96', 'build_connector_meta'
     ]

     # Clean up
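
The assertions above are keyed by "<name>-<ROLE>" because each connector instance now logs to its own connector_{name}-{role}_events.log file. For reference, a sketch of how such logs could be collected back into the events dict; this is an illustration under that file-naming assumption, not necessarily the test's actual get_connector_events helper:

import glob
import os
import tempfile


def get_connector_events() -> dict[str, list[str]]:
    """Read back the per-connector event logs written by the wrapper above."""
    events: dict[str, list[str]] = {}
    pattern = os.path.join(tempfile.gettempdir(), "connector_*_events.log")
    for path in glob.glob(pattern):
        # File name format: connector_<name>-<ROLE>_events.log
        key = os.path.basename(path)[len("connector_"):-len("_events.log")]
        with open(path) as f:
            events[key] = [line.strip() for line in f if line.strip()]
    return events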
@@ -12,12 +12,12 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.logger import init_logger
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import SchedulerOutput

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
     from vllm.forward_context import ForwardContext
-    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request

 logger = init_logger(__name__)
@@ -51,8 +51,9 @@ class MultiConnector(KVConnectorBase_V1):
             self._connectors.append(
                 KVConnectorFactory.create_connector_v1(temp_config, role))

-        # A mapping from request id to the connector that is assigned to it.
-        self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
+        # A mapping from request id to the index of the connector chosen to
+        # load the request from (if any).
+        self._requests_to_connector: dict[str, int] = {}

         # Keeps track of *additional* remaining async saves (beyond 1) to be
         # finished per request. Not needed for async loads since we only allow
@@ -136,25 +137,31 @@ class MultiConnector(KVConnectorBase_V1):
         request: "Request",
         num_computed_tokens: int,
     ) -> tuple[int, bool]:
-        for c in self._connectors:
+        to_return = (0, False)
+        for i, c in enumerate(self._connectors):
             toks, load_async = c.get_num_new_matched_tokens(
                 request, num_computed_tokens)
             # The first connector that has new matched tokens will be assigned
             # to this request.
-            if toks > 0:
-                self._requests_to_connector[request.request_id] = c
-                return toks, load_async
-        return 0, False
+            if to_return[0] == 0 and toks > 0:
+                self._requests_to_connector[request.request_id] = i
+                to_return = (toks, load_async)
+        return to_return

     def update_state_after_alloc(self, request: "Request",
                                  blocks: "KVCacheBlocks",
                                  num_external_tokens: int):
-        # If the request is not assigned to any connector, we do nothing.
-        if request.request_id not in self._requests_to_connector:
-            return
-        # We assume that the request is assigned to only one connector.
-        c = self._requests_to_connector.pop(request.request_id)
-        c.update_state_after_alloc(request, blocks, num_external_tokens)
+        chosen_connector = self._requests_to_connector.get(
+            request.request_id, -1)
+        for i, c in enumerate(self._connectors):
+            if i == chosen_connector:
+                # Forward call to the chosen connector (if any).
+                c.update_state_after_alloc(request, blocks,
+                                           num_external_tokens)
+            else:
+                # Call with empty blocks for other connectors.
+                c.update_state_after_alloc(request,
+                                           KVCacheBlocks.create_empty(), 0)

     def build_connector_meta(
         self,
@@ -170,7 +177,7 @@ class MultiConnector(KVConnectorBase_V1):
     def request_finished(
         self,
         request: "Request",
-        blocks: "KVCacheBlocks",
+        blocks: list[int],
     ) -> tuple[bool, Optional[dict[str, Any]]]:
         async_saves = 0
         kv_txfer_params = None
@@ -187,4 +194,8 @@ class MultiConnector(KVConnectorBase_V1):
                 kv_txfer_params = txfer_params
         if async_saves > 1:
             self._extra_async_saves[request.request_id] = async_saves - 1
+
+        # Clean up other state for this request.
+        self._requests_to_connector.pop(request.request_id, None)
+
         return async_saves > 0, kv_txfer_params
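
In short: get_num_new_matched_tokens now queries every child connector instead of returning at the first hit, records the index of the first connector that reported a nonzero match, and update_state_after_alloc is forwarded to all connectors, the chosen one with the real allocation and the rest with KVCacheBlocks.create_empty(). A standalone toy model of that policy (the stub class and helper below are illustrative, not vLLM types), consistent with the num_blocks=7 / 96-token values asserted in the test above:

class StubConnector:
    """Toy stand-in for a child connector (illustrative only)."""

    def __init__(self, hit_tokens: int):
        self.hit_tokens = hit_tokens
        self.seen: list[tuple[int, int]] = []   # (num_blocks, num_tokens)

    def get_num_new_matched_tokens(self) -> int:
        return self.hit_tokens

    def update_state_after_alloc(self, num_blocks: int, num_tokens: int):
        self.seen.append((num_blocks, num_tokens))


def choose_and_notify(connectors: list[StubConnector], num_blocks: int):
    # First connector with a nonzero match wins...
    chosen, tokens = -1, 0
    for i, c in enumerate(connectors):
        toks = c.get_num_new_matched_tokens()
        if tokens == 0 and toks > 0:
            chosen, tokens = i, toks
    # ...but every connector still sees update_state_after_alloc.
    for i, c in enumerate(connectors):
        if i == chosen:
            c.update_state_after_alloc(num_blocks, tokens)
        else:
            c.update_state_after_alloc(0, 0)    # empty blocks for the others
    return chosen, tokens


if __name__ == "__main__":
    conns = [StubConnector(0), StubConnector(96)]
    assert choose_and_notify(conns, num_blocks=7) == (1, 96)
    assert conns[0].seen == [(0, 0)]
    assert conns[1].seen == [(7, 96)]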
@@ -221,15 +221,6 @@ class NixlConnectorScheduler:
             if count > 0:
                 return count, True

-            # NOTE: if count is 0 here, we have less than block_size
-            # tokens to pull after subtracting the local prefix cache hit.
-            # The remote only sends fully computed blocks, so there is
-            # nothing to transfer but we still need to notify the
-            # prefill worker so that the remote blocks are freed.
-            if all(p in params for p in ("remote_engine_id", "remote_host",
-                                         "remote_port")):
-                self._reqs_need_recv[request.request_id] = (request, [])
-
         # No remote prefill for this request.
         return 0, False

@@ -247,9 +238,14 @@ class NixlConnectorScheduler:
             if params.get("remote_block_ids"):
                 if all(p in params for p in ("remote_engine_id", "remote_host",
                                              "remote_port")):
+                    # If remote_blocks and num_external_tokens = 0, we have
+                    # a full prefix cache hit on the D worker. We need to call
+                    # send_notif in _read_blocks to free the memory on the P.
+                    local_block_ids = (blocks.get_unhashed_block_ids()
+                                       if num_external_tokens > 0 else [])
                     # Get unhashed blocks to pull from remote.
                     self._reqs_need_recv[request.request_id] = (
-                        request, blocks.get_unhashed_block_ids())
+                        request, local_block_ids)
                 else:
                     logger.warning(
                         "Got invalid KVTransferParams: %s. This "
@@ -268,15 +264,6 @@ class NixlConnectorScheduler:
         # Loop through scheduled reqs and convert to ReqMeta.
         for req_id, (req, block_ids) in self._reqs_need_recv.items():
             assert req.kv_transfer_params is not None
-            # For the case where there are no remote blocks to pull
-            # (block_ids is empty), we don't need to schedule
-            # an async read on the worker side.
-            if not block_ids:
-                logger.debug(
-                    "Skipping adding request %s to NixlConnectorMetadata, "
-                    "as there are no remote blocks to pull", req_id)
-                continue
-
             meta.add_new_req(
                 request_id=req_id,
                 local_block_ids=block_ids,
@@ -660,26 +647,26 @@ class NixlConnectorWorker:

         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
-        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
-            "Local TP size must be divisible by remote TP size."
+        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
+            "Local TP size must be divisible by remote TP size.")
         tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
-        assert tp_ratio > 0, "Decode TP cannot be smaller than"
-        " prefill TP"
+        assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
         if self.use_mla:
             # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len / (
+            remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes)
             assert self.block_len == nixl_agent_meta.block_len
         else:
-            remote_block_size = nixl_agent_meta.block_len / (
+            remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)

-        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
-            "Remote P worker KV layer cache must be of shape [2, N, \
-            local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            "Remote P worker KV layer cache must be of shape [2, N, "
+            "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
+        )

-        assert self.block_size == remote_block_size, "Remote P worker with \
-            different block size is not supported"
+        assert self.block_size == remote_block_size, "Remote P worker with "
+        "different block size is not supported"

         assert self.num_blocks >= nixl_agent_meta.num_blocks

@@ -712,9 +699,9 @@ class NixlConnectorWorker:
             # (addr, len, device id)
             blocks_data.append((addr, self.block_len, remote_tp_rank))
         logger.debug(
-            "Created %s blocks for dst engine %s with remote rank %s and " \
-            "local rank %s",
-            len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
+            "Created %s blocks for dst engine %s with remote rank %s and "
+            "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
+            self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
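
A small worked example of the block-size arithmetic fixed above, with made-up sizes: floor division keeps remote_block_size an int, matching the integer block_size it is compared against, whereas true division would yield a float like 8.0:

# Worked example with assumed sizes (not taken from a real deployment).
block_len = 131072          # bytes per remote KV block (assumed)
slot_size_bytes = 8192      # bytes per token slot (assumed)
tp_ratio = 2                # decode TP size / prefill TP size

remote_block_size = block_len // (slot_size_bytes * tp_ratio)
print(remote_block_size, type(remote_block_size))   # 8 <class 'int'>
print(block_len / (slot_size_bytes * tp_ratio))     # 8.0 with '/'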
@@ -424,11 +424,11 @@ class Scheduler(SchedulerInterface):
                     # The request cannot be scheduled.
                     break

-                # KVConnector: update internal state after allocation.
+                # KVTransfer: the connector uses this info to determine
+                # if a load is needed. Note that
                 # This information is used to determine if a load is
                 # needed for this request.
-                if num_external_computed_tokens:
-                    assert self.connector is not None
+                if self.connector is not None:
                     self.connector.update_state_after_alloc(
                         request,
                         new_computed_blocks + new_blocks,
@@ -841,7 +841,7 @@ class Scheduler(SchedulerInterface):
         }

         finished_req_ids = self.finished_req_ids_dict
-        if finished_req_ids is not None:
+        if finished_req_ids:
             # Include ids of requests that finished since last outputs
             # were sent.
             for client_index, finished_set in finished_req_ids.items():