[Bugfix]: Fix TokenizerLike interface (#30009)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
Rohan Potdar 2025-12-05 22:56:40 -06:00 committed by GitHub
parent e858bc4d14
commit 40a046cd82
GPG Key ID: B5690EEEBB952194
8 changed files with 78 additions and 52 deletions

View File

@@ -32,7 +32,6 @@ from typing import Any, cast
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizerBase
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
@@ -189,7 +188,7 @@ class BenchmarkDataset(ABC):
@abstractmethod
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
@@ -201,7 +200,7 @@
for generating a list of SampleRequest objects.
Args:
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
tokenizer (TokenizerLike): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
request_id_prefix (str): The prefix of request_id.
@@ -380,7 +379,7 @@ def process_video(video: Any) -> Mapping[str, Any]:
def gen_prompt_decode_to_target_len(
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
token_sequence: list[int],
target_token_len: int,
max_retry: int = 10,
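
Why a retry loop at all: decoding random token IDs to text and re-encoding them does not round-trip exactly, so the helper has to fix the length up iteratively. A hedged sketch of what such a loop does, not the vLLM implementation; it assumes the tokenizer exposes `encode`/`decode`/`vocab_size` as the `TokenizerLike` usages in this diff suggest:

```python
import random

def decode_to_target_len_sketch(tokenizer, token_sequence: list[int],
                                target_token_len: int, max_retry: int = 10):
    # Round-tripping token IDs through decode() and encode() can merge
    # or split tokens, so adjust until the re-encoded length hits the
    # target (or retries run out).
    tokens = list(token_sequence)
    for _ in range(max_retry):
        prompt = tokenizer.decode(tokens)
        reencoded = tokenizer.encode(prompt, add_special_tokens=False)
        diff = target_token_len - len(reencoded)
        if diff == 0:
            return prompt, reencoded
        if diff > 0:
            # Too short after the round trip: pad with random IDs.
            tokens = reencoded + random.choices(range(tokenizer.vocab_size), k=diff)
        else:
            # Too long: trim the tail and try again.
            tokens = reencoded[:target_token_len]
    return tokenizer.decode(tokens), tokens
```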
@@ -468,7 +467,7 @@ class RandomDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
@@ -580,7 +579,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float,
input_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the sampling parameters for the dataset.
@@ -626,7 +625,7 @@ class RandomDataset(BenchmarkDataset):
def generate_token_sequence(
self,
*,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
prefix_token_ids: list[int],
prefix_len: int,
vocab_size: int,
@@ -686,7 +685,7 @@ class RandomDatasetForReranking(RandomDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
@@ -716,7 +715,11 @@ class RandomDatasetForReranking(RandomDataset):
doc_lens, _, doc_offsets = self.get_sampling_params(
num_requests, range_ratio, doc_len_param, 0, tokenizer
)
vocab_size = tokenizer.vocab_size
prohibited_tokens = tokenizer.all_special_ids
all_tokens = np.arange(vocab_size)
allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
query_prompt, query_input_len, token_mismatch_total = (
self.generate_token_sequence(
@@ -727,6 +730,7 @@ class RandomDatasetForReranking(RandomDataset):
input_len=query_len,
offset=int(query_offsets[0]),
index=0,
allowed_tokens=allowed_tokens,
)
)
@@ -740,6 +744,7 @@ class RandomDatasetForReranking(RandomDataset):
input_len=int(doc_lens[i]),
offset=int(doc_offsets[i]),
index=i + 1,
allowed_tokens=allowed_tokens,
)
token_mismatch_total += token_mismatch
requests.append((prompt, total_input_len))
@@ -1077,7 +1082,7 @@ class RandomMultiModalDataset(RandomDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
@@ -1231,7 +1236,7 @@ class ShareGPTDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
lora_path: str | None = None,
max_loras: int | None = None,
@@ -1633,7 +1638,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
)
def get_samples(args, tokenizer) -> list[SampleRequest]:
def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
if not hasattr(args, "request_id_prefix"):
args.request_id_prefix = ""
@@ -1971,7 +1976,7 @@ class CustomDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
lora_path: str | None = None,
max_loras: int | None = None,
@@ -2101,7 +2106,7 @@ class SonnetDataset(BenchmarkDataset):
def sample(
self,
tokenizer,
tokenizer: TokenizerLike,
num_requests: int,
prefix_len: int = DEFAULT_PREFIX_LEN,
input_len: int = DEFAULT_INPUT_LEN,
@@ -2202,7 +2207,7 @@ class BurstGPTDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
max_loras: int | None = None,
lora_path: str | None = None,
@@ -2287,7 +2292,7 @@ class ConversationDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
@@ -2347,7 +2352,7 @@ class MultiModalConversationDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
@@ -2416,7 +2421,7 @@ class VisionArenaDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
@@ -2470,7 +2475,7 @@ class MMVUDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
@@ -2531,7 +2536,7 @@ class InstructCoderDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
@@ -2595,7 +2600,7 @@ class MTBenchDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,
@@ -2661,7 +2666,7 @@ class BlazeditDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
skip_chat_template: bool = False,
@@ -2742,7 +2747,7 @@ class AIMODataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
request_id_prefix: str = "",
@@ -2852,7 +2857,7 @@ class NextEditPredictionDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
@@ -2924,7 +2929,7 @@ class ASRDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
request_id_prefix: str = "",
@@ -3002,7 +3007,7 @@ class MLPerfDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
request_id_prefix: str = "",
@@ -3081,7 +3086,7 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
prefix_len: int = DEFAULT_PREFIX_LEN,
suffix_len: int = DEFAULT_SUFFIX_LEN,
@@ -3167,7 +3172,7 @@ class MMStarDataset(HuggingFaceDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
num_requests: int,
output_len: int | None = None,
enable_multimodal_chat: bool = False,

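With every `sample` signature above widened from `PreTrainedTokenizerBase` to `TokenizerLike`, any tokenizer returned by `get_tokenizer` can be passed in, whether HF-backed or `mistral_common`-backed. A hedged usage sketch; the model name is illustrative and the `RandomDataset` constructor arguments are elided:

```python
from vllm.tokenizers import get_tokenizer

# "auto" resolves to mistral_common for Mistral models and to the HF
# tokenizer otherwise, per the tokenizer-mode help text in this PR.
tokenizer = get_tokenizer(
    "mistralai/Mistral-7B-Instruct-v0.3",
    tokenizer_mode="auto",
)

# Any BenchmarkDataset subclass now accepts this object directly:
# requests = RandomDataset(...).sample(tokenizer=tokenizer, num_requests=8)
```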
View File

@@ -36,7 +36,6 @@ from typing import Any, Literal
import aiohttp
import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.benchmarks.datasets import SampleRequest, add_dataset_parser, get_samples
from vllm.benchmarks.lib.endpoint_request_func import (
@@ -47,7 +46,7 @@ from vllm.benchmarks.lib.endpoint_request_func import (
)
from vllm.benchmarks.lib.ready_checker import wait_for_endpoint
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.tokenizers import get_tokenizer
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.network_utils import join_host_port
@@ -286,7 +285,7 @@ def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
selected_percentiles: list[float],
goodput_config_dict: dict[str, float],
) -> tuple[BenchmarkMetrics, list[int]]:
@@ -489,7 +488,7 @@ async def benchmark(
base_url: str,
model_id: str,
model_name: str,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
input_requests: list[SampleRequest],
logprobs: int | None,
request_rate: float,
@@ -1032,6 +1031,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
help="""Tokenizer mode:\n
- "auto" will use the tokenizer from `mistral_common` for Mistral models
if available, otherwise it will use the "hf" tokenizer.\n
- "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- Other custom values can be supported via plugins.""",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--logprobs",
@@ -1228,18 +1240,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Common prefix length shared by all prompts (used by random dataset)",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
choices=["auto", "slow", "mistral", "custom"],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
"always use the slow tokenizer. \n* "
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.',
)
parser.add_argument(
"--served-model-name",
type=str,

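Note that the relocated `--tokenizer-mode` drops the old `choices=[...]` restriction, which is what lets plugin-registered modes such as "deepseek_v32" through argparse validation; resolution happens later in `get_tokenizer`. A minimal demonstration of the difference:

```python
import argparse

parser = argparse.ArgumentParser()
# As in the hunk above: no choices=[...], so any mode string is
# accepted here and validated later by the tokenizer registry.
parser.add_argument("--tokenizer-mode", type=str, default="auto")

args = parser.parse_args(["--tokenizer-mode", "deepseek_v32"])
print(args.tokenizer_mode)  # deepseek_v32

# With the removed choices=["auto", "slow", "mistral", "custom"],
# the same invocation would have exited with "invalid choice".
```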
View File

@@ -14,7 +14,7 @@ from typing import Any
import torch
import uvloop
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (
AIMODataset,
@@ -35,6 +35,7 @@ from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.utils.async_utils import merge_async_iterators
@@ -246,12 +247,15 @@ async def run_vllm_async(
def run_hf(
requests: list[SampleRequest],
model: str,
tokenizer: PreTrainedTokenizerBase,
tokenizer: TokenizerLike,
n: int,
max_batch_size: int,
trust_remote_code: bool,
disable_detokenize: bool = False,
) -> float:
assert isinstance(tokenizer, PreTrainedTokenizerBase), (
"the hf backend only supports HF tokenizers"
)
llm = AutoModelForCausalLM.from_pretrained(
model, dtype=torch.float16, trust_remote_code=trust_remote_code
)
@@ -692,15 +696,21 @@ def add_cli_args(parser: argparse.ArgumentParser):
def main(args: argparse.Namespace):
if args.tokenizer is None:
args.tokenizer = args.model
validate_args(args)
if args.seed is None:
args.seed = 0
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
if (
args.backend == "hf" or args.backend == "mii"
) and args.tokenizer_mode == "auto":
# mistral_common tokenizer is only supported on vllm and vllm-chat backends;
# for hf and mii backends, we use hf tokenizer
args.tokenizer_mode = "hf"
tokenizer = get_tokenizer(
args.tokenizer,
tokenizer_mode=args.tokenizer_mode,
trust_remote_code=args.trust_remote_code,
)
requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None for request in requests)

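Because `TokenizerLike` is a structural type, code paths that rely on HF-specific behavior (like `run_hf` above) have to narrow it at runtime, which is exactly what the added assert does. A sketch of the same pattern, assuming an HF tokenizer that defines a pad token:

```python
import torch
from transformers import PreTrainedTokenizerBase

def encode_batch_hf(tokenizer, texts: list[str]) -> torch.Tensor:
    # Narrow the structural TokenizerLike to the HF base class before
    # using HF-only features such as batched __call__ with tensors.
    assert isinstance(tokenizer, PreTrainedTokenizerBase), (
        "this code path only supports HF tokenizers"
    )
    batch = tokenizer(texts, padding=True, return_tensors="pt")
    return batch.input_ids
```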
View File

@@ -136,7 +136,8 @@ class ModelConfig:
name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto"
"""Tokenizer mode:\n
- "auto" will use "hf" tokenizer if Mistral's tokenizer is not available.\n
- "auto" will use the tokenizer from `mistral_common` for Mistral models
if available, otherwise it will use the "hf" tokenizer.\n
- "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n

View File

@@ -54,6 +54,9 @@ class DeepseekV32Tokenizer(HfTokenizer):
prompt_str = encode_messages(messages, **encode_config) # type: ignore
return prompt_str
def num_special_tokens_to_add(self) -> int:
return len(self.encode(""))
@property
def all_special_tokens(self) -> list[str]:
return self.tokenizer.all_special_tokens

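The `len(self.encode(""))` trick works because encoding an empty string yields exactly the special tokens the tokenizer wraps around every sequence. For example, with an HF tokenizer (model choice illustrative; requires a download on first run):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

# Encoding "" leaves only the wrapper tokens, [CLS] and [SEP] here.
print(tok.encode(""))                   # [101, 102]
print(tok.num_special_tokens_to_add())  # 2
```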
View File

@@ -309,6 +309,9 @@ class MistralTokenizer(TokenizerLike):
for i in all_special_ids
]
def num_special_tokens_to_add(self) -> int:
return len(self.encode(""))
# the following attributes are set to fit vLLM's design and are used
# by the structured output backends.
@property
@@ -421,6 +424,7 @@ class MistralTokenizer(TokenizerLike):
) -> list[int]:
add_generation_prompt = kwargs.pop("add_generation_prompt", False)
continue_final_message = kwargs.get("continue_final_message", False)
tokenize = kwargs.get("tokenize", True)
padding = kwargs.get("padding", False)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
@@ -433,7 +437,7 @@ class MistralTokenizer(TokenizerLike):
conversation=messages,
tools=tools,
continue_final_message=continue_final_message,
tokenize=True,
tokenize=tokenize,
padding=padding,
truncation=truncation,
max_length=max_length,

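The one-line `tokenize=tokenize` change fixes a caller-visible bug: `apply_chat_template(..., tokenize=False)` previously returned token IDs anyway because the flag was read but then hardcoded to `True`. A toy reproduction of the pattern, not vLLM code; `render` is an invented stand-in for the mistral_common call:

```python
def render(text: str, *, tokenize: bool):
    return [ord(c) for c in text] if tokenize else text

def apply_chat_template_old(text: str, **kwargs):
    kwargs.get("tokenize", True)        # read...
    return render(text, tokenize=True)  # ...but ignored: the bug

def apply_chat_template_new(text: str, **kwargs):
    tokenize = kwargs.get("tokenize", True)
    return render(text, tokenize=tokenize)  # forwarded: the fix

assert apply_chat_template_old("hi", tokenize=False) == [104, 105]
assert apply_chat_template_new("hi", tokenize=False) == "hi"
```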
View File

@@ -22,6 +22,9 @@ class TokenizerLike(Protocol):
) -> "TokenizerLike":
raise NotImplementedError
def num_special_tokens_to_add(self) -> int:
raise NotImplementedError
@property
def all_special_tokens(self) -> list[str]:
raise NotImplementedError

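Since `TokenizerLike` is a `typing.Protocol`, conformance is structural: neither HF nor `mistral_common` tokenizers need to inherit from it, they only need matching members. A trimmed-down illustration; this stand-in protocol defines just what the sketch uses:

```python
from typing import Protocol

class TokenizerSketch(Protocol):
    # Stand-in for vllm.tokenizers.TokenizerLike, reduced to one method.
    def num_special_tokens_to_add(self) -> int: ...

class WrapperTokenizer:
    # No inheritance from TokenizerSketch; the matching method alone
    # makes instances acceptable wherever TokenizerSketch is expected.
    def num_special_tokens_to_add(self) -> int:
        return 2

def budget_for_prompt(tok: TokenizerSketch, max_len: int) -> int:
    # Reserve room for the special tokens every encode will add.
    return max_len - tok.num_special_tokens_to_add()

print(budget_for_prompt(WrapperTokenizer(), 1024))  # 1022
```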
View File

@@ -183,7 +183,7 @@ def get_tokenizer(
"`tokenizer_mode='custom'` when initializing vLLM.",
tokenizer_args,
str(tokenizer_kwargs),
tokenizer_mode,
tokenizer_name,
)
tokenizer_mode = str(tokenizer_name)
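
The final hunk swaps `tokenizer_mode` for `tokenizer_name` as the last argument of the lazy `%`-style log call, since such arguments are substituted positionally. A toy illustration of the pitfall; the message text here is invented, not the actual vLLM log line:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

tokenizer_name, tokenizer_mode = "my-registered-tokenizer", "custom"

# Placeholders bind by position: passing tokenizer_mode where the
# message expects the name logs "custom" instead of the tokenizer name.
logger.warning("No tokenizer found for %s; falling back.", tokenizer_mode)  # wrong
logger.warning("No tokenizer found for %s; falling back.", tokenizer_name)  # right
```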