[Frontend] Support cache_salt in /v1/completions and /v1/responses (#20981)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
This commit is contained in:
Marko Rosenmueller 2025-07-15 23:01:04 +02:00 committed by GitHub
parent f29fd8a7f8
commit 19c863068b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 4 deletions

View File

@ -1540,6 +1540,7 @@ async def init_app_state(
state.openai_serving_models, state.openai_serving_models,
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
) if "generate" in model_config.supported_tasks else None ) if "generate" in model_config.supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling( state.openai_serving_pooling = OpenAIServingPooling(

View File

@ -290,6 +290,15 @@ class ResponsesRequest(OpenAIBaseModel):
"default: 0). Any priority other than 0 will raise an error " "default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."), "if the served model does not use priority scheduling."),
) )
# Optional per-request salt for the prefix cache. Defaults to None, i.e. no
# salt supplied. Non-None values are checked by the
# `check_cache_salt_support` validator of this model.
# NOTE(review): the description reads "prevent an attacker to guess";
# "prevent an attacker from guessing" would be the correct phrasing. Left
# untouched because it is a runtime string (presumably surfaced in the
# generated API schema -- confirm before rewording).
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
# --8<-- [end:responses-extra-params] # --8<-- [end:responses-extra-params]
_DEFAULT_SAMPLING_PARAMS = { _DEFAULT_SAMPLING_PARAMS = {
@ -351,6 +360,19 @@ class ResponsesRequest(OpenAIBaseModel):
raise ValueError("prompt template is not supported") raise ValueError("prompt template is not supported")
return data return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
    """Validate the optional ``cache_salt`` entry of the raw request data.

    Runs before field validation. A missing or ``None`` salt is accepted
    as-is; any other value must be a non-empty string and is only allowed
    when ``envs.VLLM_USE_V1`` is truthy.

    Args:
        data: Raw request payload; accessed with ``.get`` and item lookup.

    Returns:
        ``data`` unchanged.

    Raises:
        ValueError: If a salt is given while ``envs.VLLM_USE_V1`` is falsy,
            or if the salt is not a non-empty string.
    """
    # Declared as an explicit classmethod to match the identically named
    # validator on CompletionRequest, instead of relying on the `cls`
    # parameter name alone.
    if data.get("cache_salt") is not None:
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Parameter 'cache_salt' is not supported with "
                "this instance of vLLM, which uses engine V0.")
        if not isinstance(data["cache_salt"],
                          str) or not data["cache_salt"]:
            raise ValueError("Parameter 'cache_salt' must be a "
                             "non-empty string if provided.")
    return data
class ChatCompletionRequest(OpenAIBaseModel): class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation
@ -1004,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel):
" as strings of the form 'token_id:{token_id}' so that tokens " " as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified.")) "that are not JSON-encodable can be identified."))
# Optional per-request salt for the prefix cache. Defaults to None, i.e. no
# salt supplied. Non-None values are checked by the
# `check_cache_salt_support` validator of this model.
# NOTE(review): the description reads "prevent an attacker to guess";
# "prevent an attacker from guessing" would be the correct phrasing. Left
# untouched because it is a runtime string (presumably surfaced in the
# generated API schema -- confirm before rewording).
cache_salt: Optional[str] = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit). Not supported by vLLM engine V0."))
kv_transfer_params: Optional[dict[str, Any]] = Field( kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, default=None,
description="KVTransfer parameters used for disaggregated serving.") description="KVTransfer parameters used for disaggregated serving.")
@ -1180,6 +1212,20 @@ class CompletionRequest(OpenAIBaseModel):
"At least one of `prompt` or `prompt_embeds` must be set.") "At least one of `prompt` or `prompt_embeds` must be set.")
return data return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
    """Reject unusable ``cache_salt`` values before field validation.

    A missing or ``None`` salt passes through untouched. Otherwise the salt
    is accepted only when ``envs.VLLM_USE_V1`` is truthy and the value is a
    non-empty string.

    Returns:
        ``data`` unchanged.

    Raises:
        ValueError: If a salt is given while ``envs.VLLM_USE_V1`` is falsy,
            or if the salt is not a non-empty string.
    """
    salt = data.get("cache_salt")
    if salt is None:
        # Nothing to validate when the caller did not supply a salt.
        return data

    # The engine check comes first so an unsupported engine is reported even
    # when the salt itself is malformed.
    if not envs.VLLM_USE_V1:
        raise ValueError("Parameter 'cache_salt' is not supported with "
                         "this instance of vLLM, which uses engine V0.")

    is_usable = isinstance(salt, str) and bool(salt)
    if not is_usable:
        raise ValueError("Parameter 'cache_salt' must be a "
                         "non-empty string if provided.")
    return data
class EmbeddingCompletionRequest(OpenAIBaseModel): class EmbeddingCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation # Ordered by official OpenAI API documentation

View File

@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseStreamChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, CompletionStreamResponse,
ErrorResponse, ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
UsageInfo) UsageInfo)
from vllm.entrypoints.openai.serving_engine import ( from vllm.entrypoints.openai.serving_engine import (
@ -56,6 +57,7 @@ class OpenAIServingCompletion(OpenAIServing):
*, *,
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
): ):
super().__init__(engine_client=engine_client, super().__init__(engine_client=engine_client,
@ -64,6 +66,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage) enable_force_include_usage=enable_force_include_usage)
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.default_sampling_params = ( self.default_sampling_params = (
self.model_config.get_diff_sampling_param()) self.model_config.get_diff_sampling_param())
if self.default_sampling_params: if self.default_sampling_params:
@ -313,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
previous_num_tokens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts num_prompt_tokens = [0] * num_prompts
num_cached_tokens = None
first_iteration = True
stream_options = request.stream_options stream_options = request.stream_options
if stream_options: if stream_options:
@ -328,6 +333,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_token_ids = res.prompt_token_ids prompt_token_ids = res.prompt_token_ids
prompt_logprobs = res.prompt_logprobs prompt_logprobs = res.prompt_logprobs
if first_iteration:
num_cached_tokens = res.num_cached_tokens
first_iteration = False
if res.prompt is not None: if res.prompt is not None:
prompt_text = res.prompt prompt_text = res.prompt
else: else:
@ -431,6 +440,10 @@ class OpenAIServingCompletion(OpenAIServing):
completion_tokens=total_completion_tokens, completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens) total_tokens=total_prompt_tokens + total_completion_tokens)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)
if include_usage: if include_usage:
final_usage_chunk = CompletionStreamResponse( final_usage_chunk = CompletionStreamResponse(
id=request_id, id=request_id,
@ -535,6 +548,10 @@ class OpenAIServingCompletion(OpenAIServing):
total_tokens=num_prompt_tokens + num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens,
) )
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage request_metadata.final_usage_info = usage
return CompletionResponse( return CompletionResponse(

View File

@ -811,6 +811,12 @@ class OpenAIServing:
prompt_token_ids=request_prompt_text["prompt_token_ids"]) prompt_token_ids=request_prompt_text["prompt_token_ids"])
for request_prompt_text in request_prompts_text for request_prompt_text in request_prompts_text
] ]
cache_salt = request.cache_salt if (
hasattr(request, "cache_salt")
and request.cache_salt is not None) else None
if cache_salt:
for prompt_text in engine_prompts_text:
prompt_text["cache_salt"] = cache_salt
# This check is equivalent to simply checking if # This check is equivalent to simply checking if
# `request_prompts_embeds` is empty, but it's difficult to propagate # `request_prompts_embeds` is empty, but it's difficult to propagate
@ -828,6 +834,9 @@ class OpenAIServing:
prompt_embeds=request_prompt_embeds["prompt_embeds"]) prompt_embeds=request_prompt_embeds["prompt_embeds"])
for request_prompt_embeds in request_prompts_embeds for request_prompt_embeds in request_prompts_embeds
] ]
if cache_salt:
for prompt_embed in engine_prompts_embeds:
prompt_embed["cache_salt"] = cache_salt
request_prompts = request_prompts_embeds + request_prompts_text request_prompts = request_prompts_embeds + request_prompts_text
engine_prompts = engine_prompts_embeds + engine_prompts_text engine_prompts = engine_prompts_embeds + engine_prompts_text