[Misc] Include matched stop string/token in responses (#2976)

Co-authored-by: Sahil Suneja <sahilsuneja@gmail.com>
Nick Hill 2024-03-25 17:31:32 -07:00 committed by GitHub
parent 3a243095e5
commit dfeb2ecc3a
7 changed files with 97 additions and 7 deletions

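For context (not part of the diff itself): with this change the offline API surfaces which stop condition actually ended each completion. A minimal usage sketch, assuming a small placeholder model and prompt:

from vllm import LLM, SamplingParams

# Sketch only: the model name and prompt are placeholder choices, not from this commit.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=64, stop=["."])

for request_output in llm.generate(["The capital of France is"], params):
    completion = request_output.outputs[0]
    # finish_reason is "stop" or "length"; stop_reason is the matched stop
    # string, the matched stop token id, or None (e.g. EOS or length limit).
    print(completion.finish_reason, repr(completion.stop_reason))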
tests/samplers/test_stop_reason.py (new file)

@@ -0,0 +1,59 @@
"""Test the different finish_reason="stop" situations during generation:
    1. One of the provided stop strings
    2. One of the provided stop tokens
    3. The EOS token

Run `pytest tests/samplers/test_stop_reason.py`.
"""

import pytest
import transformers

from vllm import SamplingParams

MODEL = "facebook/opt-350m"
STOP_STR = "."
SEED = 42
MAX_TOKENS = 1024


@pytest.fixture
def vllm_model(vllm_runner):
    vllm_model = vllm_runner(MODEL)
    yield vllm_model
    del vllm_model


def test_stop_reason(vllm_model, example_prompts):
    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
    llm = vllm_model.model

    # test stop token
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               seed=SEED,
                               max_tokens=MAX_TOKENS,
                               stop_token_ids=[stop_token_id]))
    for output in outputs:
        output = output.outputs[0]
        assert output.finish_reason == "stop"
        assert output.stop_reason == stop_token_id

    # test stop string
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               seed=SEED, max_tokens=MAX_TOKENS, stop="."))
    for output in outputs:
        output = output.outputs[0]
        assert output.finish_reason == "stop"
        assert output.stop_reason == STOP_STR

    # test EOS token
    outputs = llm.generate(example_prompts,
                           sampling_params=SamplingParams(
                               seed=SEED, max_tokens=MAX_TOKENS))
    for output in outputs:
        output = output.outputs[0]
        assert output.finish_reason == "length" or (
            output.finish_reason == "stop" and output.stop_reason is None)

vllm/engine/llm_engine.py

@@ -740,12 +740,15 @@ class LLMEngine:
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
+                seq.stop_reason = stop_str
                 return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
+                last_token_id)
             self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
             return

         # Check if the sequence has generated the EOS token.

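The effect of the change above: the engine now records what actually terminated the sequence, either the matched stop string or the matched stop token id, and leaves the field unset for any other finish (EOS, length limit). A simplified, standalone sketch of that bookkeeping (illustrative only; not the actual LLMEngine method):

def classify_stop(output_text, last_token_id, stop, stop_token_ids, eos_token_id):
    """Return (finish_reason, stop_reason) for a single stop check."""
    for stop_str in stop:
        if output_text.endswith(stop_str):
            return "stop", stop_str       # matched stop string
    if last_token_id in stop_token_ids:
        return "stop", last_token_id      # matched stop token id
    if last_token_id == eos_token_id:
        return "stop", None               # EOS: stop_reason stays None
    return None, None                     # not stopped by any of the above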
vllm/entrypoints/openai/protocol.py

@@ -338,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )


 class CompletionResponse(BaseModel):
@@ -354,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )


 class CompletionStreamResponse(BaseModel):
@@ -375,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionResponse(BaseModel):
@@ -396,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None


 class ChatCompletionStreamResponse(BaseModel):

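With these schema fields in place, clients of the OpenAI-compatible server see stop_reason next to finish_reason in every choice. A sketch of reading it over plain HTTP (the base URL and model name below are assumptions, not part of this diff):

import requests

# Assumes an OpenAI-compatible vLLM server is already running locally.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "The capital of France is",
        "max_tokens": 64,
        "stop": ["."],
    },
)
choice = resp.json()["choices"][0]
# stop_reason is the matched stop string or token id, or null when the
# completion ended for another reason (EOS, max_tokens).
print(choice["finish_reason"], choice["stop_reason"])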
vllm/entrypoints/openai/serving_chat.py

@@ -220,7 +220,8 @@ class OpenAIServingChat(OpenAIServing):
                         index=i,
                         delta=DeltaMessage(content=delta_text),
                         logprobs=logprobs,
-                        finish_reason=output.finish_reason)
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
                         object=chunk_object_type,
@@ -278,6 +279,7 @@ class OpenAIServingChat(OpenAIServing):
                 message=ChatMessage(role=role, content=output.text),
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)


vllm/entrypoints/openai/serving_completion.py

@@ -266,6 +266,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 finish_reason = output.finish_reason
+                stop_reason = output.stop_reason
                 if output.finish_reason is not None:  # return final usage
                     prompt_tokens = len(res.prompt_token_ids)
                     completion_tokens = len(output.token_ids)
@@ -286,6 +287,7 @@ class OpenAIServingCompletion(OpenAIServing):
                             text=delta_text,
                             logprobs=logprobs,
                             finish_reason=finish_reason,
+                            stop_reason=stop_reason,
                         )
                     ],
                     usage=final_usage,
@@ -342,6 +344,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 text=output_text,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)


vllm/outputs.py

@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional
+from typing import List, Optional, Union

 from vllm.lora.request import LoRARequest
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -18,6 +18,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion
+            to stop, None if the completion finished for some other reason
+            including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """

@@ -29,6 +32,7 @@ class CompletionOutput:
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -37,6 +41,7 @@ class CompletionOutput:
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request

     def finished(self) -> bool:
@@ -48,7 +53,8 @@ class CompletionOutput:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")


 class RequestOutput:
@@ -111,8 +117,8 @@ class RequestOutput:
                              seq.get_output_token_ids(),
                              seq.get_cumulative_logprob(),
                              seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status))
-            for seq in top_n_seqs
+                             SequenceStatus.get_finished_reason(seq.status),
+                             seq.stop_reason) for seq in top_n_seqs
         ]

         # Every sequence in the sequence group should have the same prompt.

vllm/sequence.py

@@ -183,6 +183,7 @@ class Sequence:
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None

         # Used for incremental detokenization
         self.prefix_offset = 0