[Model] Add Mistral Tokenization to improve robustness and chat encoding (#7739)

Authored by Patrick von Platen on 2024-08-27 14:40:02 +02:00; committed by GitHub
parent 9606c7197d
commit 6fc4e6e07a
12 changed files with 275 additions and 60 deletions

View File

@@ -11,4 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args

View File

@@ -26,3 +26,4 @@ librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4

View File

@@ -30,9 +30,11 @@ def test_models(
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model, dtype=dtype) as vllm_model:
with vllm_runner(model, dtype=dtype,
tokenizer_mode="mistral") as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
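
In user terms, the test above exercises the new mode end to end: the same prompts run through HF transformers and through vLLM with the mistral_common tokenizer, and the logprobs are compared. A minimal offline-inference sketch of that path, assuming this vLLM revision (the model name and sampling values are illustrative):

from vllm import LLM, SamplingParams

# tokenizer_mode="mistral" loads the tokenizer via mistral_common instead of
# transformers, so prompt encoding matches Mistral's reference implementation.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
          tokenizer_mode="mistral")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)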

View File

@@ -61,7 +61,8 @@ class ModelConfig:
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
dtype: Data type for model weights and activations. The "auto" option
@@ -246,10 +247,10 @@ class ModelConfig:
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
if tokenizer_mode not in ["auto", "slow", "mistral"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.")
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None:
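
The docstring and validation above define the accepted values; any other string is rejected with the new error message. A hedged sketch of that behaviour, assuming this revision (the model name is illustrative, and constructing LLM first fetches the checkpoint's config from the Hub):

from vllm import LLM

try:
    LLM(model="mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="fast")
except ValueError as err:
    # "Unknown tokenizer mode: fast. Must be either 'auto', 'slow' or 'mistral'."
    print(err)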

View File

@@ -198,10 +198,11 @@ class EngineArgs:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
choices=['auto', 'slow', 'mistral'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer.')
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
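
The same choice is exposed programmatically: EngineArgs is a dataclass mirroring the CLI flags, and the value set here flows into ModelConfig.tokenizer_mode. A small sketch, assuming this revision (the model name is illustrative; on the command line this corresponds to --tokenizer-mode mistral):

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="mistralai/Mistral-7B-Instruct-v0.3",
                  tokenizer_mode="mistral")
print(args.tokenizer_mode)  # prints "mistral"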

View File

@@ -267,7 +267,7 @@ def apply_chat_template(
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
) -> Union[str, List[int]]:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
@@ -280,6 +280,4 @@ def apply_chat_template(
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)
return prompt
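
Callers of apply_chat_template now have to handle both return types. A hedged sketch, assuming Hugging Face Hub access (the repo id and message content are illustrative):

from vllm.entrypoints.chat_utils import apply_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3",
                          tokenizer_mode="mistral")
conversation = [{"role": "user", "content": "Hello!"}]
prompt = apply_chat_template(tokenizer,
                             conversation,
                             chat_template=None,
                             add_generation_prompt=True)
if isinstance(prompt, list):
    # MistralTokenizer returns the token ids of the rendered chat prompt.
    print("token ids:", prompt[:8])
else:
    # Hugging Face tokenizers still return the rendered template as a string.
    print("text prompt:", prompt[:80])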

View File

@@ -390,15 +390,21 @@ class LLM:
conversations, _ = parse_chat_messages(messages, model_config,
tokenizer)
prompts = apply_chat_template(
prompt = apply_chat_template(
tokenizer,
conversations,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt)
inputs: PromptInputs
if isinstance(prompt, list) and isinstance(prompt[0], int):
inputs = TokensPrompt(prompt_token_ids=prompt)
else:
inputs = TextPrompt(prompt=prompt)
return self.generate(
prompts,
sampling_params,
inputs,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
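
Put together, LLM.chat now feeds Mistral-encoded token ids straight into generation instead of re-encoding a rendered string. A minimal sketch (the model name and sampling values are illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
          tokenizer_mode="mistral")
messages = [{"role": "user", "content": "Summarize what a tokenizer does."}]
outputs = llm.chat(messages, SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)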

View File

@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
PromptAdapterPath,
TextTokensPrompt)
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
@@ -130,6 +131,7 @@ class OpenAIServingChat(OpenAIServing):
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
@@ -137,6 +139,14 @@ class OpenAIServingChat(OpenAIServing):
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
assert isinstance(prompt, list) and isinstance(
prompt[0], int
), "Prompt has to be either a string or a list of token ids"
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
assert prompt_inputs is not None
sampling_params = request.to_sampling_params(
tokenizer,
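
On the serving side, this handling lets chat completions work when the server is launched with --tokenizer-mode mistral. A client-side sketch, assuming an OpenAI-compatible vLLM server is already running locally (the base URL and model name are illustrative):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",
    messages=[{"role": "user", "content": "Say hello in French."}],
    temperature=0.0,
)
print(response.choices[0].message.content)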

View File

@@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
prefix_offset = max(
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
# This is required to guard against out-of-vocab prompt token ids
_replace_none_with_empty(new_tokens)
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
return new_tokens, prefix_offset, read_offset

View File

@@ -1,4 +1,5 @@
import os
import warnings
from pathlib import Path
from typing import Optional, Union
@@ -9,12 +10,14 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
MistralTokenizer)
from vllm.utils import make_async
logger = init_logger(__name__)
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
@@ -99,24 +102,40 @@ def get_tokenizer(
kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent
    # If the tokenizer is from the official Mistral org, recommend the mistral mode.
is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn(
            'It is strongly recommended to run Mistral models with '
            '`--tokenizer-mode "mistral"` to ensure correct '
            'encoding and decoding.',
FutureWarning,
stacklevel=2)
if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
revision=revision)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
**kwargs,
)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if (not trust_remote_code and
("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
# currently being imported,
# suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)):
err_msg = ("Failed to load the tokenizer. If the tokenizer "
"is a custom tokenizer not yet available in the "
"HuggingFace transformers library, consider "
"setting `trust_remote_code=True` in LLM or using "
"the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
@@ -129,7 +148,8 @@ def get_tokenizer(
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs)
**kwargs,
)
else:
raise e
@@ -137,7 +157,9 @@ def get_tokenizer(
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead.")
return get_cached_tokenizer(tokenizer)
tokenizer = get_cached_tokenizer(tokenizer)
return tokenizer
def get_lora_tokenizer(lora_request: LoRARequest, *args,
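
A quick way to exercise the new branch directly, assuming Hugging Face Hub access (the repo id is illustrative):

from vllm.transformers_utils.tokenizer import get_tokenizer

tok = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3",
                    tokenizer_mode="mistral")
ids = tok.encode("Hello, world!")  # plain completion encoding (adds BOS)
print(ids[:8], tok.decode(ids))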

View File

@@ -1,5 +1,4 @@
from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
__all__ = [
"BaichuanTokenizer",
]
__all__ = ["BaichuanTokenizer", "MistralTokenizer"]

View File

@@ -0,0 +1,174 @@
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from huggingface_hub import HfApi, hf_hub_download
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
# yapf: enable
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer)
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
Tekkenizer)
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ConversationMessage
@dataclass
class Encoding:
input_ids: List[int]
def find_tokenizer_file(files: List[str]):
file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
matched_files = [file for file in files if file_pattern.match(file)]
    if len(matched_files) > 1:
        raise OSError(f"Found {len(matched_files)} files matching the "
                      f"pattern {file_pattern.pattern}: {matched_files}. "
                      "Make sure only one Mistral tokenizer is present "
                      "in the given path or repo.")
    elif len(matched_files) == 0:
        raise OSError(f"Found {len(matched_files)} files matching the "
                      f"pattern {file_pattern.pattern}. Make sure that a "
                      "Mistral tokenizer is present in the given path or repo.")
return matched_files[0]
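# Illustrative behaviour (hypothetical file listing): passing
# ["config.json", "tekken.json"] returns "tekken.json", while zero matches or
# more than one match raises OSError.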
class MistralTokenizer:
def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
self.mistral = tokenizer
self.instruct = tokenizer.instruct_tokenizer
self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
self.vocab_size = len(self.tokenizer.vocab())
assert isinstance(self.tokenizer,
(Tekkenizer, SentencePieceTokenizer)), type(
self.tokenizer)
self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
if self._is_tekken:
            # Make sure decoding special tokens does not raise an error.
            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
        # The following attributes are set to fit vLLM's design.
self.is_fast = True
self.chat_template = True
self.all_special_ids: List[Any] = []
self.all_special_tokens: List[Any] = []
self.all_special_tokens_extended: List[Any] = []
@classmethod
def from_pretrained(cls,
path_or_repo_id: str,
*,
revision: Optional[str] = None) -> "MistralTokenizer":
if not Path(path_or_repo_id).exists():
            assert len(path_or_repo_id.split("/")) == 2, (
                "You have either provided a non-existent path: "
                f"{path_or_repo_id} or an invalid HF Hub repo id.")
tokenizer_file = cls._download_mistral_tokenizer_from_hf(
path_or_repo_id, revision)
elif Path(path_or_repo_id).is_dir():
tokenizer_file_name = find_tokenizer_file(
os.listdir(path_or_repo_id))
tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
        else:
            assert Path(
                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
            # The given path points directly at a tokenizer file.
            tokenizer_file = str(path_or_repo_id)
mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
return cls(mistral_tokenizer)
@staticmethod
def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
revision: Optional[str]) -> str:
api = HfApi()
repo_info = api.model_info(tokenizer_name)
files = [s.rfilename for s in repo_info.siblings]
filename = find_tokenizer_file(files)
tokenizer_file = hf_hub_download(tokenizer_name,
filename=filename,
revision=revision)
return tokenizer_file
def __call__(
self,
prompt: str,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
):
        # Mistral tokenizers handle special tokens internally, so the
        # `add_special_tokens` argument is intentionally ignored here.
input_ids = self.encode(prompt)
if truncation:
input_ids = input_ids[:max_length]
return Encoding(input_ids=input_ids)
def get_added_vocab(self) -> List[str]:
# Mistral tokenizers have no added vocabulary
return []
def encode(self, prompt: str) -> List[int]:
        # `encode` should only be used for prompt (text) completion;
        # it should never be used for chat completion.
        # For chat completion, use `apply_chat_template` instead.
return self.tokenizer.encode(prompt, bos=True, eos=False)
def apply_chat_template(self,
conversation: List["ConversationMessage"],
tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]:
assert tools is None, "`tools` are not yet supported."
request = ChatCompletionRequest(
messages=conversation) # type: ignore[type-var]
encoded = self.mistral.encode_chat_completion(request)
        # Return the token ids of the fully rendered chat prompt.
return encoded.tokens
def convert_tokens_to_string(self, tokens: List[str]) -> str:
if self._is_tekken:
return "".join(tokens)
else:
return self.tokenizer.decode(tokens) # type: ignore[arg-type]
def decode(self, ids: Union[List[int], int]) -> str:
if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids)
@property
def eos_token_id(self):
return self.tokenizer.eos_id
def convert_ids_to_tokens(
self,
ids: List[int],
skip_special_tokens: Optional[bool] = True) -> List[str]:
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert (
skip_special_tokens
), "Skipping special tokens is not supported for Mistral tokenizers."
assert isinstance(self.tokenizer,
(Tekkenizer, SentencePieceTokenizer)), type(
self.tokenizer)
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
return tokens
def __len__(self):
return self.vocab_size
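
Taken together, the new tokenizer can also be used on its own. A minimal sketch, assuming Hugging Face Hub access (the repo id and prompts are illustrative):

from vllm.transformers_utils.tokenizers import MistralTokenizer

tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

# Plain completion encoding: `encode` adds BOS but no chat control tokens.
ids = tok.encode("The capital of France is")
print(tok.decode(ids))

# Chat encoding: `apply_chat_template` returns the token ids of the fully
# rendered instruct prompt, including control tokens such as [INST].
chat_ids = tok.apply_chat_template(
    [{"role": "user", "content": "Name three French cheeses."}])
print(len(chat_ids), tok.decode(chat_ids)[:80])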