[Frontend] Add per-request number of cached token stats (#10174)
parent 176fcb1c71
commit 47db6ec831
@@ -27,6 +27,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("cached_position", [0, 1])
+@pytest.mark.parametrize("block_size", [16])
 def test_mixed_requests(
     hf_runner,
     vllm_runner,
@@ -36,11 +37,12 @@ def test_mixed_requests(
     dtype: str,
     max_tokens: int,
     cached_position: int,
+    block_size: int,
     monkeypatch,
 ) -> None:
     """
     Test the case when some sequences have the prefix cache hit
     and the others don't. The cached position determines where
     the sequence is at among the batch of prefills.
     """
     override_backend_env_variable(monkeypatch, backend)
@@ -53,12 +55,30 @@ def test_mixed_requests(
             model,
             dtype=dtype,
             enable_prefix_caching=True,
+            block_size=block_size,
     ) as vllm_model:
         # Run the first prompt so the cache is populated
         vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)

         # Run all the prompts
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
+        req_outputs = vllm_model.model.generate(example_prompts, greedy_params)
+
+        # Verify number of cached tokens
+        for i in range(len(req_outputs)):
+            if i == cached_position:
+                expected_num_cached_tokens = (
+                    len(req_outputs[i].prompt_token_ids) //
+                    block_size) * block_size
+            else:
+                expected_num_cached_tokens = 0
+            assert req_outputs[
+                i].num_cached_tokens == expected_num_cached_tokens
+
+        vllm_outputs = [
+            (output.prompt_token_ids + list(output.outputs[0].token_ids),
+             output.prompt + output.outputs[0].text) for output in req_outputs
+        ]

     check_outputs_equal(
         outputs_0_lst=hf_outputs,
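For illustration only, and not part of this commit: a minimal offline-API sketch of how the new per-request cached-token count can be read. The model name and prompts are placeholders; prefix caching must be enabled, and num_cached_tokens is reported in whole blocks.

from vllm import LLM, SamplingParams

# Placeholder model; any vLLM-supported model works here.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True, block_size=16)
params = SamplingParams(temperature=0.0, max_tokens=5)

shared_prefix = "You are a helpful assistant. " * 40
prompts = [shared_prefix + "Question 1", shared_prefix + "Question 2"]

# The first call populates the prefix cache.
llm.generate([prompts[0]], params)

# Later requests that share the prefix report the cached token count,
# rounded down to a whole number of blocks.
for output in llm.generate(prompts, params):
    print(output.request_id, output.num_cached_tokens)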
@@ -540,6 +540,7 @@ def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")

     return parser

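A small sketch of the flag's behavior, using plain argparse as a stand-in for FlexibleArgumentParser (which subclasses it): the option defaults to False and maps to args.enable_prompt_tokens_details.

import argparse

# Stand-in parser; the option definition mirrors the one added above.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-prompt-tokens-details",
                    action='store_true',
                    default=False,
                    help="If set to True, enable prompt_tokens_details in usage.")

assert parser.parse_args([]).enable_prompt_tokens_details is False
assert parser.parse_args(
    ["--enable-prompt-tokens-details"]).enable_prompt_tokens_details is True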
@@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)


+class PromptTokenUsageInfo(OpenAIBaseModel):
+    cached_tokens: Optional[int] = None
+
+
 class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
+    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None


 class RequestResponseMetadata(BaseModel):
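As a sketch of the extended usage schema (not part of this commit; it assumes pydantic v2, which OpenAIBaseModel builds on): cached_tokens nests under prompt_tokens_details.

from vllm.entrypoints.openai.protocol import PromptTokenUsageInfo, UsageInfo

usage = UsageInfo(
    prompt_tokens=128,
    completion_tokens=16,
    total_tokens=144,
    prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=112),
)

# Serializes roughly to:
# {"prompt_tokens": 128, "total_tokens": 144, "completion_tokens": 16,
#  "prompt_tokens_details": {"cached_tokens": 112}}
print(usage.model_dump_json())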
@@ -78,6 +78,11 @@ def parse_args():
         help="Port number for the Prometheus metrics server "
         "(only needed if enable-metrics is set).",
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")

     return parser.parse_args()

@@ -217,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,
@@ -18,8 +18,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
-    ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
+    RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing,
@@ -49,7 +49,8 @@ class OpenAIServingChat(OpenAIServing):
                  chat_template: Optional[str],
                  return_tokens_as_token_ids: bool = False,
                  enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None):
+                 tool_parser: Optional[str] = None,
+                 enable_prompt_tokens_details: bool = False):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -80,6 +81,8 @@ class OpenAIServingChat(OpenAIServing):
                     f"tool_parser:'{tool_parser}' which has not "
                     "been registered") from e

+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
@@ -252,6 +255,7 @@ class OpenAIServingChat(OpenAIServing):
         previous_num_tokens = [0] * num_choices
         finish_reason_sent = [False] * num_choices
         num_prompt_tokens = 0
+        num_cached_tokens = None

         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -305,6 +309,7 @@ class OpenAIServingChat(OpenAIServing):
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
                 if first_iteration:
+                    num_cached_tokens = res.num_cached_tokens
                     # Send first response for each request.n (index) with
                     # the role
                     role = self.get_chat_request_role(request)
@@ -530,11 +535,13 @@ class OpenAIServingChat(OpenAIServing):
             # is sent, send the usage
             if include_usage:
                 completion_tokens = sum(previous_num_tokens)
-                final_usage = UsageInfo(
-                    prompt_tokens=num_prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=num_prompt_tokens + completion_tokens,
-                )
+                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                                        completion_tokens=completion_tokens,
+                                        total_tokens=num_prompt_tokens +
+                                        completion_tokens)
+                if self.enable_prompt_tokens_details and num_cached_tokens:
+                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
+                        cached_tokens=num_cached_tokens)

                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=request_id,
@@ -702,11 +709,13 @@ class OpenAIServingChat(OpenAIServing):
             num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
         num_generated_tokens = sum(
             len(output.token_ids) for output in final_res.outputs)
-        usage = UsageInfo(
-            prompt_tokens=num_prompt_tokens,
-            completion_tokens=num_generated_tokens,
-            total_tokens=num_prompt_tokens + num_generated_tokens,
-        )
+        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                          completion_tokens=num_generated_tokens,
+                          total_tokens=num_prompt_tokens +
+                          num_generated_tokens)
+        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+            usage.prompt_tokens_details = PromptTokenUsageInfo(
+                cached_tokens=final_res.num_cached_tokens)

         request_metadata.final_usage_info = usage

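Client-side, the new field surfaces in the chat completion usage block. A hedged sketch with the openai Python client (the base URL and model name are placeholders; the server must be launched with prefix caching and --enable-prompt-tokens-details, and older openai client versions may not expose prompt_tokens_details as a typed attribute):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Hello!"}],
)

# cached_tokens is only populated when the server enables the feature and the
# request actually hit the prefix cache.
details = getattr(resp.usage, "prompt_tokens_details", None)
if details is not None:
    print("cached prompt tokens:", details.cached_tokens)

In streaming mode the same field appears on the final usage chunk when the request sets stream_options={"include_usage": True}.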
@@ -83,10 +83,11 @@ class RequestOutput:
         finished: Whether the whole request is finished.
         metrics: Metrics associated with the request.
         lora_request: The LoRA request that was used to generate the output.
-        encoder_prompt: The encoder prompt string of the request;
-                        None if decoder-only
-        encoder_prompt_token_ids: The token IDs of the encoder prompt;
-                                  None if decoder-only
+        encoder_prompt: The encoder prompt string of the request.
+                        None if decoder-only.
+        encoder_prompt_token_ids: The token IDs of the encoder prompt.
+                                  None if decoder-only.
+        num_cached_tokens: The number of tokens with prefix cache hit.
     """

     def __init__(
@@ -101,6 +102,7 @@ class RequestOutput:
         lora_request: Optional[LoRARequest] = None,
         encoder_prompt: Optional[str] = None,
         encoder_prompt_token_ids: Optional[List[int]] = None,
+        num_cached_tokens: Optional[int] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -112,6 +114,7 @@ class RequestOutput:
         self.lora_request = lora_request
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
+        self.num_cached_tokens = num_cached_tokens

     @classmethod
     def new(
@@ -192,6 +195,8 @@ class RequestOutput:

         outputs = []
         include_prompt = True
+        # num_cached_tokens should be the same for all the sequences
+        num_cached_tokens = None
         for i, seq in enumerate(top_n_seqs):
             output_text = seq.get_output_text_to_return(
                 text_buffer_length, delta)
@@ -199,6 +204,7 @@ class RequestOutput:
             output_token_ids = seq.get_output_token_ids_to_return(delta)
             num_output_tokens = 1 if isinstance(output_token_ids,
                                                 int) else len(output_token_ids)
+            num_cached_tokens = seq.data.get_num_cached_tokens()

             output_logprobs = seq.output_logprobs if include_logprobs else None

@@ -272,7 +278,7 @@ class RequestOutput:
         init_args = (seq_group.request_id, prompt, prompt_token_ids,
                      prompt_logprobs, outputs, finished, seq_group.metrics,
                      seq_group.lora_request, encoder_prompt,
-                     encoder_prompt_token_ids)
+                     encoder_prompt_token_ids, num_cached_tokens)

         if use_cache:
             request_output = seq_group.cached_request_output
@@ -293,7 +299,8 @@ class RequestOutput:
                 f"outputs={self.outputs}, "
                 f"finished={self.finished}, "
                 f"metrics={self.metrics}, "
-                f"lora_request={self.lora_request})")
+                f"lora_request={self.lora_request}, "
+                f"num_cached_tokens={self.num_cached_tokens})")


 class EmbeddingRequestOutput:
@@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
                              ...] = msgspec.field(default_factory=tuple)
     # The number of tokens that are computed (that run against the model).
     _num_computed_tokens: int = 0
+    # The number of tokens with prefix cache hit.
+    _num_cached_tokens: int = 0
     _stage: SequenceStage = SequenceStage.PREFILL
     _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)

@@ -323,6 +325,14 @@ class SequenceData(msgspec.Struct,
         if self.get_num_uncomputed_tokens() == 0:
             self._stage = SequenceStage.DECODE

+    def get_num_cached_tokens(self) -> int:
+        """Return the number of tokens with prefix cache hit."""
+        return self._num_cached_tokens
+
+    def update_num_cached_tokens(self, num_cached_tokens: int):
+        """Update the number of tokens with prefix cache hit."""
+        self._num_cached_tokens = num_cached_tokens
+
     def reset_state_for_recompute(self) -> None:
         """Reset the number of computed tokens from this sequence. It is
         supposed to be called when a sequence needs to be started from
@@ -379,7 +389,7 @@


 class Sequence:
     """Stores the data, status, and block information of a sequence.

     The sequence is constructed from the :data:`DecoderOnlyInputs`
     (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
     instance passed in through the :code:`inputs` constructor argument.
@@ -906,7 +916,7 @@ class SequenceGroupMetadata(
         multi_modal_data: Multi modal data.
         mm_processor_kwargs: Multimodal input processor / mapper overrides.
         encoder_seq_data: Optional sequence data for encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
         cross_block_table: Optional cross-attention block table associated
@@ -542,6 +542,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             # this may be larger than the sequence length if chunked
             # prefill is enabled.
             prefix_cache_len = len(computed_block_nums) * self.block_size
+            seq_group_metadata.seq_data[inter_data.seq_ids[
+                seq_idx]].update_num_cached_tokens(prefix_cache_len)
+
             # The number of so far computed prompt tokens in this sequence.
             context_len = inter_data.context_lens[seq_idx]
             # The total number of prompt tokens in this sequence.
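The count comes from whole cached blocks, so it is the prompt length rounded down to a multiple of the block size; a small arithmetic sketch with placeholder numbers, matching the test's expectation above:

block_size = 16
prompt_len = 53                               # tokens in the prompt
num_cached_blocks = prompt_len // block_size  # 3 fully cached blocks

prefix_cache_len = num_cached_blocks * block_size
assert prefix_cache_len == 48  # == (prompt_len // block_size) * block_size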