[Core] Add xxHash as a high-performance hash option for accelerating prefix caching (#29163)

Signed-off-by: LuminolT <lumischen01@gmail.com>
Signed-off-by: Lumis Chen <lumischen01@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Lumis Chen 2025-12-04 00:06:57 +08:00 committed by GitHub
parent 5aa9b09040
commit 9bcf92295a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 332 additions and 8 deletions

View File

@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Micro benchmark comparing built-in hash(), SHA-256, and xxHash.
This focuses on a single test payload shaped like the prefix-cache hash input:
(32-byte bytes object, 32-int tuple)
Usage:
python benchmarks/benchmark_hash.py --iterations 20000
"""
from __future__ import annotations
import argparse
import random
import statistics
import time
from collections.abc import Callable, Iterable
from vllm.utils.hashing import sha256, xxhash
def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]:
"""Generate a deterministic test payload."""
random.seed(seed)
bytes_data = bytes(random.getrandbits(8) for _ in range(32))
int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32))
return (bytes_data, int_tuple)
def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int):
"""Return (avg_seconds, std_seconds) for hashing `data` `iterations` times."""
times: list[float] = []
# Warm-up to avoid first-run noise.
for _ in range(200):
func(data)
for _ in range(iterations):
start = time.perf_counter()
func(data)
end = time.perf_counter()
times.append(end - start)
avg = statistics.mean(times)
std = statistics.stdev(times) if len(times) > 1 else 0.0
return avg, std
def _run_benchmarks(
    benchmarks: Iterable[tuple[str, Callable[[tuple], object]]],
    data: tuple,
    iterations: int,
):
    """Lazily time each benchmark and yield its ``(name, avg, std)`` row.

    A benchmark whose timing run raises ``ModuleNotFoundError`` is reported
    on stdout and left out of the results.
    """
    for label, hash_func in benchmarks:
        try:
            mean_s, spread_s = _benchmark_func(hash_func, data, iterations)
        except ModuleNotFoundError as exc:
            print(f"Skipping {label}: {exc}")
        else:
            yield label, mean_s, spread_s
def builtin_hash(data: tuple) -> int:
    """Return ``hash(data)``, exposing the built-in as a named benchmark callable."""
    digest = hash(data)
    return digest
def main() -> None:
    """Parse the CLI options, time each hash function, and print a report.

    The report lists the mean and standard deviation per function in
    microseconds, followed by each function's slowdown relative to the
    built-in ``hash()`` when that baseline was measured.
    """
    builtin_label = "built-in hash()"
    rule = "=" * 60

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--iterations",
        default=10_000,
        type=int,
        help="Number of measured iterations per hash function.",
    )
    parser.add_argument(
        "--seed", default=42, type=int, help="Random seed for test payload."
    )
    options = parser.parse_args()

    payload = _generate_test_data(options.seed)
    candidates = (
        ("SHA256 (pickle)", sha256),
        ("xxHash (pickle)", xxhash),
        (builtin_label, builtin_hash),
    )

    print(rule)
    print("HASH FUNCTION MICRO BENCHMARK")
    print(rule)
    print("Test data: (32-byte bytes object, 32-int tuple)")
    print(f"Iterations: {options.iterations:,}")
    print(rule)

    rows = list(_run_benchmarks(candidates, payload, options.iterations))

    # Find the baseline row, if the built-in benchmark produced one.
    baseline = None
    for row in rows:
        if row[0] == builtin_label:
            baseline = row
            break

    print("\nResults:")
    for label, mean_s, spread_s in rows:
        print(f" {label:16s}: {mean_s * 1e6:8.2f} ± {spread_s * 1e6:6.2f} μs")

    if not baseline:
        print("\nBuilt-in hash() result missing; cannot compute speed ratios.")
        return

    _, baseline_mean, _ = baseline
    print("\n" + rule)
    print("SUMMARY (relative to built-in hash())")
    print(rule)
    for label, mean_s, _ in rows:
        if label != builtin_label:
            ratio = mean_s / baseline_mean
            print(f"{label} is {ratio:.1f}x slower than built-in hash()")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple benchmark to compare prefix-cache block hashing algorithms.
Example:
python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
"""
from __future__ import annotations
import argparse
import random
import statistics
import sys
import time
from collections.abc import Callable, Iterable, Sequence
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
# Hash algorithm names this benchmark accepts; used as both the default and
# the allowed choices of the --algorithms option.
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")
def _generate_blocks(
num_blocks: int, block_size: int, vocab_size: int, seed: int
) -> list[list[int]]:
rng = random.Random(seed)
return [
[rng.randrange(vocab_size) for _ in range(block_size)]
for _ in range(num_blocks)
]
def _hash_all_blocks(
    hash_fn: Callable[[object], bytes],
    blocks: Iterable[Sequence[int]],
) -> float:
    """Return the seconds taken to hash ``blocks`` as one chain.

    Each block is passed to ``hash_block_tokens`` together with the value
    returned for the previous block (``None`` for the first block).
    """
    previous: BlockHash | None = None
    began = time.perf_counter()
    for token_ids in blocks:
        previous = hash_block_tokens(hash_fn, previous, token_ids, extra_keys=None)
    finished = time.perf_counter()
    return finished - began
def _benchmark(
    hash_algo: str,
    blocks: list[list[int]],
    trials: int,
) -> tuple[float, float, float] | None:
    """Time hashing of all ``blocks`` with ``hash_algo`` over several trials.

    Args:
        hash_algo: Name handed to ``get_hash_fn_by_name``.
        blocks: Token-id blocks to hash; may be empty.
        trials: Number of timed passes over ``blocks``; must be at least 1.

    Returns:
        ``(avg_seconds, best_seconds, tokens_per_second)``, where the
        throughput is based on the fastest trial, or ``None`` when setting up
        or running the hash function raised ``ModuleNotFoundError`` (the skip
        is reported on stderr).

    Raises:
        ValueError: If ``trials`` is less than 1.
    """
    # Fail early with a clear message; statistics.mean() would otherwise raise
    # an opaque StatisticsError on the empty timing list.
    if trials < 1:
        raise ValueError(f"trials must be at least 1, got {trials}")

    try:
        hash_fn = get_hash_fn_by_name(hash_algo)
        init_none_hash(hash_fn)
        timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
    except ModuleNotFoundError as exc:
        print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
        return None

    avg = statistics.mean(timings)
    best = min(timings)
    # throughput: tokens / second
    # Count tokens per block instead of assuming blocks[0] exists and that
    # every block has the same length.
    tokens_hashed = sum(len(block) for block in blocks)
    if best > 0:
        throughput = tokens_hashed / best
    else:
        # A zero-length best trial would divide by zero; report an unbounded
        # rate when tokens were hashed and 0.0 when there was nothing to hash.
        throughput = float("inf") if tokens_hashed else 0.0
    return avg, best, throughput
def main() -> None:
    """Parse the CLI options and print timing results per selected algorithm.

    Algorithms for which ``_benchmark`` returns ``None`` are omitted from the
    printed results.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add_option = parser.add_argument
    add_option("--num-blocks", type=int, default=10000, help="Block count.")
    add_option("--block-size", type=int, default=32, help="Tokens per block.")
    add_option(
        "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
    )
    add_option("--seed", type=int, default=0, help="Random seed.")
    add_option(
        "--trials", type=int, default=5, help="Number of timed trials per algorithm."
    )
    add_option(
        "--algorithms",
        nargs="+",
        default=SUPPORTED_ALGOS,
        choices=SUPPORTED_ALGOS,
        help="Hash algorithms to benchmark.",
    )
    options = parser.parse_args()

    blocks = _generate_blocks(
        options.num_blocks, options.block_size, options.vocab_size, options.seed
    )

    print(
        f"Benchmarking {len(options.algorithms)} algorithms on "
        f"{options.num_blocks} blocks (block size={options.block_size})."
    )

    for algo in options.algorithms:
        outcome = _benchmark(algo, blocks, options.trials)
        if outcome is not None:
            avg, best, throughput = outcome
            print(
                f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
                f"throughput: {throughput / 1e6:.2f}M tokens/s"
            )


if __name__ == "__main__":
    main()

View File

@ -670,6 +670,35 @@ vllm bench serve \
</details> </details>
### 🧪 Hashing Benchmarks
<details class="admonition abstract" markdown="1">
<summary>Show more</summary>
Two helper scripts live in `benchmarks/` to compare hashing options used by prefix caching and related utilities. They are standalone (no server required) and help choose a hash algorithm before enabling prefix caching in production.
- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload.
```bash
python benchmarks/benchmark_hash.py --iterations 20000 --seed 42
```
- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput.
```bash
python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5
```
Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install optional deps to exercise all variants:
```bash
uv pip install xxhash cbor2
```
If an algorithm's dependency is missing, the script will skip it and continue.
</details>
### ⚡ Request Prioritization Benchmark ### ⚡ Request Prioritization Benchmark
<details class="admonition abstract" markdown="1"> <details class="admonition abstract" markdown="1">

View File

@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.hashing import _xxhash
def test_prefix_caching_from_cli(): def test_prefix_caching_from_cli():
@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"]) args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
    """Both xxHash algorithm names pass from the CLI into the cache config."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # "xxhash" is the pickle-based variant, "xxhash_cbor" the CBOR-based one.
    for algo in ("xxhash", "xxhash_cbor"):
        args = parser.parse_args(["--prefix-caching-hash-algo", algo])
        vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
        assert vllm_config.cache_config.prefix_caching_hash_algo == algo
def test_defaults_with_usage_context(): def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m") engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS) vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)

View File

@ -30,7 +30,7 @@ CacheDType = Literal[
"fp8_ds_mla", "fp8_ds_mla",
] ]
MambaDType = Literal["auto", "float32"] MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"] KVOffloadingBackend = Literal["native", "lmcache"]
@ -77,9 +77,21 @@ class CacheConfig:
"""Whether to enable prefix caching.""" """Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n """Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n - "sha256" uses Pickle for object serialization before hashing. This is the
current default, as SHA256 is the most secure choice to avoid potential
hash collisions.\n
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It - "sha256_cbor" provides a reproducible, cross-language compatible hash. It
serializes objects using canonical CBOR and hashes them with SHA-256.""" serializes objects using canonical CBOR and hashes them with SHA-256.\n
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
non-cryptographic hashing. Requires the optional ``xxhash`` package.
IMPORTANT: Use of a hashing algorithm that is not considered
cryptographically secure theoretically increases the risk of hash collisions,
which can cause undefined behavior or even leak private information in
multi-tenant environments. Even if collisions are still very unlikely, it is
important to consider your security risk tolerance against the performance
benefits before turning this on.\n
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
reproducible hashing. Requires the optional ``xxhash`` package."""
cpu_offload_gb: float = Field(default=0, ge=0) cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means """The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to no offloading. Intuitively, this argument can be seen as a virtual way to

View File

@ -11,6 +11,17 @@ from typing import Any
import cbor2 import cbor2
try:
# It is important that this remains an optional dependency.
# It would not be allowed in environments with strict security controls,
# so it's best not to have it installed when not in use.
import xxhash as _xxhash
if not hasattr(_xxhash, "xxh3_128_digest"):
_xxhash = None
except ImportError: # pragma: no cover
_xxhash = None
def sha256(input: Any) -> bytes: def sha256(input: Any) -> bytes:
"""Hash any picklable Python object using SHA-256. """Hash any picklable Python object using SHA-256.
@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes:
return hashlib.sha256(input_bytes).digest() return hashlib.sha256(input_bytes).digest()
def _xxhash_digest(input_bytes: bytes) -> bytes:
    """Return the ``xxh3_128_digest`` of ``input_bytes``.

    Args:
        input_bytes: Already-serialized payload to hash.

    Returns:
        The digest bytes produced by ``xxhash.xxh3_128_digest``.

    Raises:
        ModuleNotFoundError: If no usable ``xxhash`` module is available,
            i.e. the package is not installed or does not provide
            ``xxh3_128_digest``.
    """
    if _xxhash is None:
        # `_xxhash` is also None when the package is installed but lacks
        # `xxh3_128_digest`, so the hint covers upgrading as well as
        # installing.
        raise ModuleNotFoundError(
            "A version of xxhash providing `xxh3_128_digest` is required for "
            "the 'xxhash' and 'xxhash_cbor' prefix caching hash algorithms. "
            "Install or upgrade it via `pip install -U xxhash`."
        )
    return _xxhash.xxh3_128_digest(input_bytes)
def xxhash(input: Any) -> bytes:
    """Pickle ``input`` with the highest protocol and return its xxHash digest.

    The serialized bytes are hashed by ``_xxhash_digest``, which raises
    ``ModuleNotFoundError`` when the optional xxhash module is unavailable.
    """
    return _xxhash_digest(pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL))
def xxhash_cbor(input: Any) -> bytes:
    """Serialize ``input`` as canonical CBOR and return its xxHash digest.

    The serialized bytes are hashed by ``_xxhash_digest``, which raises
    ``ModuleNotFoundError`` when the optional xxhash module is unavailable.
    """
    return _xxhash_digest(cbor2.dumps(input, canonical=True))
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
"""Get a hash function by name, or raise an error if the function is not found. """Get a hash function by name, or raise an error if the function is not found.
@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256 return sha256
if hash_fn_name == "sha256_cbor": if hash_fn_name == "sha256_cbor":
return sha256_cbor return sha256_cbor
if hash_fn_name == "xxhash":
return xxhash
if hash_fn_name == "xxhash_cbor":
return xxhash_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}") raise ValueError(f"Unsupported hash function: {hash_fn_name}")

View File

@ -12,7 +12,7 @@ from typing import Any, NewType, TypeAlias, overload
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.hashing import sha256_cbor from vllm.utils.hashing import sha256_cbor, xxhash_cbor
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
@ -83,18 +83,19 @@ logger = init_logger(__name__)
# #
# The function `init_none_hash` initializes this variable globally. # The function `init_none_hash` initializes this variable globally.
NONE_HASH: BlockHash NONE_HASH: BlockHash
# Hash functions that `init_none_hash` treats as CBOR-based: when one of them
# is used and PYTHONHASHSEED is unset, it logs a reproducibility warning.
_CBOR_HASH_FUNCTIONS = frozenset({sha256_cbor, xxhash_cbor})
def init_none_hash(hash_fn: Callable[[Any], bytes]): def init_none_hash(hash_fn: Callable[[Any], bytes]):
global NONE_HASH global NONE_HASH
hash_seed = os.getenv("PYTHONHASHSEED") hash_seed = os.getenv("PYTHONHASHSEED")
if hash_seed is None and hash_fn is sha256_cbor: if hash_seed is None and hash_fn in _CBOR_HASH_FUNCTIONS:
logger.warning( logger.warning(
"PYTHONHASHSEED is not set. This will lead to non-reproducible " "PYTHONHASHSEED is not set. This will lead to non-reproducible "
"block-hashes when using sha256_cbor as the hash function." "block-hashes when using CBOR-based hash functions such as "
"Consider setting PYTHONHASHSEED to a fixed value for " "sha256_cbor or xxhash_cbor. Consider setting PYTHONHASHSEED to a "
"reproducibility." "fixed value for reproducibility."
) )
if hash_seed is None: if hash_seed is None: