Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-16 06:45:01 +08:00
Add LogProbs for Chat Completions in OpenAI (#2918)

Commit 70f3e8e3a1 (parent ef978fe411)
Tests (test_single_chat_session, test_completion_streaming):

@@ -155,15 +155,18 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     }]
 
     # test single completion
-    chat_completion = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-    )
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           logprobs=True,
+                                                           top_logprobs=10)
     assert chat_completion.id is not None
     assert chat_completion.choices is not None and len(
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.top_logprobs is not None
+    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
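For reference, a minimal client-side sketch of the path this test exercises, assuming a vLLM OpenAI-compatible server is already running; the base_url, api_key, and model id are placeholders, and the logprobs access mirrors the assertions above.

```python
import asyncio

import openai


async def main():
    # Placeholder endpoint and model id; point these at your own deployment.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat_completion = await client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Say hello."}],
        max_tokens=10,
        logprobs=True,
        top_logprobs=10)
    # Each generated position carries a dict of the top-10 candidate tokens,
    # which is exactly what the new test assertions check.
    for position in chat_completion.choices[0].logprobs.top_logprobs:
        print(position)


asyncio.run(main())
```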
@@ -198,13 +201,11 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     single_output = single_completion.choices[0].text
     single_usage = single_completion.usage
 
-    stream = await client.completions.create(
-        model=model_name,
-        prompt=prompt,
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
+    stream = await client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             max_tokens=5,
+                                             temperature=0.0,
+                                             stream=True)
     chunks = []
     async for chunk in stream:
         chunks.append(chunk.choices[0].text)
Protocol models (ChatCompletionRequest and the response choice models):

@@ -63,6 +63,8 @@ class ChatCompletionRequest(BaseModel):
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
@@ -84,6 +86,8 @@ class ChatCompletionRequest(BaseModel):
     length_penalty: Optional[float] = 1.0
 
     def to_sampling_params(self) -> SamplingParams:
+        if self.logprobs and not self.top_logprobs:
+            raise ValueError("Top logprobs must be set when logprobs is.")
         return SamplingParams(
             n=self.n,
             presence_penalty=self.presence_penalty,
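A quick sketch of the new guard in isolation, assuming ChatCompletionRequest is importable from vllm.entrypoints.openai.protocol (the module this hunk patches); the model name and message below are placeholders.

```python
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

# logprobs requested without top_logprobs: to_sampling_params() now rejects it.
req = ChatCompletionRequest(model="my-model",
                            messages=[{"role": "user", "content": "hi"}],
                            max_tokens=32,
                            logprobs=True)
try:
    req.to_sampling_params()
except ValueError as err:
    print(err)  # Top logprobs must be set when logprobs is.
```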
@@ -96,6 +100,8 @@ class ChatCompletionRequest(BaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             max_tokens=self.max_tokens,
+            logprobs=self.top_logprobs if self.logprobs else None,
+            prompt_logprobs=self.top_logprobs if self.echo else None,
             best_of=self.best_of,
             top_k=self.top_k,
             ignore_eos=self.ignore_eos,
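Continuing that sketch: with both fields set, the requested top_logprobs count is forwarded as SamplingParams.logprobs, and prompt_logprobs is only populated when echo is requested.

```python
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(model="my-model",
                            messages=[{"role": "user", "content": "hi"}],
                            max_tokens=32,
                            logprobs=True,
                            top_logprobs=5)
params = req.to_sampling_params()
print(params.logprobs)         # 5
print(params.prompt_logprobs)  # None, because echo defaults to False
```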
@@ -216,6 +222,7 @@ class ChatMessage(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
 
@@ -236,6 +243,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
 
 
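For orientation, this is roughly the serialized shape a non-streaming choice takes once logprobs are requested. The values are invented, and the field names assume vLLM's completions-style LogProbs model (tokens, token_logprobs, top_logprobs, text_offset) referenced by the new fields above.

```python
# Illustrative only; values are made up.
example_choice = {
    "index": 0,
    "message": {"role": "assistant", "content": "Hello"},
    "logprobs": {
        "tokens": ["Hello"],
        "token_logprobs": [-0.12],
        "top_logprobs": [{"Hello": -0.12, "Hi": -2.31}],
        "text_offset": [0],
    },
    "finish_reason": "stop",
}
print(example_choice["logprobs"]["top_logprobs"][0])
```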
Chat serving (OpenAIServingChat):

@@ -101,7 +101,10 @@ class OpenAIServingChat(OpenAIServing):
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i,
+                delta=DeltaMessage(role=role),
+                logprobs=None,
+                finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -118,6 +121,7 @@ class OpenAIServingChat(OpenAIServing):
                         "content") and request.messages[-1].get(
                             "role") == role:
                 last_msg_content = request.messages[-1]["content"]
+
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +133,7 @@ class OpenAIServingChat(OpenAIServing):
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                         model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -145,15 +150,29 @@ class OpenAIServingChat(OpenAIServing):
                 if finish_reason_sent[i]:
                     continue
 
+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+
+                if request.logprobs:
+                    logprobs = self._create_logprobs(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )
+                else:
+                    logprobs = None
+
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
 
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
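The slicing above is the heart of the streaming change: each chunk reports only the tokens generated since the previous chunk, together with their logprobs. A self-contained toy illustration (names are illustrative, not vLLM internals):

```python
# Cumulative output so far, as the engine would report it.
all_token_ids = [11, 42, 7, 99]
all_logprobs = [{11: -0.1}, {42: -0.5}, {7: -0.2}, {99: -1.3}]

previous_num_tokens = 2  # tokens already emitted in earlier chunks

# Same slicing as delta_token_ids / top_logprobs in the hunk above.
delta_token_ids = all_token_ids[previous_num_tokens:]
delta_logprobs = all_logprobs[previous_num_tokens:]

print(delta_token_ids)  # [7, 99]
print(delta_logprobs)   # [{7: -0.2}, {99: -1.3}]
```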
@@ -174,6 +193,7 @@ class OpenAIServingChat(OpenAIServing):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
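Consuming that stream from a client looks roughly like this (a sketch; base_url, api_key, and model id are placeholders): once the request asks for logprobs, every delta chunk carries a logprobs object alongside its text.

```python
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Count to three."}],
        max_tokens=16,
        logprobs=True,
        top_logprobs=5,
        stream=True)
    async for chunk in stream:
        choice = chunk.choices[0]
        # The first chunk only announces the role; later chunks carry
        # delta text plus the per-token logprobs built server-side.
        print(choice.delta.content, choice.logprobs)


asyncio.run(main())
```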
@@ -208,11 +228,25 @@ class OpenAIServingChat(OpenAIServing):
         assert final_res is not None
 
         choices = []
+
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            token_ids = output.token_ids
+            top_logprobs = output.logprobs
+
+            if request.logprobs:
+                logprobs = self._create_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
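In the non-streaming path, self._create_logprobs assembles the per-token data into the parallel-array payload the test asserts on. A toy stand-in (not vLLM's implementation), just to show the shape being built:

```python
def build_logprobs_payload(tokens, chosen_logprobs, top_k_per_position):
    """Toy helper: parallel arrays, keyed by token text for readability."""
    text_offset, offset = [], 0
    for tok in tokens:
        text_offset.append(offset)
        offset += len(tok)
    return {
        "tokens": tokens,
        "token_logprobs": chosen_logprobs,
        "top_logprobs": top_k_per_position,
        "text_offset": text_offset,
    }


payload = build_logprobs_payload(
    tokens=["Hel", "lo"],
    chosen_logprobs=[-0.11, -0.03],
    top_k_per_position=[{"Hel": -0.11, "He": -2.50}, {"lo": -0.03}],
)
print(payload["text_offset"])  # [0, 3]
```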