mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-23 05:57:03 +08:00
[Prefix Cache] Add reproducible prefix-cache block hashing using SHA-256 + CBOR (64bit) (#20511)
Signed-off-by: Maroon Ayoub <maroon.ayoub@ibm.com>
This commit is contained in:
parent
8632e831ba
commit
66f6fbd393
@ -47,3 +47,4 @@ python-json-logger # Used by logging as per examples/others/logging_configuratio
|
||||
scipy # Required for phi-4-multimodal-instruct
|
||||
ninja # Required for xgrammar, rocm, tpu, xpu
|
||||
pybase64 # fast base64 implementation
|
||||
cbor2 # Required for cross-language serialization of hashable objects
|
||||
|
||||
@ -11,6 +11,7 @@ ruff
|
||||
# Required for argparse hook only
|
||||
-f https://download.pytorch.org/whl/cpu
|
||||
cachetools
|
||||
cbor2
|
||||
cloudpickle
|
||||
fastapi
|
||||
msgspec
|
||||
|
||||
@ -8,7 +8,7 @@ import torch
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes, sha256
|
||||
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
@ -16,7 +16,8 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||
hash_block_tokens, hash_request_tokens, unify_kv_cache_configs)
|
||||
hash_block_tokens, hash_request_tokens, init_none_hash,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor,
|
||||
SlidingWindowSpec)
|
||||
@ -78,24 +79,27 @@ def new_sliding_window_spec(block_size=16,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
def test_none_hash(monkeypatch):
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
def test_none_hash(monkeypatch, hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
|
||||
# case 1: PYTHONHASHSEED is not set, use random
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv('PYTHONHASHSEED', raising=False)
|
||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH != 0
|
||||
|
||||
# case 2: PYTHONHASHSEED is set, use the seed
|
||||
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv('PYTHONHASHSEED', 'python hash seed')
|
||||
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
|
||||
reloaded_kv_cache_utils.init_none_hash(hash_fn)
|
||||
assert reloaded_kv_cache_utils.NONE_HASH is not None
|
||||
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
|
||||
assert sha256('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
|
||||
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
|
||||
|
||||
|
||||
def test_kv_cache_block():
|
||||
@ -287,9 +291,10 @@ def test_generate_block_hash_extra_keys_cache_salt():
|
||||
assert next_mm_idx == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
def test_hash_block_tokens(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
parent_block_hash = 123
|
||||
curr_block_token_ids = (1, 2, 3)
|
||||
extra_keys = ("key1", "key2")
|
||||
@ -303,9 +308,10 @@ def test_hash_block_tokens(hash_fn):
|
||||
assert block_hash.extra_keys == extra_keys
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
def test_hash_request_tokens(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -332,8 +338,10 @@ def test_hash_request_tokens(hash_fn):
|
||||
assert block_hashes[1].extra_keys == ("hash2", )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
def test_hash_tokens_different_mm_input(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
request1 = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -359,8 +367,10 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
assert block_hashes1[1] != block_hashes2[1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -916,4 +926,4 @@ def test_get_kv_cache_config():
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
|
||||
])
|
||||
])
|
||||
|
||||
@ -11,11 +11,12 @@ import torch
|
||||
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.utils import sha256, sha256_cbor_64bit
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock, hash_block_tokens)
|
||||
KVCacheBlock, hash_block_tokens,
|
||||
init_none_hash)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, SlidingWindowSpec)
|
||||
|
||||
@ -91,7 +92,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
@ -101,7 +102,8 @@ def test_prefill(hash_algo):
|
||||
)
|
||||
|
||||
# choose the hash function according to the parameter
|
||||
hash_fn = sha256 if hash_algo == "sha256" else hash
|
||||
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
|
||||
sha256 if hash_algo == "sha256" else hash)
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@ -696,12 +698,14 @@ def test_basic_prefix_caching_disabled():
|
||||
assert not blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
def test_cache_blocks(hash_fn):
|
||||
"""
|
||||
This is a unit test that tests the correctness of the _cache_full_blocks
|
||||
function of KVCacheManager.
|
||||
"""
|
||||
init_none_hash(hash_fn)
|
||||
|
||||
block_size = 4
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=5,
|
||||
|
||||
@ -1564,7 +1564,7 @@ class ModelConfig:
|
||||
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
|
||||
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
|
||||
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
|
||||
|
||||
|
||||
@config
|
||||
@ -1609,7 +1609,12 @@ class CacheConfig:
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
|
||||
"""Set the hash algorithm for prefix caching:\n
|
||||
- "builtin" is Python's built-in hash.\n
|
||||
- "sha256" is collision resistant but with certain overheads."""
|
||||
- "sha256" is collision resistant but with certain overheads.
|
||||
This option uses Pickle for object serialization before hashing.\n
|
||||
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
|
||||
hash. It serializes objects using canonical CBOR and hashes them with
|
||||
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
|
||||
digest."""
|
||||
cpu_offload_gb: float = 0
|
||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||
|
||||
@ -52,6 +52,7 @@ from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import cachetools
|
||||
import cbor2
|
||||
import cloudpickle
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@ -3177,6 +3178,29 @@ def sha256(input) -> int:
|
||||
byteorder="big")
|
||||
|
||||
|
||||
def sha256_cbor_64bit(input) -> int:
|
||||
"""
|
||||
Hash objects using CBOR serialization and SHA-256, then truncate to 64bits.
|
||||
|
||||
This option is useful for non-Python-dependent serialization and hashing.
|
||||
|
||||
Args:
|
||||
input: Object to be serialized and hashed. Supported types include
|
||||
basic Python types and complex structures like lists, tuples, and
|
||||
dictionaries.
|
||||
Custom classes must implement CBOR serialization methods.
|
||||
|
||||
Returns:
|
||||
An integer in the range [0, 2^64-1] representing the lower 64 bits
|
||||
of the SHA-256 hash of the CBOR serialized input.
|
||||
"""
|
||||
input_bytes = cbor2.dumps(input, canonical=True)
|
||||
full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
||||
byteorder="big")
|
||||
|
||||
return full_hash & ((1 << 64) - 1)
|
||||
|
||||
|
||||
def is_torch_equal_or_newer(target: str) -> bool:
|
||||
"""Check if the installed torch version is >= the target version.
|
||||
|
||||
|
||||
@ -7,10 +7,10 @@ from typing import Optional
|
||||
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256
|
||||
from vllm.utils import sha256, sha256_cbor_64bit
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
hash_request_tokens)
|
||||
hash_request_tokens, init_none_hash)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
@ -79,7 +79,10 @@ class KVCacheManager:
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.caching_hash_fn = (
|
||||
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
|
||||
sha256 if caching_hash_algo == "sha256" else hash)
|
||||
init_none_hash(self.caching_hash_fn)
|
||||
self.use_eagle = use_eagle
|
||||
self.log_stats = log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Any, Callable, NamedTuple, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import GiB_bytes, cdiv, sha256
|
||||
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, SlidingWindowSpec)
|
||||
@ -46,18 +46,30 @@ class BlockHashWithGroupId(NamedTuple):
|
||||
return self.block_hash.hash_value
|
||||
|
||||
|
||||
# The hash seed for the first block of the prefix block sequence.
|
||||
#
|
||||
# Even if the hash function is the builtin hash(), we use sha256 to generate
|
||||
# the initial hash to simplify the code. This is not performance critical
|
||||
# as it is done one per process.
|
||||
# The hash seed for the first block of any prefix block sequence.
|
||||
#
|
||||
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
|
||||
# variable if set such that processes can share the seed if needed.
|
||||
# This aligns with the behavior of Python's hash() function, which also uses
|
||||
# a random seed if PYTHONHASHSEED is not set.
|
||||
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
|
||||
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
|
||||
#
|
||||
# The function `init_none_hash` initializes this variable globally.
|
||||
NONE_HASH: int
|
||||
|
||||
|
||||
def init_none_hash(hash_fn: Callable):
|
||||
global NONE_HASH
|
||||
|
||||
hash_seed = os.getenv("PYTHONHASHSEED")
|
||||
if hash_seed is None and hash_fn is sha256_cbor_64bit:
|
||||
logger.warning(
|
||||
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
|
||||
"block-hashes when using sha256_cbor_64bit as the hash function."
|
||||
"Consider setting PYTHONHASHSEED to a fixed value for "
|
||||
"reproducibility.")
|
||||
|
||||
NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big")
|
||||
if hash_seed is None else hash_fn(hash_seed))
|
||||
|
||||
|
||||
class PrefixCachingMetrics:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user