mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 14:35:52 +08:00
[Frontend] Support cache_salt in /v1/completions and /v1/responses (#20981)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
This commit is contained in:
parent
f29fd8a7f8
commit
19c863068b
@ -1540,6 +1540,7 @@ async def init_app_state(
|
|||||||
state.openai_serving_models,
|
state.openai_serving_models,
|
||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||||
|
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||||
enable_force_include_usage=args.enable_force_include_usage,
|
enable_force_include_usage=args.enable_force_include_usage,
|
||||||
) if "generate" in model_config.supported_tasks else None
|
) if "generate" in model_config.supported_tasks else None
|
||||||
state.openai_serving_pooling = OpenAIServingPooling(
|
state.openai_serving_pooling = OpenAIServingPooling(
|
||||||
|
|||||||
@ -290,6 +290,15 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
"default: 0). Any priority other than 0 will raise an error "
|
"default: 0). Any priority other than 0 will raise an error "
|
||||||
"if the served model does not use priority scheduling."),
|
"if the served model does not use priority scheduling."),
|
||||||
)
|
)
|
||||||
|
cache_salt: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the prefix cache will be salted with the provided "
|
||||||
|
"string to prevent an attacker to guess prompts in multi-user "
|
||||||
|
"environments. The salt should be random, protected from "
|
||||||
|
"access by 3rd parties, and long enough to be "
|
||||||
|
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||||
|
"to 256 bit). Not supported by vLLM engine V0."))
|
||||||
# --8<-- [end:responses-extra-params]
|
# --8<-- [end:responses-extra-params]
|
||||||
|
|
||||||
_DEFAULT_SAMPLING_PARAMS = {
|
_DEFAULT_SAMPLING_PARAMS = {
|
||||||
@ -351,6 +360,19 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
raise ValueError("prompt template is not supported")
|
raise ValueError("prompt template is not supported")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def check_cache_salt_support(cls, data):
|
||||||
|
if data.get("cache_salt") is not None:
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
raise ValueError(
|
||||||
|
"Parameter 'cache_salt' is not supported with "
|
||||||
|
"this instance of vLLM, which uses engine V0.")
|
||||||
|
if not isinstance(data["cache_salt"],
|
||||||
|
str) or not data["cache_salt"]:
|
||||||
|
raise ValueError("Parameter 'cache_salt' must be a "
|
||||||
|
"non-empty string if provided.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(OpenAIBaseModel):
|
class ChatCompletionRequest(OpenAIBaseModel):
|
||||||
# Ordered by official OpenAI API documentation
|
# Ordered by official OpenAI API documentation
|
||||||
@ -1004,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
" as strings of the form 'token_id:{token_id}' so that tokens "
|
" as strings of the form 'token_id:{token_id}' so that tokens "
|
||||||
"that are not JSON-encodable can be identified."))
|
"that are not JSON-encodable can be identified."))
|
||||||
|
|
||||||
|
cache_salt: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If specified, the prefix cache will be salted with the provided "
|
||||||
|
"string to prevent an attacker to guess prompts in multi-user "
|
||||||
|
"environments. The salt should be random, protected from "
|
||||||
|
"access by 3rd parties, and long enough to be "
|
||||||
|
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||||
|
"to 256 bit). Not supported by vLLM engine V0."))
|
||||||
|
|
||||||
kv_transfer_params: Optional[dict[str, Any]] = Field(
|
kv_transfer_params: Optional[dict[str, Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="KVTransfer parameters used for disaggregated serving.")
|
description="KVTransfer parameters used for disaggregated serving.")
|
||||||
@ -1180,6 +1212,20 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
"At least one of `prompt` or `prompt_embeds` must be set.")
|
"At least one of `prompt` or `prompt_embeds` must be set.")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_cache_salt_support(cls, data):
|
||||||
|
if data.get("cache_salt") is not None:
|
||||||
|
if not envs.VLLM_USE_V1:
|
||||||
|
raise ValueError(
|
||||||
|
"Parameter 'cache_salt' is not supported with "
|
||||||
|
"this instance of vLLM, which uses engine V0.")
|
||||||
|
if not isinstance(data["cache_salt"],
|
||||||
|
str) or not data["cache_salt"]:
|
||||||
|
raise ValueError("Parameter 'cache_salt' must be a "
|
||||||
|
"non-empty string if provided.")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||||
# Ordered by official OpenAI API documentation
|
# Ordered by official OpenAI API documentation
|
||||||
@ -1971,7 +2017,7 @@ class TranscriptionRequest(OpenAIBaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
"""When set, it will enable output to be streamed in a similar fashion
|
"""When set, it will enable output to be streamed in a similar fashion
|
||||||
as the Chat Completion endpoint.
|
as the Chat Completion endpoint.
|
||||||
"""
|
"""
|
||||||
# --8<-- [start:transcription-extra-params]
|
# --8<-- [start:transcription-extra-params]
|
||||||
@ -2233,9 +2279,9 @@ class TranslationRequest(OpenAIBaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
"""Custom field not present in the original OpenAI definition. When set,
|
"""Custom field not present in the original OpenAI definition. When set,
|
||||||
it will enable output to be streamed in a similar fashion as the Chat
|
it will enable output to be streamed in a similar fashion as the Chat
|
||||||
Completion endpoint.
|
Completion endpoint.
|
||||||
"""
|
"""
|
||||||
# Flattened stream option to simplify form data.
|
# Flattened stream option to simplify form data.
|
||||||
stream_include_usage: Optional[bool] = False
|
stream_include_usage: Optional[bool] = False
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
|||||||
CompletionResponseStreamChoice,
|
CompletionResponseStreamChoice,
|
||||||
CompletionStreamResponse,
|
CompletionStreamResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
PromptTokenUsageInfo,
|
||||||
RequestResponseMetadata,
|
RequestResponseMetadata,
|
||||||
UsageInfo)
|
UsageInfo)
|
||||||
from vllm.entrypoints.openai.serving_engine import (
|
from vllm.entrypoints.openai.serving_engine import (
|
||||||
@ -56,6 +57,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
*,
|
*,
|
||||||
request_logger: Optional[RequestLogger],
|
request_logger: Optional[RequestLogger],
|
||||||
return_tokens_as_token_ids: bool = False,
|
return_tokens_as_token_ids: bool = False,
|
||||||
|
enable_prompt_tokens_details: bool = False,
|
||||||
enable_force_include_usage: bool = False,
|
enable_force_include_usage: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(engine_client=engine_client,
|
super().__init__(engine_client=engine_client,
|
||||||
@ -64,6 +66,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
return_tokens_as_token_ids=return_tokens_as_token_ids,
|
||||||
enable_force_include_usage=enable_force_include_usage)
|
enable_force_include_usage=enable_force_include_usage)
|
||||||
|
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||||
self.default_sampling_params = (
|
self.default_sampling_params = (
|
||||||
self.model_config.get_diff_sampling_param())
|
self.model_config.get_diff_sampling_param())
|
||||||
if self.default_sampling_params:
|
if self.default_sampling_params:
|
||||||
@ -313,6 +316,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
previous_num_tokens = [0] * num_choices * num_prompts
|
previous_num_tokens = [0] * num_choices * num_prompts
|
||||||
has_echoed = [False] * num_choices * num_prompts
|
has_echoed = [False] * num_choices * num_prompts
|
||||||
num_prompt_tokens = [0] * num_prompts
|
num_prompt_tokens = [0] * num_prompts
|
||||||
|
num_cached_tokens = None
|
||||||
|
first_iteration = True
|
||||||
|
|
||||||
stream_options = request.stream_options
|
stream_options = request.stream_options
|
||||||
if stream_options:
|
if stream_options:
|
||||||
@ -328,6 +333,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
prompt_token_ids = res.prompt_token_ids
|
prompt_token_ids = res.prompt_token_ids
|
||||||
prompt_logprobs = res.prompt_logprobs
|
prompt_logprobs = res.prompt_logprobs
|
||||||
|
|
||||||
|
if first_iteration:
|
||||||
|
num_cached_tokens = res.num_cached_tokens
|
||||||
|
first_iteration = False
|
||||||
|
|
||||||
if res.prompt is not None:
|
if res.prompt is not None:
|
||||||
prompt_text = res.prompt
|
prompt_text = res.prompt
|
||||||
else:
|
else:
|
||||||
@ -431,6 +440,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
completion_tokens=total_completion_tokens,
|
completion_tokens=total_completion_tokens,
|
||||||
total_tokens=total_prompt_tokens + total_completion_tokens)
|
total_tokens=total_prompt_tokens + total_completion_tokens)
|
||||||
|
|
||||||
|
if self.enable_prompt_tokens_details and num_cached_tokens:
|
||||||
|
final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
|
||||||
|
cached_tokens=num_cached_tokens)
|
||||||
|
|
||||||
if include_usage:
|
if include_usage:
|
||||||
final_usage_chunk = CompletionStreamResponse(
|
final_usage_chunk = CompletionStreamResponse(
|
||||||
id=request_id,
|
id=request_id,
|
||||||
@ -535,6 +548,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
|
||||||
|
usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||||
|
cached_tokens=final_res.num_cached_tokens)
|
||||||
|
|
||||||
request_metadata.final_usage_info = usage
|
request_metadata.final_usage_info = usage
|
||||||
|
|
||||||
return CompletionResponse(
|
return CompletionResponse(
|
||||||
|
|||||||
@ -226,7 +226,7 @@ class OpenAIServing:
|
|||||||
|
|
||||||
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
|
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
|
||||||
"""
|
"""
|
||||||
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
|
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
|
||||||
given tokenizer.
|
given tokenizer.
|
||||||
"""
|
"""
|
||||||
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
|
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
|
||||||
@ -811,6 +811,12 @@ class OpenAIServing:
|
|||||||
prompt_token_ids=request_prompt_text["prompt_token_ids"])
|
prompt_token_ids=request_prompt_text["prompt_token_ids"])
|
||||||
for request_prompt_text in request_prompts_text
|
for request_prompt_text in request_prompts_text
|
||||||
]
|
]
|
||||||
|
cache_salt = request.cache_salt if (
|
||||||
|
hasattr(request, "cache_salt")
|
||||||
|
and request.cache_salt is not None) else None
|
||||||
|
if cache_salt:
|
||||||
|
for prompt_text in engine_prompts_text:
|
||||||
|
prompt_text["cache_salt"] = cache_salt
|
||||||
|
|
||||||
# This check is equivalent to simply checking if
|
# This check is equivalent to simply checking if
|
||||||
# `request_prompts_embeds` is empty, but it's difficult to propagate
|
# `request_prompts_embeds` is empty, but it's difficult to propagate
|
||||||
@ -828,6 +834,9 @@ class OpenAIServing:
|
|||||||
prompt_embeds=request_prompt_embeds["prompt_embeds"])
|
prompt_embeds=request_prompt_embeds["prompt_embeds"])
|
||||||
for request_prompt_embeds in request_prompts_embeds
|
for request_prompt_embeds in request_prompts_embeds
|
||||||
]
|
]
|
||||||
|
if cache_salt:
|
||||||
|
for prompt_embed in engine_prompts_embeds:
|
||||||
|
prompt_embed["cache_salt"] = cache_salt
|
||||||
|
|
||||||
request_prompts = request_prompts_embeds + request_prompts_text
|
request_prompts = request_prompts_embeds + request_prompts_text
|
||||||
engine_prompts = engine_prompts_embeds + engine_prompts_text
|
engine_prompts = engine_prompts_embeds + engine_prompts_text
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user