mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 22:43:38 +08:00
This PR adds support for sample logprobs and prompt logprobs to vLLM v1. New behavior: - During model execution, the model runner computes sample logprobs (if the user-provided logprobs setting is not None) and prompt logprobs (if the user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order. - In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure, which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized. - During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprob values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs, respectively). - Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor, not the detokenizer. Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: Nick Hill <nhill@redhat.com>
189 lines
7.1 KiB
Python
189 lines
7.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
||
|
||
from typing import List, Optional, Tuple
|
||
|
||
from .tokenizer import AnyTokenizer
|
||
|
||
|
||
def _replace_none_with_empty(tokens: List[Optional[str]]):
    """Overwrite every ``None`` entry of ``tokens`` with ``""`` in place.

    The passed-in list itself is mutated; nothing is returned.
    """
    missing_positions = [pos for pos, tok in enumerate(tokens) if tok is None]
    for pos in missing_positions:
        tokens[pos] = ""
|
||
|
||
|
||
def _convert_tokens_to_string_with_added_encoders(
    tokenizer: AnyTokenizer,
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    """Join ``output_tokens`` into text, emitting added-vocab tokens verbatim.

    Tokens found in ``tokenizer.get_added_vocab()`` are appended to the
    output unchanged. Each run of other tokens between them is converted
    with ``tokenizer.convert_tokens_to_string``.

    Args:
        tokenizer: The tokenizer to use.
        output_tokens: The token strings to convert.
        skip_special_tokens: If True, tokens listed in
            ``tokenizer.all_special_tokens`` are dropped.
        spaces_between_special_tokens: If True, the resulting pieces are
            joined with a single space; otherwise with no separator.

    Returns:
        The joined text.
    """
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    # The added vocab does not change while we iterate, so fetch it once
    # instead of calling get_added_vocab() for every token.
    added_vocab = tokenizer.get_added_vocab()
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in added_vocab:
            # Flush the pending run of regular tokens before the added token.
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    separator = " " if spaces_between_special_tokens else ""
    return separator.join(sub_texts)
|
||
|
||
|
||
# Upper bound on the distance between the initial prefix offset and read
# offset (in tokens) computed from the prompt for incremental detokenization.
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||
|
||
|
||
def convert_prompt_ids_to_tokens(
    tokenizer: AnyTokenizer,
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Convert the tail of the prompt ids to tokens for incremental decoding.

    Only the last few prompt ids are converted, since only those are needed
    to start incremental detokenization; the rest of the prompt is never
    turned into token strings.

    Returns:
        A tuple ``(tokens, prefix_offset, read_offset)`` where ``tokens`` are
        the converted trailing tokens, ``read_offset`` is their count, and
        ``prefix_offset`` is ``read_offset`` minus
        ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET``, clamped at 0.
    """
    # Take a couple of ids beyond the offset in case special tokens get
    # dropped from the converted tail.
    tail_length = INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2
    tail_ids = prompt_ids[-tail_length:]
    tail_tokens = tokenizer.convert_ids_to_tokens(
        tail_ids, skip_special_tokens=skip_special_tokens)

    # Out-of-vocab prompt token ids come back as None; blank them out.
    _replace_none_with_empty(tail_tokens)  # type: ignore[arg-type]

    read_offset = len(tail_tokens)
    prefix_offset = max(
        0, read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET)
    return tail_tokens, prefix_offset, read_offset
|
||
|
||
|
||
def convert_ids_list_to_tokens(
    tokenizer: AnyTokenizer,
    token_ids: List[int],
) -> List[str]:
    """Map each token id to its token string.

    The ids are passed to ``tokenizer.convert_ids_to_tokens``; any ``None``
    entries in the result are replaced with empty strings.

    Args:
        tokenizer: tokenizer used by model under test
        token_ids: the token ids to convert (Python list form)

    Returns:
        Python list of token string representations, one per input id
    """
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    _replace_none_with_empty(tokens)  # type: ignore
    return tokens
|
||
|
||
|
||
# Based on
|
||
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
|
||
# under Apache 2.0 license
|
||
def detokenize_incrementally(
    tokenizer: AnyTokenizer,
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenizes the input ids incrementally and returns the new tokens
    and the new text.

    If `prev_tokens` is None, this function will convert the input ids to
    tokens and return the tokens and the new text. Otherwise, it will return the
    new tokens and the new text.

    This function will also return the new prefix offset and the new read
    offset to be used in the next iteration.

    The offsets are necessary to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previous tokens. If None, this function will convert
            the input ids to tokens and return the tokens and the new text.
        prefix_offset: The prefix offset. Ignored (recomputed from the input
            ids) when `prev_tokens` is None.
        read_offset: The read offset. Ignored (recomputed from the input
            ids) when `prev_tokens` is None.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.

    Returns:
        A tuple `(new_tokens, new_text, prefix_offset, read_offset)`.
        `new_tokens` holds the token(s) for the new id, or all tokens on the
        first iteration. If no new text can be emitted yet (the decoded text
        did not grow, or it ends with the U+FFFD replacement character),
        `new_text` is "" and the offsets are returned unchanged. Otherwise
        the old read offset becomes the new prefix offset and the new read
        offset is the total number of tokens.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    is_first_iter = prev_tokens is None
    if is_first_iter:
        (prev_tokens, prefix_offset,
         read_offset) = convert_prompt_ids_to_tokens(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)
    assert prev_tokens is not None

    # If the new token id is out of bounds, return an empty string.
    if 0 <= new_token_id < len(tokenizer):
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]
    else:
        new_tokens = [""]
    output_tokens = prev_tokens + new_tokens

    # If this is the first iteration, return all tokens.
    if is_first_iter:
        new_tokens = output_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    # "\ufffd" is the Unicode replacement character. It is written as an
    # escape so the check cannot be corrupted by a file-encoding mishap.
    if len(new_text) <= len(prefix_text) or new_text.endswith("\ufffd"):
        # A replacement char at the end means it's a potential unfinished
        # byte sequence from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        return new_tokens, "", prefix_offset, read_offset

    new_text = new_text[len(prefix_text):]
    return new_tokens, new_text, read_offset, len(output_tokens)
|