[Prefix Cache] Use LoRA name for consistent KV-cache block hashing (#27211)

Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
This commit is contained in:
Sage 2025-10-22 21:13:03 +03:00 committed by GitHub
parent 1cb8c6c5fe
commit 1651003c35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 6 deletions

View File

@ -8,6 +8,7 @@ import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
@ -449,6 +450,24 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1
def test_generate_block_hash_extra_keys_lora():
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
)
request.lora_request = LoRARequest(
lora_name="test_lora_adapter", lora_int_id=1, lora_path="/path/to/lora"
)
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0)
assert extra_keys == ("test_lora_adapter",)
request.lora_request = None
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 3, 0)
assert extra_keys is None
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens(hash_fn):
parent_block_hash = BlockHash(b"123")

View File

@ -373,7 +373,7 @@ def need_extra_keys(request: Request) -> bool:
"""
# Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID.
# LoRA requests need to include the LoRA name.
# Request with provided cache salt need to include the salt.
return (
bool(request.mm_features)
@ -446,26 +446,26 @@ def _gen_mm_extra_hash_keys(
return extra_keys, curr_mm_idx
def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
def _gen_lora_extra_hash_keys(request: Request) -> list[str]:
"""Generate extra keys related to LoRA for block hash computation.
Args:
request: The request object.
Returns:
Return LoRA id of the request if it is a LoRA request. Return empty
Return LoRA name of the request if it is a LoRA request. Return empty
list otherwise.
"""
if not request.lora_request:
return []
return [request.lora_request.lora_int_id]
return [request.lora_request.lora_name]
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int
) -> tuple[tuple[Any, ...] | None, int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
the multi-modal inputs and request specific metadata (e.g., LoRA name).
Args:
request: The request object.
@ -480,7 +480,7 @@ def generate_block_hash_extra_keys(
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
request, start_token_idx, end_token_idx, start_mm_idx
)
lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
lora_extra_keys: list[str] = _gen_lora_extra_hash_keys(request)
cache_salt_keys: list[str] = (
[request.cache_salt] if (start_token_idx == 0 and request.cache_salt) else []
)