mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 21:37:11 +08:00
[CORE] Support Prefix Caching with Prompt Embeds (#27219)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
This commit is contained in:
parent
243ed7d32e
commit
ff93cc8c84
@ -52,7 +52,7 @@ th:not(:first-child) {
|
||||
| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](https://github.com/vllm-project/vllm/pull/4194)<sup>^</sup> | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
|
||||
| best-of | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ✅ | ✅ | | |
|
||||
| beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
|
||||
| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](https://github.com/vllm-project/vllm/issues/25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
|
||||
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
|
||||
|
||||
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
|
||||
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
|
||||
@ -75,4 +75,4 @@ th:not(:first-child) {
|
||||
| multi-step | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/8477) | ✅ | ❌ | ✅ |
|
||||
| best-of | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
|
||||
| beam-search | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |
|
||||
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ? | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
|
||||
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/25097) | ✅ |
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -32,6 +33,7 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
init_none_hash,
|
||||
is_kv_cache_spec_uniform,
|
||||
make_block_hash_with_group_id,
|
||||
tensor_data,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
@ -61,12 +63,13 @@ def _auto_init_hash_fn(request):
|
||||
|
||||
def make_request(
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
prompt_token_ids: list[int] | None,
|
||||
block_size: int = 3,
|
||||
hash_fn: Callable = hash,
|
||||
mm_positions: list[PlaceholderRange] | None = None,
|
||||
mm_hashes: list[str] | None = None,
|
||||
cache_salt: str | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
):
|
||||
mm_features = []
|
||||
if mm_positions is not None:
|
||||
@ -90,6 +93,7 @@ def make_request(
|
||||
lora_request=None,
|
||||
cache_salt=cache_salt,
|
||||
block_hasher=get_request_block_hasher(block_size, hash_fn),
|
||||
prompt_embeds=prompt_embeds,
|
||||
)
|
||||
|
||||
|
||||
@ -450,6 +454,52 @@ def test_generate_block_hash_extra_keys_cache_salt():
|
||||
assert next_mm_idx == 1
|
||||
|
||||
|
||||
def test_generate_block_hash_extra_keys_prompt_embeds():
|
||||
prompt_embeds = torch.randn(10, 3)
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=None,
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
prompt_embeds=prompt_embeds,
|
||||
)
|
||||
|
||||
# Test with prompt embeds for the first block
|
||||
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
|
||||
expected_embeds = prompt_embeds[0:5]
|
||||
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
|
||||
assert extra_keys == (expected_bytes,)
|
||||
|
||||
# Test with prompt embeds for the second block
|
||||
extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
|
||||
expected_embeds = prompt_embeds[5:10]
|
||||
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
|
||||
assert extra_keys == (expected_bytes,)
|
||||
|
||||
|
||||
def test_generate_block_hash_extra_keys_different_prompt_embeds():
|
||||
prompt_embeds1 = torch.randn(10, 3)
|
||||
prompt_embeds2 = torch.randn(10, 3)
|
||||
request1 = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=None,
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
prompt_embeds=prompt_embeds1,
|
||||
)
|
||||
request2 = make_request(
|
||||
request_id="1",
|
||||
prompt_token_ids=None,
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
prompt_embeds=prompt_embeds2,
|
||||
)
|
||||
|
||||
extra_keys1, _ = generate_block_hash_extra_keys(request1, 0, 5, 0)
|
||||
extra_keys2, _ = generate_block_hash_extra_keys(request2, 0, 5, 0)
|
||||
assert extra_keys1 != extra_keys2
|
||||
|
||||
|
||||
def test_generate_block_hash_extra_keys_lora():
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
@ -1556,3 +1606,88 @@ def test_merge_mla_spec():
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
|
||||
block_size = 3
|
||||
num_tokens = 2 * block_size
|
||||
prompt_token_ids = [_ for _ in range(num_tokens)]
|
||||
hidden_size = 5
|
||||
prompt_embeds = torch.randn((num_tokens, hidden_size))
|
||||
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
prompt_embeds=prompt_embeds,
|
||||
)
|
||||
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
|
||||
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
|
||||
expected_hash1 = hash_fn(
|
||||
(
|
||||
kv_cache_utils.NONE_HASH,
|
||||
tuple(prompt_token_ids[:block_size]),
|
||||
(block1_embeds_bytes,),
|
||||
)
|
||||
)
|
||||
assert block_hashes[0] == expected_hash1
|
||||
|
||||
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
|
||||
expected_hash2 = hash_fn(
|
||||
(
|
||||
block_hashes[0],
|
||||
tuple(prompt_token_ids[block_size:num_tokens]),
|
||||
(block2_embeds_bytes,),
|
||||
)
|
||||
)
|
||||
assert block_hashes[1] == expected_hash2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
|
||||
block_size = 3
|
||||
num_tokens = 2 * block_size
|
||||
prompt_token_ids = [_ for _ in range(num_tokens)]
|
||||
hidden_size = 5
|
||||
prompt_embeds = torch.randn((num_tokens, hidden_size))
|
||||
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
PlaceholderRange(offset=3, length=3),
|
||||
],
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
prompt_embeds=prompt_embeds,
|
||||
)
|
||||
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
|
||||
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
|
||||
expected_hash1 = hash_fn(
|
||||
(
|
||||
kv_cache_utils.NONE_HASH,
|
||||
tuple(prompt_token_ids[:block_size]),
|
||||
("hash1", block1_embeds_bytes),
|
||||
)
|
||||
)
|
||||
assert block_hashes[0] == expected_hash1
|
||||
|
||||
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
|
||||
expected_hash2 = hash_fn(
|
||||
(
|
||||
block_hashes[0],
|
||||
tuple(prompt_token_ids[block_size:num_tokens]),
|
||||
("hash2", block2_embeds_bytes),
|
||||
)
|
||||
)
|
||||
assert block_hashes[1] == expected_hash2
|
||||
|
||||
@ -1743,16 +1743,6 @@ class EngineArgs:
|
||||
if model_config.runner_type != "pooling":
|
||||
self.enable_chunked_prefill = True
|
||||
|
||||
# TODO: When prefix caching supports prompt embeds inputs, this
|
||||
# check can be removed.
|
||||
if self.enable_prompt_embeds and self.enable_prefix_caching is not False:
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V1. Prefix caching has "
|
||||
"been disabled."
|
||||
)
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
if self.enable_prefix_caching is None:
|
||||
# Disable prefix caching default for hybrid models
|
||||
# since the feature is still experimental.
|
||||
|
||||
@ -26,6 +26,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
UniformTypeKVCacheSpecs,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.utils import tensor_data
|
||||
|
||||
# BlockHash represents the hash of a single KV-cache block used for
|
||||
# prefix caching. Treating it as a distinct type from `bytes` helps
|
||||
@ -461,11 +462,33 @@ def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
|
||||
return [request.lora_request.lora_name]
|
||||
|
||||
|
||||
def _gen_prompt_embeds_extra_hash_keys(
|
||||
request: Request, start_token_idx: int, end_token_idx: int
|
||||
) -> list[bytes]:
|
||||
"""Generate extra keys related to prompt embeds for block hash computation.
|
||||
|
||||
Args:
|
||||
request: The request object.
|
||||
start_token_idx: The start token index of the block.
|
||||
end_token_idx: The end token index of the block.
|
||||
|
||||
Returns:
|
||||
Return prompt embeddings data of the request if it has prompt embeds.
|
||||
Return empty list otherwise.
|
||||
"""
|
||||
if request.prompt_embeds is None:
|
||||
return []
|
||||
block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
|
||||
embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
|
||||
return [embeds_bytes]
|
||||
|
||||
|
||||
def generate_block_hash_extra_keys(
|
||||
request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
|
||||
) -> tuple[tuple[Any, ...] | None, int]:
|
||||
"""Generate extra keys for the block hash. The extra keys can come from
|
||||
the multi-modal inputs and request specific metadata (e.g., LoRA name).
|
||||
the multi-modal inputs, request specific metadata (e.g., LoRA names), and
|
||||
data from prompt embeddings.
|
||||
|
||||
Args:
|
||||
request: The request object.
|
||||
@ -484,8 +507,13 @@ def generate_block_hash_extra_keys(
|
||||
cache_salt_keys: list[str] = (
|
||||
[request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
|
||||
)
|
||||
prompt_embeds_keys = _gen_prompt_embeds_extra_hash_keys(
|
||||
request, start_token_idx, end_token_idx
|
||||
)
|
||||
|
||||
extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys
|
||||
extra_keys: list[Any] = (
|
||||
lora_extra_keys + mm_extra_keys + cache_salt_keys + prompt_embeds_keys
|
||||
)
|
||||
|
||||
if not extra_keys:
|
||||
return None, new_start_mm_idx
|
||||
|
||||
@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
|
||||
NestedTensors,
|
||||
)
|
||||
from vllm.v1.engine import UtilityResult
|
||||
from vllm.v1.utils import tensor_data
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -218,14 +219,14 @@ class MsgpackEncoder:
|
||||
) -> tuple[str, tuple[int, ...], int | memoryview]:
|
||||
assert self.aux_buffers is not None
|
||||
# view the tensor as a contiguous 1D array of bytes
|
||||
arr = obj.flatten().contiguous().view(torch.uint8).numpy()
|
||||
arr_data = tensor_data(obj)
|
||||
if obj.nbytes < self.size_threshold:
|
||||
# Smaller tensors are encoded inline, just like ndarrays.
|
||||
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
|
||||
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
|
||||
else:
|
||||
# Otherwise encode index of backing buffer to avoid copy.
|
||||
data = len(self.aux_buffers)
|
||||
self.aux_buffers.append(arr.data)
|
||||
self.aux_buffers.append(arr_data)
|
||||
dtype = str(obj.dtype).removeprefix("torch.")
|
||||
return dtype, obj.shape, data
|
||||
|
||||
|
||||
@ -396,3 +396,16 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
|
||||
|
||||
_PROFILER_FUNC = func
|
||||
return func(name)
|
||||
|
||||
|
||||
def tensor_data(tensor: torch.Tensor) -> memoryview:
|
||||
"""Get the raw data of a tensor as a uint8 memoryview, useful for
|
||||
serializing and hashing.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor.
|
||||
|
||||
Returns:
|
||||
A memoryview of the tensor data as uint8.
|
||||
"""
|
||||
return tensor.flatten().contiguous().view(torch.uint8).numpy().data
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user