Add LogProbs for Chat Completions in OpenAI (#2918)

Author: Jared Moore, 2024-02-25 18:39:34 -08:00 (committed by GitHub)
Commit: 70f3e8e3a1 (parent: ef978fe411)
3 changed files with 57 additions and 14 deletions
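
In short, ChatCompletionRequest gains logprobs and top_logprobs fields, the chat choice models gain a logprobs field, and OpenAIServingChat fills it in for both streaming and non-streaming responses. Below is a client-side sketch of the new behaviour, assuming a locally running vLLM OpenAI-compatible server and the openai Python client v1.x; the base URL, API key, and model name are placeholders, not part of this commit.

    # Sketch only: assumes a vLLM server that includes this change is running
    # locally (e.g. python -m vllm.entrypoints.openai.api_server --model <model>)
    # and that the openai Python client (v1.x) is installed.
    import openai

    client = openai.OpenAI(base_url="http://localhost:8000/v1",  # assumed endpoint
                           api_key="EMPTY")                      # placeholder key

    chat_completion = client.chat.completions.create(
        model="placeholder-model-name",  # placeholder
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=10,
        logprobs=True,    # new: request per-token log probabilities
        top_logprobs=10,  # new: number of candidate tokens returned per position
    )

    choice = chat_completion.choices[0]
    print(choice.message.content)
    # choice.logprobs follows vLLM's completions-style LogProbs layout
    # (tokens / token_logprobs / top_logprobs); the test added in this commit
    # relies on the client preserving these fields and asserts that each
    # top_logprobs entry has ten candidates when top_logprobs=10.
    print(choice.logprobs.top_logprobs[0])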


@@ -155,15 +155,18 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     }]
 
     # test single completion
-    chat_completion = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-    )
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=10)
     assert chat_completion.id is not None
     assert chat_completion.choices is not None and len(
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.top_logprobs is not None
+    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
@@ -198,13 +201,11 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     single_output = single_completion.choices[0].text
     single_usage = single_completion.usage
 
-    stream = await client.completions.create(
-        model=model_name,
-        prompt=prompt,
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True)
     chunks = []
     async for chunk in stream:
         chunks.append(chunk.choices[0].text)


@@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -84,6 +86,8 @@ class ChatCompletionRequest(BaseModel):
     length_penalty: Optional[float] = 1.0
 
     def to_sampling_params(self) -> SamplingParams:
+        if self.logprobs and not self.top_logprobs:
+            raise ValueError("Top logprobs must be set when logprobs is.")
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
@@ -96,6 +100,8 @@ class ChatCompletionRequest(BaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            logprobs=self.top_logprobs if self.logprobs else None,
+            prompt_logprobs=self.top_logprobs if self.echo else None,
             best_of=self.best_of,
             top_k=self.top_k,
             ignore_eos=self.ignore_eos,
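
A rough sketch of how the new request fields behave, assuming vLLM is installed at a revision that includes this change (the model name is a placeholder): requesting logprobs without top_logprobs raises the ValueError above, and top_logprobs is forwarded to SamplingParams.logprobs.

    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

    req = ChatCompletionRequest(
        model="placeholder-model-name",
        messages=[{"role": "user", "content": "hi"}],
        logprobs=True,  # logprobs requested but top_logprobs left unset
    )
    try:
        req.to_sampling_params()
    except ValueError as err:
        print(err)  # "Top logprobs must be set when logprobs is."

    req.top_logprobs = 5
    params = req.to_sampling_params()
    print(params.logprobs)         # 5    -> forwarded because logprobs=True
    print(params.prompt_logprobs)  # None -> only set when request.echo is True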
@@ -216,6 +222,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
@@ -236,6 +243,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
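
Note that the chat choice reuses the existing completions-style LogProbs model (text_offset / tokens / token_logprobs / top_logprobs) rather than OpenAI's newer content-based chat logprobs format, so a serialized non-streaming choice looks roughly like the sketch below (illustrative values only, not output from a real run).

    # Illustrative shape of one response choice; tokens and numbers are made up.
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": "1 + 1 equals 2"},
        "logprobs": {
            "text_offset": [0, 1, 3],
            "tokens": ["1", " +", " 1"],
            "token_logprobs": [-0.01, -0.20, -0.05],
            "top_logprobs": [
                {"1": -0.01, "2": -4.3, "One": -5.1},
                # ... one dict per generated token (truncated here)
            ],
        },
        "finish_reason": "stop",
    }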


@@ -101,7 +101,10 @@ class OpenAIServingChat(OpenAIServing):
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i,
+                delta=DeltaMessage(role=role),
+                logprobs=None,
+                finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -118,6 +121,7 @@ class OpenAIServingChat(OpenAIServing):
                         "content") and request.messages[-1].get(
                             "role") == role:
                 last_msg_content = request.messages[-1]["content"]
+
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +133,7 @@ class OpenAIServingChat(OpenAIServing):
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                         model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -145,15 +150,29 @@ class OpenAIServingChat(OpenAIServing):
                 if finish_reason_sent[i]:
                     continue
+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+                if request.logprobs:
+                    logprobs = self._create_logprobs(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )
+                else:
+                    logprobs = None
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -174,6 +193,7 @@ class OpenAIServingChat(OpenAIServing):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -208,11 +228,25 @@ class OpenAIServingChat(OpenAIServing):
         assert final_res is not None
         choices = []
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            token_ids = output.token_ids
+            top_logprobs = output.logprobs
+            if request.logprobs:
+                logprobs = self._create_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
            )
            choices.append(choice_data)
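
On the streaming side, the initial role-only chunk carries logprobs=None, while each content chunk carries the logprobs for the tokens in its delta, with text offsets continued from the already-streamed text via initial_text_offset. A client-side sketch, under the same assumptions (and reusing the client object and placeholder model name) as the first example:

    stream = client.chat.completions.create(
        model="placeholder-model-name",  # placeholder
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=10,
        logprobs=True,
        top_logprobs=5,
        stream=True,
    )
    for chunk in stream:
        choice = chunk.choices[0]
        # The first chunk only announces the assistant role and has
        # logprobs=None; later chunks carry logprobs for their delta.
        if choice.logprobs is not None and choice.delta.content:
            print(choice.delta.content, choice.logprobs.top_logprobs)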