[BugFix] Fix and simplify completion API usage streaming (#9475)

Author: Nick Hill
Date:   2024-10-18 15:10:26 +01:00 (committed by GitHub)
Parent: d2b1bf55ec
Commit: 25aeb7d4c9

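For context, the two stream_options flags that this commit consolidates (include_usage and continuous_usage_stats) are set by the client on the /v1/completions request. Below is a minimal client-side sketch of streaming with usage reporting; the server URL and model name are placeholder assumptions and are not part of this change.

    import json
    import requests

    payload = {
        "model": "my-model",   # assumed served model name
        "prompt": "Hello",
        "max_tokens": 16,
        "stream": True,
        # include_usage: a trailing chunk with an empty choices list carries
        # aggregate usage for the whole request.
        # continuous_usage_stats: every streamed chunk also carries running usage.
        "stream_options": {"include_usage": True,
                           "continuous_usage_stats": True},
    }

    # Assumes a vLLM OpenAI-compatible server listening on localhost:8000.
    with requests.post("http://localhost:8000/v1/completions",
                       json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            # The server emits SSE lines of the form "data: {json}".
            if not line or not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            chunk = json.loads(data)
            if chunk.get("usage"):
                print(chunk["usage"])

With only include_usage set, usage arrives once in the trailing chunk whose choices list is empty; adding continuous_usage_stats attaches running totals to every chunk as well.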

@@ -258,6 +258,14 @@ class OpenAIServingCompletion(OpenAIServing):
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts

+        stream_options = request.stream_options
+        if stream_options:
+            include_usage = stream_options.include_usage
+            include_continuous_usage = include_usage and \
+                stream_options.continuous_usage_stats
+        else:
+            include_usage, include_continuous_usage = False, False
+
         try:
             async for prompt_idx, res in result_generator:
                 prompt_token_ids = res.prompt_token_ids
@@ -276,18 +284,15 @@ class OpenAIServingCompletion(OpenAIServing):
                     i = output.index + prompt_idx * num_choices

                     assert request.max_tokens is not None
-                    if request.echo and request.max_tokens == 0:
+                    if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
                         assert prompt_text is not None
-                        # only return the prompt
-                        delta_text = prompt_text
-                        delta_token_ids = prompt_token_ids
-                        out_logprobs = prompt_logprobs
-                        has_echoed[i] = True
-                    elif (request.echo and request.max_tokens > 0
-                          and not has_echoed[i]):
-                        assert prompt_token_ids is not None
-                        assert prompt_text is not None
-                        assert prompt_logprobs is not None
-                        # echo the prompt and first token
-                        delta_text = prompt_text + output.text
+                        if request.max_tokens == 0:
+                            # only return the prompt
+                            delta_text = prompt_text
+                            delta_token_ids = prompt_token_ids
+                            out_logprobs = prompt_logprobs
+                        else:
+                            assert prompt_logprobs is not None
+                            # echo the prompt and first token
+                            delta_text = prompt_text + output.text
@@ -341,45 +346,39 @@ class OpenAIServingCompletion(OpenAIServing):
                                 stop_reason=stop_reason,
                             )
                         ])
-                    if (request.stream_options
-                            and request.stream_options.include_usage):
-                        if (request.stream_options.continuous_usage_stats
-                                or output.finish_reason is not None):
-                            prompt_tokens = num_prompt_tokens[prompt_idx]
-                            completion_tokens = previous_num_tokens[i]
-                            usage = UsageInfo(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens + completion_tokens,
-                            )
-                        if request.stream_options.continuous_usage_stats:
-                            chunk.usage = usage
-                        else:
-                            chunk.usage = None
+                    if include_continuous_usage:
+                        prompt_tokens = num_prompt_tokens[prompt_idx]
+                        completion_tokens = previous_num_tokens[i]
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        )

                     response_json = chunk.model_dump_json(exclude_unset=False)
                     yield f"data: {response_json}\n\n"

-            if (request.stream_options
-                    and request.stream_options.include_usage):
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
+            if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
                     model=model_name,
                     choices=[],
-                    usage=usage,
+                    usage=final_usage_info,
                 )
                 final_usage_data = (final_usage_chunk.model_dump_json(
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"

             # report to FastAPI middleware aggregate usage across all choices
-            total_prompt_tokens = sum(num_prompt_tokens)
-            total_completion_tokens = sum(previous_num_tokens)
-            request_metadata.final_usage_info = UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+            request_metadata.final_usage_info = final_usage_info

         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -413,13 +412,13 @@ class OpenAIServingCompletion(OpenAIServing):
             for output in final_res.outputs:
                 assert request.max_tokens is not None
-                if request.echo and request.max_tokens == 0:
+                if request.echo:
                     assert prompt_text is not None
-                    token_ids = prompt_token_ids
-                    out_logprobs = prompt_logprobs
-                    output_text = prompt_text
-                elif request.echo and request.max_tokens > 0:
-                    assert prompt_text is not None
-                    token_ids = [*prompt_token_ids, *output.token_ids]
-                    if request.logprobs is None:
+                    if request.max_tokens == 0:
+                        token_ids = prompt_token_ids
+                        out_logprobs = prompt_logprobs
+                        output_text = prompt_text
+                    else:
+                        token_ids = [*prompt_token_ids, *output.token_ids]
+                        if request.logprobs is None: