[Frontend] Support cache_salt in /v1/completions and /v1/responses (#20981)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
Marko Rosenmueller 2025-07-15 23:01:04 +02:00 committed by GitHub
parent f29fd8a7f8
commit 19c863068b
4 changed files with 77 additions and 4 deletions

@@ -1540,6 +1540,7 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling(

@@ -290,6 +290,15 @@ class ResponsesRequest(OpenAIBaseModel):
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
# --8<-- [end:responses-extra-params]
_DEFAULT_SAMPLING_PARAMS = {
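
Usage sketch for the new field: the description above asks for a random, unguessable salt. A minimal client-side example, assuming the official openai Python client (with Responses API support) and that vLLM-specific fields are forwarded via extra_body; the base URL and model name are placeholders:

    # Illustrative sketch only: generate a 256-bit salt and attach it to a
    # /v1/responses request. Base URL and model name are placeholders.
    import secrets

    from openai import OpenAI

    salt = secrets.token_urlsafe(32)  # 32 random bytes -> 43 URL-safe base64 chars

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    response = client.responses.create(
        model="my-model",
        input="Summarize this document.",
        extra_body={"cache_salt": salt},  # vLLM-specific, not in the OpenAI spec
    )
    print(response.output_text)

The intent, per the description, is that only requests carrying the same salt can share cached prefixes with each other.
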
@@ -351,6 +360,19 @@ class ResponsesRequest(OpenAIBaseModel):
raise ValueError("prompt template is not supported")
return data
@model_validator(mode="before")
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0.")
if not isinstance(data["cache_salt"],
str) or not data["cache_salt"]:
raise ValueError("Parameter 'cache_salt' must be a "
"non-empty string if provided.")
return data
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
@@ -1004,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel):
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.")
@@ -1180,6 +1212,20 @@ class CompletionRequest(OpenAIBaseModel):
"At least one of `prompt` or `prompt_embeds` must be set.")
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None:
if not envs.VLLM_USE_V1:
raise ValueError(
"Parameter 'cache_salt' is not supported with "
"this instance of vLLM, which uses engine V0.")
if not isinstance(data["cache_salt"],
str) or not data["cache_salt"]:
raise ValueError("Parameter 'cache_salt' must be a "
"non-empty string if provided.")
return data
class EmbeddingCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
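
To see what the new check_cache_salt_support validator accepts and rejects, a small sketch that instantiates the request model directly; it assumes a V1 engine environment (VLLM_USE_V1=1), since on engine V0 any cache_salt is rejected with the first error message above:

    # Illustrative sketch only: exercise the validator added in this commit.
    from pydantic import ValidationError

    from vllm.entrypoints.openai.protocol import CompletionRequest

    # A non-empty string passes validation.
    ok = CompletionRequest(model="my-model", prompt="Hello", cache_salt="abc123")

    # An empty string is rejected with the message introduced above.
    try:
        CompletionRequest(model="my-model", prompt="Hello", cache_salt="")
    except ValidationError as err:
        print(err)  # Parameter 'cache_salt' must be a non-empty string if provided.
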
@@ -1971,7 +2017,7 @@ class TranscriptionRequest(OpenAIBaseModel):
"""
stream: Optional[bool] = False
"""When set, it will enable output to be streamed in a similar fashion
"""When set, it will enable output to be streamed in a similar fashion
as the Chat Completion endpoint.
"""
# --8<-- [start:transcription-extra-params]
@@ -2233,9 +2279,9 @@ class TranslationRequest(OpenAIBaseModel):
"""
stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
"""
# Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False

@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (
@@ -56,6 +57,7 @@ class OpenAIServingCompletion(OpenAIServing):
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
):
super().__init__(engine_client=engine_client,
@@ -64,6 +66,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
@@ -313,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options
if stream_options:
@@ -328,6 +333,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
if res.prompt is not None:
prompt_text = res.prompt
else:
@@ -431,6 +440,10 @@ class OpenAIServingCompletion(OpenAIServing):
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)
if include_usage:
final_usage_chunk = CompletionStreamResponse(
id=request_id,
@@ -535,6 +548,10 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
return CompletionResponse(
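
These serving changes mirror what the chat endpoint already does: when the server is started with --enable-prompt-tokens-details, completion usage now reports prompt_tokens_details.cached_tokens. A hedged client-side sketch (placeholder URL and model; attribute access is guarded because older client versions may not expose the field):

    # Illustrative sketch only: observe cached prompt tokens across two
    # identical requests. Requires --enable-prompt-tokens-details on the server.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    prompt = "A long shared preamble. " * 100  # long enough to span cached blocks

    for i in range(2):
        completion = client.completions.create(
            model="my-model",
            prompt=prompt,
            max_tokens=8,
            extra_body={"cache_salt": "REPLACE-WITH-A-RANDOM-43-CHAR-SALT"},
        )
        details = getattr(completion.usage, "prompt_tokens_details", None)
        cached = getattr(details, "cached_tokens", None) if details else None
        print(f"request {i}: prompt_tokens={completion.usage.prompt_tokens}, "
              f"cached_tokens={cached}")
    # The second request is expected to report a non-zero cached_tokens value.
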

@@ -226,7 +226,7 @@ class OpenAIServing:
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
@@ -811,6 +811,12 @@ class OpenAIServing:
prompt_token_ids=request_prompt_text["prompt_token_ids"])
for request_prompt_text in request_prompts_text
]
cache_salt = request.cache_salt if (
hasattr(request, "cache_salt")
and request.cache_salt is not None) else None
if cache_salt:
for prompt_text in engine_prompts_text:
prompt_text["cache_salt"] = cache_salt
# This check is equivalent to simply checking if
# `request_prompts_embeds` is empty, but it's difficult to propagate
@@ -828,6 +834,9 @@ class OpenAIServing:
prompt_embeds=request_prompt_embeds["prompt_embeds"])
for request_prompt_embeds in request_prompts_embeds
]
if cache_salt:
for prompt_embed in engine_prompts_embeds:
prompt_embed["cache_salt"] = cache_salt
request_prompts = request_prompts_embeds + request_prompts_text
engine_prompts = engine_prompts_embeds + engine_prompts_text
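
The last two blocks simply copy the request's cache_salt onto every engine prompt dict, whether token-based or embedding-based. Roughly, the engine ends up receiving prompts shaped like the sketch below; the keys mirror the assignments above, and treating these dicts as vLLM's engine prompt inputs is an assumption:

    # Illustrative sketch only: shape of the prompts handed to the engine after
    # the serving layer attaches cache_salt.
    cache_salt = "REPLACE-WITH-A-RANDOM-43-CHAR-SALT"

    engine_prompt_from_text = {
        "prompt_token_ids": [101, 2023, 2003, 1037, 7953, 102],
        "cache_salt": cache_salt,  # salts the prefix-cache hash for this request
    }
    engine_prompt_from_embeds = {
        "prompt_embeds": None,  # a torch.Tensor of prompt embeddings in practice
        "cache_salt": cache_salt,
    }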