Incrementally decode output tokens (#121)
commit e86717833d
parent aedba6d5ec
@@ -291,7 +291,7 @@ class Scheduler:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
-                seq.append_token(output.output_token, output.logprobs)
+                seq.append_token_id(output.output_token, output.logprobs)
         return self.running.copy()

     def free_seq(self, seq: Sequence) -> None:
@@ -24,7 +24,7 @@ class SequenceData:
         self.output_token_ids: List[int] = []
         self.cumulative_logprob = 0.0

-    def append_token(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self, token_id: int, logprob: float) -> None:
         self.output_token_ids.append(token_id)
         self.cumulative_logprob += logprob

@@ -64,6 +64,7 @@ class Sequence:

         self.data = SequenceData(prompt_token_ids)
         self.output_logprobs: List[Dict[int, float]] = []
+        self.output_tokens: List[str] = []
         self.output_text = ""

         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -92,11 +93,15 @@ class Sequence:
         last_block.append_tokens(token_ids[:num_empty_slots])
         token_ids = token_ids[num_empty_slots:]

-    def append_token(self, token_id: int, logprobs: Dict[int, float]) -> None:
+    def append_token_id(
+        self,
+        token_id: int,
+        logprobs: Dict[int, float],
+    ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token(token_id, logprobs[token_id])
+        self.data.append_token_id(token_id, logprobs[token_id])

     def get_len(self) -> int:
         return self.data.get_len()
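For quick reference, the sketch below condenses the bookkeeping that the hunks above rename: Sequence.append_token_id records the per-step logprob dict and forwards the chosen token to SequenceData, which accumulates output_token_ids and cumulative_logprob, while the newly added output_tokens list caches decoded token strings. This is an illustrative, trimmed-down version only; the real classes take more constructor arguments and also update the block table via _append_tokens_to_blocks, which is omitted here.

from typing import Dict, List


class SequenceData:
    # Trimmed-down sketch of the token/logprob bookkeeping; not the real class.
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob


class Sequence:
    # Trimmed-down sketch; block-table updates (_append_tokens_to_blocks) omitted.
    def __init__(self, prompt_token_ids: List[int]) -> None:
        self.data = SequenceData(prompt_token_ids)
        self.output_logprobs: List[Dict[int, float]] = []
        self.output_tokens: List[str] = []  # token strings cached for incremental decoding
        self.output_text = ""

    def append_token_id(self, token_id: int, logprobs: Dict[int, float]) -> None:
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id])


# Appending one sampled token with its logprob distribution:
seq = Sequence(prompt_token_ids=[1, 2, 3])
seq.append_token_id(token_id=42, logprobs={42: -0.7, 7: -1.9})
assert seq.data.output_token_ids == [42]
assert abs(seq.data.cumulative_logprob - (-0.7)) < 1e-9

Storing the full per-step logprob dicts on the Sequence while SequenceData keeps only the chosen token's running total keeps cumulative_logprob cheap to maintain.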
@@ -14,7 +14,8 @@ from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.server.tokenizer_utils import (get_tokenizer,
+                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -185,18 +186,17 @@ class LLMServer:
         return request_outputs

     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Batch-decode the sequence outputs.
-        seqs: List[Sequence] = []
+        # Decode the sequence outputs.
         for seq_group in seq_groups:
-            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
-        output_tokens_per_seq = []
-        for seq in seqs:
-            output_tokens_per_seq.append(seq.get_output_token_ids())
-        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
-                                                   skip_special_tokens=True)
-        # Update the sequences with the output texts.
-        for seq, output_text in zip(seqs, output_texts):
-            seq.output_text = output_text
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                new_token, new_output_text = detokenize_incrementally(
+                    self.tokenizer,
+                    seq.output_tokens,
+                    seq.get_last_token_id(),
+                    skip_special_tokens=True,
+                )
+                seq.output_tokens.append(new_token)
+                seq.output_text = new_output_text

     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         # Stop the sequences.
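To make the new control flow concrete, here is a small stand-alone sketch of the per-step update that _decode_sequences now performs: rather than re-running tokenizer.batch_decode over each sequence's full output ids, every step detokenizes only the newly sampled id against the cached seq.output_tokens and overwrites seq.output_text. The Seq class, the toy vocabulary, and toy_detokenize_incrementally below are hypothetical stand-ins that only mirror the shape of the real detokenize_incrementally call; they are not cacheflow types.

from typing import List, Tuple


class Seq:
    # Hypothetical stand-in for a running sequence's decode state.
    def __init__(self) -> None:
        self.output_token_ids: List[int] = []
        self.output_tokens: List[str] = []  # cached token strings
        self.output_text = ""

    def get_last_token_id(self) -> int:
        return self.output_token_ids[-1]


# Toy GPT-2-style vocabulary; "Ġ" marks a leading space.
TOY_VOCAB = {0: "Hello", 1: "Ġworld", 2: "!"}


def toy_detokenize_incrementally(
    prev_output_tokens: List[str], new_token_id: int
) -> Tuple[str, str]:
    # Same shape as detokenize_incrementally(tokenizer, prev_output_tokens,
    # new_token_id, skip_special_tokens) -> (new_token, output_text), with a
    # toy vocabulary standing in for a real tokenizer.
    new_token = TOY_VOCAB[new_token_id]
    output_tokens = prev_output_tokens + [new_token]
    output_text = "".join(output_tokens).replace("Ġ", " ")
    return new_token, output_text


seq = Seq()
for sampled_id in (0, 1, 2):  # pretend these ids arrive one engine step at a time
    seq.output_token_ids.append(sampled_id)
    new_token, new_text = toy_detokenize_incrementally(
        seq.output_tokens, seq.get_last_token_id())
    seq.output_tokens.append(new_token)
    seq.output_text = new_text

assert seq.output_text == "Hello world!"

Because each sequence carries its own output_tokens cache, only the newest id has to be converted to a token string per step.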
@@ -1,8 +1,12 @@
-from typing import Union
+from typing import List, Tuple, Union

 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
                           PreTrainedTokenizerFast)

+from cacheflow.logger import init_logger
+
+logger = init_logger(__name__)
+
 _MODEL_TYPES_WITH_SLOW_TOKENIZER = [
     # LLaMA fast tokenizer has a bug related to protobuf.
     # See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
@@ -17,5 +21,62 @@ def get_tokenizer(
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
+        if getattr(kwargs, "use_fast", False) == True:
+            raise ValueError(
+                f"Cannot use the fast tokenizer for {config.model_type} due to "
+                "bugs in the fast tokenizer.")
+        logger.info(
+            f"Using the slow tokenizer for {config.model_type} due to bugs in "
+            "the fast tokenizer. This could potentially lead to performance "
+            "degradation.")
         kwargs["use_fast"] = False
     return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)
+
+
+def detokenize_incrementally(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prev_output_tokens: List[str],
+    new_token_id: int,
+    skip_special_tokens: bool,
+) -> Tuple[str, str]:
+    """Detokenizes the new token in conjunction with the previous output tokens.
+
+    NOTE: This function does not update prev_output_tokens.
+
+    Returns:
+        new_token: The new token as a string.
+        output_text: The new output text as a string.
+    """
+    new_token = tokenizer.convert_ids_to_tokens(
+        new_token_id, skip_special_tokens=skip_special_tokens)
+    output_tokens = prev_output_tokens + [new_token]
+
+    # Convert the tokens to a string.
+    # Optimization: If the tokenizer does not have `added_tokens_encoder`,
+    # then we can directly use `convert_tokens_to_string`.
+    if not getattr(tokenizer, "added_tokens_encoder", {}):
+        output_text = tokenizer.convert_tokens_to_string(output_tokens)
+        return new_token, output_text
+
+    # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts = []
+    current_sub_text = []
+    for token in output_tokens:
+        if skip_special_tokens and token in tokenizer.all_special_ids:
+            continue
+        if token in tokenizer.added_tokens_encoder:
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    output_text = " ".join(sub_texts)
+    return new_token, output_text
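As a usage sketch of the function added above, the loop below feeds token ids one at a time and accumulates output_tokens and output_text the same way _decode_sequences does. It assumes this commit's cacheflow package is importable and that the Hugging Face gpt2 tokenizer can be loaded; neither assumption comes from the diff itself. For a tokenizer without extra added tokens the fast convert_tokens_to_string path is taken, and the incremental result should match a one-shot decode of the same ids.

# Usage sketch only: assumes this commit's cacheflow package and the Hugging Face
# "gpt2" tokenizer are available in the environment.
from transformers import AutoTokenizer

from cacheflow.server.tokenizer_utils import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = tokenizer.encode("Incremental decoding avoids repeated work.")

output_tokens = []  # plays the role of Sequence.output_tokens (list of token strings)
output_text = ""

# Feed the ids one by one, the way the server sees them during generation.
for token_id in token_ids:
    new_token, output_text = detokenize_incrementally(
        tokenizer,
        prev_output_tokens=output_tokens,
        new_token_id=token_id,
        skip_special_tokens=True,
    )
    output_tokens.append(new_token)

# For this tokenizer and input, the incremental text should equal a full decode.
assert output_text == tokenizer.decode(token_ids, skip_special_tokens=True)
print(output_text)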