[Frontend] Add per-request number of cached token stats (#10174)

zifeitong 2024-11-12 08:42:28 -08:00 committed by GitHub
parent 176fcb1c71
commit 47db6ec831
9 changed files with 89 additions and 23 deletions
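
The change threads a per-request `num_cached_tokens` value from `SequenceData` up through `RequestOutput`, and, behind the new `--enable-prompt-tokens-details` flag, reports it as `usage.prompt_tokens_details.cached_tokens` in the OpenAI-compatible chat API. A minimal offline sketch of how the new field can be read; the model name and prompts are placeholders, not part of this commit:

from vllm import LLM, SamplingParams

# Prefix caching must be enabled, otherwise the reported count stays at 0.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
params = SamplingParams(temperature=0.0, max_tokens=16)

prefix = "You are a helpful assistant. " * 20

# Warm the prefix cache with a first request.
llm.generate([prefix + "Question A?"], params)

# A later request sharing the prefix reports how many prompt tokens hit the cache.
for out in llm.generate([prefix + "Question B?"], params):
    print(out.num_cached_tokens, "of", len(out.prompt_token_ids), "prompt tokens cached")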

View File

@@ -27,6 +27,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("cached_position", [0, 1])
+@pytest.mark.parametrize("block_size", [16])
 def test_mixed_requests(
     hf_runner,
     vllm_runner,
@@ -36,11 +37,12 @@ def test_mixed_requests(
     dtype: str,
     max_tokens: int,
     cached_position: int,
+    block_size: int,
     monkeypatch,
 ) -> None:
     """
     Test the case when some sequences have the prefix cache hit
     and the others don't. The cached position determines where
     the sequence is at among the batch of prefills.
     """
     override_backend_env_variable(monkeypatch, backend)
@@ -53,12 +55,30 @@ def test_mixed_requests(
             model,
             dtype=dtype,
             enable_prefix_caching=True,
+            block_size=block_size,
     ) as vllm_model:
         # Run the first prompt so the cache is populated
         vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
         # Run all the prompts
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
+        req_outputs = vllm_model.model.generate(example_prompts, greedy_params)
+
+        # Verify number of cached tokens
+        for i in range(len(req_outputs)):
+            if i == cached_position:
+                expected_num_cached_tokens = (
+                    len(req_outputs[i].prompt_token_ids) //
+                    block_size) * block_size
+            else:
+                expected_num_cached_tokens = 0
+            assert req_outputs[
+                i].num_cached_tokens == expected_num_cached_tokens
+
+        vllm_outputs = [
+            (output.prompt_token_ids + list(output.outputs[0].token_ids),
+             output.prompt + output.outputs[0].text) for output in req_outputs
+        ]

     check_outputs_equal(
         outputs_0_lst=hf_outputs,
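
The expected value in this assertion is rounded down to a whole number of cache blocks: only fully filled blocks can be reused, so a prompt whose length is not a multiple of block_size caches the largest multiple below it. A small worked example with made-up numbers, not taken from the test:

block_size = 16
prompt_len = 37          # tokens in the cached prompt (illustrative value)
expected = (prompt_len // block_size) * block_size
assert expected == 32    # the trailing partial block of 5 tokens is not counted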

View File

@@ -540,6 +540,7 @@ def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,

View File

@@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")
     return parser

View File

@@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)


+class PromptTokenUsageInfo(OpenAIBaseModel):
+    cached_tokens: Optional[int] = None
+
+
 class UsageInfo(OpenAIBaseModel):
     prompt_tokens: int = 0
     total_tokens: int = 0
     completion_tokens: Optional[int] = 0
+    prompt_tokens_details: Optional[PromptTokenUsageInfo] = None


 class RequestResponseMetadata(BaseModel):
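
When the server is launched with --enable-prompt-tokens-details, this extra structure appears under usage in chat completion responses. A rough client-side sketch using the OpenAI Python SDK; the endpoint, model name, and API key are placeholders, and the SDK needs to be recent enough to expose prompt_tokens_details:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
)

# prompt_tokens_details is only populated when the flag is set and the
# request actually hit the prefix cache.
details = resp.usage.prompt_tokens_details
if details is not None:
    print("cached prompt tokens:", details.cached_tokens)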

View File

@@ -78,6 +78,11 @@ def parse_args():
         help="Port number for the Prometheus metrics server "
         "(only needed if enable-metrics is set).",
     )
+    parser.add_argument(
+        "--enable-prompt-tokens-details",
+        action='store_true',
+        default=False,
+        help="If set to True, enable prompt_tokens_details in usage.")
     return parser.parse_args()

@@ -217,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
         engine,

View File

@@ -18,8 +18,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
-    ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
+    RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing,
@@ -49,7 +49,8 @@ class OpenAIServingChat(OpenAIServing):
                  chat_template: Optional[str],
                  return_tokens_as_token_ids: bool = False,
                  enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None):
+                 tool_parser: Optional[str] = None,
+                 enable_prompt_tokens_details: bool = False):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -80,6 +81,8 @@ class OpenAIServingChat(OpenAIServing):
                     f"tool_parser:'{tool_parser}' which has not "
                     "been registered") from e

+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
+
     async def create_chat_completion(
         self,
         request: ChatCompletionRequest,
@@ -252,6 +255,7 @@ class OpenAIServingChat(OpenAIServing):
         previous_num_tokens = [0] * num_choices
         finish_reason_sent = [False] * num_choices
         num_prompt_tokens = 0
+        num_cached_tokens = None

         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -305,6 +309,7 @@ class OpenAIServingChat(OpenAIServing):
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
                 if first_iteration:
+                    num_cached_tokens = res.num_cached_tokens
                     # Send first response for each request.n (index) with
                     # the role
                     role = self.get_chat_request_role(request)
@@ -530,11 +535,13 @@ class OpenAIServingChat(OpenAIServing):
             # is sent, send the usage
             if include_usage:
                 completion_tokens = sum(previous_num_tokens)
-                final_usage = UsageInfo(
-                    prompt_tokens=num_prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=num_prompt_tokens + completion_tokens,
-                )
+                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                                        completion_tokens=completion_tokens,
+                                        total_tokens=num_prompt_tokens +
+                                        completion_tokens)
+                if self.enable_prompt_tokens_details and num_cached_tokens:
+                    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
+                        cached_tokens=num_cached_tokens)

                 final_usage_chunk = ChatCompletionStreamResponse(
                     id=request_id,
@@ -702,11 +709,13 @@ class OpenAIServingChat(OpenAIServing):
             num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
         num_generated_tokens = sum(
             len(output.token_ids) for output in final_res.outputs)
-        usage = UsageInfo(
-            prompt_tokens=num_prompt_tokens,
-            completion_tokens=num_generated_tokens,
-            total_tokens=num_prompt_tokens + num_generated_tokens,
-        )
+        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                          completion_tokens=num_generated_tokens,
+                          total_tokens=num_prompt_tokens +
+                          num_generated_tokens)
+        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+            usage.prompt_tokens_details = PromptTokenUsageInfo(
+                cached_tokens=final_res.num_cached_tokens)

         request_metadata.final_usage_info = usage
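
On the streaming path the usage block, including the cached-token details, is only attached to the final chunk, and only when the client asks for it. A hedged sketch of that request, again with placeholder endpoint and model:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    # Intermediate chunks carry usage=None; the last chunk has the totals.
    if chunk.usage is not None:
        print(chunk.usage.prompt_tokens_details)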

View File

@@ -83,10 +83,11 @@ class RequestOutput:
         finished: Whether the whole request is finished.
         metrics: Metrics associated with the request.
         lora_request: The LoRA request that was used to generate the output.
-        encoder_prompt: The encoder prompt string of the request;
-                        None if decoder-only
-        encoder_prompt_token_ids: The token IDs of the encoder prompt;
-                                  None if decoder-only
+        encoder_prompt: The encoder prompt string of the request.
+                        None if decoder-only.
+        encoder_prompt_token_ids: The token IDs of the encoder prompt.
+                                  None if decoder-only.
+        num_cached_tokens: The number of tokens with prefix cache hit.
     """

     def __init__(
@@ -101,6 +102,7 @@ class RequestOutput:
         lora_request: Optional[LoRARequest] = None,
         encoder_prompt: Optional[str] = None,
         encoder_prompt_token_ids: Optional[List[int]] = None,
+        num_cached_tokens: Optional[int] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -112,6 +114,7 @@ class RequestOutput:
         self.lora_request = lora_request
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
+        self.num_cached_tokens = num_cached_tokens

     @classmethod
     def new(
@@ -192,6 +195,8 @@ class RequestOutput:
         outputs = []
         include_prompt = True
+        # num_cached_tokens should be the same for all the sequences
+        num_cached_tokens = None
         for i, seq in enumerate(top_n_seqs):
             output_text = seq.get_output_text_to_return(
                 text_buffer_length, delta)
@@ -199,6 +204,7 @@ class RequestOutput:
             output_token_ids = seq.get_output_token_ids_to_return(delta)
             num_output_tokens = 1 if isinstance(output_token_ids,
                                                 int) else len(output_token_ids)
+            num_cached_tokens = seq.data.get_num_cached_tokens()

             output_logprobs = seq.output_logprobs if include_logprobs else None
@@ -272,7 +278,7 @@ class RequestOutput:
         init_args = (seq_group.request_id, prompt, prompt_token_ids,
                      prompt_logprobs, outputs, finished, seq_group.metrics,
                      seq_group.lora_request, encoder_prompt,
-                     encoder_prompt_token_ids)
+                     encoder_prompt_token_ids, num_cached_tokens)

         if use_cache:
             request_output = seq_group.cached_request_output
@@ -293,7 +299,8 @@ class RequestOutput:
                 f"outputs={self.outputs}, "
                 f"finished={self.finished}, "
                 f"metrics={self.metrics}, "
-                f"lora_request={self.lora_request})")
+                f"lora_request={self.lora_request}, "
+                f"num_cached_tokens={self.num_cached_tokens})")


 class EmbeddingRequestOutput:

View File

@@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
                              ...] = msgspec.field(default_factory=tuple)
     # The number of tokens that are computed (that run against the model).
     _num_computed_tokens: int = 0
+    # The number of tokens with prefix cache hit.
+    _num_cached_tokens: int = 0
     _stage: SequenceStage = SequenceStage.PREFILL
     _cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
@@ -323,6 +325,14 @@ class SequenceData(msgspec.Struct,
         if self.get_num_uncomputed_tokens() == 0:
             self._stage = SequenceStage.DECODE

+    def get_num_cached_tokens(self) -> int:
+        """Return the number of tokens with prefix cache hit."""
+        return self._num_cached_tokens
+
+    def update_num_cached_tokens(self, num_cached_tokens: int):
+        """Update the number of tokens with prefix cache hit."""
+        self._num_cached_tokens = num_cached_tokens
+
     def reset_state_for_recompute(self) -> None:
         """Reset the number of computed tokens from this sequence. It is
         supposed to be called when a sequence needs to be started from
@@ -379,7 +389,7 @@ class SequenceData(msgspec.Struct,

 class Sequence:
     """Stores the data, status, and block information of a sequence.

     The sequence is constructed from the :data:`DecoderOnlyInputs`
     (for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
     instance passed in through the :code:`inputs` constructor argument.
@@ -906,7 +916,7 @@ class SequenceGroupMetadata(
         multi_modal_data: Multi modal data.
         mm_processor_kwargs: Multimodal input processor / mapper overrides.
         encoder_seq_data: Optional sequence data for encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
         cross_block_table: Optional cross-attention block table associated

View File

@@ -542,6 +542,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             # this may be larger than the sequence length if chunked
             # prefill is enabled.
             prefix_cache_len = len(computed_block_nums) * self.block_size
+            seq_group_metadata.seq_data[inter_data.seq_ids[
+                seq_idx]].update_num_cached_tokens(prefix_cache_len)
+
             # The number of so far computed prompt tokens in this sequence.
             context_len = inter_data.context_lens[seq_idx]
             # The total number of prompt tokens in this sequence.
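
This is where the block-granularity numbers seen in the test come from: the GPU model runner records prefix_cache_len, computed as len(computed_block_nums) * block_size, onto the sequence data, so the reported count is always a multiple of the block size. A short illustration with assumed values, not vLLM code:

block_size = 16
computed_block_nums = [0, 1]   # two whole blocks found in the prefix cache
prefix_cache_len = len(computed_block_nums) * block_size
print(prefix_cache_len)        # 32 tokens reported as cached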