[BugFix] Work-around incremental detokenization edge case error (#19449)
Signed-off-by: Nick Hill <nhill@redhat.com>
commit d5bdf899e4
parent 7e3e74c97c
tests/v1/engine/test_fast_incdec_prefix_err.py (new file, 80 lines)
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from transformers import AutoTokenizer

from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer

# ruff: noqa: E501


def test_fast_inc_detok_invalid_utf8_err_case():
    """
    Test edge case where tokenizer can produce non-monotonic,
    invalid UTF-8 output, which breaks the internal state of
    tokenizers' DecodeStream.
    See https://github.com/vllm-project/vllm/issues/17448.

    Thanks to reproducer from @fpaupier:
    https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

    # Create a test request
    prompt_token_ids = [107, 4606, 236787, 107]
    params = SamplingParams(skip_special_tokens=True)
    request = EngineCoreRequest(
        "test",
        prompt_token_ids,
        None,
        None,
        None,
        params,
        None,
        0.0,
        None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

    assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
        "Should use FastIncrementalDetokenizer by default"

    # Process tokens incrementally
    test_tokens = [
        236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
        147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
        827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
        569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
        35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
        236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
        236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
        236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
        4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
        19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
        432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
        2555, 513, 236789, 602, 31118, 569
    ]

    output = ""
    for i, token_id in enumerate(test_tokens):
        detokenizer.update([token_id], False)

        finished = i == len(test_tokens) - 1
        output += detokenizer.get_next_output_text(finished, delta=True)

    # fmt: off
    assert output == r'''[
  {
    "source": "Résultats",
    "source_type": "CONCEPT",
    "source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
    "target": "Israël",
    "target_type": "ORGANIZATION",
    "target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
    "relationship": "Obtention d'un niveau de'''
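The reproducer above can also be run outside of pytest. A minimal driver along these lines should work, assuming the test file's imports resolve in your environment; the __main__ wrapper is illustrative and not part of the committed file:

# Hypothetical standalone driver; not part of the committed test file.
if __name__ == "__main__":
    test_fast_inc_detok_invalid_utf8_err_case()
    print("Reproducer passed: incremental detokenization completed without errors.")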
vllm/v1/engine/detokenizer.py
@@ -17,6 +17,14 @@ from vllm.v1.engine import EngineCoreRequest
 
 logger = init_logger(__name__)
 
+# Only tokenizers >= 0.21.1 supports DecodeStream used for
+# FastIncrementalDetokenizer.
+USE_FAST_DETOKENIZER = version.parse(
+    tokenizers.__version__) >= version.parse("0.21.1")
+
+# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
+INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
+
+
 class IncrementalDetokenizer:
 
@@ -46,10 +54,9 @@ class IncrementalDetokenizer:
             # No tokenizer => skipping detokenization.
             return IncrementalDetokenizer()
 
-        if (isinstance(tokenizer, PreTrainedTokenizerFast) and version.parse(
-                tokenizers.__version__) >= version.parse("0.21.1")):
+        if USE_FAST_DETOKENIZER and isinstance(tokenizer,
+                                               PreTrainedTokenizerFast):
             # Fast tokenizer => use tokenizers library DecodeStream.
-            # And only tokenizers >= 0.21.1 supports Fast Detokenizer.
             return FastIncrementalDetokenizer(tokenizer, request)
 
         # Fall back to slow python-based incremental detokenization.
@@ -157,8 +164,11 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
         super().__init__(request)
 
         sampling_params = request.sampling_params
+
+        self.request_id = request.request_id
+        self.skip_special_tokens = sampling_params.skip_special_tokens
         self.stream = DecodeStream(
-            skip_special_tokens=sampling_params.skip_special_tokens)
+            skip_special_tokens=self.skip_special_tokens)
 
         self.tokenizer: Tokenizer = tokenizer._tokenizer
 
@@ -174,7 +184,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
 
         # Prime the stream.
         for tid in prompt_suffix:
-            self.stream.step(self.tokenizer, tid)
+            self._protected_step(tid)
 
         self.spaces_between_special_tokens = (
             sampling_params.skip_special_tokens
@@ -199,7 +209,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
             self.spaces_between_special_tokens = True
 
     def decode_next(self, next_token_id: int) -> str:
-        token = self.stream.step(self.tokenizer, next_token_id)
+        token = self._protected_step(next_token_id)
 
         if not self.spaces_between_special_tokens:
             special_token = self.added_token_ids.get(next_token_id)
@@ -211,6 +221,23 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
 
         return token or ""
 
+    def _protected_step(self, next_token_id: int) -> Optional[str]:
+        try:
+            token = self.stream.step(self.tokenizer, next_token_id)
+        except Exception as e:
+            if str(e) != INVALID_PREFIX_ERR_MSG:
+                raise e
+            # Recover from edge case where tokenizer can produce non-monotonic,
+            # invalid UTF-8 output, which breaks the internal state of
+            # tokenizers' DecodeStream.
+            # See https://github.com/vllm-project/vllm/issues/17448.
+            logger.warning(
+                "Encountered invalid prefix detokenization error"
+                " for request %s, resetting decode stream.", self.request_id)
+            self.stream = DecodeStream(self.skip_special_tokens)
+            token = self.stream.step(self.tokenizer, next_token_id)
+        return token
+
 
 class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
 
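Viewed outside the class, the work-around amounts to guarding each DecodeStream.step call and recreating the stream when the known "Invalid prefix encountered" error surfaces, at the cost of discarding the stream's accumulated prefix state. A minimal sketch of that pattern follows; GuardedDecodeStream is an invented name, and the tokenizers.decoders import path is an assumption, since the import block of detokenizer.py is not shown in the hunks above:

# Illustrative sketch of the recovery pattern used by _protected_step above.
# Assumes tokenizers >= 0.21.1; GuardedDecodeStream is not part of vLLM
# or the tokenizers library.
from typing import Optional

from tokenizers import Tokenizer
from tokenizers.decoders import DecodeStream

INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"


class GuardedDecodeStream:

    def __init__(self, tokenizer: Tokenizer, skip_special_tokens: bool = True):
        self.tokenizer = tokenizer
        self.skip_special_tokens = skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=skip_special_tokens)

    def step(self, token_id: int) -> Optional[str]:
        try:
            return self.stream.step(self.tokenizer, token_id)
        except Exception as e:
            if str(e) != INVALID_PREFIX_ERR_MSG:
                raise
            # Drop the broken internal state and restart from this token,
            # mirroring the patch above; the accumulated prefix context is lost.
            self.stream = DecodeStream(
                skip_special_tokens=self.skip_special_tokens)
            return self.stream.step(self.tokenizer, token_id)


# Example usage against the underlying tokenizers.Tokenizer of a fast HF
# tokenizer, as FastIncrementalDetokenizer does via tokenizer._tokenizer:
#   hf_tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
#   stream = GuardedDecodeStream(hf_tok._tokenizer, skip_special_tokens=True)
#   text = "".join(stream.step(t) or "" for t in token_ids)

Because the error string is compared exactly, the guard only absorbs this specific tokenizers failure mode and re-raises everything else, matching the behaviour of _protected_step in the patch.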