From 19c863068b2d70a452bde25318dbcf04f274ad19 Mon Sep 17 00:00:00 2001
From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
Date: Tue, 15 Jul 2025 23:01:04 +0200
Subject: [PATCH] [Frontend] Support cache_salt in /v1/completions and /v1/responses (#20981)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
---
 vllm/entrypoints/openai/api_server.py         |  1 +
 vllm/entrypoints/openai/protocol.py           | 52 +++++++++++++++++--
 vllm/entrypoints/openai/serving_completion.py | 17 ++++++
 vllm/entrypoints/openai/serving_engine.py     | 11 +++-
 4 files changed, 77 insertions(+), 4 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 65ceeff8eb4e6..19d0110ff3712 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1540,6 +1540,7 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
         enable_force_include_usage=args.enable_force_include_usage,
     ) if "generate" in model_config.supported_tasks else None
     state.openai_serving_pooling = OpenAIServingPooling(
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index fdac6ccd19ed6..f17faa23d01c0 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -290,6 +290,15 @@ class ResponsesRequest(OpenAIBaseModel):
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."),
     )
+    cache_salt: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker to guess prompts in multi-user "
+            "environments. The salt should be random, protected from "
+            "access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
+            "to 256 bit). Not supported by vLLM engine V0."))
     # --8<-- [end:responses-extra-params]
 
     _DEFAULT_SAMPLING_PARAMS = {
@@ -351,6 +360,19 @@ class ResponsesRequest(OpenAIBaseModel):
             raise ValueError("prompt template is not supported")
         return data
 
+    @model_validator(mode="before")
+    def check_cache_salt_support(cls, data):
+        if data.get("cache_salt") is not None:
+            if not envs.VLLM_USE_V1:
+                raise ValueError(
+                    "Parameter 'cache_salt' is not supported with "
+                    "this instance of vLLM, which uses engine V0.")
+            if not isinstance(data["cache_salt"],
+                              str) or not data["cache_salt"]:
+                raise ValueError("Parameter 'cache_salt' must be a "
+                                 "non-empty string if provided.")
+        return data
+
 
 class ChatCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -1004,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel):
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))
 
+    cache_salt: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, the prefix cache will be salted with the provided "
+            "string to prevent an attacker to guess prompts in multi-user "
+            "environments. The salt should be random, protected from "
+            "access by 3rd parties, and long enough to be "
+            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
+            "to 256 bit). Not supported by vLLM engine V0."))
+
     kv_transfer_params: Optional[dict[str, Any]] = Field(
         default=None,
         description="KVTransfer parameters used for disaggregated serving.")
@@ -1180,6 +1212,20 @@ class CompletionRequest(OpenAIBaseModel):
                 "At least one of `prompt` or `prompt_embeds` must be set.")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_cache_salt_support(cls, data):
+        if data.get("cache_salt") is not None:
+            if not envs.VLLM_USE_V1:
+                raise ValueError(
+                    "Parameter 'cache_salt' is not supported with "
+                    "this instance of vLLM, which uses engine V0.")
+            if not isinstance(data["cache_salt"],
+                              str) or not data["cache_salt"]:
+                raise ValueError("Parameter 'cache_salt' must be a "
+                                 "non-empty string if provided.")
+        return data
+
 
 class EmbeddingCompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
@@ -1971,7 +2017,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     """
 
     stream: Optional[bool] = False
-    """When set, it will enable output to be streamed in a similar fashion 
+    """When set, it will enable output to be streamed in a similar fashion
     as the Chat Completion endpoint.
     """
     # --8<-- [start:transcription-extra-params]
@@ -2233,9 +2279,9 @@ class TranslationRequest(OpenAIBaseModel):
     """
 
     stream: Optional[bool] = False
-    """Custom field not present in the original OpenAI definition. When set, 
+    """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
-    Completion endpoint. 
+    Completion endpoint.
     """
     # Flattened stream option to simplify form data.
     stream_include_usage: Optional[bool] = False
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 6c9c29b714457..eb9a35a7a37d3 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               ErrorResponse,
+                                              PromptTokenUsageInfo,
                                               RequestResponseMetadata,
                                               UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (
@@ -56,6 +57,7 @@ class OpenAIServingCompletion(OpenAIServing):
         *,
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
     ):
         super().__init__(engine_client=engine_client,
@@ -64,6 +66,7 @@ class OpenAIServingCompletion(OpenAIServing):
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids,
                          enable_force_include_usage=enable_force_include_usage)
+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
@@ -313,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
         previous_num_tokens = [0] * num_choices * num_prompts
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
+        num_cached_tokens = None
+        first_iteration = True
 
         stream_options = request.stream_options
         if stream_options:
@@ -328,6 +333,10 @@ class OpenAIServingCompletion(OpenAIServing):
                 prompt_token_ids = res.prompt_token_ids
                 prompt_logprobs = res.prompt_logprobs
 
+                if first_iteration:
+                    num_cached_tokens = res.num_cached_tokens
+                    first_iteration = False
+
                 if res.prompt is not None:
                     prompt_text = res.prompt
                 else:
@@ -431,6 +440,10 @@ class OpenAIServingCompletion(OpenAIServing):
                 completion_tokens=total_completion_tokens,
                 total_tokens=total_prompt_tokens + total_completion_tokens)
 
+            if self.enable_prompt_tokens_details and num_cached_tokens:
+                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
+                    cached_tokens=num_cached_tokens)
+
             if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
                     id=request_id,
@@ -535,6 +548,10 @@ class OpenAIServingCompletion(OpenAIServing):
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
+        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+            usage.prompt_tokens_details = PromptTokenUsageInfo(
+                cached_tokens=final_res.num_cached_tokens)
+
         request_metadata.final_usage_info = usage
 
         return CompletionResponse(
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index dab5ac0325327..462317a0878c7 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -226,7 +226,7 @@ class OpenAIServing:
 
     def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
         """
-        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the 
+        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
         given tokenizer.
         """
         async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
@@ -811,6 +811,12 @@ class OpenAIServing:
                 prompt_token_ids=request_prompt_text["prompt_token_ids"])
             for request_prompt_text in request_prompts_text
         ]
+        cache_salt = request.cache_salt if (
+            hasattr(request, "cache_salt")
+            and request.cache_salt is not None) else None
+        if cache_salt:
+            for prompt_text in engine_prompts_text:
+                prompt_text["cache_salt"] = cache_salt
 
         # This check is equivalent to simply checking if
         # `request_prompts_embeds` is empty, but it's difficult to propagate
@@ -828,6 +834,9 @@ class OpenAIServing:
                 prompt_embeds=request_prompt_embeds["prompt_embeds"])
             for request_prompt_embeds in request_prompts_embeds
         ]
+        if cache_salt:
+            for prompt_embed in engine_prompts_embeds:
+                prompt_embed["cache_salt"] = cache_salt
 
         request_prompts = request_prompts_embeds + request_prompts_text
         engine_prompts = engine_prompts_embeds + engine_prompts_text
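Illustrative usage (not part of the patch): the sketch below shows how a client could pass the new cache_salt field to /v1/completions once this change is deployed. The server address (http://localhost:8000) and model name ("my-model") are placeholder assumptions, and the snippet uses the plain `requests` library rather than any vLLM-specific client.

# Sketch only (not part of the diff above): exercise the new cache_salt
# parameter of /v1/completions. Server address and model name are
# placeholder assumptions.
import secrets

import requests

# 32 random bytes -> ~43 URL-safe base64 characters (256 bits), matching the
# recommendation in the cache_salt field description. Requests with different
# salts never share prefix-cache entries, so reuse the same salt within one
# trust domain if you still want cache hits there.
salt = secrets.token_urlsafe(32)

response = requests.post(
    "http://localhost:8000/v1/completions",  # assumed local vLLM V1 server
    json={
        "model": "my-model",                 # assumed served model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "cache_salt": salt,
    },
    timeout=30,
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])

The same field is accepted by /v1/responses via ResponsesRequest.cache_salt; on a V0 engine the request is rejected by the check_cache_salt_support validators added in the patch.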