# vllm/tests/tokenization/test_detokenize.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
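"""Tests for vLLM's incremental detokenizers.

Exercises streaming (delta) detokenization against full-sequence decoding for
a range of Hugging Face and Mistral tokenizers, covering special-token
handling, multi-byte UTF-8 edge cases, and out-of-vocabulary token ids.
"""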
from collections.abc import Generator
from typing import Any

import pytest
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import (
    FastIncrementalDetokenizer,
    IncrementalDetokenizer,
    SlowIncrementalDetokenizer,
)

SPECIAL_TOKS_TRUTH = [
    "Some text with adjacent special tokens <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
]

TRUTH = [
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",  # noqa
    "我很感谢你的热情",
    # Burmese text triggers an edge case for Mistral's V3-Tekken tokenizer
    # (e.g. for mistralai/Pixtral-12B-2409) where tokens may map to bytes with
    # incomplete UTF-8 characters;
    # see https://github.com/vllm-project/vllm/pull/9625
    "ပုံပြင်လေးပြောပြပါ်",
] + SPECIAL_TOKS_TRUTH
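
# Illustrative sketch (not exercised by the tests): byte-level vocabularies
# such as Tekken's can split one UTF-8 character across two tokens, which is
# why an incremental detokenizer must buffer bytes until a character is
# complete. The manual byte split below is a hypothetical stand-in for such a
# token boundary.
def _demo_incomplete_utf8_split() -> None:
    char = "ပ"  # U+1015 MYANMAR LETTER PA, three bytes in UTF-8
    raw = char.encode("utf-8")
    first, rest = raw[:2], raw[2:]
    # Decoding the truncated prefix alone yields U+FFFD replacement output...
    assert "\ufffd" in first.decode("utf-8", errors="replace")
    # ...while buffering until all bytes have arrived round-trips cleanly.
    assert (first + rest).decode("utf-8") == char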

TOKENIZERS = [
    "facebook/opt-125m",
    "gpt2",
    "bigcode/tiny_starcoder_py",
    "EleutherAI/gpt-j-6b",
    "EleutherAI/pythia-70m",
    "bigscience/bloom-560m",
    "mosaicml/mpt-7b",
    "tiiuae/falcon-7b",
    "meta-llama/Llama-3.2-1B-Instruct",
    "codellama/CodeLlama-7b-hf",
    "mistralai/Pixtral-12B-2409",
]


def _run_incremental_decode(
    tokenizer,
    all_input_ids,
    skip_special_tokens: bool,
    starting_index: int,
    spaces_between_special_tokens: bool = True,
    fast: bool | None = None,
):
    prompt_token_ids = all_input_ids[:starting_index]
    params = SamplingParams(
        skip_special_tokens=skip_special_tokens,
        spaces_between_special_tokens=spaces_between_special_tokens,
    )
    request = EngineCoreRequest(
        request_id="",
        prompt_token_ids=prompt_token_ids,
        mm_features=None,
        sampling_params=params,
        pooling_params=None,
        eos_token_id=None,
        arrival_time=0.0,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    if fast is None:
        detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
    elif fast:
        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
    else:
        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)

    output_text = ""
    generated_ids = all_input_ids[starting_index:]
    for i, token_id in enumerate(generated_ids):
        detokenizer.update([token_id], False)
        # `finished` must flag the last *generated* token, not the last index
        # of the full sequence; otherwise it would never be True whenever a
        # prompt is present (starting_index > 0).
        finished = i == len(generated_ids) - 1
        output_text += detokenizer.get_next_output_text(finished, delta=True)

    return output_text, detokenizer.output_token_ids
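
# Example usage (sketch): decode the tokens of a short prompt incrementally,
# treating the first two ids as the prompt. The model name is illustrative;
# any entry from TOKENIZERS works.
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     ids = tok("Hello here, this is a simple test").input_ids
#     text, out_ids = _run_incremental_decode(
#         tok, ids, skip_special_tokens=True, starting_index=2
#     )
#     assert out_ids == ids[2:]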


@pytest.fixture
def tokenizer(tokenizer_name):
    return (
        MistralTokenizer.from_pretrained(tokenizer_name)
        if "mistral" in tokenizer_name
        else AutoTokenizer.from_pretrained(tokenizer_name)
    )
@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"])
@pytest.mark.parametrize(
"truth",
[
# Burmese text triggers an edge-case where tokens may map to bytes with
# incomplete UTF-8 characters
"ပုံပြင်လေးပြောပြပါ",
# Using "URGENCY" since "CY" has token id 130282
"URGENCY🌶",
],
)
def test_mistral_edge_case(tokenizer, truth):
"""Test for a specific edge cases with V3-Tekken MistralTokenizer.
See https://github.com/vllm-project/vllm/pull/9625
"""
starting_index = 0
all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=True,
starting_index=starting_index,
)
assert decoded_text == truth
assert out_ids == all_input_ids[starting_index:]


@pytest.fixture
def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
    if "mistral" in tokenizer_name:
        yield (
            True
            if request.param
            else pytest.skip("mistral doesn't support skip_special_tokens=False")
        )
    else:
        yield bool(request.param)
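
# Note: `skip_special_tokens` is parametrized with `indirect=True` below, so
# pytest routes each value through the fixture above via `request.param`,
# letting the fixture skip the combination Mistral tokenizers don't support.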
@pytest.mark.parametrize("truth", TRUTH)
@pytest.mark.parametrize("with_prompt", [True, False])
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
@pytest.mark.parametrize("fast", (True, False))
def test_decode_streaming(
tokenizer,
truth,
with_prompt,
skip_special_tokens,
spaces_between_special_tokens,
fast,
):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
if skip_special_tokens and not spaces_between_special_tokens:
pytest.skip()
if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
# Fix up inconsistency in fast/slow tokenizer behaviour.
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
at
for at in tokenizer._tokenizer.get_added_tokens_decoder().values()
if at.special
]
}
)
extra_decode_args = (
{}
if not isinstance(tokenizer, PreTrainedTokenizer)
else {"spaces_between_special_tokens": spaces_between_special_tokens}
)
truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
if tokenizer.bos_token_id is not None:
truth_tokens.insert(0, tokenizer.bos_token_id)
truth_tokens.append(tokenizer.eos_token_id)
new_truth = tokenizer.decode(
truth_tokens, skip_special_tokens=skip_special_tokens, **extra_decode_args
)
if with_prompt:
num_prompt_tokens = len(
tokenizer(truth[: len(truth) // 2], add_special_tokens=False).input_ids
)
if tokenizer.bos_token_id is not None:
num_prompt_tokens += 1
prompt_input_ids = truth_tokens[:num_prompt_tokens]
generated_input_ids = truth_tokens[num_prompt_tokens:]
all_input_ids = prompt_input_ids + generated_input_ids
starting_index = len(prompt_input_ids)
prompt = tokenizer.decode(
prompt_input_ids,
skip_special_tokens=skip_special_tokens,
**extra_decode_args,
)
generated = new_truth[len(prompt) :]
else:
generated = new_truth
starting_index = 0
all_input_ids = truth_tokens
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
all_input_ids,
skip_special_tokens=skip_special_tokens,
starting_index=starting_index,
spaces_between_special_tokens=spaces_between_special_tokens,
fast=fast,
)
assert decoded_text == generated
assert out_ids == all_input_ids[starting_index:]
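
# `len(tokenizer)` is one past the largest valid token id, i.e. guaranteed to
# be out of vocabulary; the test below checks that the detokenizers emit no
# text for such an id while still recording it in `output_token_ids`.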
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
@pytest.mark.parametrize("fast", (True, False))
def test_oov_decode(tokenizer, fast):
if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
pytest.skip()
decoded_text, out_ids = _run_incremental_decode(
tokenizer,
[len(tokenizer)],
skip_special_tokens=True,
starting_index=0,
spaces_between_special_tokens=True,
fast=fast,
)
assert decoded_text == ""
assert out_ids == [len(tokenizer)]