[V1] Prefix caching for vision language models (#11187)
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
parent c77eb8a33c
commit bf8717ebae
@@ -2,16 +2,23 @@
 import pytest
 
 from vllm.inputs import token_inputs
+from vllm.multimodal.inputs import PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import cdiv
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
 from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
 
 
-def make_request(request_id, prompt_token_ids):
+def make_request(request_id,
+                 prompt_token_ids,
+                 mm_positions=None,
+                 mm_hashes=None):
     return Request(
         request_id=request_id,
-        inputs=token_inputs(prompt_token_ids=prompt_token_ids),
+        inputs=token_inputs(prompt_token_ids=prompt_token_ids,
+                            multi_modal_placeholders={"image": mm_positions}
+                            if mm_positions else None,
+                            multi_modal_hashes=mm_hashes),
        sampling_params=SamplingParams(max_tokens=17),
         eos_token_id=100,
         arrival_time=0,
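For reference, a minimal usage sketch of the updated helper; the token layout, placeholder offset, and hash value below are illustrative assumptions, not taken from the diff.

# A 16-token prompt whose last six tokens are an image placeholder,
# tagged with a hash for the preprocessed image (illustrative values).
req = make_request(
    "demo",
    prompt_token_ids=list(range(10)) + [-1] * 6,
    mm_positions=[PlaceholderRange(offset=10, length=6)],
    mm_hashes=["image-hash-0"],
)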
@@ -38,6 +45,7 @@ def test_prefill():
     all_token_ids = common_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     computed_blocks = manager.get_computed_blocks(req0)
+    assert len(req0.kv_block_hashes) == 3
     assert not computed_blocks
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
     assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@@ -61,6 +69,7 @@ def test_prefill():
     unique_token_ids = [3] * 5
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks = manager.get_computed_blocks(req1)
+    assert len(req1.kv_block_hashes) == 3
     assert [b.block_id for b in computed_blocks] == [0, 1, 2]
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
@@ -90,6 +99,7 @@ def test_prefill():
     unique_token_ids = [3] * 6
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_block = manager.get_computed_blocks(req2)
+    assert len(req2.kv_block_hashes) == 3
     assert [b.block_id for b in computed_block] == [0, 1, 2]
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
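A quick sanity check on the new assertions (a sketch, assuming the same 55-token prompt and block_size=16 used in test_prefill): only full blocks are hashed, so exactly three block hashes are expected.

block_size = 16
num_prompt_tokens = 55          # 48 common tokens + 7 unique tokens
num_full_blocks = num_prompt_tokens // block_size
assert num_full_blocks == 3     # matches the len(reqX.kv_block_hashes) checks above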
@@ -416,3 +426,77 @@ def test_cache_blocks():
     )
     assert len(manager.cached_block_hash_to_block) == 3
     assert blocks[0].block_hash is not None
+
+
+def test_mm_prefix_caching():
+    """
+    This tests that the multi-modal prefix caching is correct.
+    """
+    manager = KVCacheManager(
+        block_size=16,
+        num_gpu_blocks=10,
+        max_model_len=8192,
+        sliding_window=None,
+        enable_caching=True,
+        num_preallocate_tokens=16,
+    )
+
+    # Common prompt tokens (T is text tokens and P is image placeholder tokens)
+    # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
+    common_token_ids = list(range(10)) + [-1] * 6
+    common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
+    common_token_ids += [-1] * 16
+
+    common_mm_positions = [
+        PlaceholderRange(offset=11, length=10),
+        PlaceholderRange(offset=30, length=18),
+    ]
+    common_mm_hashes = ["aaa", "bbb"]
+
+    # A unique image plus some text tokens.
+    unique_token_ids = [-1] * 7 + [100] * 4
+    all_token_ids = common_token_ids + unique_token_ids
+    mm_positions = common_mm_positions + [
+        PlaceholderRange(offset=48, length=7)
+    ]
+    mm_hashes = common_mm_hashes + ["ccc"]
+    req0 = make_request("0",
+                        all_token_ids,
+                        mm_positions=mm_positions,
+                        mm_hashes=mm_hashes)
+    computed_blocks = manager.get_computed_blocks(req0)
+
+    # Completed block should have hashes with extra keys.
+    assert not computed_blocks
+    assert len(req0.kv_block_hashes) == 3
+    assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
+    assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
+    assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )
+
+    blocks = manager.allocate_slots(req0, 59, computed_blocks)
+    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
+    req0.num_computed_tokens = 59
+
+    # Append slots without allocating a new block.
+    for _ in range(5):
+        req0.append_output_token_ids(8)
+    new_blocks = manager.append_slots(req0, 5)
+    assert new_blocks is not None and len(new_blocks) == 0
+
+    # The just completed block should have hashes with extra keys.
+    assert len(req0.kv_block_hashes) == 4
+    assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), )
+
+    # Cache hit.
+    unique_token_ids = [-1] * 7 + [200] * 5
+    all_token_ids = common_token_ids + unique_token_ids
+    mm_positions = common_mm_positions + [
+        PlaceholderRange(offset=48, length=7)
+    ]
+    mm_hashes = common_mm_hashes + ["ccc"]
+    req1 = make_request("1",
+                        all_token_ids,
+                        mm_positions=mm_positions,
+                        mm_hashes=mm_hashes)
+    computed_blocks = manager.get_computed_blocks(req1)
+    assert len(computed_blocks) == 3
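To see where the expected extra keys in this test come from, here is a small self-contained sketch (not vLLM code) that recomputes the (mm_hash, start-offset-within-input) pairs per 16-token block from the placeholder layout above.

block_size = 16
mm_inputs = [("aaa", 11, 10), ("bbb", 30, 18), ("ccc", 48, 7)]  # (hash, offset, length)

def extra_keys_for_block(block_idx):
    start, end = block_idx * block_size, (block_idx + 1) * block_size
    keys = []
    for mm_hash, offset, length in mm_inputs:
        if offset + length <= start or offset >= end:
            continue  # the mm input does not overlap this block
        # Offset of the block start inside the mm input (0 if the input starts here).
        keys.append((mm_hash, max(0, start - offset)))
    return tuple(keys)

assert extra_keys_for_block(0) == (("aaa", 0),)
assert extra_keys_for_block(1) == (("aaa", 5), ("bbb", 0))
assert extra_keys_for_block(2) == (("bbb", 2),)
assert extra_keys_for_block(3) == (("ccc", 0),)   # the block completed during decode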
@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
     assert engine_args.enable_prefix_caching
 
 
-def test_defaults():
-    engine_args = EngineArgs(model="facebook/opt-125m")
-
-    # Assert V1 defaults
-    assert (engine_args.enable_prefix_caching
-            ), "V1 turns on prefix caching by default"
-
-
 def test_defaults_with_usage_context():
     engine_args = EngineArgs(model="facebook/opt-125m")
     vllm_config: VllmConfig = engine_args.create_engine_config(
@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
         UsageContext.OPENAI_API_SERVER)
     assert vllm_config.scheduler_config.max_num_seqs == 1024
     assert vllm_config.scheduler_config.max_num_batched_tokens == 2048
-
-
-def test_prefix_cache_disabled_with_multimodel():
-    engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
-
-    vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
-    assert not vllm_config.cache_config.enable_prefix_caching
@@ -205,6 +205,7 @@ class EngineArgs:
         # by user.
         if self.enable_prefix_caching is None:
             self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
+
         # Override max_num_seqs if it's not set by user.
         if self.max_num_seqs is None:
             self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024
@@ -1026,11 +1027,11 @@ class EngineArgs:
         device_config = DeviceConfig(device=self.device)
         model_config = self.create_model_config()
 
-        if model_config.is_multimodal_model:
-            if self.enable_prefix_caching:
-                logger.warning(
-                    "--enable-prefix-caching is currently not "
-                    "supported for multimodal models and has been disabled.")
+        if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
+                and self.enable_prefix_caching):
+            logger.warning("--enable-prefix-caching is currently not "
+                           "supported for multimodal models in v0 and "
+                           "has been disabled.")
             self.enable_prefix_caching = False
 
         cache_config = CacheConfig(
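Read as a standalone predicate, the new condition only force-disables prefix caching for multimodal models on the V0 engine. A minimal sketch (names are illustrative, not from the diff):

def prefix_caching_force_disabled(is_multimodal_model: bool, use_v1: bool,
                                  enable_prefix_caching: bool) -> bool:
    return is_multimodal_model and not use_v1 and enable_prefix_caching

assert prefix_caching_force_disabled(True, False, True)        # V0 + VLM: disabled
assert not prefix_caching_force_disabled(True, True, True)     # V1 + VLM: stays on
assert not prefix_caching_force_disabled(False, False, True)   # text-only: stays on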
@@ -1249,11 +1250,14 @@ class EngineArgs:
         # When no user override, set the default values based on the usage
         # context.
         # TODO(woosuk): Tune the default values for different hardware.
-        if self.max_num_batched_tokens is None:
-            if usage_context == UsageContext.LLM_CLASS:
-                self.max_num_batched_tokens = 8192
-            elif usage_context == UsageContext.OPENAI_API_SERVER:
-                self.max_num_batched_tokens = 2048
+        default_max_num_batched_tokens = {
+            UsageContext.LLM_CLASS: 8192,
+            UsageContext.OPENAI_API_SERVER: 2048,
+        }
+        if (self.max_num_batched_tokens is None
+                and usage_context in default_max_num_batched_tokens):
+            self.max_num_batched_tokens = default_max_num_batched_tokens[
+                usage_context]
             logger.warning(
                 "Setting max_num_batched_tokens to %d for %s usage context.",
                 self.max_num_batched_tokens, usage_context.value)
@@ -1263,9 +1267,6 @@
         Override the EngineConfig's configs based on the usage context for V1.
         """
         assert envs.VLLM_USE_V1, "V1 is not enabled"
-        if engine_config.model_config.is_multimodal_model:
-            # TODO (ywang96): Enable APC by default when VLM supports it.
-            assert not engine_config.cache_config.enable_prefix_caching
 
 
 @dataclass
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
     Placeholder ranges for the multi-modal data.
     """
 
+    multi_modal_hashes: NotRequired[List[str]]
+    """
+    The hashes of the multi-modal data.
+    """
+
     mm_processor_kwargs: NotRequired[Dict[str, Any]]
     """
     Optional multi-modal processor kwargs to be forwarded to the
@@ -177,6 +182,7 @@ def token_inputs(
     prompt: Optional[str] = None,
     multi_modal_data: Optional["MultiModalDataDict"] = None,
     multi_modal_inputs: Optional["MultiModalKwargs"] = None,
+    multi_modal_hashes: Optional[List[str]] = None,
     multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
     mm_processor_kwargs: Optional[Dict[str, Any]] = None,
 ) -> TokenInputs:
@@ -191,6 +197,8 @@ def token_inputs(
         inputs["multi_modal_data"] = multi_modal_data
     if multi_modal_inputs is not None:
         inputs["multi_modal_inputs"] = multi_modal_inputs
+    if multi_modal_hashes is not None:
+        inputs["multi_modal_hashes"] = multi_modal_hashes
     if multi_modal_placeholders is not None:
         inputs["multi_modal_placeholders"] = multi_modal_placeholders
     if mm_processor_kwargs is not None:
@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
 
         assert_never(inputs)
 
+    @cached_property
+    def multi_modal_hashes(self) -> List[str]:
+        inputs = self.inputs
+
+        if inputs["type"] == "token":
+            return inputs.get("multi_modal_hashes", [])
+
+        if inputs["type"] == "multimodal":
+            return inputs.get("mm_hashes", [])
+
+        assert_never(inputs)
+
     @cached_property
     def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
         inputs = self.inputs
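A short usage sketch of the new field and accessor; the token IDs and hash below are illustrative, and the imports are the same ones used in the test file earlier in this commit.

from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange

inputs = token_inputs(
    prompt_token_ids=[0, 1, -1, -1, 2],
    multi_modal_placeholders={"image": [PlaceholderRange(offset=2, length=2)]},
    multi_modal_hashes=["image-hash-0"],
)
# Mirrors what the new multi_modal_hashes accessor returns for "token" inputs.
assert inputs.get("multi_modal_hashes", []) == ["image-hash-0"]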
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
     mm_kwargs: MultiModalKwargs
     """Keyword arguments to be directly passed to the model after batching."""
 
+    mm_hashes: NotRequired[List[str]]
+    """The hashes of the multi-modal data."""
+
     mm_placeholders: MultiModalPlaceholderDict
     """
     For each modality, information about the placeholder tokens in
@@ -4,7 +4,9 @@
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
-                                         KVCacheBlock, hash_block_tokens,
+                                         KVCacheBlock,
+                                         generate_block_hash_extra_keys,
+                                         hash_block_tokens,
                                          hash_request_tokens)
 from vllm.v1.request import Request
 
@@ -83,10 +85,12 @@ class KVCacheManager:
 
         computed_blocks = []
 
-        # TODO(rickyx): potentially we could cache this so we don't have to
-        # recompute it every time.
-        block_hashes = hash_request_tokens(self.block_size,
-                                           request.all_token_ids)
+        # The block hashes for the request may already be computed
+        # if the request was preempted and resumed.
+        if not request.kv_block_hashes:
+            request.set_kv_block_hashes(
+                hash_request_tokens(self.block_size, request))
+        block_hashes = request.kv_block_hashes
 
         for block_hash in block_hashes:
             # block_hashes is a chain of block hashes. If a block hash is not
@@ -242,14 +246,16 @@
         num_computed_tokens = len(computed_blocks) * self.block_size
         num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
 
-        self._cache_full_blocks(
-            request=request,
-            blk_start_idx=len(computed_blocks),
-            # The new full blocks are the full blocks that are not computed.
-            full_blocks=self.req_to_blocks[request.request_id]
-            [len(computed_blocks):num_full_blocks],
-            prev_block=computed_blocks[-1] if computed_blocks else None,
-        )
+        new_full_blocks = self.req_to_blocks[
+            request.request_id][len(computed_blocks):num_full_blocks]
+        if new_full_blocks:
+            self._cache_full_blocks(
+                request=request,
+                blk_start_idx=len(computed_blocks),
+                # The new full blocks are the full blocks that are not computed.
+                full_blocks=new_full_blocks,
+                prev_block=computed_blocks[-1] if computed_blocks else None,
+            )
 
         return new_blocks
 
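A small worked example of the new guard (a sketch, assuming block_size=16): a request with three computed blocks that gains only eleven new tokens completes no additional full block, so _cache_full_blocks can be skipped entirely.

block_size = 16
num_computed_blocks = 3
num_new_tokens = 11

num_computed_tokens = num_computed_blocks * block_size
num_full_blocks = (num_computed_tokens + num_new_tokens) // block_size
new_full_block_indices = list(range(num_computed_blocks, num_full_blocks))
assert new_full_block_indices == []   # nothing new to hash or cache yet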
@@ -376,6 +382,8 @@
             full_blocks: The list of blocks to update hash metadata.
             prev_block: The previous block in the chain.
         """
+        num_cached_block_hashes = len(request.kv_block_hashes)
+
         # Update the new blocks with the block hashes through the chain.
         prev_block_hash_value = None
         if prev_block is not None:
@@ -387,17 +395,35 @@
         for i, blk in enumerate(full_blocks):
             blk_idx = blk_start_idx + i
 
-            block_tokens = request.all_token_ids[blk_idx *
-                                                 self.block_size:(blk_idx +
-                                                                  1) *
-                                                 self.block_size]
-            assert len(block_tokens) == self.block_size, (
-                f"Expected {self.block_size} tokens, got {len(block_tokens)} "
-                f"at {blk_idx}th block for request "
-                f"{request.request_id}({request})")
+            if blk_idx < num_cached_block_hashes:
+                # The block hash may already be computed in
+                # "get_computed_blocks" if the tokens are not generated by
+                # this request (either the prompt tokens or the previously
+                # generated tokens with preemption). In this case we simply
+                # reuse the block hash.
+                block_hash = request.kv_block_hashes[blk_idx]
+            else:
+                # Otherwise compute the block hash and cache it in the request
+                # in case it will be preempted in the future.
+                start_token_idx = blk_idx * self.block_size
+                end_token_idx = (blk_idx + 1) * self.block_size
+                block_tokens = request.all_token_ids[
+                    start_token_idx:end_token_idx]
+                assert len(block_tokens) == self.block_size, (
+                    f"Expected {self.block_size} tokens, got "
+                    f"{len(block_tokens)} at {blk_idx}th block for request "
+                    f"{request.request_id}({request})")
 
-            # Compute the hash of the current block.
-            block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
+                # Generate extra keys for multi-modal inputs. Note that since
+                # we reach to this branch only when the block is completed with
+                # generated tokens, we only need to consider the last mm input.
+                extra_keys, _ = generate_block_hash_extra_keys(
+                    request, start_token_idx, end_token_idx, -1)
+
+                # Compute the hash of the current block.
+                block_hash = hash_block_tokens(prev_block_hash_value,
+                                               block_tokens, extra_keys)
+                request.append_kv_block_hashes(block_hash)
 
             # Update and added the full block to the cache.
             blk.block_hash = block_hash
@@ -1,20 +1,25 @@
 """KV-Cache Utilities."""
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import List, NamedTuple, Optional, Tuple
+from typing import Any, List, NamedTuple, Optional, Tuple
 
 from vllm.logger import init_logger
+from vllm.v1.request import Request
 
 logger = init_logger(__name__)
 
 
 class BlockHashType(NamedTuple):
-    """Hash value of a block and the token IDs in the block.
-    The reason we keep a tuple of token IDs is to make sure no hash
-    collision happens when the hash value is the same.
+    """Hash value of a block (int), the token IDs in the block, and extra keys.
+    The reason we keep a tuple of token IDs and extra keys is to make sure
+    no hash collision happens when the hash value is the same.
     """
+    # Hash value of the block in an integer.
     hash_value: int
+    # Token IDs in the block.
     token_ids: Tuple[int, ...]
+    # Extra keys for the block.
+    extra_keys: Optional[Any] = None
 
 
 @dataclass
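Why the extra keys matter for cache lookups: BlockHashType is a NamedTuple, so two entries with identical raw hash values and token IDs still compare unequal when their extra keys differ. A standalone copy of the definition, for illustration only:

from typing import Any, NamedTuple, Optional, Tuple

class BlockHashType(NamedTuple):
    hash_value: int
    token_ids: Tuple[int, ...]
    extra_keys: Optional[Any] = None

text_only = BlockHashType(12345, (1, 2, -1, -1))
with_image = BlockHashType(12345, (1, 2, -1, -1), (("aaa", 0),))

assert text_only != with_image            # extra keys keep the entries distinct
cache = {text_only: "block-A", with_image: "block-B"}
assert len(cache) == 2                    # both can live in the prefix cache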
@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
         return ret
 
 
-def hash_block_tokens(parent_block_hash: Optional[int],
-                      curr_block_token_ids: Sequence[int]) -> BlockHashType:
+def generate_block_hash_extra_keys(
+        request: Request, start_token_idx: int, end_token_idx: int,
+        start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
+    """Generate extra keys for the block hash. The extra keys can come from
+    the multi-modal inputs and request specific metadata (e.g., LoRA ID).
+    For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
+    indicate a mm input contained in the block and its starting offset in
+    the block tokens.
+
+    Args:
+        request: The request object.
+        start_token_idx: The start token index of the block.
+        end_token_idx: The end token index of the block.
+        start_mm_idx: The start multi-modal index of the block.
+
+    Returns:
+        A tuple of extra keys and the next multi-modal index.
+    """
+
+    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
+    if not mm_positions:
+        return None, start_mm_idx
+
+    if mm_positions and len(mm_positions) != len(mm_hashes):
+        raise ValueError(
+            "The number of multi-modal positions and hashes must match. This "
+            "is likely because you do not enable MM preprocessor hashing. "
+            "Please set mm_cache_preprocessor=True.")
+
+    # Note that we assume mm_positions is sorted by offset.
+    # We do not need to check all mm inputs if the start token index is out of
+    # range. This usually happens in the late prefill phase and decoding phase.
+    if mm_positions[-1]["offset"] + mm_positions[-1][
+            "length"] < start_token_idx:
+        return None, start_mm_idx
+
+    # Support start_mm_idx == -1 to indicate the last mm input.
+    if start_mm_idx < 0:
+        assert -start_mm_idx <= len(mm_positions)
+        start_mm_idx = len(mm_positions) + start_mm_idx
+
+    extra_keys = []
+    curr_mm_idx = start_mm_idx
+    while mm_positions and curr_mm_idx < len(mm_positions):
+        assert mm_hashes[curr_mm_idx] is not None
+        offset = mm_positions[curr_mm_idx]["offset"]
+        length = mm_positions[curr_mm_idx]["length"]
+        if end_token_idx > offset:
+            if start_token_idx > offset + length:
+                # This block has passed the current mm input.
+                curr_mm_idx += 1
+                continue
+
+            # The block contains the current mm input.
+            mm_start = max(0, start_token_idx - offset)
+            extra_keys.append((mm_hashes[curr_mm_idx], mm_start))
+            if end_token_idx >= offset + length:
+                # If this block contains the end of the current mm input,
+                # move to the next mm input as this block may also contain
+                # the next mm input.
+                curr_mm_idx += 1
+            else:
+                # Otherwise this block is done with mm inputs.
+                break
+        else:
+            # This block has not reached the current mm input.
+            break
+    return tuple(extra_keys), curr_mm_idx
+
+
+def hash_block_tokens(
+        parent_block_hash: Optional[int],
+        curr_block_token_ids: Sequence[int],
+        extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
     """Computes a hash value corresponding to the contents of a block and
     the contents of the preceding block(s). The hash value is used for
     prefix caching. We use LRU cache for this function to avoid recomputing
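A hedged usage sketch of the helper above with a lightweight stand-in for Request (in vLLM it is always called with a real vllm.v1.request.Request); the placeholder mirrors the first image in the test earlier in this commit.

from types import SimpleNamespace

from vllm.v1.core.kv_cache_utils import generate_block_hash_extra_keys

fake_request = SimpleNamespace(
    mm_positions=[{"offset": 11, "length": 10}],   # one image spanning tokens 11-20
    mm_hashes=["aaa"],
)
extra_keys, next_mm_idx = generate_block_hash_extra_keys(fake_request, 0, 16, 0)
assert extra_keys == (("aaa", 0),)   # block 0 overlaps the image from its offset 0
assert next_mm_idx == 0              # the image continues into the next block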
@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
             if this is the first block.
         curr_block_token_ids: A list of token ids in the current
             block. The current block is assumed to be full.
+        extra_keys: Extra keys for the block.
 
     Returns:
         The hash value of the block and the token ids in the block.
         The entire tuple is used as the hash key of the block.
     """
     return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
-                         tuple(curr_block_token_ids))
+                         tuple(curr_block_token_ids), extra_keys)
 
 
 def hash_request_tokens(block_size: int,
-                        token_ids: Sequence[int]) -> List[BlockHashType]:
+                        request: Request) -> List[BlockHashType]:
     """Computes hash values of a chain of blocks given a sequence of
     token IDs. The hash value is used for prefix caching.
 
     Args:
         block_size: The size of each block.
-        token_ids: A sequence of token ids in the request.
+        request: The request object.
 
     Returns:
         The list of computed hash values.
     """
+    token_ids = request.all_token_ids
+    mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
+    if mm_positions and len(mm_positions) != len(mm_hashes):
+        raise ValueError(
+            "The number of multi-modal positions and hashes must match.")
+
+    # TODO: Extend this to support other features such as LoRA.
+    need_extra_keys = bool(mm_positions)
+    extra_keys = None
+    curr_mm_idx = 0
+
     ret = []
     parent_block_hash_value = None
     for start in range(0, len(token_ids), block_size):
@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
         # Do not hash the block if it is not full.
         if len(block_token_ids) < block_size:
             break
+
+        # Add extra keys if the block is a multi-modal block.
+        if need_extra_keys:
+            extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+                request, start, end, curr_mm_idx)
+
         block_hash = hash_block_tokens(parent_block_hash_value,
-                                       block_token_ids)
+                                       block_token_ids, extra_keys)
         ret.append(block_hash)
         parent_block_hash_value = block_hash.hash_value
     return ret
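A condensed, self-contained sketch of the chained hashing above (illustrative, not the vLLM implementation): each full block's key combines a hash over the parent hash plus the block's token IDs, the token IDs themselves, and any multi-modal extra keys.

def sketch_hash_request_tokens(block_size, token_ids, extra_keys_per_block):
    ret, parent_hash_value = [], None
    for i, start in enumerate(range(0, len(token_ids), block_size)):
        block_token_ids = tuple(token_ids[start:start + block_size])
        if len(block_token_ids) < block_size:
            break                       # partial trailing blocks are never hashed
        hash_value = hash((parent_hash_value, *block_token_ids))
        ret.append((hash_value, block_token_ids, extra_keys_per_block[i]))
        parent_hash_value = hash_value
    return ret

# Identical placeholder tokens but different image hashes yield different keys.
a = sketch_hash_request_tokens(4, [0, 1, -1, -1, 2, 3], [(("img-a", 2),), None])
b = sketch_hash_request_tokens(4, [0, 1, -1, -1, 2, 3], [(("img-b", 2),), None])
assert a[0][0] == b[0][0]   # same raw hash value over the token IDs...
assert a[0] != b[0]         # ...but the full keys differ through the extra keys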
@@ -516,6 +516,7 @@ class NewRequestData:
     prompt_token_ids: List[int]
     prompt: Optional[str]
     mm_inputs: List["MultiModalKwargs"]
+    mm_hashes: List[str]
     mm_positions: List["PlaceholderRange"]
     sampling_params: SamplingParams
     block_ids: List[int]
@@ -533,6 +534,7 @@ class NewRequestData:
             prompt_token_ids=request.prompt_token_ids,
             prompt=request.prompt,
             mm_inputs=request.mm_inputs,
+            mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
             block_ids=block_ids,
@@ -60,9 +60,13 @@ class AsyncLLM(EngineClient):
         self.client_aborted_requests: List[str] = []
 
         # Processor (converts Inputs --> EngineCoreRequests).
-        self.processor = Processor(vllm_config.model_config,
-                                   vllm_config.lora_config, self.tokenizer,
-                                   input_registry)
+        self.processor = Processor(
+            model_config=vllm_config.model_config,
+            cache_config=vllm_config.cache_config,
+            lora_config=vllm_config.lora_config,
+            tokenizer=self.tokenizer,
+            input_registry=input_registry,
+        )
 
         # Detokenizer (converts EngineCoreOutputs --> RequestOutput).
         self.detokenizer = Detokenizer(
@@ -65,7 +65,8 @@ class EngineCore:
 
         self._last_logging_time = time.time()
 
-        self.mm_input_mapper_server = MMInputMapperServer()
+        self.mm_input_mapper_server = MMInputMapperServer(
+            vllm_config.model_config)
 
     def _initialize_kv_caches(self,
                               cache_config: CacheConfig) -> Tuple[int, int]:
@@ -98,9 +99,8 @@ class EngineCore:
             # MM mapper, so anything that has a hash must have a HIT cache
             # entry here as well.
             assert request.mm_inputs is not None
-            request.mm_inputs, request.mm_hashes = (
-                self.mm_input_mapper_server.process_inputs(
-                    request.mm_inputs, request.mm_hashes))
+            request.mm_inputs = self.mm_input_mapper_server.process_inputs(
+                request.mm_inputs, request.mm_hashes)
 
         req = Request.from_engine_core_request(request)
 
@@ -55,9 +55,12 @@ class LLMEngine:
         self.tokenizer.ping()
 
         # Processor (convert Inputs --> EngineCoreRequests)
-        self.processor = Processor(vllm_config.model_config,
-                                   vllm_config.lora_config, self.tokenizer,
-                                   input_registry, mm_registry)
+        self.processor = Processor(model_config=vllm_config.model_config,
+                                   cache_config=vllm_config.cache_config,
+                                   lora_config=vllm_config.lora_config,
+                                   tokenizer=self.tokenizer,
+                                   input_registry=input_registry,
+                                   mm_registry=mm_registry)
 
         # Detokenizer (converts EngineCoreOutputs --> RequestOutput)
         self.detokenizer = Detokenizer(
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 import PIL
 from blake3 import blake3
@@ -42,6 +42,8 @@ class MMInputMapperClient:
             model_config)
         self.mm_registry.init_mm_limits_per_prompt(model_config)
 
+        # Init cache
+        self.use_cache = model_config.mm_cache_preprocessor
         self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
         # DEBUG: Set to None to disable
@@ -61,7 +63,7 @@ class MMInputMapperClient:
         mm_hashes: Optional[List[str]],
         mm_processor_kwargs: Optional[Dict[str, Any]],
         precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
-    ) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
+    ) -> List[MultiModalKwargs]:
         if precomputed_mm_inputs is None:
             image_inputs = mm_data["image"]
             if not isinstance(image_inputs, list):
@@ -70,26 +72,21 @@ class MMInputMapperClient:
         else:
             num_inputs = len(precomputed_mm_inputs)
 
-        # Check if hash is enabled
-        use_hash = mm_hashes is not None
-        if use_hash:
-            # Sanity
+        if self.use_cache:
             assert mm_hashes is not None
-            assert num_inputs == len(
-                mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
-                    num_inputs, len(mm_hashes))
+            assert num_inputs == len(mm_hashes)
 
         # Process each image input separately, so that later we can schedule
         # them in a fine-grained manner.
+        # Apply caching (if enabled) and reuse precomputed inputs (if provided)
-        ret_hashes: Optional[List[str]] = [] if use_hash else None
         ret_inputs: List[MultiModalKwargs] = []
         for input_id in range(num_inputs):
             if self.mm_debug_cache_hit_ratio_steps is not None:
                 self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
 
             mm_hash = None
             mm_input = None
-            if use_hash:
+            if self.use_cache:
                 assert mm_hashes is not None
                 mm_hash = mm_hashes[input_id]
                 mm_input = self.mm_cache.get(mm_hash)
@@ -106,7 +103,7 @@ class MMInputMapperClient:
                     mm_processor_kwargs=mm_processor_kwargs,
                 )
 
-            if use_hash:
+            if self.use_cache:
                 # Add to cache
                 assert mm_hash is not None
                 self.mm_cache.put(mm_hash, mm_input)
@@ -114,18 +111,15 @@ class MMInputMapperClient:
                 self.mm_cache_hits += 1
                 mm_input = None  # Avoids sending mm_input to Server
 
-            if use_hash:
-                assert mm_hash is not None
-                assert ret_hashes is not None
-                ret_hashes.append(mm_hash)
             ret_inputs.append(mm_input)
 
-        return ret_inputs, ret_hashes
+        return ret_inputs
 
 
 class MMInputMapperServer:
 
-    def __init__(self, ):
+    def __init__(self, model_config):
+        self.use_cache = model_config.mm_cache_preprocessor
         self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
     def process_inputs(
@@ -135,6 +129,9 @@ class MMInputMapperServer:
     ) -> List[MultiModalKwargs]:
         assert len(mm_inputs) == len(mm_hashes)
 
+        if not self.use_cache:
+            return mm_inputs
+
         full_mm_inputs = []
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
             assert mm_hash is not None
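A minimal sketch of the client/server handshake this enables (plain dicts instead of LRUDictCache, strings instead of MultiModalKwargs, illustrative only): on a client-side cache hit only the hash is shipped, and the server restores the full input from its own cache.

client_cache, server_cache = {}, {}

def client_process(mm_inputs, mm_hashes):
    sent = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_hash in client_cache:
            sent.append(None)            # hit: avoid sending the tensors again
        else:
            client_cache[mm_hash] = mm_input
            sent.append(mm_input)
    return sent

def server_process(mm_inputs, mm_hashes):
    full = []
    for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
        if mm_input is None:
            full.append(server_cache[mm_hash])   # rebuild from the server cache
        else:
            server_cache[mm_hash] = mm_input
            full.append(mm_input)
    return full

sent = client_process(["<mm-kwargs>", "<mm-kwargs>"], ["h1", "h1"])
assert sent == ["<mm-kwargs>", None]
assert server_process(sent, ["h1", "h1"]) == ["<mm-kwargs>", "<mm-kwargs>"]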
@@ -1,7 +1,7 @@
 import time
 from typing import Any, Dict, Mapping, Optional, Tuple, Union
 
-from vllm.config import LoRAConfig, ModelConfig
+from vllm.config import CacheConfig, LoRAConfig, ModelConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
 from vllm.inputs.parse import is_encoder_decoder_inputs
@@ -23,6 +23,7 @@ class Processor:
     def __init__(
         self,
         model_config: ModelConfig,
+        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
         tokenizer: BaseTokenizerGroup,
         input_registry: InputRegistry = INPUT_REGISTRY,
@@ -45,8 +46,9 @@ class Processor:
         self.mm_input_mapper_client = MMInputMapperClient(model_config)
 
         # Multi-modal hasher (for images)
-        self.mm_hasher = MMHasher(
-        ) if model_config.mm_cache_preprocessor else None
+        self.use_hash = model_config.mm_cache_preprocessor or \
+            cache_config.enable_prefix_caching
+        self.mm_hasher = MMHasher()
 
         # TODO: run in an ThreadpoolExecutor or BackgroundProcess.
         # This ideally should releases the GIL, so we should not block the
@@ -77,7 +79,7 @@ class Processor:
 
         # Compute MM hashes (if enabled)
         mm_hashes = None
-        if self.mm_hasher is not None:
+        if self.use_hash:
             mm_hashes = self.mm_hasher.hash(prompt)
 
         # Process inputs.
@@ -118,7 +120,7 @@ class Processor:
         # Apply MM mapper
         mm_inputs = None
         if len(decoder_inputs.multi_modal_data) > 0:
-            mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
+            mm_inputs = self.mm_input_mapper_client.process_inputs(
                 decoder_inputs.multi_modal_data, mm_hashes,
                 decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
 
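The hashing decision above, written out as a plain predicate (a sketch; the names are illustrative): multi-modal hashes are now computed whenever either the preprocessor cache or V1 prefix caching needs them.

def need_mm_hashes(mm_cache_preprocessor: bool, enable_prefix_caching: bool) -> bool:
    return mm_cache_preprocessor or enable_prefix_caching

assert need_mm_hashes(False, True)        # prefix caching alone now triggers hashing
assert need_mm_hashes(True, False)        # so does the preprocessor cache
assert not need_mm_hashes(False, False)   # otherwise no hashes are computed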
@@ -1,5 +1,5 @@
 import enum
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
 from vllm.lora.request import LoRARequest
@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.utils import ConstantList
 
+if TYPE_CHECKING:
+    from vllm.v1.core.kv_cache_utils import BlockHashType
+
 
 class Request:
 
@@ -45,6 +48,7 @@ class Request:
         self._all_token_ids: List[int] = self.prompt_token_ids.copy()
         self.num_computed_tokens = 0
 
+        # Multi-modal input metadata.
         mm_positions = self.inputs.multi_modal_placeholders
         if mm_positions:
             # FIXME(woosuk): Support other modalities.
@@ -56,6 +60,12 @@ class Request:
         if self.inputs.multi_modal_inputs:
             self.mm_inputs = self.inputs.multi_modal_inputs
 
+        self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
+
+        # Cache the computed kv block hashes of the request to avoid
+        # recomputing.
+        self._kv_block_hashes: List[BlockHashType] = []
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
@@ -65,6 +75,7 @@ class Request:
                 prompt=request.prompt,
                 multi_modal_data=None,
                 multi_modal_inputs=request.mm_inputs,
+                multi_modal_hashes=request.mm_hashes,
                 multi_modal_placeholders=request.mm_placeholders,
                 mm_processor_kwargs=None,
             ),
@@ -121,6 +132,17 @@ class Request:
             num_tokens = self.mm_positions[input_id]["length"]
         return num_tokens
 
+    @property
+    def kv_block_hashes(self) -> ConstantList["BlockHashType"]:
+        # Prevent directly appending to the kv_block_hashes.
+        return ConstantList(self._kv_block_hashes)
+
+    def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
+        self._kv_block_hashes = value
+
+    def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
+        self._kv_block_hashes.append(block_hash)
+
 
 class RequestStatus(enum.IntEnum):
     """Status of a request."""
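A note on the accessor design, with a standalone sketch of the idea (not the vLLM ConstantList itself): the property hands out a read-only view over the internal list, so callers can only grow it through set_kv_block_hashes/append_kv_block_hashes while readers always see the up-to-date hashes.

from collections.abc import Sequence

class ReadOnlyView(Sequence):
    """Read-only window over a list owned by someone else."""

    def __init__(self, backing):
        self._backing = backing

    def __getitem__(self, idx):
        return self._backing[idx]

    def __len__(self):
        return len(self._backing)

backing = ["hash-0"]
view = ReadOnlyView(backing)
backing.append("hash-1")       # the owner appends through its own API
assert len(view) == 2          # readers observe the new hash immediately
assert view[1] == "hash-1"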