mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 13:37:15 +08:00
181 lines
6.3 KiB
Python
181 lines
6.3 KiB
Python
from dataclasses import dataclass
|
|
from typing import List, Optional, Union
|
|
|
|
from vllm.engine.output_processor.stop_checker import StopChecker
|
|
from vllm.logger import init_logger
|
|
from vllm.sampling_params import RequestOutputKind
|
|
from vllm.transformers_utils.detokenizer_utils import (
|
|
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
|
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)
|
|
|
|
|
|
@dataclass
class DetokenizerOutput:
    """Plain record carrying the detokenization result for one request."""

    # Text to report for the request.
    output_text: str
    # Token ids reported alongside ``output_text``.
    token_ids: List[int]
    # Whether the request has finished.
    finished: bool
    # Reason the request finished; None when no reason was given.
    finish_reason: Optional[str] = None
    # What triggered the stop, as an int or a string; None when not given.
    stop_reason: Optional[Union[int, str]] = None
|
|
|
|
|
|
@dataclass
class IncrementalDetokenizer:
    """Per-request state for turning generated token ids into text.

    Instances are normally built with ``from_new_request`` and then fed each
    engine output through ``update_from_output``.
    """

    # --- Generation data ---
    # Detokenized text accumulated so far.
    output_text: str
    # Token strings, seeded from the prompt and extended on every update.
    tokens: List[str]
    # Prompt token ids followed by every generated token id.
    token_ids: List[int]
    # Number of prompt tokens at the front of ``token_ids``.
    prompt_len: int

    # --- Stop strings ---
    stop: List[str]
    include_stop_str_in_output: bool

    # --- Offsets carried between detokenize_incrementally calls ---
    prefix_offset: int
    read_offset: int

    # --- Detokenization parameters ---
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    output_kind: RequestOutputKind

    # --- Tokenizer used for this request ---
    tokenizer: AnyTokenizer

    # --- Stop-string buffering bookkeeping ---
    # Trailing characters withheld from the output of an unfinished request.
    stop_buffer_length: int
    # End position of the text already returned in delta mode.
    _last_output_text_offset: int = 0

    @property
    def output_token_ids(self) -> List[int]:
        """Generated token ids, i.e. ``token_ids`` minus the prompt part."""
        return self.token_ids[self.prompt_len:]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Create the detokenizer state for a new request.

        Args:
            tokenizer: Tokenizer to use for this request.
            request: Request supplying ``prompt_token_ids`` and
                ``sampling_params``.

        Returns:
            A detokenizer with empty output text whose token list and
            offsets are seeded from the prompt.
        """
        params = request.sampling_params
        prompt_ids = request.prompt_token_ids

        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
            tokenizer=tokenizer,
            prompt_ids=prompt_ids,
            skip_special_tokens=params.skip_special_tokens,
        )

        stops = params.stop
        # When stop strings are to be excluded from streamed output, hold
        # back one character fewer than the longest stop string.
        hide_stops = stops and not params.include_stop_str_in_output
        stop_buffer_length = (max(map(len, stops)) - 1) if hide_stops else 0

        return cls(
            output_text="",
            tokens=tokens,
            # Detokenizer mutates this list, so need a unique copy.
            # NOTE(Nick): could we take ownership of it though?
            token_ids=prompt_ids.copy(),
            prompt_len=len(prompt_ids),
            stop=stops,
            include_stop_str_in_output=params.include_stop_str_in_output,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=params.skip_special_tokens,
            spaces_between_special_tokens=(
                params.spaces_between_special_tokens),
            output_kind=params.output_kind,
            tokenizer=tokenizer,
            stop_buffer_length=stop_buffer_length,
        )

    def update_from_output(
        self,
        output: EngineCoreOutput,
    ) -> Optional[DetokenizerOutput]:
        """Fold the newly generated tokens of ``output`` into this state.

        Steps:
        1) Incrementally detokenize each id in ``output.new_token_ids``.
        2) Check the accumulated text against the stop strings, if any.
        3) Assemble the text and token ids to report.

        Returns:
            A ``DetokenizerOutput``, or None when the output kind is
            FINAL_ONLY and the request has not finished yet.
        """
        new_token_ids = output.new_token_ids
        finish_reason = output.finish_reason
        stop_reason = output.stop_reason

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        text_len_before = len(self.output_text)
        for token_id in new_token_ids:
            self.token_ids.append(token_id)
            new_tokens, new_text, prefix_offset, read_offset = (
                detokenize_incrementally(
                    tokenizer=self.tokenizer,
                    all_input_ids=self.token_ids,
                    prev_tokens=self.tokens,
                    prefix_offset=self.prefix_offset,
                    read_offset=self.read_offset,
                    skip_special_tokens=self.skip_special_tokens,
                    spaces_between_special_tokens=(
                        self.spaces_between_special_tokens),
                ))

            self.tokens.extend(new_tokens)
            self.prefix_offset = prefix_offset
            self.read_offset = read_offset
            self.output_text += new_text

        # 2) Evaluate stop criteria.
        stop_match = None
        if self.stop:
            stop_match = StopChecker.check_stop_strings(
                output_text=self.output_text,
                # Characters appended to output_text by the loop above.
                new_char_count=len(self.output_text) - text_len_before,
                stop=self.stop,
                include_in_output=self.include_stop_str_in_output,
            )
        if stop_match is not None:
            stop_str, truncate_to = stop_match
            # -1 is treated as "leave the text as it is".
            if truncate_to != -1:
                self.output_text = self.output_text[:truncate_to]
            finish_reason = "stop"  # TODO: use constant
            stop_reason = stop_str

        # TODO: handle stop_token_ids here too?

        # 3) Update the RequestOutput object with the new text.
        finished = bool(finish_reason)
        if not finished and self.output_kind == RequestOutputKind.FINAL_ONLY:
            # FINAL_ONLY requests report nothing until they finish.
            return None

        is_delta = self.output_kind == RequestOutputKind.DELTA
        text = self._get_next_output_text(finished, is_delta)
        if is_delta:
            reported_ids = new_token_ids
        else:
            reported_ids = self.output_token_ids

        return DetokenizerOutput(
            output_text=text,
            token_ids=reported_ids,
            finished=finished,
            finish_reason=finish_reason,
            stop_reason=stop_reason,
        )

    def _get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return the text to report for the current step.

        With ``delta`` set, only the text added since the previous call to
        this method is returned; otherwise the whole output text is. While
        the request is unfinished, the last ``stop_buffer_length``
        characters are withheld.
        """
        # Nothing is held back once the sequence is finished.
        hold_back = 0 if finished else self.stop_buffer_length

        if not delta:
            if hold_back:
                return self.output_text[:-hold_back]
            return self.output_text

        visible_end = len(self.output_text) - hold_back
        start = self._last_output_text_offset
        if start >= visible_end:
            # No releasable text beyond what was already returned.
            return ""

        self._last_output_text_offset = visible_end
        return self.output_text[start:visible_end]
|