[Frontend] Add per-request number of cached token stats (#10174)

This commit is contained in:
zifeitong 2024-11-12 08:42:28 -08:00 committed by GitHub
parent 176fcb1c71
commit 47db6ec831
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 89 additions and 23 deletions

View File

@ -27,6 +27,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("block_size", [16])
def test_mixed_requests(
hf_runner,
vllm_runner,
@ -36,6 +37,7 @@ def test_mixed_requests(
dtype: str,
max_tokens: int,
cached_position: int,
block_size: int,
monkeypatch,
) -> None:
"""
@ -53,12 +55,30 @@ def test_mixed_requests(
model,
dtype=dtype,
enable_prefix_caching=True,
block_size=block_size,
) as vllm_model:
# Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
# Run all the promopts
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
req_outputs = vllm_model.model.generate(example_prompts, greedy_params)
# Verify number of cached tokens
for i in range(len(req_outputs)):
if i == cached_position:
expected_num_cached_tokens = (
len(req_outputs[i].prompt_token_ids) //
block_size) * block_size
else:
expected_num_cached_tokens = 0
assert req_outputs[
i].num_cached_tokens == expected_num_cached_tokens
vllm_outputs = [
(output.prompt_token_ids + list(output.outputs[0].token_ids),
output.prompt + output.outputs[0].text) for output in req_outputs
]
check_outputs_equal(
outputs_0_lst=hf_outputs,

View File

@ -540,6 +540,7 @@ def init_app_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,

View File

@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
parser.add_argument(
"--enable-prompt-tokens-details",
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")
return parser

View File

@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
data: List[ModelCard] = Field(default_factory=list)
class PromptTokenUsageInfo(OpenAIBaseModel):
cached_tokens: Optional[int] = None
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
class RequestResponseMetadata(BaseModel):

View File

@ -78,6 +78,11 @@ def parse_args():
help="Port number for the Prometheus metrics server "
"(only needed if enable-metrics is set).",
)
parser.add_argument(
"--enable-prompt-tokens-details",
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")
return parser.parse_args()
@ -217,6 +222,7 @@ async def main(args):
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,

View File

@ -18,8 +18,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
ToolCall, UsageInfo)
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
@ -49,7 +49,8 @@ class OpenAIServingChat(OpenAIServing):
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None):
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@ -80,6 +81,8 @@ class OpenAIServingChat(OpenAIServing):
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details
async def create_chat_completion(
self,
request: ChatCompletionRequest,
@ -252,6 +255,7 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
num_cached_tokens = None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
@ -305,6 +309,7 @@ class OpenAIServingChat(OpenAIServing):
# the result_generator, it needs to be sent as the FIRST
# response (by the try...catch).
if first_iteration:
num_cached_tokens = res.num_cached_tokens
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
@ -530,11 +535,13 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage
if include_usage:
completion_tokens = sum(previous_num_tokens)
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
total_tokens=num_prompt_tokens +
completion_tokens)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens)
final_usage_chunk = ChatCompletionStreamResponse(
id=request_id,
@ -702,11 +709,13 @@ class OpenAIServingChat(OpenAIServing):
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
total_tokens=num_prompt_tokens +
num_generated_tokens)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage

View File

@ -83,10 +83,11 @@ class RequestOutput:
finished: Whether the whole request is finished.
metrics: Metrics associated with the request.
lora_request: The LoRA request that was used to generate the output.
encoder_prompt: The encoder prompt string of the request;
None if decoder-only
encoder_prompt_token_ids: The token IDs of the encoder prompt;
None if decoder-only
encoder_prompt: The encoder prompt string of the request.
None if decoder-only.
encoder_prompt_token_ids: The token IDs of the encoder prompt.
None if decoder-only.
num_cached_tokens: The number of tokens with prefix cache hit.
"""
def __init__(
@ -101,6 +102,7 @@ class RequestOutput:
lora_request: Optional[LoRARequest] = None,
encoder_prompt: Optional[str] = None,
encoder_prompt_token_ids: Optional[List[int]] = None,
num_cached_tokens: Optional[int] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
@ -112,6 +114,7 @@ class RequestOutput:
self.lora_request = lora_request
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
@classmethod
def new(
@ -192,6 +195,8 @@ class RequestOutput:
outputs = []
include_prompt = True
# num_cached_tokens should be the same for all the sequences
num_cached_tokens = None
for i, seq in enumerate(top_n_seqs):
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)
@ -199,6 +204,7 @@ class RequestOutput:
output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)
num_cached_tokens = seq.data.get_num_cached_tokens()
output_logprobs = seq.output_logprobs if include_logprobs else None
@ -272,7 +278,7 @@ class RequestOutput:
init_args = (seq_group.request_id, prompt, prompt_token_ids,
prompt_logprobs, outputs, finished, seq_group.metrics,
seq_group.lora_request, encoder_prompt,
encoder_prompt_token_ids)
encoder_prompt_token_ids, num_cached_tokens)
if use_cache:
request_output = seq_group.cached_request_output
@ -293,7 +299,8 @@ class RequestOutput:
f"outputs={self.outputs}, "
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"lora_request={self.lora_request})")
f"lora_request={self.lora_request}, "
f"num_cached_tokens={self.num_cached_tokens})")
class EmbeddingRequestOutput:

View File

@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
...] = msgspec.field(default_factory=tuple)
# The number of tokens that are computed (that run against the model).
_num_computed_tokens: int = 0
# The number of tokens with prefix cache hit.
_num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
@ -323,6 +325,14 @@ class SequenceData(msgspec.Struct,
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
def get_num_cached_tokens(self) -> int:
"""Return the number of tokens with prefix cache hit."""
return self._num_cached_tokens
def update_num_cached_tokens(self, num_cached_tokens: int):
"""Update the number of tokens with prefix cache hit."""
self._num_cached_tokens = num_cached_tokens
def reset_state_for_recompute(self) -> None:
"""Reset the number of computed tokens from this sequence. It is
supposed to be called when a sequence needs to be started from

View File

@ -542,6 +542,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
seq_group_metadata.seq_data[inter_data.seq_ids[
seq_idx]].update_num_cached_tokens(prefix_cache_len)
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.