[Frontend] Support cache_salt in /v1/completions and /v1/responses (#20981)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
parent f29fd8a7f8
commit 19c863068b
@@ -1540,6 +1540,7 @@ async def init_app_state(
        state.openai_serving_models,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
    ) if "generate" in model_config.supported_tasks else None
    state.openai_serving_pooling = OpenAIServingPooling(

@@ -290,6 +290,15 @@ class ResponsesRequest(OpenAIBaseModel):
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."),
    )
    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))
    # --8<-- [end:responses-extra-params]

    _DEFAULT_SAMPLING_PARAMS = {

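The field description above suggests a 256-bit salt, i.e. about 43 URL-safe base64 characters. A minimal client-side sketch (not part of this diff) of generating such a salt with the Python standard library:

import secrets

# 32 random bytes -> 43 URL-safe base64 characters (256 bits of entropy),
# matching the strength recommended in the field description above.
cache_salt = secrets.token_urlsafe(32)
print(len(cache_salt), cache_salt)
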
@@ -351,6 +360,19 @@ class ResponsesRequest(OpenAIBaseModel):
            raise ValueError("prompt template is not supported")
        return data

    @model_validator(mode="before")
    def check_cache_salt_support(cls, data):
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0.")
            if not isinstance(data["cache_salt"],
                              str) or not data["cache_salt"]:
                raise ValueError("Parameter 'cache_salt' must be a "
                                 "non-empty string if provided.")
        return data


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation

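For readers less familiar with "before"-mode validators: they receive the raw request payload before field parsing, which is why the code above can reject a bad cache_salt early. A standalone toy sketch of the same pattern (ToyRequest is invented for illustration and is not vLLM code):

from typing import Optional

from pydantic import BaseModel, Field, model_validator


class ToyRequest(BaseModel):
    prompt: str
    cache_salt: Optional[str] = Field(default=None)

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt(cls, data):
        # "before" validators see the raw input (here a dict), so invalid
        # values can be rejected before field parsing.
        if isinstance(data, dict) and data.get("cache_salt") is not None:
            if not isinstance(data["cache_salt"], str) or not data["cache_salt"]:
                raise ValueError("cache_salt must be a non-empty string")
        return data


ToyRequest(prompt="hi", cache_salt="abc")   # accepted
# ToyRequest(prompt="hi", cache_salt="")    # raises a ValidationError
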
@@ -1004,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel):
            " as strings of the form 'token_id:{token_id}' so that tokens "
            "that are not JSON-encodable can be identified."))

    cache_salt: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit). Not supported by vLLM engine V0."))

    kv_transfer_params: Optional[dict[str, Any]] = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.")

@@ -1180,6 +1212,20 @@ class CompletionRequest(OpenAIBaseModel):
                "At least one of `prompt` or `prompt_embeds` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_cache_salt_support(cls, data):
        if data.get("cache_salt") is not None:
            if not envs.VLLM_USE_V1:
                raise ValueError(
                    "Parameter 'cache_salt' is not supported with "
                    "this instance of vLLM, which uses engine V0.")
            if not isinstance(data["cache_salt"],
                              str) or not data["cache_salt"]:
                raise ValueError("Parameter 'cache_salt' must be a "
                                 "non-empty string if provided.")
        return data


class EmbeddingCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation

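With the field and validator above in place, a client can send cache_salt on /v1/completions. A hedged sketch using the OpenAI Python SDK's extra_body passthrough; the base_url, api_key, model name, and salt value are placeholders, not values taken from this PR:

from openai import OpenAI

# Placeholders: point base_url/model at your own vLLM deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="Summarize the following ticket: ...",
    max_tokens=64,
    # cache_salt is a vLLM extension, so it is passed via extra_body rather
    # than as a first-class SDK argument.
    extra_body={"cache_salt": "J6sOIqbDsdampNKJMAGvIA"},
)
print(completion.choices[0].text)
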
@@ -1971,7 +2017,7 @@ class TranscriptionRequest(OpenAIBaseModel):
    """

    stream: Optional[bool] = False
    """When set, it will enable output to be streamed in a similar fashion
    as the Chat Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]

@@ -2233,9 +2279,9 @@ class TranslationRequest(OpenAIBaseModel):
    """

    stream: Optional[bool] = False
    """Custom field not present in the original OpenAI definition. When set,
    it will enable output to be streamed in a similar fashion as the Chat
    Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: Optional[bool] = False

@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              ErrorResponse,
                                              PromptTokenUsageInfo,
                                              RequestResponseMetadata,
                                              UsageInfo)
from vllm.entrypoints.openai.serving_engine import (

@@ -56,6 +57,7 @@ class OpenAIServingCompletion(OpenAIServing):
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
    ):
        super().__init__(engine_client=engine_client,

@@ -64,6 +66,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         request_logger=request_logger,
                         return_tokens_as_token_ids=return_tokens_as_token_ids,
                         enable_force_include_usage=enable_force_include_usage)
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:

@@ -313,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
        previous_num_tokens = [0] * num_choices * num_prompts
        has_echoed = [False] * num_choices * num_prompts
        num_prompt_tokens = [0] * num_prompts
        num_cached_tokens = None
        first_iteration = True

        stream_options = request.stream_options
        if stream_options:

@@ -328,6 +333,10 @@ class OpenAIServingCompletion(OpenAIServing):
                prompt_token_ids = res.prompt_token_ids
                prompt_logprobs = res.prompt_logprobs

                if first_iteration:
                    num_cached_tokens = res.num_cached_tokens
                    first_iteration = False

                if res.prompt is not None:
                    prompt_text = res.prompt
                else:

@@ -431,6 +440,10 @@ class OpenAIServingCompletion(OpenAIServing):
                completion_tokens=total_completion_tokens,
                total_tokens=total_prompt_tokens + total_completion_tokens)

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
                    cached_tokens=num_cached_tokens)

            if include_usage:
                final_usage_chunk = CompletionStreamResponse(
                    id=request_id,

@@ -535,6 +548,10 @@ class OpenAIServingCompletion(OpenAIServing):
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
            usage.prompt_tokens_details = PromptTokenUsageInfo(
                cached_tokens=final_res.num_cached_tokens)

        request_metadata.final_usage_info = usage

        return CompletionResponse(

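The hunks above surface cached-token counts in the completion usage payload. A hedged sketch of reading them, reusing the client from the earlier sketch; it assumes the server was started with --enable-prompt-tokens-details (the flag threaded through in the first hunk) and accesses the field defensively since it is only populated when cached tokens are reported:

resp = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    prompt="Summarize the following ticket: ...",
    max_tokens=64,
    extra_body={"cache_salt": "J6sOIqbDsdampNKJMAGvIA"},
)
# prompt_tokens_details is only set when the server reports cached tokens,
# e.g. on a repeated request that hits the (salted) prefix cache.
details = getattr(resp.usage, "prompt_tokens_details", None)
if details is not None:
    print("cached prompt tokens:", details.cached_tokens)
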
@@ -226,7 +226,7 @@ class OpenAIServing:

    def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
        """
        Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
        given tokenizer.
        """
        async_tokenizer = self._async_tokenizer_pool.get(tokenizer)

@@ -811,6 +811,12 @@ class OpenAIServing:
                prompt_token_ids=request_prompt_text["prompt_token_ids"])
            for request_prompt_text in request_prompts_text
        ]
        cache_salt = request.cache_salt if (
            hasattr(request, "cache_salt")
            and request.cache_salt is not None) else None
        if cache_salt:
            for prompt_text in engine_prompts_text:
                prompt_text["cache_salt"] = cache_salt

        # This check is equivalent to simply checking if
        # `request_prompts_embeds` is empty, but it's difficult to propagate
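
For illustration only (token ids and salt value are invented), this is the shape of an engine prompt dict after the block above attaches the request's salt:

engine_prompt = {
    "prompt_token_ids": [128000, 9906, 1917],
    "cache_salt": "J6sOIqbDsdampNKJMAGvIA",
}
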
@@ -828,6 +834,9 @@ class OpenAIServing:
                prompt_embeds=request_prompt_embeds["prompt_embeds"])
            for request_prompt_embeds in request_prompts_embeds
        ]
        if cache_salt:
            for prompt_embed in engine_prompts_embeds:
                prompt_embed["cache_salt"] = cache_salt

        request_prompts = request_prompts_embeds + request_prompts_text
        engine_prompts = engine_prompts_embeds + engine_prompts_text

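Why attaching the salt to the engine prompt helps: the prefix cache keys blocks by hashes of their token contents, and (roughly) the salt is folded into the hash of the first block, so identical prompts sent with different salts cannot hit each other's cached blocks. A conceptual sketch of that idea, not vLLM's actual hashing code:

import hashlib
from typing import Optional


def first_block_key(token_ids: list[int], cache_salt: Optional[str]) -> str:
    # Conceptual sketch: mixing the salt into the key means two users sending
    # the same tokens with different salts map to different cache entries, so
    # neither can probe for the other's cached prefixes via latency differences.
    h = hashlib.sha256()
    if cache_salt is not None:
        h.update(cache_salt.encode("utf-8"))
    h.update(",".join(map(str, token_ids)).encode("utf-8"))
    return h.hexdigest()


same_tokens = [128000, 9906, 1917]
print(first_block_key(same_tokens, "salt-user-a") ==
      first_block_key(same_tokens, "salt-user-b"))   # False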