mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-30 16:30:04 +08:00
[Core] Use sha256 bytes instead of BlockHash to reduce GC overhead (#23673)
Signed-off-by: linzebing <linzebing1995@gmail.com>
This commit is contained in:
parent
bba1042c6f
commit
82dfb12e52
@ -6,6 +6,8 @@ import msgspec
|
|||||||
import zmq
|
import zmq
|
||||||
from msgspec.msgpack import Decoder
|
from msgspec.msgpack import Decoder
|
||||||
|
|
||||||
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Types copied from vllm.distributed.kv_events
|
# Types copied from vllm.distributed.kv_events
|
||||||
@ -22,8 +24,8 @@ class KVCacheEvent(
|
|||||||
|
|
||||||
|
|
||||||
class BlockStored(KVCacheEvent):
|
class BlockStored(KVCacheEvent):
|
||||||
block_hashes: list[int]
|
block_hashes: list[BlockHash]
|
||||||
parent_block_hash: Optional[int]
|
parent_block_hash: Optional[BlockHash]
|
||||||
token_ids: list[int]
|
token_ids: list[int]
|
||||||
block_size: int
|
block_size: int
|
||||||
lora_id: Optional[int]
|
lora_id: Optional[int]
|
||||||
@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent):
|
|||||||
|
|
||||||
|
|
||||||
class BlockRemoved(KVCacheEvent):
|
class BlockRemoved(KVCacheEvent):
|
||||||
block_hashes: list[int]
|
block_hashes: list[BlockHash]
|
||||||
medium: Optional[str]
|
medium: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
|
|||||||
|
|
||||||
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
||||||
(None, bool, [1, 2, 3])])
|
(None, bool, [1, 2, 3])])
|
||||||
@pytest.mark.parametrize("output", [0, 1, 2])
|
def test_sha256(input: tuple):
|
||||||
def test_sha256(input: tuple, output: int):
|
digest = sha256(input)
|
||||||
hash = sha256(input)
|
assert digest is not None
|
||||||
assert hash is not None
|
assert isinstance(digest, bytes)
|
||||||
assert isinstance(hash, int)
|
assert digest != b""
|
||||||
assert hash != 0
|
|
||||||
|
|
||||||
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
|
assert digest == hashlib.sha256(input_bytes).digest()
|
||||||
byteorder="big")
|
|
||||||
|
|
||||||
# hashing again, returns the same value
|
# hashing again, returns the same value
|
||||||
assert hash == sha256(input)
|
assert digest == sha256(input)
|
||||||
|
|
||||||
# hashing different input, returns different value
|
# hashing different input, returns different value
|
||||||
assert hash != sha256(input + (1, ))
|
assert digest != sha256(input + (1, ))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -6,20 +6,22 @@ from typing import Callable, Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
||||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||||
MultiModalKwargsItem, PlaceholderRange)
|
MultiModalKwargsItem, PlaceholderRange)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
|
from vllm.utils import GiB_bytes, sha256, sha256_cbor
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||||
# disable yapf here as it formats differently than isort such that both fail
|
# disable yapf here as it formats differently than isort such that both fail
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.v1.core.kv_cache_utils import (
|
from vllm.v1.core.kv_cache_utils import (
|
||||||
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||||
get_request_block_hasher, hash_block_tokens, init_none_hash,
|
get_request_block_hasher, hash_block_tokens, init_none_hash,
|
||||||
is_kv_cache_type_uniform, unify_kv_cache_configs)
|
is_kv_cache_type_uniform, make_block_hash_with_group_id,
|
||||||
|
unify_kv_cache_configs)
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec, KVCacheTensor,
|
KVCacheGroupSpec, KVCacheTensor,
|
||||||
SlidingWindowSpec)
|
SlidingWindowSpec)
|
||||||
@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
|
|||||||
sliding_window=sliding_window)
|
sliding_window=sliding_window)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||||
def test_none_hash(monkeypatch, hash_fn):
|
def test_none_hash(monkeypatch, hash_fn):
|
||||||
import vllm.v1.core.kv_cache_utils
|
import vllm.v1.core.kv_cache_utils
|
||||||
|
|
||||||
@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
|
|||||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
|
||||||
assert reloaded_kv_cache_utils.NONE_HASH != 0
|
assert reloaded_kv_cache_utils.NONE_HASH != b""
|
||||||
|
|
||||||
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
|
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
|
|||||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
|
||||||
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
|
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
|
||||||
|
|
||||||
|
|
||||||
def test_kv_cache_block():
|
def test_kv_cache_block():
|
||||||
import vllm.v1.core.kv_cache_utils
|
|
||||||
|
|
||||||
# Test KVCacheBlock initialization
|
# Test KVCacheBlock initialization
|
||||||
block = KVCacheBlock(block_id=0)
|
block = KVCacheBlock(block_id=0)
|
||||||
@ -127,8 +128,7 @@ def test_kv_cache_block():
|
|||||||
assert block.ref_cnt == 0
|
assert block.ref_cnt == 0
|
||||||
|
|
||||||
# Test block hash setting and resetting
|
# Test block hash setting and resetting
|
||||||
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123,
|
block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0)
|
||||||
token_ids=(1, 2, 3))
|
|
||||||
block.block_hash = block_hash
|
block.block_hash = block_hash
|
||||||
assert block.block_hash == block_hash
|
assert block.block_hash == block_hash
|
||||||
|
|
||||||
@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
|
|||||||
assert next_mm_idx == 1
|
assert next_mm_idx == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||||
def test_hash_block_tokens(hash_fn):
|
def test_hash_block_tokens(hash_fn):
|
||||||
import vllm.v1.core.kv_cache_utils
|
|
||||||
init_none_hash(hash_fn)
|
init_none_hash(hash_fn)
|
||||||
parent_block_hash = 123
|
parent_block_hash = BlockHash(b"123")
|
||||||
curr_block_token_ids = (1, 2, 3)
|
curr_block_token_ids = (1, 2, 3)
|
||||||
extra_keys = ("key1", "key2")
|
extra_keys = ("key1", "key2")
|
||||||
|
|
||||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||||
curr_block_token_ids, extra_keys)
|
curr_block_token_ids, extra_keys)
|
||||||
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash)
|
expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
|
||||||
assert block_hash.hash_value == hash_fn(
|
assert block_hash == expected
|
||||||
(parent_block_hash, curr_block_token_ids, extra_keys))
|
|
||||||
assert block_hash.token_ids == curr_block_token_ids
|
|
||||||
assert block_hash.extra_keys == extra_keys
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||||
def test_request_block_hasher(hash_fn):
|
def test_request_block_hasher(hash_fn):
|
||||||
import vllm.v1.core.kv_cache_utils
|
kv_cache_utils.init_none_hash(hash_fn)
|
||||||
init_none_hash(hash_fn)
|
|
||||||
request = make_request(
|
request = make_request(
|
||||||
request_id="0",
|
request_id="0",
|
||||||
prompt_token_ids=[_ for _ in range(6)],
|
prompt_token_ids=[_ for _ in range(6)],
|
||||||
@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
|
|||||||
|
|
||||||
block_hashes = request.block_hashes
|
block_hashes = request.block_hashes
|
||||||
assert len(block_hashes) == 2
|
assert len(block_hashes) == 2
|
||||||
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
|
assert block_hashes[0] == hash_fn(
|
||||||
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
|
(kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
|
||||||
|
assert block_hashes[1] == hash_fn(
|
||||||
# Check the first block
|
(block_hashes[0], (3, 4, 5), ("hash2", )))
|
||||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
|
||||||
assert block_hashes[0].extra_keys == ("hash1", )
|
|
||||||
|
|
||||||
# Check the second block
|
|
||||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
|
||||||
assert block_hashes[1].extra_keys == ("hash2", )
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||||
def test_hash_tokens_different_mm_input(hash_fn):
|
def test_hash_tokens_different_mm_input(hash_fn):
|
||||||
init_none_hash(hash_fn)
|
init_none_hash(hash_fn)
|
||||||
|
|
||||||
@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
|||||||
assert block_hashes1[1] != block_hashes2[1]
|
assert block_hashes1[1] != block_hashes2[1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||||
def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||||
init_none_hash(hash_fn)
|
kv_cache_utils.init_none_hash(hash_fn)
|
||||||
|
|
||||||
request = make_request(
|
request = make_request(
|
||||||
request_id="0",
|
request_id="0",
|
||||||
@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
|||||||
block_hashes = request.block_hashes
|
block_hashes = request.block_hashes
|
||||||
|
|
||||||
assert len(block_hashes) == 2
|
assert len(block_hashes) == 2
|
||||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
assert block_hashes[0] == hash_fn(
|
||||||
assert block_hashes[0].extra_keys is None
|
(kv_cache_utils.NONE_HASH, (0, 1, 2), None))
|
||||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
|
||||||
assert block_hashes[1].extra_keys is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_metrics():
|
def test_metrics():
|
||||||
|
|||||||
@ -8,17 +8,19 @@ from typing import Callable, Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import vllm.v1.core.kv_cache_utils as kv_cache_utils
|
||||||
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
|
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
|
||||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||||
MultiModalKwargsItem, PlaceholderRange)
|
MultiModalKwargsItem, PlaceholderRange)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import sha256, sha256_cbor_64bit
|
from vllm.utils import sha256, sha256_cbor
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||||
KVCacheBlock,
|
get_block_hash, get_group_id,
|
||||||
get_request_block_hasher,
|
get_request_block_hasher,
|
||||||
hash_block_tokens, init_none_hash)
|
hash_block_tokens, init_none_hash,
|
||||||
|
make_block_hash_with_group_id)
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec, SlidingWindowSpec)
|
KVCacheGroupSpec, SlidingWindowSpec)
|
||||||
|
|
||||||
@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||||
def test_prefill(hash_algo):
|
def test_prefill(hash_fn):
|
||||||
|
init_none_hash(hash_fn)
|
||||||
|
|
||||||
block_size = 16
|
block_size = 16
|
||||||
manager = KVCacheManager(
|
manager = KVCacheManager(
|
||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
@ -110,10 +114,6 @@ def test_prefill(hash_algo):
|
|||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# choose the hash function according to the parameter
|
|
||||||
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
|
|
||||||
sha256 if hash_algo == "sha256" else hash)
|
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||||
|
|
||||||
@ -137,10 +137,12 @@ def test_prefill(hash_algo):
|
|||||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||||
block_tokens)
|
block_tokens)
|
||||||
assert manager.block_pool.blocks[
|
blk_hash = manager.block_pool.blocks[block_id].block_hash
|
||||||
block_id].block_hash.block_hash == block_hash
|
assert blk_hash is not None
|
||||||
|
assert get_block_hash(blk_hash) == block_hash
|
||||||
|
assert get_group_id(blk_hash) == 0
|
||||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||||
parent_block_hash = block_hash.hash_value
|
parent_block_hash = block_hash
|
||||||
|
|
||||||
# Check partial block metadata
|
# Check partial block metadata
|
||||||
for block_id in (4, ):
|
for block_id in (4, ):
|
||||||
@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
|
|||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
hash_fn = hash
|
hash_fn = sha256
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||||
@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
|
|||||||
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
|
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
|
||||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||||
block_tokens)
|
block_tokens)
|
||||||
for block_id in block_ids:
|
for group_id, block_id in enumerate(block_ids):
|
||||||
assert manager.block_pool.blocks[
|
blk_hash = manager.block_pool.blocks[block_id].block_hash
|
||||||
block_id].block_hash.block_hash == block_hash
|
assert blk_hash is not None
|
||||||
|
assert get_block_hash(blk_hash) == block_hash
|
||||||
|
assert get_group_id(blk_hash) == group_id
|
||||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||||
parent_block_hash = block_hash.hash_value
|
parent_block_hash = block_hash
|
||||||
|
|
||||||
# Check partial block metadata
|
# Check partial block metadata
|
||||||
for block_id in (4, 8, 12):
|
for block_id in (4, 8, 12):
|
||||||
@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
|
|||||||
cached_block_hash_to_block_bak = copy.copy(
|
cached_block_hash_to_block_bak = copy.copy(
|
||||||
manager.block_pool.cached_block_hash_to_block)
|
manager.block_pool.cached_block_hash_to_block)
|
||||||
|
|
||||||
def test_partial_request_hit(request_id: str,
|
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
|
||||||
hash_to_evict: list[BlockHashWithGroupId],
|
|
||||||
expect_hit_length: int):
|
expect_hit_length: int):
|
||||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||||
block_size, hash)
|
block_size, sha256)
|
||||||
for hash_with_group_id in hash_to_evict:
|
for hash_with_group_id in hash_to_evict:
|
||||||
manager.block_pool.cached_block_hash_to_block.pop(
|
manager.block_pool.cached_block_hash_to_block.pop(
|
||||||
hash_with_group_id)
|
hash_with_group_id)
|
||||||
@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
|
|||||||
|
|
||||||
# Evict the blocks outside sliding window, does not affect the hit length.
|
# Evict the blocks outside sliding window, does not affect the hit length.
|
||||||
test_partial_request_hit("2", [
|
test_partial_request_hit("2", [
|
||||||
BlockHashWithGroupId(block_hashes[0], 1),
|
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||||
BlockHashWithGroupId(block_hashes[0], 2)
|
make_block_hash_with_group_id(block_hashes[0], 2)
|
||||||
], 3)
|
], 3)
|
||||||
|
|
||||||
# Evict the first block of full attention, makes total cache miss.
|
# Evict the first block of full attention, makes total cache miss.
|
||||||
test_partial_request_hit("3", [
|
test_partial_request_hit(
|
||||||
BlockHashWithGroupId(block_hashes[0], 0),
|
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0)
|
||||||
], 0)
|
|
||||||
|
|
||||||
# Evict the last block of all layers, reduces the hit length to 2.
|
# Evict the last block of all layers, reduces the hit length to 2.
|
||||||
test_partial_request_hit("4", [
|
test_partial_request_hit("4", [
|
||||||
BlockHashWithGroupId(block_hashes[2], 0),
|
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||||
BlockHashWithGroupId(block_hashes[2], 1),
|
make_block_hash_with_group_id(block_hashes[2], 1),
|
||||||
BlockHashWithGroupId(block_hashes[2], 2),
|
make_block_hash_with_group_id(block_hashes[2], 2),
|
||||||
], 2)
|
], 2)
|
||||||
|
|
||||||
# Evict the last block of full attention, reduces the hit length to 2.
|
# Evict the last block of full attention, reduces the hit length to 2.
|
||||||
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
|
test_partial_request_hit(
|
||||||
2)
|
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2)
|
||||||
|
|
||||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||||
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
|
test_partial_request_hit(
|
||||||
2)
|
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2)
|
||||||
|
|
||||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||||
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
|
test_partial_request_hit(
|
||||||
2)
|
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2)
|
||||||
|
|
||||||
# Evict different set of blocks for full attention and sliding window makes
|
# Evict different set of blocks for full attention and sliding window makes
|
||||||
# total cache miss.
|
# total cache miss.
|
||||||
@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
|
|||||||
# The cache hit length of sliding window is 2 * block_size.
|
# The cache hit length of sliding window is 2 * block_size.
|
||||||
# Then it is cache miss as the two type of layers have different hit length.
|
# Then it is cache miss as the two type of layers have different hit length.
|
||||||
test_partial_request_hit("8", [
|
test_partial_request_hit("8", [
|
||||||
BlockHashWithGroupId(block_hashes[2], 0),
|
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||||
BlockHashWithGroupId(block_hashes[0], 1),
|
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||||
BlockHashWithGroupId(block_hashes[0], 2),
|
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||||
], 0)
|
], 0)
|
||||||
|
|
||||||
|
|
||||||
@ -372,8 +374,8 @@ def test_prefill_plp():
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
)
|
)
|
||||||
# the default hash function is hash
|
# the default hash function is sha256
|
||||||
hash_fn = hash
|
hash_fn = sha256
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||||
@ -404,10 +406,12 @@ def test_prefill_plp():
|
|||||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||||
block_tokens)
|
block_tokens)
|
||||||
assert manager.block_pool.blocks[
|
blk_hash = (manager.block_pool.blocks[block_id].block_hash)
|
||||||
block_id].block_hash.block_hash == block_hash
|
assert blk_hash is not None
|
||||||
|
assert get_block_hash(blk_hash) == block_hash
|
||||||
|
assert get_group_id(blk_hash) == 0
|
||||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||||
parent_block_hash = block_hash.hash_value
|
parent_block_hash = block_hash
|
||||||
|
|
||||||
# Check partial block metadata
|
# Check partial block metadata
|
||||||
for block_id in (4, ):
|
for block_id in (4, ):
|
||||||
@ -493,7 +497,7 @@ def test_decode():
|
|||||||
# Incomplete 1 block (7 tokens)
|
# Incomplete 1 block (7 tokens)
|
||||||
unique_token_ids = [3] * 7
|
unique_token_ids = [3] * 7
|
||||||
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
|
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
|
||||||
hash)
|
sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -538,7 +542,7 @@ def test_evict():
|
|||||||
)
|
)
|
||||||
|
|
||||||
last_token_id = 5 * 16 + 7
|
last_token_id = 5 * 16 + 7
|
||||||
req0 = make_request("0", list(range(last_token_id)), block_size, hash)
|
req0 = make_request("0", list(range(last_token_id)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -550,7 +554,7 @@ def test_evict():
|
|||||||
# 3 blocks.
|
# 3 blocks.
|
||||||
req1 = make_request("1", list(range(last_token_id,
|
req1 = make_request("1", list(range(last_token_id,
|
||||||
last_token_id + 3 * 16)), block_size,
|
last_token_id + 3 * 16)), block_size,
|
||||||
hash)
|
sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -572,7 +576,7 @@ def test_evict():
|
|||||||
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
|
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
|
||||||
|
|
||||||
# Touch the first 2 blocks.
|
# Touch the first 2 blocks.
|
||||||
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash)
|
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert computed_blocks.get_block_ids() == ([1, 2], )
|
assert computed_blocks.get_block_ids() == ([1, 2], )
|
||||||
assert num_computed_tokens == 2 * 16
|
assert num_computed_tokens == 2 * 16
|
||||||
@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
|
|||||||
|
|
||||||
# Allocate 1 block and cache it.
|
# Allocate 1 block and cache it.
|
||||||
num_tokens = block_size * 1
|
num_tokens = block_size * 1
|
||||||
req = make_request("0", list(range(num_tokens)), block_size, hash)
|
req = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
|
|||||||
|
|
||||||
# Allocate a new block that's not full, make sure hash info on the
|
# Allocate a new block that's not full, make sure hash info on the
|
||||||
# block is cleared.
|
# block is cleared.
|
||||||
req = make_request("1", list(range(num_tokens - 1)), block_size, hash)
|
req = make_request("1", list(range(num_tokens - 1)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
|
|
||||||
# Allocate a block and cache it.
|
# Allocate a block and cache it.
|
||||||
num_tokens = block_size * 1
|
num_tokens = block_size * 1
|
||||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
|
|
||||||
# Allocate another block.
|
# Allocate another block.
|
||||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
|
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
|
||||||
block_size, hash)
|
block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
|
|
||||||
# Now if we have a cache hit on the first block, we should evict the second
|
# Now if we have a cache hit on the first block, we should evict the second
|
||||||
# cached block rather than the first one.
|
# cached block rather than the first one.
|
||||||
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash)
|
req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert len(computed_blocks.blocks[0]) == 1
|
assert len(computed_blocks.blocks[0]) == 1
|
||||||
assert computed_blocks.blocks[0][0].block_id == 1
|
assert computed_blocks.blocks[0][0].block_id == 1
|
||||||
@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
)
|
)
|
||||||
|
|
||||||
req1 = make_request("1", list(range(10)), block_size,
|
req1 = make_request("1", list(range(10)), block_size,
|
||||||
hash) # 2 blocks and some more
|
sha256) # 2 blocks and some more
|
||||||
|
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
|
|
||||||
# No caching.
|
# No caching.
|
||||||
req2 = make_request("2", list(range(16)), block_size,
|
req2 = make_request("2", list(range(16)), block_size,
|
||||||
hash) # shared prefix
|
sha256) # shared prefix
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
assert len(blocks.blocks[0]) == 4
|
assert len(blocks.blocks[0]) == 4
|
||||||
|
|
||||||
# New requests should not have any blocks.
|
# New requests should not have any blocks.
|
||||||
req3 = make_request("3", list(range(4)), block_size, hash)
|
req3 = make_request("3", list(range(4)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
assert not blocks
|
assert not blocks
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||||
def test_cache_blocks(hash_fn):
|
def test_cache_blocks(hash_fn):
|
||||||
"""
|
"""
|
||||||
This is a unit test that tests the correctness of the _cache_full_blocks
|
This is a unit test that tests the correctness of the _cache_full_blocks
|
||||||
@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
|
|||||||
# Block 1/5: [4, 5, 6, 7]
|
# Block 1/5: [4, 5, 6, 7]
|
||||||
# Block 2/6: [8, 9, 10, 11]
|
# Block 2/6: [8, 9, 10, 11]
|
||||||
# Block 3/7: [12, 13]
|
# Block 3/7: [12, 13]
|
||||||
req = make_request("0", list(range(14)), block_size, hash)
|
req = make_request("0", list(range(14)), block_size, sha256)
|
||||||
|
|
||||||
# Cache the blocks for group 0.
|
# Cache the blocks for group 0.
|
||||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||||
@ -845,6 +849,8 @@ def test_mm_prefix_caching():
|
|||||||
"""
|
"""
|
||||||
This tests that the multi-modal prefix caching is correct.
|
This tests that the multi-modal prefix caching is correct.
|
||||||
"""
|
"""
|
||||||
|
kv_cache_utils.init_none_hash(sha256)
|
||||||
|
|
||||||
block_size = 16
|
block_size = 16
|
||||||
manager = KVCacheManager(
|
manager = KVCacheManager(
|
||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
@ -874,23 +880,30 @@ def test_mm_prefix_caching():
|
|||||||
req0 = make_request("0",
|
req0 = make_request("0",
|
||||||
all_token_ids,
|
all_token_ids,
|
||||||
block_size,
|
block_size,
|
||||||
hash,
|
sha256,
|
||||||
mm_positions=mm_positions,
|
mm_positions=mm_positions,
|
||||||
mm_hashes=mm_hashes)
|
mm_hashes=mm_hashes)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
|
|
||||||
# Completed block should have hashes with extra keys.
|
# Completed block should have hashes
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
block_hashes = req0.block_hashes
|
block_hashes = req0.block_hashes
|
||||||
assert len(block_hashes) == 3
|
assert len(block_hashes) == 3
|
||||||
assert block_hashes[0].extra_keys == ("aaa", )
|
assert block_hashes[0] == sha256(
|
||||||
assert block_hashes[1].extra_keys == ("aaa", "bbb")
|
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
|
||||||
assert block_hashes[2].extra_keys == ("bbb", )
|
("aaa", )))
|
||||||
|
assert block_hashes[1] == sha256(
|
||||||
|
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
|
||||||
|
("aaa", "bbb")))
|
||||||
|
assert block_hashes[2] == sha256(
|
||||||
|
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
|
||||||
|
("bbb", )))
|
||||||
|
|
||||||
blocks = manager.allocate_slots(req0, 59,
|
blocks = manager.allocate_slots(req0, 59,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
|
assert blocks is not None
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
req0.num_computed_tokens = 59
|
req0.num_computed_tokens = 59
|
||||||
|
|
||||||
@ -901,10 +914,10 @@ def test_mm_prefix_caching():
|
|||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||||
|
|
||||||
# The just completed block should have hashes with extra keys.
|
|
||||||
assert len(block_hashes) == 4
|
assert len(block_hashes) == 4
|
||||||
assert block_hashes[3].extra_keys == ("ccc", )
|
assert block_hashes[3] == sha256(
|
||||||
|
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
|
||||||
|
("ccc", )))
|
||||||
|
|
||||||
# Cache hit.
|
# Cache hit.
|
||||||
unique_token_ids = [-1] * 7 + [200] * 5
|
unique_token_ids = [-1] * 7 + [200] * 5
|
||||||
@ -916,7 +929,7 @@ def test_mm_prefix_caching():
|
|||||||
req1 = make_request("1",
|
req1 = make_request("1",
|
||||||
all_token_ids,
|
all_token_ids,
|
||||||
block_size,
|
block_size,
|
||||||
hash,
|
sha256,
|
||||||
mm_positions=mm_positions,
|
mm_positions=mm_positions,
|
||||||
mm_hashes=mm_hashes)
|
mm_hashes=mm_hashes)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
@ -929,6 +942,8 @@ def test_cache_key_salting():
|
|||||||
This tests that cache salts are applied during hashing and the cache
|
This tests that cache salts are applied during hashing and the cache
|
||||||
is separated cache as expected.
|
is separated cache as expected.
|
||||||
"""
|
"""
|
||||||
|
kv_cache_utils.init_none_hash(sha256)
|
||||||
|
|
||||||
block_size = 16
|
block_size = 16
|
||||||
manager = KVCacheManager(
|
manager = KVCacheManager(
|
||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
@ -939,21 +954,26 @@ def test_cache_key_salting():
|
|||||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||||
token_ids = common_token_ids + [3] * 11
|
token_ids = common_token_ids + [3] * 11
|
||||||
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1")
|
req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1")
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
|
|
||||||
# Completed block should have hashes with extra keys.
|
# Completed block should have hashes
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
block_hashes = req0.block_hashes
|
block_hashes = req0.block_hashes
|
||||||
assert len(block_hashes) == 3
|
assert len(block_hashes) == 3
|
||||||
assert block_hashes[0].extra_keys == ("salt1", )
|
assert block_hashes[0] == sha256(
|
||||||
assert block_hashes[1].extra_keys is None
|
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
|
||||||
assert block_hashes[2].extra_keys is None
|
assert block_hashes[1] == sha256(
|
||||||
|
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
|
||||||
|
assert block_hashes[2] == sha256(
|
||||||
|
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
|
||||||
|
None))
|
||||||
|
|
||||||
blocks = manager.allocate_slots(req0, 59,
|
blocks = manager.allocate_slots(req0, 59,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
|
assert blocks is not None
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
req0.num_computed_tokens = 59
|
req0.num_computed_tokens = 59
|
||||||
|
|
||||||
@ -964,14 +984,13 @@ def test_cache_key_salting():
|
|||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||||
|
|
||||||
# Now one more block that should not have extra keys.
|
|
||||||
assert len(block_hashes) == 4
|
assert len(block_hashes) == 4
|
||||||
assert block_hashes[3].extra_keys is None
|
assert block_hashes[3] == sha256(
|
||||||
|
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
|
||||||
|
|
||||||
# Test cache hit with a new request that has the same salt.
|
# Test cache hit with a new request that has the same salt.
|
||||||
token_ids = common_token_ids + [4] * 11
|
token_ids = common_token_ids + [4] * 11
|
||||||
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1")
|
req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1")
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
# Should match only a prefix of 3 blocks.
|
# Should match only a prefix of 3 blocks.
|
||||||
assert len(computed_blocks.blocks[0]) == 3
|
assert len(computed_blocks.blocks[0]) == 3
|
||||||
@ -979,13 +998,19 @@ def test_cache_key_salting():
|
|||||||
|
|
||||||
# Test cache miss with same content but different salt.
|
# Test cache miss with same content but different salt.
|
||||||
token_ids = common_token_ids + [4] * 11
|
token_ids = common_token_ids + [4] * 11
|
||||||
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2")
|
req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2")
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert len(computed_blocks.blocks[0]) == 0
|
assert len(computed_blocks.blocks[0]) == 0
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
block_hashes = req2.block_hashes
|
block_hashes = req2.block_hashes
|
||||||
assert len(block_hashes) == 3
|
assert len(block_hashes) == 3
|
||||||
assert block_hashes[0].extra_keys == ("salt2", )
|
assert block_hashes[0] == sha256(
|
||||||
|
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
|
||||||
|
assert block_hashes[1] == sha256(
|
||||||
|
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
|
||||||
|
assert block_hashes[2] == sha256(
|
||||||
|
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
|
||||||
|
None))
|
||||||
|
|
||||||
|
|
||||||
def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||||
@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
# | Common-0 | Common-1 | Common-2 | ... |
|
# | Common-0 | Common-1 | Common-2 | ... |
|
||||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||||
req0 = make_request("0", common_token_ids, block_size, hash)
|
req0 = make_request("0", common_token_ids, block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|||||||
req0.request_id]
|
req0.request_id]
|
||||||
|
|
||||||
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
||||||
req1 = make_request("1", common_token_ids * 2, block_size, hash)
|
req1 = make_request("1", common_token_ids * 2, block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert computed_blocks.blocks[0] == block_part0
|
assert computed_blocks.blocks[0] == block_part0
|
||||||
assert num_computed_tokens == 3 * 16
|
assert num_computed_tokens == 3 * 16
|
||||||
@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|||||||
|
|
||||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||||
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
||||||
req2 = make_request("2", [7] * block_size * 2, block_size, hash)
|
req2 = make_request("2", [7] * block_size * 2, block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|||||||
# but it cannot be allocated due to insufficient free blocks (2).
|
# but it cannot be allocated due to insufficient free blocks (2).
|
||||||
# In this case, the ref_cnt of the computed blocks should not be changed.
|
# In this case, the ref_cnt of the computed blocks should not be changed.
|
||||||
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
||||||
req3 = make_request("3", common_token_ids * 3, block_size, hash)
|
req3 = make_request("3", common_token_ids * 3, block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||||
assert computed_blocks.blocks[0] == block_part1
|
assert computed_blocks.blocks[0] == block_part1
|
||||||
assert num_computed_tokens == 6 * 16
|
assert num_computed_tokens == 6 * 16
|
||||||
@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
|
|||||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||||
unique_token_ids = [3] * 7
|
unique_token_ids = [3] * 7
|
||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
req0 = make_request("0", all_token_ids, block_size, hash)
|
req0 = make_request("0", all_token_ids, block_size, sha256)
|
||||||
blocks = manager.allocate_slots(req0, 55)
|
blocks = manager.allocate_slots(req0, 55)
|
||||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
|
|
||||||
unique_token_ids = [4] * 7
|
unique_token_ids = [4] * 7
|
||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
req1 = make_request("1", all_token_ids, block_size, hash)
|
req1 = make_request("1", all_token_ids, block_size, sha256)
|
||||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||||
assert len(req1.block_hashes) == 3
|
assert len(req1.block_hashes) == 3
|
||||||
assert len(computed_blocks.blocks[0]) == 3
|
assert len(computed_blocks.blocks[0]) == 3
|
||||||
@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
|
|||||||
assert manager.prefix_cache_stats is None
|
assert manager.prefix_cache_stats is None
|
||||||
|
|
||||||
# Call all functions that check whether log_stats is disabled.
|
# Call all functions that check whether log_stats is disabled.
|
||||||
req = make_request("0", list(range(16)), block_size, hash)
|
req = make_request("0", list(range(16)), block_size, sha256)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||||
assert not computed_blocks.blocks[0]
|
assert not computed_blocks.blocks[0]
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
|
|||||||
|
|
||||||
def test_maybe_evict_cached_block():
|
def test_maybe_evict_cached_block():
|
||||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
||||||
block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10,
|
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
||||||
token_ids=(100, )),
|
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
||||||
group_id=1000)
|
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
||||||
block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
|
|
||||||
token_ids=(200, )),
|
|
||||||
group_id=2000)
|
|
||||||
block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
|
|
||||||
token_ids=(300, )),
|
|
||||||
group_id=3000)
|
|
||||||
block_hashes = [
|
block_hashes = [
|
||||||
block_hash0,
|
block_hash0,
|
||||||
block_hash1,
|
block_hash1,
|
||||||
@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
num_tokens = block_size * blocks_to_cache
|
num_tokens = block_size * blocks_to_cache
|
||||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
|
||||||
_ = manager.allocate_slots(req0, num_tokens)
|
_ = manager.allocate_slots(req0, num_tokens)
|
||||||
events = manager.take_events()
|
events = manager.take_events()
|
||||||
|
|
||||||
@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
|||||||
# Should see block_to_cache number of removed block events and a new block
|
# Should see block_to_cache number of removed block events and a new block
|
||||||
# stored event
|
# stored event
|
||||||
manager.free(req0)
|
manager.free(req0)
|
||||||
req1 = make_request("1", list(range(num_tokens)), block_size, hash)
|
req1 = make_request("1", list(range(num_tokens)), block_size, sha256)
|
||||||
_ = manager.allocate_slots(req1, num_tokens)
|
_ = manager.allocate_slots(req1, num_tokens)
|
||||||
events = manager.take_events()
|
events = manager.take_events()
|
||||||
|
|
||||||
@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
|
|||||||
|
|
||||||
# Request with 3 full blocks (48 tokens)
|
# Request with 3 full blocks (48 tokens)
|
||||||
token_ids = [0] * (3 * block_size)
|
token_ids = [0] * (3 * block_size)
|
||||||
req = make_request("divisible_request", token_ids, block_size, hash)
|
req = make_request("divisible_request", token_ids, block_size, sha256)
|
||||||
|
|
||||||
# Prime the cache
|
# Prime the cache
|
||||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
|
|||||||
manager.free(req)
|
manager.free(req)
|
||||||
|
|
||||||
# New request with same tokens + Eagle enabled
|
# New request with same tokens + Eagle enabled
|
||||||
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash)
|
req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256)
|
||||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||||
|
|
||||||
# Should retain 1 block:
|
# Should retain 1 block:
|
||||||
@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
|
|||||||
)
|
)
|
||||||
# 2 full blocks + 5 tokens (non-divisible length)
|
# 2 full blocks + 5 tokens (non-divisible length)
|
||||||
token_ids = [0] * (2 * block_size + 5)
|
token_ids = [0] * (2 * block_size + 5)
|
||||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
req = make_request("partial_block_test", token_ids, block_size, sha256)
|
||||||
|
|
||||||
# Prime the cache
|
# Prime the cache
|
||||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
|
|||||||
manager.free(req)
|
manager.free(req)
|
||||||
|
|
||||||
# New request with Eagle enabled
|
# New request with Eagle enabled
|
||||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
|
||||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||||
assert len(computed_blocks.blocks[0]) == 1
|
assert len(computed_blocks.blocks[0]) == 1
|
||||||
@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
|
|||||||
|
|
||||||
# 2 full blocks + 5 tokens (non-divisible length)
|
# 2 full blocks + 5 tokens (non-divisible length)
|
||||||
token_ids = [0] * (2 * block_size + 5)
|
token_ids = [0] * (2 * block_size + 5)
|
||||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
req = make_request("partial_block_test", token_ids, block_size, sha256)
|
||||||
|
|
||||||
# Prime the cache
|
# Prime the cache
|
||||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||||
@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
|
|||||||
manager.free(req)
|
manager.free(req)
|
||||||
|
|
||||||
# New request with Eagle enabled
|
# New request with Eagle enabled
|
||||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
|
||||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||||
assert len(computed_blocks.blocks[0]) == 1
|
assert len(computed_blocks.blocks[0]) == 1
|
||||||
@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
|
|||||||
assert manager.block_pool.get_cached_block(
|
assert manager.block_pool.get_cached_block(
|
||||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||||
manager.block_pool.cached_block_hash_to_block.pop(
|
manager.block_pool.cached_block_hash_to_block.pop(
|
||||||
BlockHashWithGroupId(block_hash_first_block, 0))
|
make_block_hash_with_group_id(block_hash_first_block, 0))
|
||||||
|
|
||||||
# New request
|
# New request
|
||||||
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
|
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
|
||||||
block_size, hash)
|
block_size, sha256)
|
||||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
|
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
|
||||||
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
||||||
# not considered. But after dropping the last matched block due to eagle,
|
# not considered. But after dropping the last matched block due to eagle,
|
||||||
|
|||||||
@ -6,8 +6,8 @@ import random
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||||
KVCacheBlock)
|
make_block_hash_with_group_id)
|
||||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||||
ChunkedLocalAttentionManager, SlidingWindowManager)
|
ChunkedLocalAttentionManager, SlidingWindowManager)
|
||||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||||
@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
|||||||
|
|
||||||
def run_one_case(block_is_cached, tail_token, expect_length):
|
def run_one_case(block_is_cached, tail_token, expect_length):
|
||||||
block_hash_list = [
|
block_hash_list = [
|
||||||
BlockHash(i, ()) for i in range(len(block_is_cached))
|
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||||
]
|
]
|
||||||
|
|
||||||
block_pool.cached_block_hash_to_block.clear()
|
block_pool.cached_block_hash_to_block.clear()
|
||||||
@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
|
|||||||
for i, (block_hash,
|
for i, (block_hash,
|
||||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||||
if is_cached:
|
if is_cached:
|
||||||
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
|
block_pool.cached_block_hash_to_block[
|
||||||
block_hash, 0)] = {
|
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||||
i: block_pool.blocks[i + 10],
|
i: block_pool.blocks[i + 10],
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
|
|
||||||
def run_one_case(block_is_cached, expect_length):
|
def run_one_case(block_is_cached, expect_length):
|
||||||
block_hash_list = [
|
block_hash_list = [
|
||||||
BlockHash(i, ()) for i in range(len(block_is_cached))
|
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||||
]
|
]
|
||||||
|
|
||||||
block_pool.cached_block_hash_to_block.clear()
|
block_pool.cached_block_hash_to_block.clear()
|
||||||
@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
for i, (block_hash,
|
for i, (block_hash,
|
||||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||||
if is_cached:
|
if is_cached:
|
||||||
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
|
block_pool.cached_block_hash_to_block[
|
||||||
block_hash, 0)] = {
|
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||||
i: block_pool.blocks[i + 10],
|
i: block_pool.blocks[i + 10],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
|||||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||||
MultiModalKwargsItem, PlaceholderRange)
|
MultiModalKwargsItem, PlaceholderRange)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import sha256
|
||||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||||
init_none_hash)
|
init_none_hash)
|
||||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||||
@ -130,10 +131,10 @@ def create_requests(
|
|||||||
) -> list[Request]:
|
) -> list[Request]:
|
||||||
global _none_hash_initialized
|
global _none_hash_initialized
|
||||||
if not _none_hash_initialized:
|
if not _none_hash_initialized:
|
||||||
init_none_hash(hash)
|
init_none_hash(sha256)
|
||||||
_none_hash_initialized = True
|
_none_hash_initialized = True
|
||||||
|
|
||||||
block_hasher = get_request_block_hasher(block_size, hash)
|
block_hasher = get_request_block_hasher(block_size, sha256)
|
||||||
sampling_params = SamplingParams(ignore_eos=False,
|
sampling_params = SamplingParams(ignore_eos=False,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
|
|||||||
@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
|
|||||||
assert vllm_config.cache_config.enable_prefix_caching
|
assert vllm_config.cache_config.enable_prefix_caching
|
||||||
|
|
||||||
# default hash algorithm is "builtin"
|
# default hash algorithm is "sha256"
|
||||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||||
|
|
||||||
|
# set hash algorithm to sha256_cbor
|
||||||
|
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
|
||||||
|
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||||
|
assert vllm_config.cache_config.prefix_caching_hash_algo == \
|
||||||
|
"sha256_cbor"
|
||||||
|
|
||||||
# set hash algorithm to sha256
|
# set hash algorithm to sha256
|
||||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
|
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
|
||||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||||
|
|
||||||
# set hash algorithm to builtin
|
|
||||||
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
|
|
||||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
|
||||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
|
||||||
|
|
||||||
# an invalid hash algorithm raises an error
|
# an invalid hash algorithm raises an error
|
||||||
parser.exit_on_error = False
|
parser.exit_on_error = False
|
||||||
with pytest.raises(ArgumentError):
|
with pytest.raises(ArgumentError):
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
|
|||||||
KVConnectorFactory)
|
KVConnectorFactory)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
||||||
SharedStorageConnector)
|
SharedStorageConnector)
|
||||||
|
from vllm.utils import sha256
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||||
init_none_hash)
|
init_none_hash)
|
||||||
@ -127,11 +128,11 @@ def create_request(request_id: int,
|
|||||||
use_all_1s_for_prompt_tokens: bool = False,
|
use_all_1s_for_prompt_tokens: bool = False,
|
||||||
num_remote_blocks: int = 3,
|
num_remote_blocks: int = 3,
|
||||||
block_size: int = 16,
|
block_size: int = 16,
|
||||||
hash_fn: Callable = hash) -> Request:
|
hash_fn: Callable = sha256) -> Request:
|
||||||
"""Make dummy request for testing."""
|
"""Make dummy request for testing."""
|
||||||
global _none_hash_initialized
|
global _none_hash_initialized
|
||||||
if not _none_hash_initialized:
|
if not _none_hash_initialized:
|
||||||
init_none_hash(hash)
|
init_none_hash(hash_fn)
|
||||||
_none_hash_initialized = True
|
_none_hash_initialized = True
|
||||||
|
|
||||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||||
|
|||||||
@ -24,7 +24,7 @@ logger = init_logger(__name__)
|
|||||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||||
MambaDType = Literal["auto", "float32"]
|
MambaDType = Literal["auto", "float32"]
|
||||||
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
|
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@ -63,17 +63,12 @@ class CacheConfig:
|
|||||||
"""Sliding window size for the KV cache. This is primarily set in
|
"""Sliding window size for the KV cache. This is primarily set in
|
||||||
`ModelConfig` and that value should be manually duplicated here."""
|
`ModelConfig` and that value should be manually duplicated here."""
|
||||||
enable_prefix_caching: Optional[bool] = None
|
enable_prefix_caching: Optional[bool] = None
|
||||||
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
|
"""Whether to enable prefix caching. Enabled by default for V1."""
|
||||||
default for V1."""
|
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
|
||||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
|
|
||||||
"""Set the hash algorithm for prefix caching:\n
|
"""Set the hash algorithm for prefix caching:\n
|
||||||
- "builtin" is Python's built-in hash.\n
|
- "sha256" uses Pickle for object serialization before hashing.\n
|
||||||
- "sha256" is collision resistant but with certain overheads.
|
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
|
||||||
This option uses Pickle for object serialization before hashing.\n
|
serializes objects using canonical CBOR and hashes them with SHA-256."""
|
||||||
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
|
|
||||||
hash. It serializes objects using canonical CBOR and hashes them with
|
|
||||||
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
|
|
||||||
digest."""
|
|
||||||
cpu_offload_gb: float = 0
|
cpu_offload_gb: float = 0
|
||||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import zmq
|
|||||||
|
|
||||||
from vllm.config.kv_events import KVEventsConfig
|
from vllm.config.kv_events import KVEventsConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
|
|||||||
|
|
||||||
|
|
||||||
class BlockStored(KVCacheEvent):
|
class BlockStored(KVCacheEvent):
|
||||||
block_hashes: list[int]
|
block_hashes: list[ExternalBlockHash]
|
||||||
parent_block_hash: Optional[int]
|
parent_block_hash: Optional[ExternalBlockHash]
|
||||||
token_ids: list[int]
|
token_ids: list[int]
|
||||||
block_size: int
|
block_size: int
|
||||||
lora_id: Optional[int]
|
lora_id: Optional[int]
|
||||||
@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
|
|||||||
|
|
||||||
|
|
||||||
class BlockRemoved(KVCacheEvent):
|
class BlockRemoved(KVCacheEvent):
|
||||||
block_hashes: list[int]
|
block_hashes: list[ExternalBlockHash]
|
||||||
medium: Optional[str]
|
medium: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1592,20 +1592,12 @@ class EngineArgs:
|
|||||||
"in low performance due to small KV cache size. Consider "
|
"in low performance due to small KV cache size. Consider "
|
||||||
"setting --max-model-len to a smaller value.", max_model_len)
|
"setting --max-model-len to a smaller value.", max_model_len)
|
||||||
|
|
||||||
# if using prefix caching, we must set a hash algo
|
# Disable prefix caching for multimodal models for VLLM_V0.
|
||||||
if self.enable_prefix_caching:
|
if self.enable_prefix_caching and model_config.is_multimodal_model:
|
||||||
# Disable prefix caching for multimodal models for VLLM_V0.
|
logger.warning(
|
||||||
if model_config.is_multimodal_model:
|
"--enable-prefix-caching is not supported for multimodal "
|
||||||
logger.warning(
|
"models in V0 and has been disabled.")
|
||||||
"--enable-prefix-caching is not supported for multimodal "
|
self.enable_prefix_caching = False
|
||||||
"models in V0 and has been disabled.")
|
|
||||||
self.enable_prefix_caching = False
|
|
||||||
|
|
||||||
# VLLM_V0 only supports builtin hash algo for prefix caching.
|
|
||||||
if self.prefix_caching_hash_algo == "sha256":
|
|
||||||
raise ValueError(
|
|
||||||
"sha256 is not supported for prefix caching in V0 engine. "
|
|
||||||
"Please use 'builtin'.")
|
|
||||||
|
|
||||||
# Set max_num_seqs to 256 for VLLM_V0.
|
# Set max_num_seqs to 256 for VLLM_V0.
|
||||||
if self.max_num_seqs is None:
|
if self.max_num_seqs is None:
|
||||||
|
|||||||
@ -171,6 +171,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
||||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||||
|
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# Add optional custom scopes for profiling, disable to avoid overheads
|
# Add optional custom scopes for profiling, disable to avoid overheads
|
||||||
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
"VLLM_CUSTOM_SCOPES_FOR_PROFILING":
|
||||||
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
|
||||||
|
|
||||||
|
# Represent block hashes in KV cache events as 64-bit integers instead of
|
||||||
|
# raw bytes. Defaults to True for backward compatibility.
|
||||||
|
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
|
||||||
}
|
}
|
||||||
|
|
||||||
# --8<-- [end:env-vars-definition]
|
# --8<-- [end:env-vars-definition]
|
||||||
|
|||||||
@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool:
|
|||||||
and getattr(cfg.attn_config, "alibi", False)))))
|
and getattr(cfg.attn_config, "alibi", False)))))
|
||||||
|
|
||||||
|
|
||||||
def sha256(input) -> int:
|
def sha256(input) -> bytes:
|
||||||
"""Hash any picklable Python object using SHA-256.
|
"""Hash any picklable Python object using SHA-256.
|
||||||
|
|
||||||
The input is serialized using pickle before hashing, which allows
|
The input is serialized using pickle before hashing, which allows
|
||||||
@ -3260,16 +3260,15 @@ def sha256(input) -> int:
|
|||||||
input: Any picklable Python object.
|
input: Any picklable Python object.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An integer representing the SHA-256 hash of the serialized input.
|
Bytes representing the SHA-256 hash of the serialized input.
|
||||||
"""
|
"""
|
||||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
return hashlib.sha256(input_bytes).digest()
|
||||||
byteorder="big")
|
|
||||||
|
|
||||||
|
|
||||||
def sha256_cbor_64bit(input) -> int:
|
def sha256_cbor(input) -> bytes:
|
||||||
"""
|
"""
|
||||||
Hash objects using CBOR serialization and SHA-256, then truncate to 64bits.
|
Hash objects using CBOR serialization and SHA-256.
|
||||||
|
|
||||||
This option is useful for non-Python-dependent serialization and hashing.
|
This option is useful for non-Python-dependent serialization and hashing.
|
||||||
|
|
||||||
@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int:
|
|||||||
Custom classes must implement CBOR serialization methods.
|
Custom classes must implement CBOR serialization methods.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An integer in the range [0, 2^64-1] representing the lower 64 bits
|
Bytes representing the SHA-256 hash of the CBOR serialized input.
|
||||||
of the SHA-256 hash of the CBOR serialized input.
|
|
||||||
"""
|
"""
|
||||||
input_bytes = cbor2.dumps(input, canonical=True)
|
input_bytes = cbor2.dumps(input, canonical=True)
|
||||||
full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
return hashlib.sha256(input_bytes).digest()
|
||||||
byteorder="big")
|
|
||||||
|
|
||||||
return full_hash & ((1 << 64) - 1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
|
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
|
||||||
"""Get a hash function by name, or raise an error if
|
"""Get a hash function by name, or raise an error if
|
||||||
the function is not found.
|
the function is not found.
|
||||||
Args:
|
Args:
|
||||||
@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
|
|||||||
"""
|
"""
|
||||||
if hash_fn_name == "sha256":
|
if hash_fn_name == "sha256":
|
||||||
return sha256
|
return sha256
|
||||||
if hash_fn_name == "sha256_cbor_64bit":
|
if hash_fn_name == "sha256_cbor":
|
||||||
return sha256_cbor_64bit
|
return sha256_cbor
|
||||||
if hash_fn_name == "builtin":
|
|
||||||
return hash
|
|
||||||
|
|
||||||
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
|
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
|
|||||||
KVCacheEvent)
|
KVCacheEvent)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||||
FreeKVCacheBlockQueue, KVCacheBlock)
|
ExternalBlockHash,
|
||||||
|
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||||
|
get_block_hash,
|
||||||
|
make_block_hash_with_group_id,
|
||||||
|
maybe_convert_block_hash)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -84,8 +88,10 @@ class BlockPool:
|
|||||||
"""
|
"""
|
||||||
cached_blocks = []
|
cached_blocks = []
|
||||||
for group_id in kv_cache_group_ids:
|
for group_id in kv_cache_group_ids:
|
||||||
|
block_hash_with_group_id = make_block_hash_with_group_id(
|
||||||
|
block_hash, group_id)
|
||||||
cached_blocks_one_group = self.cached_block_hash_to_block.get(
|
cached_blocks_one_group = self.cached_block_hash_to_block.get(
|
||||||
BlockHashWithGroupId(block_hash, group_id))
|
block_hash_with_group_id)
|
||||||
if not cached_blocks_one_group:
|
if not cached_blocks_one_group:
|
||||||
return None
|
return None
|
||||||
first_block = next(iter(cached_blocks_one_group.values()))
|
first_block = next(iter(cached_blocks_one_group.values()))
|
||||||
@ -124,28 +130,29 @@ class BlockPool:
|
|||||||
assert len(request.block_hashes) >= num_full_blocks
|
assert len(request.block_hashes) >= num_full_blocks
|
||||||
new_block_hashes = request.block_hashes[num_cached_blocks:]
|
new_block_hashes = request.block_hashes[num_cached_blocks:]
|
||||||
|
|
||||||
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
|
new_hashes: Optional[list[ExternalBlockHash]] = (
|
||||||
else None)
|
[] if self.enable_kv_cache_events else None)
|
||||||
for i, blk in enumerate(new_full_blocks):
|
for i, blk in enumerate(new_full_blocks):
|
||||||
assert blk.block_hash is None
|
assert blk.block_hash is None
|
||||||
block_hash = new_block_hashes[i]
|
block_hash = new_block_hashes[i]
|
||||||
|
|
||||||
# Update and added the full block to the cache.
|
# Update and added the full block to the cache.
|
||||||
block_hash_with_group_id = BlockHashWithGroupId(
|
block_hash_with_group_id = make_block_hash_with_group_id(
|
||||||
block_hash, kv_cache_group_id)
|
block_hash, kv_cache_group_id)
|
||||||
blk.block_hash = block_hash_with_group_id
|
blk.block_hash = block_hash_with_group_id
|
||||||
self.cached_block_hash_to_block[block_hash_with_group_id][
|
self.cached_block_hash_to_block[block_hash_with_group_id][
|
||||||
blk.block_id] = blk
|
blk.block_id] = blk
|
||||||
if new_hashes is not None:
|
if new_hashes is not None:
|
||||||
new_hashes.append(block_hash.hash_value)
|
new_hashes.append(maybe_convert_block_hash(block_hash))
|
||||||
|
|
||||||
if self.enable_kv_cache_events:
|
if self.enable_kv_cache_events:
|
||||||
if num_cached_blocks == 0:
|
if num_cached_blocks == 0:
|
||||||
parent_block_hash = None
|
parent_block_hash: Optional[ExternalBlockHash] = None
|
||||||
else:
|
else:
|
||||||
parent_block = blocks[num_cached_blocks - 1]
|
parent_block = blocks[num_cached_blocks - 1]
|
||||||
assert parent_block.block_hash is not None
|
assert parent_block.block_hash is not None
|
||||||
parent_block_hash = parent_block.block_hash.get_hash_value()
|
parent_block_hash = maybe_convert_block_hash(
|
||||||
|
get_block_hash(parent_block.block_hash))
|
||||||
|
|
||||||
self.kv_event_queue.append(
|
self.kv_event_queue.append(
|
||||||
BlockStored(
|
BlockStored(
|
||||||
@ -220,7 +227,9 @@ class BlockPool:
|
|||||||
# we disable hybrid kv cache manager when kv cache event is
|
# we disable hybrid kv cache manager when kv cache event is
|
||||||
# enabled, so there is only one group.
|
# enabled, so there is only one group.
|
||||||
self.kv_event_queue.append(
|
self.kv_event_queue.append(
|
||||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
|
BlockRemoved(block_hashes=[
|
||||||
|
maybe_convert_block_hash(get_block_hash(block_hash))
|
||||||
|
],
|
||||||
medium=MEDIUM_GPU))
|
medium=MEDIUM_GPU))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -6,11 +6,12 @@ import os
|
|||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from collections.abc import Iterable, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
from dataclasses import astuple, dataclass
|
from dataclasses import astuple, dataclass
|
||||||
from typing import Any, Callable, NamedTuple, Optional
|
from typing import Any, Callable, NewType, Optional, Union
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
|
from vllm.utils import GiB_bytes, cdiv, sha256_cbor
|
||||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||||
FullAttentionSpec, KVCacheConfig,
|
FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec, KVCacheSpec,
|
KVCacheGroupSpec, KVCacheSpec,
|
||||||
@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
|||||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
# BlockHash represents the hash of a single KV-cache block used for
|
||||||
|
# prefix caching. Treating it as a distinct type from ``bytes`` helps
|
||||||
|
# catch accidental misuse when passing around raw byte strings.
|
||||||
|
BlockHash = NewType("BlockHash", bytes)
|
||||||
|
|
||||||
|
# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID.
|
||||||
|
# It is represented as raw bytes for compactness and efficiency. The helper
|
||||||
|
# functions below pack/unpack the ``BlockHash`` and group id into/from the key.
|
||||||
|
BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes)
|
||||||
|
|
||||||
|
# ExternalBlockHash is used for reproducible prefix-cache block hashing.
|
||||||
|
# It's a union of ``bytes`` and ``int`` to keep backward compatibility
|
||||||
|
# after we default block hashing to use sha256 bytes.
|
||||||
|
ExternalBlockHash = Union[bytes, int]
|
||||||
|
|
||||||
|
|
||||||
class BlockHash(NamedTuple):
|
def make_block_hash_with_group_id(block_hash: BlockHash,
|
||||||
"""Hash value of a block (int), the token IDs in the block, and extra keys.
|
group_id: int) -> BlockHashWithGroupId:
|
||||||
We keep a tuple of token IDs and extra keys to reduce the likelihood of
|
"""Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``.
|
||||||
hash collisions when the hash value is the same. By using SHA256 however,
|
|
||||||
hash collisions are practically impossible.
|
The group id is encoded using 4 bytes in big-endian order and appended to
|
||||||
|
the block hash bytes. This representation avoids creating tuples while
|
||||||
|
still allowing us to recover both components when needed.
|
||||||
"""
|
"""
|
||||||
# Hash value of the block in an integer.
|
return BlockHashWithGroupId(block_hash +
|
||||||
hash_value: int
|
group_id.to_bytes(4, "big", signed=False))
|
||||||
# Token IDs in the block.
|
|
||||||
token_ids: tuple[int, ...]
|
|
||||||
# Extra keys for the block.
|
|
||||||
extra_keys: Optional[Any] = None
|
|
||||||
|
|
||||||
|
|
||||||
class BlockHashWithGroupId(NamedTuple):
|
def get_block_hash(key: BlockHashWithGroupId) -> BlockHash:
|
||||||
# The hash value for the contents (e.g., token_ids) of a block without group
|
"""Extract the ``BlockHash`` from a ``BlockHashWithGroupId``."""
|
||||||
# ID. The value is the same for blocks representing the same tokens but for
|
return BlockHash(key[:-4])
|
||||||
# different groups.
|
|
||||||
block_hash: BlockHash
|
|
||||||
# The KV cache group ID.
|
|
||||||
group_id: int
|
|
||||||
|
|
||||||
def get_hash_value(self) -> int:
|
|
||||||
return self.block_hash.hash_value
|
|
||||||
|
|
||||||
|
def get_group_id(key: BlockHashWithGroupId) -> int:
|
||||||
|
"""Extract the group id from a ``BlockHashWithGroupId``."""
|
||||||
|
return int.from_bytes(key[-4:], "big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:
|
||||||
|
if not envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES:
|
||||||
|
return hash_bytes
|
||||||
|
return int.from_bytes(hash_bytes, byteorder="big") & ((1 << 64) - 1)
|
||||||
|
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
# The hash seed for the first block of any prefix block sequence.
|
# The hash seed for the first block of any prefix block sequence.
|
||||||
#
|
#
|
||||||
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
|
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
|
||||||
# variable if set such that processes can share the seed if needed.
|
# variable if set such that processes can share the seed if needed. This aligns
|
||||||
# This aligns with the behavior of Python's hash() function, which also uses
|
# with the behavior of Python's hash() function, which also uses a random seed
|
||||||
# a random seed if PYTHONHASHSEED is not set.
|
# if PYTHONHASHSEED is not set.
|
||||||
#
|
#
|
||||||
# The function `init_none_hash` initializes this variable globally.
|
# The function `init_none_hash` initializes this variable globally.
|
||||||
NONE_HASH: int
|
NONE_HASH: BlockHash
|
||||||
|
|
||||||
|
|
||||||
def init_none_hash(hash_fn: Callable):
|
def init_none_hash(hash_fn: Callable[[Any], bytes]):
|
||||||
global NONE_HASH
|
global NONE_HASH
|
||||||
|
|
||||||
hash_seed = os.getenv("PYTHONHASHSEED")
|
hash_seed = os.getenv("PYTHONHASHSEED")
|
||||||
if hash_seed is None and hash_fn is sha256_cbor_64bit:
|
if hash_seed is None and hash_fn is sha256_cbor:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
|
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
|
||||||
"block-hashes when using sha256_cbor_64bit as the hash function."
|
"block-hashes when using sha256_cbor as the hash function."
|
||||||
"Consider setting PYTHONHASHSEED to a fixed value for "
|
"Consider setting PYTHONHASHSEED to a fixed value for "
|
||||||
"reproducibility.")
|
"reproducibility.")
|
||||||
|
|
||||||
NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big")
|
if hash_seed is None:
|
||||||
if hash_seed is None else hash_fn(hash_seed))
|
NONE_HASH = BlockHash(os.urandom(32))
|
||||||
|
else:
|
||||||
|
NONE_HASH = BlockHash(hash_fn(hash_seed))
|
||||||
|
|
||||||
|
|
||||||
class PrefixCachingMetrics:
|
class PrefixCachingMetrics:
|
||||||
@ -142,8 +162,8 @@ class KVCacheBlock:
|
|||||||
block_id: int
|
block_id: int
|
||||||
# Reference count.
|
# Reference count.
|
||||||
ref_cnt: int = 0
|
ref_cnt: int = 0
|
||||||
# The hash of the block composed of (block hash, tuple of token IDs).
|
# The hash key (block hash + group id) of the block, only available
|
||||||
# It is only available when the block is full.
|
# when the block is full and cached.
|
||||||
_block_hash: Optional[BlockHashWithGroupId] = None
|
_block_hash: Optional[BlockHashWithGroupId] = None
|
||||||
|
|
||||||
# Used to construct a doubly linked list for free blocks.
|
# Used to construct a doubly linked list for free blocks.
|
||||||
@ -177,7 +197,7 @@ class KVCacheBlock:
|
|||||||
if self.next_free_block else None)
|
if self.next_free_block else None)
|
||||||
return (f"KVCacheBlock(block_id={self.block_id}, "
|
return (f"KVCacheBlock(block_id={self.block_id}, "
|
||||||
f"ref_cnt={self.ref_cnt}, "
|
f"ref_cnt={self.ref_cnt}, "
|
||||||
f"_block_hash={self._block_hash}, "
|
f"_block_hash={self._block_hash!r}, "
|
||||||
f"prev_free_block={prev_block_id}, "
|
f"prev_free_block={prev_block_id}, "
|
||||||
f"next_free_block={next_block_id})")
|
f"next_free_block={next_block_id})")
|
||||||
|
|
||||||
@ -517,15 +537,14 @@ def generate_block_hash_extra_keys(
|
|||||||
|
|
||||||
|
|
||||||
def hash_block_tokens(
|
def hash_block_tokens(
|
||||||
hash_function: Callable,
|
hash_function: Callable[[Any], bytes],
|
||||||
parent_block_hash: Optional[int],
|
parent_block_hash: Optional[BlockHash],
|
||||||
curr_block_token_ids: Sequence[int],
|
curr_block_token_ids: Sequence[int],
|
||||||
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
|
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
|
||||||
"""Computes a hash value corresponding to the contents of a block and
|
"""Computes a hash value corresponding to the contents of a block and
|
||||||
the contents of the preceding block(s). The hash value is used for
|
the contents of the preceding block(s). The hash value is used for
|
||||||
prefix caching. We use LRU cache for this function to avoid recomputing
|
prefix caching. We use LRU cache for this function to avoid recomputing
|
||||||
hash values for the same block contents.
|
hash values for the same block contents.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hash_function: The hash function used to compute block hash.
|
hash_function: The hash function used to compute block hash.
|
||||||
parent_block_hash: The hash of the parent block. None
|
parent_block_hash: The hash of the parent block. None
|
||||||
@ -533,7 +552,6 @@ def hash_block_tokens(
|
|||||||
curr_block_token_ids: A list of token ids in the current
|
curr_block_token_ids: A list of token ids in the current
|
||||||
block. The current block is assumed to be full.
|
block. The current block is assumed to be full.
|
||||||
extra_keys: Extra keys for the block.
|
extra_keys: Extra keys for the block.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The hash value of the block and the token ids in the block.
|
The hash value of the block and the token ids in the block.
|
||||||
The entire tuple is used as the hash key of the block.
|
The entire tuple is used as the hash key of the block.
|
||||||
@ -544,26 +562,16 @@ def hash_block_tokens(
|
|||||||
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
|
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
|
||||||
return BlockHash(
|
return BlockHash(
|
||||||
hash_function(
|
hash_function(
|
||||||
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
|
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)))
|
||||||
curr_block_token_ids_tuple, extra_keys)
|
|
||||||
|
|
||||||
|
|
||||||
def get_request_block_hasher(
|
def get_request_block_hasher(
|
||||||
block_size: int,
|
block_size: int,
|
||||||
caching_hash_fn: Callable[[Any],
|
caching_hash_fn: Callable[[Any], bytes],
|
||||||
int]) -> Callable[[Request], list[BlockHash]]:
|
) -> Callable[[Request], list[BlockHash]]:
|
||||||
"""
|
"""
|
||||||
Returns a function which computes the list of un-computed block hashes
|
Returns a function which computes the list of un-computed block hashes
|
||||||
of a request.
|
of a request."""
|
||||||
|
|
||||||
Each request holds a list of its block hashes (request.block_hashes).
|
|
||||||
When a request is created, it calls the below function to compute
|
|
||||||
the hashes of all full blocks of the request's initial tokens.
|
|
||||||
The hashes are then stored in request.block_hashes.
|
|
||||||
Later, whenever new tokens are appended to the request, it calls
|
|
||||||
the below function again to compute any new full blocks of tokens.
|
|
||||||
The returned new hashes are appended to request.block_hashes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def request_block_hasher(request: Request) -> list[BlockHash]:
|
def request_block_hasher(request: Request) -> list[BlockHash]:
|
||||||
start_token_idx = len(request.block_hashes) * block_size
|
start_token_idx = len(request.block_hashes) * block_size
|
||||||
@ -577,8 +585,8 @@ def get_request_block_hasher(
|
|||||||
# last mm input.
|
# last mm input.
|
||||||
curr_mm_idx = -1
|
curr_mm_idx = -1
|
||||||
|
|
||||||
prev_block_hash_value = request.block_hashes[-1].hash_value \
|
prev_block_hash_value = (request.block_hashes[-1]
|
||||||
if request.block_hashes else None
|
if request.block_hashes else None)
|
||||||
new_block_hashes: list[BlockHash] = []
|
new_block_hashes: list[BlockHash] = []
|
||||||
while True:
|
while True:
|
||||||
end_token_idx = start_token_idx + block_size
|
end_token_idx = start_token_idx + block_size
|
||||||
@ -598,7 +606,7 @@ def get_request_block_hasher(
|
|||||||
|
|
||||||
new_block_hashes.append(block_hash)
|
new_block_hashes.append(block_hash)
|
||||||
start_token_idx += block_size
|
start_token_idx += block_size
|
||||||
prev_block_hash_value = block_hash.hash_value
|
prev_block_hash_value = block_hash
|
||||||
|
|
||||||
return new_block_hashes
|
return new_block_hashes
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user