[Misc] Include matched stop string/token in responses (#2976)
Co-authored-by: Sahil Suneja <sahilsuneja@gmail.com>
This commit is contained in:
parent 3a243095e5
commit dfeb2ecc3a
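For context (not part of the diff): a minimal offline-inference sketch of how the new field is read. The model name and prompt are illustrative; the calls use the existing LLM/SamplingParams interface.

from vllm import LLM, SamplingParams

# Model and prompt are placeholders (the model matches the new test below).
llm = LLM(model="facebook/opt-350m")
params = SamplingParams(max_tokens=64, stop=["."])

for request_output in llm.generate(["The capital of France is"], params):
    completion = request_output.outputs[0]
    # finish_reason is "stop" or "length" as before; stop_reason now carries
    # the matched stop string (str) or stop token id (int), and is None when
    # generation ended at the EOS token or at max_tokens.
    print(completion.finish_reason, repr(completion.stop_reason))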
tests/samplers/test_stop_reason.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+"""Test the different finish_reason="stop" situations during generation:
+    1. One of the provided stop strings
+    2. One of the provided stop tokens
+    3. The EOS token
+
+Run `pytest tests/samplers/test_stop_reason.py`.
+"""
+
+import pytest
+import transformers
+
+from vllm import SamplingParams
+
+MODEL = "facebook/opt-350m"
+STOP_STR = "."
+SEED = 42
+MAX_TOKENS = 1024
+
+
+@pytest.fixture
+def vllm_model(vllm_runner):
+    vllm_model = vllm_runner(MODEL)
+    yield vllm_model
+    del vllm_model
+
+
+def test_stop_reason(vllm_model, example_prompts):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
+    stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
+    llm = vllm_model.model
+
+    # test stop token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED,
+                               max_tokens=MAX_TOKENS,
+                               stop_token_ids=[stop_token_id]))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == stop_token_id
+
+    # test stop string
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED, max_tokens=MAX_TOKENS, stop="."))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "stop"
+        assert output.stop_reason == STOP_STR
+
+    # test EOS token
+    outputs = llm.generate(example_prompts,
+                           sampling_params=SamplingParams(
+                               seed=SEED, max_tokens=MAX_TOKENS))
+    for output in outputs:
+        output = output.outputs[0]
+        assert output.finish_reason == "length" or (
+            output.finish_reason == "stop" and output.stop_reason is None)
@@ -740,12 +740,15 @@ class LLMEngine:
             if seq.output_text.endswith(stop_str):
                 self._finalize_sequence(seq, sampling_params, stop_str)
                 seq.status = SequenceStatus.FINISHED_STOPPED
+                seq.stop_reason = stop_str
                 return
-        if seq.get_last_token_id() in sampling_params.stop_token_ids:
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
             stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
-                seq.get_last_token_id())
+                last_token_id)
             self._finalize_sequence(seq, sampling_params, stop_str)
             seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
             return
 
         # Check if the sequence has generated the EOS token.
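The branch above records either the matched stop string or the matched stop token id on the sequence; the EOS and length paths leave it as None. A hedged sketch (hypothetical helper, not part of this commit) of how a caller can interpret the resulting CompletionOutput:

from vllm.outputs import CompletionOutput


def describe_stop(completion: CompletionOutput) -> str:
    # Hypothetical helper for illustration only.
    if completion.finish_reason != "stop":
        return f"finished: {completion.finish_reason}"  # e.g. "length"
    if completion.stop_reason is None:
        return "stopped at the EOS token"
    if isinstance(completion.stop_reason, int):
        return f"stopped on stop token id {completion.stop_reason}"
    return f"stopped on stop string {completion.stop_reason!r}"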
@@ -338,6 +338,13 @@ class CompletionResponseChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
 
 
 class CompletionResponse(BaseModel):
@@ -354,6 +361,13 @@ class CompletionResponseStreamChoice(BaseModel):
     text: str
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = Field(
+        default=None,
+        description=(
+            "The stop string or token id that caused the completion "
+            "to stop, None if the completion finished for some other reason "
+            "including encountering the EOS token"),
+    )
 
 
 class CompletionStreamResponse(BaseModel):
@@ -375,6 +389,7 @@ class ChatCompletionResponseChoice(BaseModel):
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None
 
 
 class ChatCompletionResponse(BaseModel):
@@ -396,6 +411,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length"]] = None
+    stop_reason: Union[None, int, str] = None
 
 
 class ChatCompletionStreamResponse(BaseModel):
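With the schema fields above, an OpenAI-compatible completions response carries stop_reason alongside finish_reason. A sketch assuming a locally running vLLM OpenAI-compatible server; the URL and model name are placeholders, not something this diff configures:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-350m",
        "prompt": "The capital of France is",
        "max_tokens": 64,
        "stop": ["."],
    },
)
choice = resp.json()["choices"][0]
# New field from CompletionResponseChoice above: the matched stop string or
# stop token id, or None (e.g. EOS / max_tokens).
print(choice["finish_reason"], choice.get("stop_reason"))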
@@ -220,7 +220,8 @@ class OpenAIServingChat(OpenAIServing):
                     index=i,
                     delta=DeltaMessage(content=delta_text),
                     logprobs=logprobs,
-                    finish_reason=output.finish_reason)
+                    finish_reason=output.finish_reason,
+                    stop_reason=output.stop_reason)
                 chunk = ChatCompletionStreamResponse(
                     id=request_id,
                     object=chunk_object_type,
@@ -278,6 +279,7 @@ class OpenAIServingChat(OpenAIServing):
                 message=ChatMessage(role=role, content=output.text),
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)
 
@@ -266,6 +266,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     previous_texts[i] = output.text
                     previous_num_tokens[i] = len(output.token_ids)
                     finish_reason = output.finish_reason
+                    stop_reason = output.stop_reason
                     if output.finish_reason is not None:  # return final usage
                         prompt_tokens = len(res.prompt_token_ids)
                         completion_tokens = len(output.token_ids)
@@ -286,6 +287,7 @@ class OpenAIServingCompletion(OpenAIServing):
                                 text=delta_text,
                                 logprobs=logprobs,
                                 finish_reason=finish_reason,
+                                stop_reason=stop_reason,
                             )
                         ],
                         usage=final_usage,
@@ -342,6 +344,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 text=output_text,
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
+                stop_reason=output.stop_reason,
             )
             choices.append(choice_data)
 
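For the streaming path, the chunk that carries a non-null finish_reason also carries the matched stop string/token id. A hedged sketch of a client reading it over server-sent events, with the same assumed local server and placeholder model as above:

import json
import requests

with requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-350m",
            "prompt": "The capital of France is",
            "max_tokens": 64,
            "stop": ["."],
            "stream": True,
        },
        stream=True,
) as resp:
    for line in resp.iter_lines():
        # OpenAI-style SSE: payload lines are prefixed with "data: " and the
        # stream ends with "data: [DONE]".
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        choice = json.loads(payload)["choices"][0]
        if choice["finish_reason"] is not None:
            print(choice["finish_reason"], choice.get("stop_reason"))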
@@ -1,5 +1,5 @@
 import time
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
@@ -18,6 +18,9 @@ class CompletionOutput:
         logprobs: The log probabilities of the top probability words at each
             position if the logprobs are requested.
         finish_reason: The reason why the sequence is finished.
+        stop_reason: The stop string or token id that caused the completion
+            to stop, None if the completion finished for some other reason
+            including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
     """
 
@@ -29,6 +32,7 @@ class CompletionOutput:
         cumulative_logprob: float,
         logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
+        stop_reason: Union[int, str, None] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.index = index
@@ -37,6 +41,7 @@ class CompletionOutput:
         self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs
         self.finish_reason = finish_reason
+        self.stop_reason = stop_reason
         self.lora_request = lora_request
 
     def finished(self) -> bool:
@@ -48,7 +53,8 @@ class CompletionOutput:
                 f"token_ids={self.token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs}, "
-                f"finish_reason={self.finish_reason})")
+                f"finish_reason={self.finish_reason}, "
+                f"stop_reason={self.stop_reason})")
 
 
 class RequestOutput:
@@ -111,8 +117,8 @@ class RequestOutput:
                              seq.get_output_token_ids(),
                              seq.get_cumulative_logprob(),
                              seq.output_logprobs if include_logprobs else None,
-                             SequenceStatus.get_finished_reason(seq.status))
-            for seq in top_n_seqs
+                             SequenceStatus.get_finished_reason(seq.status),
+                             seq.stop_reason) for seq in top_n_seqs
         ]
 
         # Every sequence in the sequence group should have the same prompt.
@@ -183,6 +183,7 @@ class Sequence:
         # Initialize the logical token blocks with the prompt token ids.
         self._append_tokens_to_blocks(prompt_token_ids)
         self.status = SequenceStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None
 
         # Used for incremental detokenization
         self.prefix_offset = 0