diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index f336b4656555..0852d49fcad9 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -245,6 +245,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
         index: int,
         text: str,
         finish_reason: Optional[str] = None,
+        usage: Optional[UsageInfo] = None,
     ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
@@ -257,7 +258,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
             model=model_name,
             choices=[choice_data],
         )
-        response_json = response.json(ensure_ascii=False)
+        if usage is not None:
+            response.usage = usage
+        # exclude unset to leave details out of each sse
+        response_json = response.json(exclude_unset=True, ensure_ascii=False)
 
         return response_json
 
@@ -283,17 +287,25 @@ async def create_chat_completion(request: ChatCompletionRequest,
                 i = output.index
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
-                previous_num_tokens[i] = len(output.token_ids)
+                completion_tokens = len(output.token_ids)
+                previous_num_tokens[i] = completion_tokens
                 response_json = create_stream_response_json(
                     index=i,
                     text=delta_text,
                 )
                 yield f"data: {response_json}\n\n"
                 if output.finish_reason is not None:
+                    prompt_tokens = len(res.prompt_token_ids)
+                    final_usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
                     response_json = create_stream_response_json(
                         index=i,
                         text="",
                         finish_reason=output.finish_reason,
+                        usage=final_usage,
                     )
                     yield f"data: {response_json}\n\n"
         yield "data: [DONE]\n\n"
@@ -462,6 +474,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         text: str,
         logprobs: Optional[LogProbs] = None,
         finish_reason: Optional[str] = None,
+        usage: Optional[UsageInfo] = None,
     ) -> str:
         choice_data = CompletionResponseStreamChoice(
             index=index,
@@ -475,7 +488,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             model=model_name,
             choices=[choice_data],
         )
-        response_json = response.json(ensure_ascii=False)
+        if usage is not None:
+            response.usage = usage
+        response_json = response.json(exclude_unset=True, ensure_ascii=False)
 
         return response_json
 
@@ -505,11 +520,19 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
                 if output.finish_reason is not None:
                     logprobs = (LogProbs()
                                 if request.logprobs is not None else None)
+                    prompt_tokens = len(res.prompt_token_ids)
+                    completion_tokens = len(output.token_ids)
+                    final_usage = UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    )
                     response_json = create_stream_response_json(
                         index=i,
                         text="",
                         logprobs=logprobs,
                         finish_reason=output.finish_reason,
+                        usage=final_usage,
                     )
                     yield f"data: {response_json}\n\n"
         yield "data: [DONE]\n\n"
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7700c5dd483e..39db35620307 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -139,6 +139,7 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
+    usage: Optional[UsageInfo]
 
 
 class ChatMessage(BaseModel):
@@ -178,3 +179,5 @@ class ChatCompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(
+        default=None, description="data about request and response")
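
For reference, a minimal client sketch (not part of the patch) showing how the usage block could be read from the stream after this change. The host, port, and model name are placeholders; because chunks are serialized with `exclude_unset=True`, intermediate chunks omit `usage` and only the chunk carrying `finish_reason` includes it.

```python
# Minimal consumer sketch (not part of the patch): stream a completion from the
# OpenAI-compatible server and read the usage block from the final chunk.
# The host, port, and model name below are placeholders.
import json

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={"model": "facebook/opt-125m", "prompt": "Hello", "stream": True},
    stream=True,
)

usage = None
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    # Intermediate chunks are serialized with exclude_unset=True, so they omit
    # "usage"; only the chunk that carries finish_reason includes it.
    if chunk.get("usage") is not None:
        usage = chunk["usage"]

print(usage)  # dict with prompt_tokens, completion_tokens, total_tokens
```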