[V1] Prefix caching for vision language models (#11187)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Cody Yu 2024-12-17 16:37:59 -08:00 committed by GitHub
parent c77eb8a33c
commit bf8717ebae
14 changed files with 341 additions and 97 deletions

View File

@ -2,16 +2,23 @@
import pytest
from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
def make_request(request_id, prompt_token_ids):
def make_request(request_id,
prompt_token_ids,
mm_positions=None,
mm_hashes=None):
return Request(
request_id=request_id,
inputs=token_inputs(prompt_token_ids=prompt_token_ids),
inputs=token_inputs(prompt_token_ids=prompt_token_ids,
multi_modal_placeholders={"image": mm_positions}
if mm_positions else None,
multi_modal_hashes=mm_hashes),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
@ -38,6 +45,7 @@ def test_prefill():
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@ -61,6 +69,7 @@ def test_prefill():
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
@ -90,6 +99,7 @@ def test_prefill():
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_block = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_block] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
@ -416,3 +426,77 @@ def test_cache_blocks():
)
assert len(manager.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None
def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
common_token_ids = list(range(10)) + [-1] * 6
common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
common_token_ids += [-1] * 16
common_mm_positions = [
PlaceholderRange(offset=11, length=10),
PlaceholderRange(offset=30, length=18),
]
common_mm_hashes = ["aaa", "bbb"]
# A unique image plus some text tokens.
unique_token_ids = [-1] * 7 + [100] * 4
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req0 = make_request("0",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req0)
# Completed blocks should have hashes with extra keys.
assert not computed_blocks
assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == (("aaa", 0), )
assert req0.kv_block_hashes[1].extra_keys == (("aaa", 5), ("bbb", 0))
assert req0.kv_block_hashes[2].extra_keys == (("bbb", 2), )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
# The just-completed block should have a hash with extra keys.
assert len(req0.kv_block_hashes) == 4
assert req0.kv_block_hashes[3].extra_keys == (("ccc", 0), )
# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
all_token_ids = common_token_ids + unique_token_ids
mm_positions = common_mm_positions + [
PlaceholderRange(offset=48, length=7)
]
mm_hashes = common_mm_hashes + ["ccc"]
req1 = make_request("1",
all_token_ids,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
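
The extra keys asserted for req0 follow from simple overlap arithmetic: with block_size=16, block 0 covers tokens [0, 16) and meets image "aaa" (tokens [11, 21)) at relative position 0; block 1 covers [16, 32), reaching 5 tokens into "aaa" and the start of "bbb" (tokens [30, 48)); block 2 covers [32, 48), entering "bbb" at relative position 2; the block completed during decode covers [48, 64) and starts exactly at "ccc". A minimal standalone sketch of that arithmetic (independent of the vLLM classes; the helper name is hypothetical):

def block_extra_keys(block_idx, block_size, mm_ranges):
    # mm_ranges: list of (mm_hash, offset, length); returns (hash, relative
    # start) pairs for every mm input overlapping the block, as asserted above.
    start, end = block_idx * block_size, (block_idx + 1) * block_size
    return tuple((mm_hash, max(0, start - offset))
                 for mm_hash, offset, length in mm_ranges
                 if offset < end and start < offset + length)

mm_ranges = [("aaa", 11, 10), ("bbb", 30, 18), ("ccc", 48, 7)]
assert block_extra_keys(0, 16, mm_ranges) == (("aaa", 0),)
assert block_extra_keys(1, 16, mm_ranges) == (("aaa", 5), ("bbb", 0))
assert block_extra_keys(2, 16, mm_ranges) == (("bbb", 2),)
assert block_extra_keys(3, 16, mm_ranges) == (("ccc", 0),)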

View File

@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
assert engine_args.enable_prefix_caching
def test_defaults():
engine_args = EngineArgs(model="facebook/opt-125m")
# Assert V1 defaults
assert (engine_args.enable_prefix_caching
), "V1 turns on prefix caching by default"
def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(
@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
UsageContext.OPENAI_API_SERVER)
assert vllm_config.scheduler_config.max_num_seqs == 1024
assert vllm_config.scheduler_config.max_num_batched_tokens == 2048
def test_prefix_cache_disabled_with_multimodel():
engine_args = EngineArgs(model="llava-hf/llava-1.5-7b-hf")
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert not vllm_config.cache_config.enable_prefix_caching

View File

@ -205,6 +205,7 @@ class EngineArgs:
# by user.
if self.enable_prefix_caching is None:
self.enable_prefix_caching = bool(envs.VLLM_USE_V1)
# Override max_num_seqs if it's not set by user.
if self.max_num_seqs is None:
self.max_num_seqs = 256 if not envs.VLLM_USE_V1 else 1024
@ -1026,11 +1027,11 @@ class EngineArgs:
device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config()
if model_config.is_multimodal_model:
if self.enable_prefix_caching:
logger.warning(
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
if (model_config.is_multimodal_model and not envs.VLLM_USE_V1
and self.enable_prefix_caching):
logger.warning("--enable-prefix-caching is currently not "
"supported for multimodal models in v0 and "
"has been disabled.")
self.enable_prefix_caching = False
cache_config = CacheConfig(
@ -1249,11 +1250,14 @@ class EngineArgs:
# When no user override, set the default values based on the usage
# context.
# TODO(woosuk): Tune the default values for different hardware.
if self.max_num_batched_tokens is None:
if usage_context == UsageContext.LLM_CLASS:
self.max_num_batched_tokens = 8192
elif usage_context == UsageContext.OPENAI_API_SERVER:
self.max_num_batched_tokens = 2048
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 8192,
UsageContext.OPENAI_API_SERVER: 2048,
}
if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context]
logger.warning(
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, usage_context.value)
@ -1263,9 +1267,6 @@ class EngineArgs:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert envs.VLLM_USE_V1, "V1 is not enabled"
if engine_config.model_config.is_multimodal_model:
# TODO (ywang96): Enable APC by default when VLM supports it.
assert not engine_config.cache_config.enable_prefix_caching
@dataclass

View File

@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data.
"""
multi_modal_hashes: NotRequired[List[str]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs: NotRequired[Dict[str, Any]]
"""
Optional multi-modal processor kwargs to be forwarded to the
@ -177,6 +182,7 @@ def token_inputs(
prompt: Optional[str] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
multi_modal_inputs: Optional["MultiModalKwargs"] = None,
multi_modal_hashes: Optional[List[str]] = None,
multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
@ -191,6 +197,8 @@ def token_inputs(
inputs["multi_modal_data"] = multi_modal_data
if multi_modal_inputs is not None:
inputs["multi_modal_inputs"] = multi_modal_inputs
if multi_modal_hashes is not None:
inputs["multi_modal_hashes"] = multi_modal_hashes
if multi_modal_placeholders is not None:
inputs["multi_modal_placeholders"] = multi_modal_placeholders
if mm_processor_kwargs is not None:
@ -295,6 +303,18 @@ class SingletonInputsAdapter:
assert_never(inputs)
@cached_property
def multi_modal_hashes(self) -> List[str]:
inputs = self.inputs
if inputs["type"] == "token":
return inputs.get("multi_modal_hashes", [])
if inputs["type"] == "multimodal":
return inputs.get("mm_hashes", [])
assert_never(inputs)
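
A quick round trip through the new plumbing, as a hedged sketch (assuming token_inputs and SingletonInputsAdapter are importable from vllm.inputs and PlaceholderRange from vllm.multimodal.inputs, as in the test file above):

from vllm.inputs import SingletonInputsAdapter, token_inputs
from vllm.multimodal.inputs import PlaceholderRange

# Hashes ride along with the token inputs and surface through the adapter;
# when no hashes were attached, the adapter falls back to an empty list.
inputs = token_inputs(prompt_token_ids=list(range(20)),
                      multi_modal_hashes=["aaa"],
                      multi_modal_placeholders={
                          "image": [PlaceholderRange(offset=4, length=6)]
                      })
assert SingletonInputsAdapter(inputs).multi_modal_hashes == ["aaa"]
assert SingletonInputsAdapter(
    token_inputs(prompt_token_ids=[1, 2, 3])).multi_modal_hashes == []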
@cached_property
def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
inputs = self.inputs

View File

@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs: MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: NotRequired[List[str]]
"""The hashes of the multi-modal data."""
mm_placeholders: MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in

View File

@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request
@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks = []
# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if not request.kv_block_hashes:
request.set_kv_block_hashes(
hash_request_tokens(self.block_size, request))
block_hashes = request.kv_block_hashes
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
@ -242,14 +246,16 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id]
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None,
)
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None,
)
return new_blocks
@ -376,6 +382,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
num_cached_block_hashes = len(request.kv_block_hashes)
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
if prev_block is not None:
@ -387,17 +395,35 @@ class KVCacheManager:
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
block_tokens = request.all_token_ids[blk_idx *
self.block_size:(blk_idx +
1) *
self.block_size]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got {len(block_tokens)} "
f"at {blk_idx}th block for request "
f"{request.request_id}({request})")
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = request.kv_block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value, block_tokens)
# Generate extra keys for multi-modal inputs. Note that since
# we reach this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
request.append_kv_block_hashes(block_hash)
# Update and add the full block to the cache.
blk.block_hash = block_hash

View File

@ -1,20 +1,25 @@
"""KV-Cache Utilities."""
from collections.abc import Sequence
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Tuple
from typing import Any, List, NamedTuple, Optional, Tuple
from vllm.logger import init_logger
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockHashType(NamedTuple):
"""Hash value of a block and the token IDs in the block.
The reason we keep a tuple of token IDs is to make sure no hash
collision happens when the hash value is the same.
"""Hash value of a block (int), the token IDs in the block, and extra keys.
The reason we keep a tuple of token IDs and extra keys is to make sure
no hash collision happens when the hash value is the same.
"""
# Hash value of the block in an integer.
hash_value: int
# Token IDs in the block.
token_ids: Tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None
@dataclass
@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return ret
def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int]) -> BlockHashType:
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
return None, start_mm_idx
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set mm_cache_preprocessor=True.")
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
return None, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx
extra_keys = []
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx]["offset"]
length = mm_positions[curr_mm_idx]["length"]
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
curr_mm_idx += 1
continue
# The block contains the current mm input.
mm_start = max(0, start_token_idx - offset)
extra_keys.append((mm_hashes[curr_mm_idx], mm_start))
if end_token_idx >= offset + length:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx += 1
else:
# Otherwise this block is done with mm inputs.
break
else:
# This block has not reached the current mm input.
break
return tuple(extra_keys), curr_mm_idx
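
The helper can be exercised directly; a hedged usage sketch where a SimpleNamespace stands in for the Request (only mm_positions and mm_hashes are read, and plain dicts mimic PlaceholderRange entries):

from types import SimpleNamespace

req = SimpleNamespace(
    mm_positions=[{"offset": 11, "length": 10}, {"offset": 30, "length": 18}],
    mm_hashes=["aaa", "bbb"])

# start_mm_idx=-1 considers only the last mm input, the mode _cache_full_blocks
# uses for blocks completed by newly generated tokens (block [32, 48) here).
extra_keys, next_mm_idx = generate_block_hash_extra_keys(req, 32, 48, -1)
assert extra_keys == (("bbb", 2),)

# A block starting past every mm input contributes no extra keys at all.
extra_keys, _ = generate_block_hash_extra_keys(req, 64, 80, 0)
assert extra_keys is None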
def hash_block_tokens(
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[Tuple[Any, ...]] = None) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block.
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
tuple(curr_block_token_ids))
tuple(curr_block_token_ids), extra_keys)
def hash_request_tokens(block_size: int,
token_ids: Sequence[int]) -> List[BlockHashType]:
request: Request) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
request: The request object.
Returns:
The list of computed hash values.
"""
token_ids = request.all_token_ids
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match.")
# TODO: Extend this to support other features such as LoRA.
need_extra_keys = bool(mm_positions)
extra_keys = None
curr_mm_idx = 0
ret = []
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
# Add extra keys if the block is a multi-modal block.
if need_extra_keys:
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
block_token_ids, extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
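
A small sanity sketch of the hashing contract defined above (import path as used by the KV cache manager earlier): hash_value folds in only the parent hash and the token IDs, while the full BlockHashType tuple, extra keys included, is what distinguishes cache entries, so identical placeholder tokens backed by different images cannot produce a false prefix hit.

from vllm.v1.core.kv_cache_utils import hash_block_tokens

placeholder_block = [-1] * 16   # 16 image-placeholder tokens, as in the test
h_img_a = hash_block_tokens(None, placeholder_block, (("aaa", 0),))
h_img_b = hash_block_tokens(None, placeholder_block, (("xyz", 0),))

# Same tokens, different images: the raw hash_value matches, but the full
# BlockHashType differs because of extra_keys, so the cache keys do not collide.
assert h_img_a.hash_value == h_img_b.hash_value
assert h_img_a != h_img_b
assert h_img_a == hash_block_tokens(None, placeholder_block, (("aaa", 0),))

# Chaining: a child block folds its parent's hash_value into its own hash.
h_child = hash_block_tokens(h_img_a.hash_value, list(range(16)))
assert h_child.hash_value == hash((h_img_a.hash_value, *range(16)))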

View File

@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids: List[int]
prompt: Optional[str]
mm_inputs: List["MultiModalKwargs"]
mm_hashes: List[str]
mm_positions: List["PlaceholderRange"]
sampling_params: SamplingParams
block_ids: List[int]
@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids=request.prompt_token_ids,
prompt=request.prompt,
mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,

View File

@ -60,9 +60,13 @@ class AsyncLLM(EngineClient):
self.client_aborted_requests: List[str] = []
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry)
self.processor = Processor(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self.detokenizer = Detokenizer(

View File

@ -65,7 +65,8 @@ class EngineCore:
self._last_logging_time = time.time()
self.mm_input_mapper_server = MMInputMapperServer()
self.mm_input_mapper_server = MMInputMapperServer(
vllm_config.model_config)
def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
assert request.mm_inputs is not None
request.mm_inputs, request.mm_hashes = (
self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes))
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)

View File

@ -55,9 +55,12 @@ class LLMEngine:
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(vllm_config.model_config,
vllm_config.lora_config, self.tokenizer,
input_registry, mm_registry)
self.processor = Processor(model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self.detokenizer = Detokenizer(

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import PIL
from blake3 import blake3
@ -42,6 +42,8 @@ class MMInputMapperClient:
model_config)
self.mm_registry.init_mm_limits_per_prompt(model_config)
# Init cache
self.use_cache = model_config.mm_cache_preprocessor
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable
@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes: Optional[List[str]],
mm_processor_kwargs: Optional[Dict[str, Any]],
precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
) -> Tuple[List[MultiModalKwargs], Optional[List[str]]]:
) -> List[MultiModalKwargs]:
if precomputed_mm_inputs is None:
image_inputs = mm_data["image"]
if not isinstance(image_inputs, list):
@ -70,26 +72,21 @@ class MMInputMapperClient:
else:
num_inputs = len(precomputed_mm_inputs)
# Check if hash is enabled
use_hash = mm_hashes is not None
if use_hash:
# Sanity
if self.use_cache:
assert mm_hashes is not None
assert num_inputs == len(
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
num_inputs, len(mm_hashes))
assert num_inputs == len(mm_hashes)
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes: Optional[List[str]] = [] if use_hash else None
ret_inputs: List[MultiModalKwargs] = []
for input_id in range(num_inputs):
if self.mm_debug_cache_hit_ratio_steps is not None:
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
mm_hash = None
mm_input = None
if use_hash:
if self.use_cache:
assert mm_hashes is not None
mm_hash = mm_hashes[input_id]
mm_input = self.mm_cache.get(mm_hash)
@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs=mm_processor_kwargs,
)
if use_hash:
if self.use_cache:
# Add to cache
assert mm_hash is not None
self.mm_cache.put(mm_hash, mm_input)
@ -114,18 +111,15 @@ class MMInputMapperClient:
self.mm_cache_hits += 1
mm_input = None # Avoids sending mm_input to Server
if use_hash:
assert mm_hash is not None
assert ret_hashes is not None
ret_hashes.append(mm_hash)
ret_inputs.append(mm_input)
return ret_inputs, ret_hashes
return ret_inputs
class MMInputMapperServer:
def __init__(self, ):
def __init__(self, model_config):
self.use_cache = model_config.mm_cache_preprocessor
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs(
@ -135,6 +129,9 @@ class MMInputMapperServer:
) -> List[MultiModalKwargs]:
assert len(mm_inputs) == len(mm_hashes)
if not self.use_cache:
return mm_inputs
full_mm_inputs = []
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
assert mm_hash is not None
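
The cache handshake between the two mapper halves can be summarized as a simplified, hypothetical sketch (plain dicts instead of LRUDictCache, made-up helper names) of the protocol implemented by MMInputMapperClient and MMInputMapperServer above:

def client_process(mm_hash, mm_input, client_cache):
    # Client-side hit: send None instead of re-shipping the mapped tensors.
    if client_cache.get(mm_hash) is not None:
        return None
    client_cache[mm_hash] = mm_input
    return mm_input

def server_process(mm_hash, maybe_input, server_cache):
    # A client-side hit implies the server has already cached this hash.
    if maybe_input is None:
        return server_cache[mm_hash]
    server_cache[mm_hash] = maybe_input
    return maybe_input

client_cache, server_cache = {}, {}
first = client_process("aaa", {"pixel_values": "..."}, client_cache)    # miss
assert server_process("aaa", first, server_cache) == {"pixel_values": "..."}
second = client_process("aaa", {"pixel_values": "..."}, client_cache)   # hit
assert second is None
assert server_process("aaa", second, server_cache) == {"pixel_values": "..."}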

View File

@ -1,7 +1,7 @@
import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from vllm.config import LoRAConfig, ModelConfig
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
@ -23,6 +23,7 @@ class Processor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
@ -45,8 +46,9 @@ class Processor:
self.mm_input_mapper_client = MMInputMapperClient(model_config)
# Multi-modal hasher (for images)
self.mm_hasher = MMHasher(
) if model_config.mm_cache_preprocessor else None
self.use_hash = model_config.mm_cache_preprocessor or \
cache_config.enable_prefix_caching
self.mm_hasher = MMHasher()
# TODO: run in a ThreadPoolExecutor or BackgroundProcess.
# This ideally should release the GIL, so we should not block the
@ -77,7 +79,7 @@ class Processor:
# Compute MM hashes (if enabled)
mm_hashes = None
if self.mm_hasher is not None:
if self.use_hash:
mm_hashes = self.mm_hasher.hash(prompt)
# Process inputs.
@ -118,7 +120,7 @@ class Processor:
# Apply MM mapper
mm_inputs = None
if len(decoder_inputs.multi_modal_data) > 0:
mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
mm_inputs = self.mm_input_mapper_client.process_inputs(
decoder_inputs.multi_modal_data, mm_hashes,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)

View File

@ -1,5 +1,5 @@
import enum
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union
from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import BlockHashType
class Request:
@ -45,6 +48,7 @@ class Request:
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.num_computed_tokens = 0
# Multi-modal input metadata.
mm_positions = self.inputs.multi_modal_placeholders
if mm_positions:
# FIXME(woosuk): Support other modalities.
@ -56,6 +60,12 @@ class Request:
if self.inputs.multi_modal_inputs:
self.mm_inputs = self.inputs.multi_modal_inputs
self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self._kv_block_hashes: List[BlockHashType] = []
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
return cls(
@ -65,6 +75,7 @@ class Request:
prompt=request.prompt,
multi_modal_data=None,
multi_modal_inputs=request.mm_inputs,
multi_modal_hashes=request.mm_hashes,
multi_modal_placeholders=request.mm_placeholders,
mm_processor_kwargs=None,
),
@ -121,6 +132,17 @@ class Request:
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
@property
def kv_block_hashes(self) -> ConstantList["BlockHashType"]:
# Prevent directly appending to the kv_block_hashes.
return ConstantList(self._kv_block_hashes)
def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
self._kv_block_hashes = value
def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
self._kv_block_hashes.append(block_hash)
class RequestStatus(enum.IntEnum):
"""Status of a request."""