[BugFix] Fix and simplify completion API usage streaming (#9475)

Nick Hill 2024-10-18 15:10:26 +01:00 committed by GitHub
parent d2b1bf55ec
commit 25aeb7d4c9

@@ -258,6 +258,14 @@ class OpenAIServingCompletion(OpenAIServing):
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
 
+        stream_options = request.stream_options
+        if stream_options:
+            include_usage = stream_options.include_usage
+            include_continuous_usage = include_usage and \
+                                       stream_options.continuous_usage_stats
+        else:
+            include_usage, include_continuous_usage = False, False
+
         try:
             async for prompt_idx, res in result_generator:
                 prompt_token_ids = res.prompt_token_ids
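
Note: the usage flags are now computed once from request.stream_options instead of being re-derived for every chunk. For reference, a minimal client-side sketch of a request that exercises both flags, assuming a vLLM OpenAI-compatible server at http://localhost:8000 and a placeholder model name (continuous_usage_stats is the vLLM-specific stream_options extension referenced in this diff):

    # Sketch only: server URL and model name are placeholders.
    import json
    import requests

    payload = {
        "model": "my-model",
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "stream": True,
        "stream_options": {
            "include_usage": True,            # emit a final usage-only chunk
            "continuous_usage_stats": True,   # vLLM extension: usage on every chunk
        },
    }

    with requests.post("http://localhost:8000/v1/completions",
                       json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line.startswith("data: ") and line != "data: [DONE]":
                chunk = json.loads(line[len("data: "):])
                print(chunk.get("usage"))
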
@@ -276,18 +284,15 @@ class OpenAIServingCompletion(OpenAIServing):
                     i = output.index + prompt_idx * num_choices
 
                     assert request.max_tokens is not None
-                    if request.echo and request.max_tokens == 0:
+                    if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
                         assert prompt_text is not None
+                        if request.max_tokens == 0:
                             # only return the prompt
                             delta_text = prompt_text
                             delta_token_ids = prompt_token_ids
                             out_logprobs = prompt_logprobs
-                        has_echoed[i] = True
-                    elif (request.echo and request.max_tokens > 0
-                          and not has_echoed[i]):
-                        assert prompt_token_ids is not None
-                        assert prompt_text is not None
+                        else:
                             assert prompt_logprobs is not None
                             # echo the prompt and first token
                             delta_text = prompt_text + output.text
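
Note: the two echo branches in the streaming path collapse into a single "if request.echo and not has_echoed[i]" with the max_tokens == 0 case nested inside. A simplified, hypothetical restatement of that decision (the helper name and signature below are illustrative, not part of the diff):

    def first_stream_delta(prompt_text: str, output_text: str,
                           echo: bool, max_tokens: int,
                           already_echoed: bool) -> str:
        if echo and not already_echoed:
            if max_tokens == 0:
                # echo-only request: the prompt itself is the whole delta
                return prompt_text
            # echo plus generation: prompt and first generated text go out together
            return prompt_text + output_text
        # normal case: stream only the newly generated text
        return output_text

    assert first_stream_delta("Hi", "", True, 0, False) == "Hi"
    assert first_stream_delta("Hi", " there", True, 8, False) == "Hi there"
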
@@ -341,45 +346,39 @@ class OpenAIServingCompletion(OpenAIServing):
                                 stop_reason=stop_reason,
                             )
                         ])
-                    if (request.stream_options
-                            and request.stream_options.include_usage):
-                        if (request.stream_options.continuous_usage_stats
-                                or output.finish_reason is not None):
+                    if include_continuous_usage:
                         prompt_tokens = num_prompt_tokens[prompt_idx]
                         completion_tokens = previous_num_tokens[i]
-                            usage = UsageInfo(
+                        chunk.usage = UsageInfo(
                             prompt_tokens=prompt_tokens,
                             completion_tokens=completion_tokens,
                             total_tokens=prompt_tokens + completion_tokens,
                         )
-                        if request.stream_options.continuous_usage_stats:
-                            chunk.usage = usage
-                        else:
-                            chunk.usage = None
 
                     response_json = chunk.model_dump_json(exclude_unset=False)
                     yield f"data: {response_json}\n\n"
 
-            if (request.stream_options
-                    and request.stream_options.include_usage):
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
+            if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
                     model=model_name,
                     choices=[],
-                    usage=usage,
+                    usage=final_usage_info,
                 )
                 final_usage_data = (final_usage_chunk.model_dump_json(
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
             # report to FastAPI middleware aggregate usage across all choices
-            total_prompt_tokens = sum(num_prompt_tokens)
-            total_completion_tokens = sum(previous_num_tokens)
-            request_metadata.final_usage_info = UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+            request_metadata.final_usage_info = final_usage_info
 
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
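
Note: per-chunk usage is now attached only when include_continuous_usage is set, and one aggregate final_usage_info feeds both the optional usage-only final chunk and the FastAPI middleware metadata. A schematic sketch of that flow, with hypothetical names and a reduced chunk shape (not the actual vLLM types):

    from dataclasses import dataclass

    @dataclass
    class Usage:
        prompt_tokens: int
        completion_tokens: int

        @property
        def total_tokens(self) -> int:
            return self.prompt_tokens + self.completion_tokens

    def stream_chunks(deltas, prompt_tokens, include_usage, include_continuous_usage):
        completion_tokens = 0
        for delta in deltas:
            completion_tokens += 1
            chunk = {"text": delta}
            if include_continuous_usage:
                # usage on every chunk only when continuous stats were requested
                chunk["usage"] = Usage(prompt_tokens, completion_tokens)
            yield chunk
        final_usage = Usage(prompt_tokens, completion_tokens)
        if include_usage:
            # usage-only final chunk with an empty choices list, mirroring the diff
            yield {"choices": [], "usage": final_usage}
        # the same aggregate object would also be reported to request metadata
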
@@ -413,13 +412,13 @@ class OpenAIServingCompletion(OpenAIServing):
             for output in final_res.outputs:
                 assert request.max_tokens is not None
-                if request.echo and request.max_tokens == 0:
+                if request.echo:
                     assert prompt_text is not None
+                    if request.max_tokens == 0:
                         token_ids = prompt_token_ids
                         out_logprobs = prompt_logprobs
                         output_text = prompt_text
-                elif request.echo and request.max_tokens > 0:
-                    assert prompt_text is not None
+                    else:
                         token_ids = [*prompt_token_ids, *output.token_ids]
                         if request.logprobs is None:
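
Note: the non-streaming path gets the same consolidation, a single "if request.echo:" with the max_tokens == 0 special case nested inside. An illustrative-only sketch of the resulting assembly (the helper below is hypothetical; only the field names follow the diff):

    from typing import List, Tuple

    def assemble_echoed_output(prompt_token_ids: List[int],
                               output_token_ids: List[int],
                               prompt_text: str, output_text: str,
                               echo: bool, max_tokens: int) -> Tuple[List[int], str]:
        if echo:
            if max_tokens == 0:
                # echo-only request: return exactly the prompt tokens and text
                return list(prompt_token_ids), prompt_text
            # echo with generation: prompt tokens followed by generated tokens
            return [*prompt_token_ids, *output_token_ids], prompt_text + output_text
        return list(output_token_ids), output_text
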