From d5bdf899e4ea7db68f731af6be0635b54de4adb3 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Wed, 11 Jun 2025 23:43:20 -0700
Subject: [PATCH] [BugFix] Work-around incremental detokenization edge case
 error (#19449)

Signed-off-by: Nick Hill
---
 .../v1/engine/test_fast_incdec_prefix_err.py | 80 +++++++++++++++++++
 vllm/v1/engine/detokenizer.py                | 39 +++++++--
 2 files changed, 113 insertions(+), 6 deletions(-)
 create mode 100644 tests/v1/engine/test_fast_incdec_prefix_err.py

diff --git a/tests/v1/engine/test_fast_incdec_prefix_err.py b/tests/v1/engine/test_fast_incdec_prefix_err.py
new file mode 100644
index 0000000000000..5c844e0e7095e
--- /dev/null
+++ b/tests/v1/engine/test_fast_incdec_prefix_err.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from transformers import AutoTokenizer
+
+from vllm.sampling_params import SamplingParams
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import IncrementalDetokenizer
+
+# ruff: noqa: E501
+
+
+def test_fast_inc_detok_invalid_utf8_err_case():
+    """
+    Test edge case where tokenizer can produce non-monotonic,
+    invalid UTF-8 output, which breaks the internal state of
+    tokenizers' DecodeStream.
+    See https://github.com/vllm-project/vllm/issues/17448.
+
+    Thanks to reproducer from @fpaupier:
+    https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
+
+    # Create a test request
+    prompt_token_ids = [107, 4606, 236787, 107]
+    params = SamplingParams(skip_special_tokens=True)
+    request = EngineCoreRequest(
+        "test",
+        prompt_token_ids,
+        None,
+        None,
+        None,
+        params,
+        None,
+        0.0,
+        None,
+        cache_salt=None,
+        data_parallel_rank=None,
+    )
+
+    detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
+
+    assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
+        "Should use FastIncrementalDetokenizer by default"
+
+    # Process tokens incrementally
+    test_tokens = [
+        236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
+        147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
+        827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
+        569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
+        35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
+        236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
+        236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
+        236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
+        4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
+        19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
+        432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
+        2555, 513, 236789, 602, 31118, 569
+    ]
+
+    output = ""
+    for i, token_id in enumerate(test_tokens):
+        detokenizer.update([token_id], False)
+
+        finished = i == len(test_tokens) - 1
+        output += detokenizer.get_next_output_text(finished, delta=True)
+
+    # fmt: off
+    assert output == r'''[
+  {
+    "source": "Résultats",
+    "source_type": "CONCEPT",
+    "source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
+    "target": "Israël",
+    "target_type": "ORGANIZATION",
+    "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
+    "relationship": "Obtention d'un niveau de'''
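
For reference, the failure this test guards against can be reproduced with the
tokenizers library alone, without vLLM's wrappers. The following is a minimal
sketch, not part of the patch: it assumes tokenizers >= 0.21.1 and access to
the google/gemma-3-1b-it tokenizer, primes a raw DecodeStream with the prompt
tokens the way FastIncrementalDetokenizer does, then feeds it the same
generated tokens; without the work-around added in the next file, step()
eventually fails with "Invalid prefix encountered".

# Illustrative reproducer sketch (assumes tokenizers >= 0.21.1 and Hub
# access to google/gemma-3-1b-it; the exact failing step may vary).
from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

tok = Tokenizer.from_pretrained("google/gemma-3-1b-it")
stream = DecodeStream(skip_special_tokens=True)

# Prime with the prompt tokens, mirroring FastIncrementalDetokenizer.
for tid in [107, 4606, 236787, 107]:
    stream.step(tok, tid)

# Same generated-token sequence as test_tokens in the test above.
token_ids = [
    236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
    147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
    827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
    569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
    35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
    236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
    236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
    236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
    4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
    19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
    432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
    2555, 513, 236789, 602, 31118, 569
]

for i, tid in enumerate(token_ids):
    try:
        stream.step(tok, tid)
    except Exception as exc:
        # The stream's internal prefix state was corrupted by
        # non-monotonic, invalid UTF-8 output (vLLM issue #17448).
        print(f"DecodeStream failed at step {i} (token {tid}): {exc}")
        break
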
diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py
index c6fe2d339c93d..35aceba0fe766 100644
--- a/vllm/v1/engine/detokenizer.py
+++ b/vllm/v1/engine/detokenizer.py
@@ -17,6 +17,14 @@
 from vllm.v1.engine import EngineCoreRequest
 
 logger = init_logger(__name__)
 
+# Only tokenizers >= 0.21.1 supports the DecodeStream used by
+# FastIncrementalDetokenizer.
+USE_FAST_DETOKENIZER = version.parse(
+    tokenizers.__version__) >= version.parse("0.21.1")
+
+# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
+INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
+
 
 class IncrementalDetokenizer:
 
@@ -46,10 +54,9 @@ class IncrementalDetokenizer:
             # No tokenizer => skipping detokenization.
             return IncrementalDetokenizer()
 
-        if (isinstance(tokenizer, PreTrainedTokenizerFast) and version.parse(
-                tokenizers.__version__) >= version.parse("0.21.1")):
+        if USE_FAST_DETOKENIZER and isinstance(tokenizer,
+                                               PreTrainedTokenizerFast):
             # Fast tokenizer => use tokenizers library DecodeStream.
-            # And only tokenizers >= 0.21.1 supports Fast Detokenizer.
             return FastIncrementalDetokenizer(tokenizer, request)
 
         # Fall back to slow python-based incremental detokenization.
@@ -157,8 +164,11 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
         super().__init__(request)
 
         sampling_params = request.sampling_params
+
+        self.request_id = request.request_id
+        self.skip_special_tokens = sampling_params.skip_special_tokens
         self.stream = DecodeStream(
-            skip_special_tokens=sampling_params.skip_special_tokens)
+            skip_special_tokens=self.skip_special_tokens)
 
         self.tokenizer: Tokenizer = tokenizer._tokenizer
 
@@ -174,7 +184,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
 
         # Prime the stream.
         for tid in prompt_suffix:
-            self.stream.step(self.tokenizer, tid)
+            self._protected_step(tid)
 
         self.spaces_between_special_tokens = (
             sampling_params.skip_special_tokens
@@ -199,7 +209,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
             self.spaces_between_special_tokens = True
 
     def decode_next(self, next_token_id: int) -> str:
-        token = self.stream.step(self.tokenizer, next_token_id)
+        token = self._protected_step(next_token_id)
 
         if not self.spaces_between_special_tokens:
             special_token = self.added_token_ids.get(next_token_id)
@@ -211,6 +221,23 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
 
         return token or ""
 
+    def _protected_step(self, next_token_id: int) -> Optional[str]:
+        try:
+            token = self.stream.step(self.tokenizer, next_token_id)
+        except Exception as e:
+            if str(e) != INVALID_PREFIX_ERR_MSG:
+                raise e
+            # Recover from edge case where tokenizer can produce non-monotonic,
+            # invalid UTF-8 output, which breaks the internal state of
+            # tokenizers' DecodeStream.
+            # See https://github.com/vllm-project/vllm/issues/17448.
+            logger.warning(
+                "Encountered invalid prefix detokenization error"
+                " for request %s, resetting decode stream.", self.request_id)
+            self.stream = DecodeStream(self.skip_special_tokens)
+            token = self.stream.step(self.tokenizer, next_token_id)
+        return token
+
 
 class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
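
A note on the recovery semantics: resetting the DecodeStream discards whatever
partial prefix the old stream was still buffering, so a few characters can be
lost at the reset point; the trade-off is that the request keeps streaming
instead of failing outright, which is what the new test pins down with its
exact expected-output string. Below is a standalone sketch of the same
reset-and-retry pattern against the raw tokenizers API; RecoveringStream is a
hypothetical wrapper for illustration, not vLLM's or tokenizers' API.

# Hedged sketch of the patch's reset-and-retry pattern (illustrative
# names; assumes tokenizers >= 0.21.1).
from typing import Optional

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

# Error string raised from tokenizers' Rust side (see the patch comment).
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"


class RecoveringStream:
    """DecodeStream wrapper that resets itself on the invalid-prefix error."""

    def __init__(self, tokenizer: Tokenizer, skip_special_tokens: bool):
        self.tokenizer = tokenizer
        self.skip_special_tokens = skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=skip_special_tokens)

    def step(self, token_id: int) -> Optional[str]:
        try:
            return self.stream.step(self.tokenizer, token_id)
        except Exception as e:
            if str(e) != INVALID_PREFIX_ERR_MSG:
                raise
            # Replace the corrupted stream and retry the failed token;
            # any text buffered in the old stream is dropped.
            self.stream = DecodeStream(
                skip_special_tokens=self.skip_special_tokens)
            return self.stream.step(self.tokenizer, token_id)

The regression test itself can be run with:
pytest tests/v1/engine/test_fast_incdec_prefix_err.py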