[BugFix] Work-around incremental detokenization edge case error (#19449)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-06-11 23:43:20 -07:00 committed by GitHub
parent 7e3e74c97c
commit d5bdf899e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 113 additions and 6 deletions

View File

@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
# ruff: noqa: E501
def test_fast_inc_detok_invalid_utf8_err_case():
"""
Test edge case where tokenizer can produce non-monotonic,
invalid UTF-8 output, which breaks the internal state of
tokenizers' DecodeStream.
See https://github.com/vllm-project/vllm/issues/17448.
Thanks to reproducer from @fpaupier:
https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
"""
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
# Create a test request
prompt_token_ids = [107, 4606, 236787, 107]
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest(
"test",
prompt_token_ids,
None,
None,
None,
params,
None,
0.0,
None,
cache_salt=None,
data_parallel_rank=None,
)
detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
"Should use FastIncrementalDetokenizer by default"
# Process tokens incrementally
test_tokens = [
236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
2555, 513, 236789, 602, 31118, 569
]
output = ""
for i, token_id in enumerate(test_tokens):
detokenizer.update([token_id], False)
finished = i == len(test_tokens) - 1
output += detokenizer.get_next_output_text(finished, delta=True)
# fmt: off
assert output == r'''[
{
"source": "Résultats",
"source_type": "CONCEPT",
"source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
"target": "Israël",
"target_type": "ORGANIZATION",
"target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
"relationship": "Obtention d'un niveau de'''

View File

@ -17,6 +17,14 @@ from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
# Only tokenizers >= 0.21.1 supports DecodeStream used for
# FastIncrementalDetokenizer.
USE_FAST_DETOKENIZER = version.parse(
tokenizers.__version__) >= version.parse("0.21.1")
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
class IncrementalDetokenizer:
@ -46,10 +54,9 @@ class IncrementalDetokenizer:
# No tokenizer => skipping detokenization.
return IncrementalDetokenizer()
if (isinstance(tokenizer, PreTrainedTokenizerFast) and version.parse(
tokenizers.__version__) >= version.parse("0.21.1")):
if USE_FAST_DETOKENIZER and isinstance(tokenizer,
PreTrainedTokenizerFast):
# Fast tokenizer => use tokenizers library DecodeStream.
# And only tokenizers >= 0.21.1 supports Fast Detokenizer.
return FastIncrementalDetokenizer(tokenizer, request)
# Fall back to slow python-based incremental detokenization.
@ -157,8 +164,11 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
super().__init__(request)
sampling_params = request.sampling_params
self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens
self.stream = DecodeStream(
skip_special_tokens=sampling_params.skip_special_tokens)
skip_special_tokens=self.skip_special_tokens)
self.tokenizer: Tokenizer = tokenizer._tokenizer
@ -174,7 +184,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
# Prime the stream.
for tid in prompt_suffix:
self.stream.step(self.tokenizer, tid)
self._protected_step(tid)
self.spaces_between_special_tokens = (
sampling_params.skip_special_tokens
@ -199,7 +209,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
self.spaces_between_special_tokens = True
def decode_next(self, next_token_id: int) -> str:
token = self.stream.step(self.tokenizer, next_token_id)
token = self._protected_step(next_token_id)
if not self.spaces_between_special_tokens:
special_token = self.added_token_ids.get(next_token_id)
@ -211,6 +221,23 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
return token or ""
def _protected_step(self, next_token_id: int) -> Optional[str]:
try:
token = self.stream.step(self.tokenizer, next_token_id)
except Exception as e:
if str(e) != INVALID_PREFIX_ERR_MSG:
raise e
# Recover from edge case where tokenizer can produce non-monotonic,
# invalid UTF-8 output, which breaks the internal state of
# tokenizers' DecodeStream.
# See https://github.com/vllm-project/vllm/issues/17448.
logger.warning(
"Encountered invalid prefix detokenization error"
" for request %s, resetting decode stream.", self.request_id)
self.stream = DecodeStream(self.skip_special_tokens)
token = self.stream.step(self.tokenizer, next_token_id)
return token
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):