[Core] Prevent side-channel attacks via cache salting (#17045)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
This commit is contained in:
Marko Rosenmueller 2025-04-30 14:27:21 +02:00 committed by GitHub
parent a7d5b016bd
commit 77073c77bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 328 additions and 126 deletions

View File

@ -16,7 +16,7 @@ In the example above, the KV cache in the first block can be uniquely identified
* Parent hash value: The hash value of the parent hash block. * Parent hash value: The hash value of the parent hash block.
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision. * Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below). * Extra hashes: Other values required to make this block unique, such as LoRA IDs, multi-modality input hashes (see the example below), and cache salts to isolate caches in multi-tenant environments.
> **Note 1:** We only cache full blocks. > **Note 1:** We only cache full blocks.
@ -76,6 +76,24 @@ Block 3
In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow. In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow.
**Cache Isolation for Security**
To improve privacy in shared environments, vLLM supports isolating prefix cache reuse through optional per-request salting. By including a `cache_salt` in the request, this value is injected into the hash of the first block, ensuring that only requests with the same salt can reuse cached KV blocks. This prevents timing-based attacks where an adversary could infer cached content by observing latency differences. This offers protection without compromising performance.
```json
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Here is a document with details about the world series: ..."},
{"role": "user", "content": "Who won the world series in 2020?"}
],
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ=="
}
```
With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others.
> **Note:** Cache isolation is not supported in engine V0.
## Data Structure ## Data Structure
The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified): The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified):

View File

@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config():
assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].temperature == 0.0
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
def test_serving_chat_did_set_correct_cache_salt():
mock_model_config = MockModelConfig()
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
# Initialize the serving chat
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)
serving_chat = OpenAIServingChat(mock_engine,
mock_model_config,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None)
# Test cache_salt
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?"
}],
)
# By default cache_salt in the engine prompt is not set
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert "cache_salt" not in mock_engine.generate.call_args.args[0]
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
asyncio.run(serving_chat.create_chat_completion(req))
assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"

View File

@ -60,8 +60,16 @@ def _run_incremental_decode(tokenizer,
skip_special_tokens=skip_special_tokens, skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens,
) )
request = EngineCoreRequest("", prompt_token_ids, None, None, None, params, request = EngineCoreRequest("",
None, 0.0, None) prompt_token_ids,
None,
None,
None,
params,
None,
0.0,
None,
cache_salt=None)
if fast is None: if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request( detokenizer = IncrementalDetokenizer.from_new_request(

View File

@ -29,7 +29,8 @@ from vllm.v1.request import Request
def make_request(request_id, def make_request(request_id,
prompt_token_ids, prompt_token_ids,
mm_positions=None, mm_positions=None,
mm_hashes=None): mm_hashes=None,
cache_salt=None):
if mm_positions is None: if mm_positions is None:
multi_modal_inputs = None multi_modal_inputs = None
else: else:
@ -45,6 +46,7 @@ def make_request(request_id,
eos_token_id=100, eos_token_id=100,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=cache_salt,
) )
@ -213,6 +215,45 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
assert next_mm_idx == 0 assert next_mm_idx == 0
def test_generate_block_hash_extra_keys_cache_salt():
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
mm_positions=None,
mm_hashes=None,
cache_salt="salt",
)
# salt is added for the first token
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 1, 0)
assert extra_keys == ('salt', )
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 10, 0)
assert extra_keys == ('salt', )
# no salt added for other tokens
extra_keys, _ = generate_block_hash_extra_keys(request, 1, 2, 0)
assert extra_keys is None
extra_keys, _ = generate_block_hash_extra_keys(request, 6, 10, 0)
assert extra_keys is None
# works together with other extra keys
request_mm = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(20)],
mm_positions=[
PlaceholderRange(offset=0, length=5),
],
mm_hashes=["hash1"],
cache_salt="salt",
)
# Test with no extra keys
extra_keys, next_mm_idx = generate_block_hash_extra_keys(
request_mm, 0, 5, 0)
assert extra_keys == ("hash1", "salt")
assert next_mm_idx == 1
@pytest.mark.parametrize("hash_fn", [sha256, hash]) @pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_block_tokens(hash_fn): def test_hash_block_tokens(hash_fn):
parent_block_hash = 123 parent_block_hash = 123

View File

@ -21,7 +21,8 @@ def make_request(request_id,
prompt_token_ids, prompt_token_ids,
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
prompt_logprobs: Optional[int] = None): prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None):
if mm_positions is None: if mm_positions is None:
multi_modal_inputs = None multi_modal_inputs = None
else: else:
@ -38,6 +39,7 @@ def make_request(request_id,
eos_token_id=100, eos_token_id=100,
arrival_time=0, arrival_time=0,
lora_request=None, lora_request=None,
cache_salt=cache_salt,
) )
@ -603,6 +605,66 @@ def test_mm_prefix_caching():
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
def test_cache_key_salting():
"""
This tests that cache salts are applied during hashing and the cache
is separated cache as expected.
"""
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
)
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids = [i for i in range(3) for _ in range(block_size)]
token_ids = common_token_ids + [3] * 11
req0 = make_request("0", token_ids, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
assert not computed_blocks
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt1", )
assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
# Now one more block that should not have extra keys.
assert len(block_hashes) == 4
assert block_hashes[3].extra_keys is None
# Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11
req1 = make_request("1", token_ids, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks.
assert len(computed_blocks) == 3
assert num_computed_tokens == 3 * block_size
# Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 0
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req2.request_id]
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt2", )
def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_prefill_not_enough_free_blocks_with_computed_blocks():
""" """
This is a unit test that tests the correctness of the allocate_slots This is a unit test that tests the correctness of the allocate_slots

View File

@ -40,6 +40,7 @@ def make_request() -> EngineCoreRequest:
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
cache_salt=None,
) )

View File

@ -43,6 +43,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
cache_salt=None,
) )

View File

@ -57,6 +57,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
@ -403,6 +404,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
@ -565,6 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
@ -661,6 +664,7 @@ def test_stop_string(include_stop_str_in_output: bool,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
@ -774,6 +778,7 @@ def test_iteration_stats(dummy_test_vectors):
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]

View File

@ -14,6 +14,7 @@ from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator) ValidationInfo, field_validator, model_validator)
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from vllm import envs
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
@ -408,6 +409,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified with 'logprobs', tokens are represented " "If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens " " as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified.")) "that are not JSON-encodable can be identified."))
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
@ -726,6 +736,20 @@ class ChatCompletionRequest(OpenAIBaseModel):
"`add_generation_prompt` to True.") "`add_generation_prompt` to True.")
return data return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0.")
if not isinstance(data["cache_salt"],
str) or not data["cache_salt"]:
raise ValueError("Parameter 'cache_salt' must be a "
"non-empty string if provided.")
return data
class CompletionRequest(OpenAIBaseModel): class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation

View File

@ -470,6 +470,9 @@ class OpenAIServing:
if request.mm_processor_kwargs is not None: if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return conversation, [request_prompt], [engine_prompt] return conversation, [request_prompt], [engine_prompt]
def _log_inputs( def _log_inputs(

View File

@ -28,6 +28,11 @@ class TextPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them. to pass the mm_processor_kwargs to each of them.
""" """
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class TokensPrompt(TypedDict): class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt.""" """Schema for a tokenized prompt."""
@ -52,6 +57,11 @@ class TokensPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them. to pass the mm_processor_kwargs to each of them.
""" """
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
""" """
@ -141,11 +151,17 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available. The original prompt text corresponding to the token IDs, if available.
""" """
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def token_inputs( def token_inputs(
prompt_token_ids: list[int], prompt_token_ids: list[int],
token_type_ids: Optional[list[int]] = None, token_type_ids: Optional[list[int]] = None,
prompt: Optional[str] = None, prompt: Optional[str] = None,
cache_salt: Optional[str] = None,
) -> TokenInputs: ) -> TokenInputs:
"""Construct :class:`TokenInputs` from optional values.""" """Construct :class:`TokenInputs` from optional values."""
inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
@ -154,6 +170,8 @@ def token_inputs(
inputs["prompt"] = prompt inputs["prompt"] = prompt
if token_type_ids is not None: if token_type_ids is not None:
inputs["token_type_ids"] = token_type_ids inputs["token_type_ids"] = token_type_ids
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs return inputs

View File

@ -17,7 +17,8 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs) PromptType, SingletonInputs, SingletonPrompt, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -283,6 +284,29 @@ class InputPreprocessor:
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs, return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
return_mm_hashes) return_mm_hashes)
def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
ParsedTextPrompt,
ParsedTokensPrompt]):
prompt_text = None
prompt_token_ids = None
token_type_ids = None
cache_salt = None
if parsed_prompt["type"] == "str":
prompt_text = parsed_prompt["content"]
else:
cache_salt = parsed_prompt["content"].get("cache_salt")
if parsed_prompt["type"] == "text":
prompt_text = parsed_prompt["content"]["prompt"]
elif parsed_prompt["type"] == "tokens":
prompt_token_ids = parsed_prompt["content"].get(
"prompt_token_ids")
token_type_ids = parsed_prompt["content"].get("token_type_ids")
else:
assert_never(parsed_prompt)
return prompt_text, prompt_token_ids, token_type_ids, cache_salt
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
@ -304,70 +328,36 @@ class InputPreprocessor:
* :class:`SingletonInputs` instance * :class:`SingletonInputs` instance
""" """
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
if parsed["type"] == "str": # If multimodal data is present, process and return immediately
prompt_text = parsed["content"] if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = self._process_multimodal(
prompt_text if prompt_text is not None else prompt_token_ids,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
if prompt_token_ids is None:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
return token_inputs( return token_inputs(
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) token_type_ids=token_type_ids,
cache_salt=cache_salt,
if parsed["type"] == "tokens": )
tokens_content = parsed["content"]
prompt_token_ids = tokens_content["prompt_token_ids"]
token_type_ids = tokens_content.get("token_type_ids")
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
if multi_modal_data is not None:
return self._process_multimodal(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
return token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if parsed["type"] == "text":
text_content = parsed["content"]
prompt_text = text_content["prompt"]
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
if multi_modal_data is not None:
return self._process_multimodal(
prompt_text,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
assert_never(parsed)
async def _prompt_to_llm_inputs_async( async def _prompt_to_llm_inputs_async(
self, self,
@ -379,64 +369,35 @@ class InputPreprocessor:
"""Async version of :meth:`_extract_prompt_components`.""" """Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str": prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
prompt_text = parsed["content"] self._get_prompt_data(parsed)
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = await self._process_multimodal_async(
prompt_token_ids if prompt_text is None else prompt_text,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
if prompt_token_ids is None:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
lora_request=lora_request, lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
return token_inputs( return token_inputs(
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) token_type_ids=token_type_ids,
cache_salt=cache_salt,
if parsed["type"] == "tokens": )
tokens_content = parsed["content"]
prompt_token_ids = tokens_content["prompt_token_ids"]
multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
if multi_modal_data is not None:
return await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
return token_inputs(prompt_token_ids=prompt_token_ids)
if parsed["type"] == "text":
text_content = parsed["content"]
prompt_text = text_content["prompt"]
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
if multi_modal_data is not None:
return await self._process_multimodal_async(
prompt_text,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
assert_never(parsed)
def _build_enc_dec_llm_inputs( def _build_enc_dec_llm_inputs(
self, self,
@ -516,6 +477,11 @@ class InputPreprocessor:
mm_hashes=inputs["mm_hashes"], mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=inputs["mm_placeholders"],
) )
cache_salt = inputs.get("cache_salt")
if cache_salt is not None:
decoder_inputs["cache_salt"] = cache_salt
elif inputs["type"] == "token": elif inputs["type"] == "token":
# Text-only inputs # Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[]) encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])

View File

@ -826,6 +826,11 @@ class MultiModalInputs(TypedDict):
:code:`prompt_token_ids`. :code:`prompt_token_ids`.
""" """
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class MultiModalEncDecInputs(MultiModalInputs): class MultiModalEncDecInputs(MultiModalInputs):
""" """

View File

@ -275,7 +275,10 @@ def need_extra_keys(request: Request) -> bool:
# Multimodal requests need to include the MM hash. # Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID. # LoRA requests need to include the LoRA ID.
return bool(request.mm_positions) or (request.lora_request is not None) # Request with provided cache salt need to include the salt.
return bool(request.mm_positions) or (request.lora_request
is not None) or (request.cache_salt
is not None)
def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
@ -380,8 +383,10 @@ def generate_block_hash_extra_keys(
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys( mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
request, start_token_idx, end_token_idx, start_mm_idx) request, start_token_idx, end_token_idx, start_mm_idx)
lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request) lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
cache_salt_keys: list[str] = [request.cache_salt] if (
start_token_idx == 0 and request.cache_salt) else []
extra_keys: list[Any] = lora_extra_keys + mm_extra_keys extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys
if not extra_keys: if not extra_keys:
return None, new_start_mm_idx return None, new_start_mm_idx

View File

@ -57,6 +57,7 @@ class EngineCoreRequest(
eos_token_id: Optional[int] eos_token_id: Optional[int]
arrival_time: float arrival_time: float
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
# Used in DP case to indicate which wave of requests this is expected to # Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before # belong to, to cover a race condition where the request is sent before

View File

@ -317,6 +317,7 @@ class Processor:
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
) )
def _validate_model_inputs(self, def _validate_model_inputs(self,

View File

@ -29,6 +29,7 @@ class Request:
arrival_time: float, arrival_time: float,
lora_request: Optional["LoRARequest"] = None, lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.sampling_params = sampling_params self.sampling_params = sampling_params
@ -51,6 +52,7 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy() self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
self.cache_salt: Optional[str] = cache_salt
# Multi-modal related # Multi-modal related
self.mm_positions = multi_modal_placeholders or [] self.mm_positions = multi_modal_placeholders or []
@ -89,6 +91,7 @@ class Request:
lora_request=request.lora_request, lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest( structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params), sampling_params=request.sampling_params),
cache_salt=request.cache_salt,
) )
def append_output_token_ids( def append_output_token_ids(