mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:16:06 +08:00
[Frontend] Add per-request number of cached token stats (#10174)
This commit is contained in:
parent
176fcb1c71
commit
47db6ec831
@ -27,6 +27,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("cached_position", [0, 1])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_mixed_requests(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
@ -36,11 +37,12 @@ def test_mixed_requests(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
cached_position: int,
|
||||
block_size: int,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""
|
||||
Test the case when some sequences have the prefix cache hit
|
||||
and the others don't. The cached position determines where
|
||||
and the others don't. The cached position determines where
|
||||
the sequence is at among the batch of prefills.
|
||||
"""
|
||||
override_backend_env_variable(monkeypatch, backend)
|
||||
@ -53,12 +55,30 @@ def test_mixed_requests(
|
||||
model,
|
||||
dtype=dtype,
|
||||
enable_prefix_caching=True,
|
||||
block_size=block_size,
|
||||
) as vllm_model:
|
||||
# Run the first prompt so the cache is populated
|
||||
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
|
||||
|
||||
# Run all the promopts
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
req_outputs = vllm_model.model.generate(example_prompts, greedy_params)
|
||||
|
||||
# Verify number of cached tokens
|
||||
for i in range(len(req_outputs)):
|
||||
if i == cached_position:
|
||||
expected_num_cached_tokens = (
|
||||
len(req_outputs[i].prompt_token_ids) //
|
||||
block_size) * block_size
|
||||
else:
|
||||
expected_num_cached_tokens = 0
|
||||
assert req_outputs[
|
||||
i].num_cached_tokens == expected_num_cached_tokens
|
||||
|
||||
vllm_outputs = [
|
||||
(output.prompt_token_ids + list(output.outputs[0].token_ids),
|
||||
output.prompt + output.outputs[0].text) for output in req_outputs
|
||||
]
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
|
||||
@ -540,6 +540,7 @@ def init_app_state(
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if model_config.task == "generate" else None
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
engine_client,
|
||||
|
||||
@ -228,6 +228,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
default=False,
|
||||
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -99,10 +99,15 @@ class ModelList(OpenAIBaseModel):
|
||||
data: List[ModelCard] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PromptTokenUsageInfo(OpenAIBaseModel):
|
||||
cached_tokens: Optional[int] = None
|
||||
|
||||
|
||||
class UsageInfo(OpenAIBaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
completion_tokens: Optional[int] = 0
|
||||
prompt_tokens_details: Optional[PromptTokenUsageInfo] = None
|
||||
|
||||
|
||||
class RequestResponseMetadata(BaseModel):
|
||||
|
||||
@ -78,6 +78,11 @@ def parse_args():
|
||||
help="Port number for the Prometheus metrics server "
|
||||
"(only needed if enable-metrics is set).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-prompt-tokens-details",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="If set to True, enable prompt_tokens_details in usage.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -217,6 +222,7 @@ async def main(args):
|
||||
prompt_adapters=None,
|
||||
request_logger=request_logger,
|
||||
chat_template=None,
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if model_config.task == "generate" else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
|
||||
@ -18,8 +18,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
|
||||
ToolCall, UsageInfo)
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
|
||||
RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
LoRAModulePath,
|
||||
OpenAIServing,
|
||||
@ -49,7 +49,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
chat_template: Optional[str],
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
enable_auto_tools: bool = False,
|
||||
tool_parser: Optional[str] = None):
|
||||
tool_parser: Optional[str] = None,
|
||||
enable_prompt_tokens_details: bool = False):
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
@ -80,6 +81,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
f"tool_parser:'{tool_parser}' which has not "
|
||||
"been registered") from e
|
||||
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
@ -252,6 +255,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
previous_num_tokens = [0] * num_choices
|
||||
finish_reason_sent = [False] * num_choices
|
||||
num_prompt_tokens = 0
|
||||
num_cached_tokens = None
|
||||
|
||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||
tool_choice_function_name = request.tool_choice.function.name
|
||||
@ -305,6 +309,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# the result_generator, it needs to be sent as the FIRST
|
||||
# response (by the try...catch).
|
||||
if first_iteration:
|
||||
num_cached_tokens = res.num_cached_tokens
|
||||
# Send first response for each request.n (index) with
|
||||
# the role
|
||||
role = self.get_chat_request_role(request)
|
||||
@ -530,11 +535,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# is sent, send the usage
|
||||
if include_usage:
|
||||
completion_tokens = sum(previous_num_tokens)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens + completion_tokens,
|
||||
)
|
||||
final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
completion_tokens)
|
||||
if self.enable_prompt_tokens_details and num_cached_tokens:
|
||||
final_usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=num_cached_tokens)
|
||||
|
||||
final_usage_chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
@ -702,11 +709,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
num_prompt_tokens += len(final_res.encoder_prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens +
|
||||
num_generated_tokens)
|
||||
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
|
||||
usage.prompt_tokens_details = PromptTokenUsageInfo(
|
||||
cached_tokens=final_res.num_cached_tokens)
|
||||
|
||||
request_metadata.final_usage_info = usage
|
||||
|
||||
|
||||
@ -83,10 +83,11 @@ class RequestOutput:
|
||||
finished: Whether the whole request is finished.
|
||||
metrics: Metrics associated with the request.
|
||||
lora_request: The LoRA request that was used to generate the output.
|
||||
encoder_prompt: The encoder prompt string of the request;
|
||||
None if decoder-only
|
||||
encoder_prompt_token_ids: The token IDs of the encoder prompt;
|
||||
None if decoder-only
|
||||
encoder_prompt: The encoder prompt string of the request.
|
||||
None if decoder-only.
|
||||
encoder_prompt_token_ids: The token IDs of the encoder prompt.
|
||||
None if decoder-only.
|
||||
num_cached_tokens: The number of tokens with prefix cache hit.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -101,6 +102,7 @@ class RequestOutput:
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
encoder_prompt: Optional[str] = None,
|
||||
encoder_prompt_token_ids: Optional[List[int]] = None,
|
||||
num_cached_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
@ -112,6 +114,7 @@ class RequestOutput:
|
||||
self.lora_request = lora_request
|
||||
self.encoder_prompt = encoder_prompt
|
||||
self.encoder_prompt_token_ids = encoder_prompt_token_ids
|
||||
self.num_cached_tokens = num_cached_tokens
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
@ -192,6 +195,8 @@ class RequestOutput:
|
||||
|
||||
outputs = []
|
||||
include_prompt = True
|
||||
# num_cached_tokens should be the same for all the sequences
|
||||
num_cached_tokens = None
|
||||
for i, seq in enumerate(top_n_seqs):
|
||||
output_text = seq.get_output_text_to_return(
|
||||
text_buffer_length, delta)
|
||||
@ -199,6 +204,7 @@ class RequestOutput:
|
||||
output_token_ids = seq.get_output_token_ids_to_return(delta)
|
||||
num_output_tokens = 1 if isinstance(output_token_ids,
|
||||
int) else len(output_token_ids)
|
||||
num_cached_tokens = seq.data.get_num_cached_tokens()
|
||||
|
||||
output_logprobs = seq.output_logprobs if include_logprobs else None
|
||||
|
||||
@ -272,7 +278,7 @@ class RequestOutput:
|
||||
init_args = (seq_group.request_id, prompt, prompt_token_ids,
|
||||
prompt_logprobs, outputs, finished, seq_group.metrics,
|
||||
seq_group.lora_request, encoder_prompt,
|
||||
encoder_prompt_token_ids)
|
||||
encoder_prompt_token_ids, num_cached_tokens)
|
||||
|
||||
if use_cache:
|
||||
request_output = seq_group.cached_request_output
|
||||
@ -293,7 +299,8 @@ class RequestOutput:
|
||||
f"outputs={self.outputs}, "
|
||||
f"finished={self.finished}, "
|
||||
f"metrics={self.metrics}, "
|
||||
f"lora_request={self.lora_request})")
|
||||
f"lora_request={self.lora_request}, "
|
||||
f"num_cached_tokens={self.num_cached_tokens})")
|
||||
|
||||
|
||||
class EmbeddingRequestOutput:
|
||||
|
||||
@ -167,6 +167,8 @@ class SequenceData(msgspec.Struct,
|
||||
...] = msgspec.field(default_factory=tuple)
|
||||
# The number of tokens that are computed (that run against the model).
|
||||
_num_computed_tokens: int = 0
|
||||
# The number of tokens with prefix cache hit.
|
||||
_num_cached_tokens: int = 0
|
||||
_stage: SequenceStage = SequenceStage.PREFILL
|
||||
_cached_all_token_ids: List[int] = msgspec.field(default_factory=list)
|
||||
|
||||
@ -323,6 +325,14 @@ class SequenceData(msgspec.Struct,
|
||||
if self.get_num_uncomputed_tokens() == 0:
|
||||
self._stage = SequenceStage.DECODE
|
||||
|
||||
def get_num_cached_tokens(self) -> int:
|
||||
"""Return the number of tokens with prefix cache hit."""
|
||||
return self._num_cached_tokens
|
||||
|
||||
def update_num_cached_tokens(self, num_cached_tokens: int):
|
||||
"""Update the number of tokens with prefix cache hit."""
|
||||
self._num_cached_tokens = num_cached_tokens
|
||||
|
||||
def reset_state_for_recompute(self) -> None:
|
||||
"""Reset the number of computed tokens from this sequence. It is
|
||||
supposed to be called when a sequence needs to be started from
|
||||
@ -379,7 +389,7 @@ class SequenceData(msgspec.Struct,
|
||||
|
||||
class Sequence:
|
||||
"""Stores the data, status, and block information of a sequence.
|
||||
|
||||
|
||||
The sequence is constructed from the :data:`DecoderOnlyInputs`
|
||||
(for decoder-only) or :data:`EncoderDecoderInputs` (for encoder-decoder)
|
||||
instance passed in through the :code:`inputs` constructor argument.
|
||||
@ -906,7 +916,7 @@ class SequenceGroupMetadata(
|
||||
multi_modal_data: Multi modal data.
|
||||
mm_processor_kwargs: Multimodal input processor / mapper overrides.
|
||||
encoder_seq_data: Optional sequence data for encoder prompt
|
||||
(SequenceGroup.encoder_seq). Should be None
|
||||
(SequenceGroup.encoder_seq). Should be None
|
||||
unless you are working with an encoder/decoder
|
||||
model.
|
||||
cross_block_table: Optional cross-attention block table associated
|
||||
|
||||
@ -542,6 +542,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
# this may be larger than the sequence length if chunked
|
||||
# prefill is enabled.
|
||||
prefix_cache_len = len(computed_block_nums) * self.block_size
|
||||
seq_group_metadata.seq_data[inter_data.seq_ids[
|
||||
seq_idx]].update_num_cached_tokens(prefix_cache_len)
|
||||
|
||||
# The number of so far computed prompt tokens in this sequence.
|
||||
context_len = inter_data.context_lens[seq_idx]
|
||||
# The total number of prompt tokens in this sequence.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user