mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 01:45:01 +08:00
[BugFix] Work-around incremental detokenization edge case error (#19449)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
7e3e74c97c
commit
d5bdf899e4
80
tests/v1/engine/test_fast_incdec_prefix_err.py
Normal file
80
tests/v1/engine/test_fast_incdec_prefix_err.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
|
||||||
|
|
||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
|
|
||||||
|
def test_fast_inc_detok_invalid_utf8_err_case():
    """
    Regression test for an edge case where the tokenizer can produce
    non-monotonic, invalid UTF-8 output, which breaks the internal
    state of tokenizers' DecodeStream.
    See https://github.com/vllm-project/vllm/issues/17448.

    Thanks to reproducer from @fpaupier:
    https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

    # Build a minimal engine-core request for the detokenizer.
    prompt_token_ids = [107, 4606, 236787, 107]
    params = SamplingParams(skip_special_tokens=True)
    request = EngineCoreRequest(
        "test",
        prompt_token_ids,
        None,
        None,
        None,
        params,
        None,
        0.0,
        None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

    assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
        "Should use FastIncrementalDetokenizer by default"

    # Token sequence that reproduces the invalid-prefix failure mode.
    test_tokens = [
        236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
        147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
        827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
        569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
        35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
        236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
        236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
        236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
        4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
        19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
        432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
        2555, 513, 236789, 602, 31118, 569
    ]

    # Feed the tokens one at a time, collecting the streamed text deltas.
    pieces = []
    last_idx = len(test_tokens) - 1
    for idx, tok in enumerate(test_tokens):
        detokenizer.update([tok], False)
        pieces.append(
            detokenizer.get_next_output_text(idx == last_idx, delta=True))
    output = "".join(pieces)

    # fmt: off
    assert output == r'''[
{
"source": "Résultats",
"source_type": "CONCEPT",
"source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
"target": "Israël",
"target_type": "ORGANIZATION",
"target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
"relationship": "Obtention d'un niveau de'''
    # fmt: on
|
||||||
@ -17,6 +17,14 @@ from vllm.v1.engine import EngineCoreRequest
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Only tokenizers >= 0.21.1 supports DecodeStream used for
|
||||||
|
# FastIncrementalDetokenizer.
|
||||||
|
USE_FAST_DETOKENIZER = version.parse(
|
||||||
|
tokenizers.__version__) >= version.parse("0.21.1")
|
||||||
|
|
||||||
|
# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
|
||||||
|
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
||||||
|
|
||||||
|
|
||||||
class IncrementalDetokenizer:
|
class IncrementalDetokenizer:
|
||||||
|
|
||||||
@ -46,10 +54,9 @@ class IncrementalDetokenizer:
|
|||||||
# No tokenizer => skipping detokenization.
|
# No tokenizer => skipping detokenization.
|
||||||
return IncrementalDetokenizer()
|
return IncrementalDetokenizer()
|
||||||
|
|
||||||
if (isinstance(tokenizer, PreTrainedTokenizerFast) and version.parse(
|
if USE_FAST_DETOKENIZER and isinstance(tokenizer,
|
||||||
tokenizers.__version__) >= version.parse("0.21.1")):
|
PreTrainedTokenizerFast):
|
||||||
# Fast tokenizer => use tokenizers library DecodeStream.
|
# Fast tokenizer => use tokenizers library DecodeStream.
|
||||||
# And only tokenizers >= 0.21.1 supports Fast Detokenizer.
|
|
||||||
return FastIncrementalDetokenizer(tokenizer, request)
|
return FastIncrementalDetokenizer(tokenizer, request)
|
||||||
|
|
||||||
# Fall back to slow python-based incremental detokenization.
|
# Fall back to slow python-based incremental detokenization.
|
||||||
@ -157,8 +164,11 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
super().__init__(request)
|
super().__init__(request)
|
||||||
|
|
||||||
sampling_params = request.sampling_params
|
sampling_params = request.sampling_params
|
||||||
|
|
||||||
|
self.request_id = request.request_id
|
||||||
|
self.skip_special_tokens = sampling_params.skip_special_tokens
|
||||||
self.stream = DecodeStream(
|
self.stream = DecodeStream(
|
||||||
skip_special_tokens=sampling_params.skip_special_tokens)
|
skip_special_tokens=self.skip_special_tokens)
|
||||||
|
|
||||||
self.tokenizer: Tokenizer = tokenizer._tokenizer
|
self.tokenizer: Tokenizer = tokenizer._tokenizer
|
||||||
|
|
||||||
@ -174,7 +184,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
|
|
||||||
# Prime the stream.
|
# Prime the stream.
|
||||||
for tid in prompt_suffix:
|
for tid in prompt_suffix:
|
||||||
self.stream.step(self.tokenizer, tid)
|
self._protected_step(tid)
|
||||||
|
|
||||||
self.spaces_between_special_tokens = (
|
self.spaces_between_special_tokens = (
|
||||||
sampling_params.skip_special_tokens
|
sampling_params.skip_special_tokens
|
||||||
@ -199,7 +209,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
self.spaces_between_special_tokens = True
|
self.spaces_between_special_tokens = True
|
||||||
|
|
||||||
def decode_next(self, next_token_id: int) -> str:
|
def decode_next(self, next_token_id: int) -> str:
|
||||||
token = self.stream.step(self.tokenizer, next_token_id)
|
token = self._protected_step(next_token_id)
|
||||||
|
|
||||||
if not self.spaces_between_special_tokens:
|
if not self.spaces_between_special_tokens:
|
||||||
special_token = self.added_token_ids.get(next_token_id)
|
special_token = self.added_token_ids.get(next_token_id)
|
||||||
@ -211,6 +221,23 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
|||||||
|
|
||||||
return token or ""
|
return token or ""
|
||||||
|
|
||||||
|
def _protected_step(self, next_token_id: int) -> Optional[str]:
    """Step the DecodeStream, recovering from invalid-prefix errors.

    Wraps tokenizers' ``DecodeStream.step()``. If it fails with the
    known "Invalid prefix encountered" error, the stream's internal
    state is unusable, so a fresh stream is created and the step is
    retried once. Any other exception propagates unchanged.

    Args:
        next_token_id: The newly sampled token id to decode.

    Returns:
        The decoded text fragment, or None if the stream produced
        no output for this token.
    """
    try:
        return self.stream.step(self.tokenizer, next_token_id)
    except Exception as e:
        if str(e) != INVALID_PREFIX_ERR_MSG:
            # Unrelated failure; re-raise with the original traceback.
            raise
        # Recover from edge case where tokenizer can produce non-monotonic,
        # invalid UTF-8 output, which breaks the internal state of
        # tokenizers' DecodeStream.
        # See https://github.com/vllm-project/vllm/issues/17448.
        logger.warning(
            "Encountered invalid prefix detokenization error"
            " for request %s, resetting decode stream.", self.request_id)
        # Keyword arg for consistency with the constructor's
        # DecodeStream(skip_special_tokens=...) call.
        self.stream = DecodeStream(
            skip_special_tokens=self.skip_special_tokens)
        return self.stream.step(self.tokenizer, next_token_id)
|
||||||
|
|
||||||
|
|
||||||
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user