Fix handling of special tokens in decoding. (#418)

xcnick 2023-07-12 23:14:56 +08:00 committed by GitHub
parent 51be365143
commit c6dfc3cdbe
2 changed files with 6 additions and 3 deletions


@@ -276,8 +276,9 @@ class LLMEngine:
                     seq.get_last_token_id(),
                     skip_special_tokens=True,
                 )
-                seq.output_tokens.append(new_token)
-                seq.output_text = new_output_text
+                if new_token is not None:
+                    seq.output_tokens.append(new_token)
+                    seq.output_text = new_output_text
 
     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         """Stop the finished sequences."""


@@ -80,6 +80,8 @@ def detokenize_incrementally(
         new_token: The new token as a string.
         output_text: The new output text as a string.
     """
+    if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
+        return None, prev_output_tokens
     new_token = tokenizer.convert_ids_to_tokens(
         new_token_id, skip_special_tokens=skip_special_tokens)
     output_tokens = prev_output_tokens + [new_token]
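The two added lines short-circuit before any string conversion: when skip_special_tokens is set and the new id is a special token, the caller gets None and the previous output unchanged. A runnable sketch of that contract, assuming a HuggingFace tokenizer ("gpt2" and the sample tokens are illustrative, not from this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
prev_output_tokens = ["Hello", "Ġworld"]  # tokens decoded so far (made up)
new_token_id = tokenizer.eos_token_id     # 50256, a special token for GPT-2
skip_special_tokens = True

# Mirrors the added lines: a skipped special token yields no token string and
# leaves the accumulated output untouched.
if skip_special_tokens and (new_token_id in tokenizer.all_special_ids):
    new_token, output_tokens = None, prev_output_tokens

print(new_token, output_tokens)  # None ['Hello', 'Ġworld']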
@@ -99,7 +101,7 @@ def detokenize_incrementally(
         sub_texts = []
         current_sub_text = []
         for token in output_tokens:
-            if skip_special_tokens and token in tokenizer.all_special_ids:
+            if skip_special_tokens and token in tokenizer.all_special_tokens:
                 continue
             if token in tokenizer.added_tokens_encoder:
                 if current_sub_text:
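The one-word change in the last hunk is the substantive fix: output_tokens holds token strings, while all_special_ids holds integer ids, so the old membership test could never be true and special tokens leaked into the joined output text. all_special_tokens is the string-valued counterpart. A hedged demonstration, again with an illustrative GPT-2 tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token = tokenizer.eos_token                   # "<|endoftext|>", a string

print(token in tokenizer.all_special_ids)     # False: ints never match a string
print(token in tokenizer.all_special_tokens)  # True: string-to-string check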