[Model] Add Mistral Tokenization to improve robustness and chat encoding (#7739)
This commit is contained in:
parent 9606c7197d
commit 6fc4e6e07a
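In practice, this change lets a Mistral model be served with `mistral_common`'s own tokenizer instead of the Hugging Face one. A minimal offline-inference sketch, assuming a vLLM build that includes this commit; the model id `mistralai/Mistral-7B-Instruct-v0.3` is only an example:

    from vllm import LLM, SamplingParams

    # tokenizer_mode="mistral" is the new option added by this commit;
    # it loads the tokenizer via mistral_common instead of transformers.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
              tokenizer_mode="mistral")

    messages = [{"role": "user", "content": "Tell me a joke."}]
    outputs = llm.chat(messages, SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)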
@@ -11,4 +11,5 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
+mistral_common >= 1.3.4
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
@@ -26,3 +26,4 @@ librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
 importlib_metadata
+mistral_common >= 1.3.4
@@ -30,9 +30,11 @@ def test_models(
     hf_outputs = hf_model.generate_greedy_logprobs_limit(
         example_prompts, max_tokens, num_logprobs)

-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype,
+                     tokenizer_mode="mistral") as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,
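The test compares the HF reference run and the vLLM run token by token. A rough, purely illustrative sketch of the kind of check `check_logprobs_close` performs (the real helper lives in vLLM's test utilities and has a richer signature):

    def logprobs_close(outputs_0_lst, outputs_1_lst):
        # Each element is (token_ids, text, logprobs), one entry per prompt.
        for (ids_0, _, lp_0), (ids_1, _, lp_1) in zip(outputs_0_lst,
                                                      outputs_1_lst):
            for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
                if tok_0 == tok_1:
                    continue
                # A mismatch is tolerated only if each model's token is among
                # the other model's top-k logprobs at that position.
                assert tok_0 in lp_1[pos] and tok_1 in lp_0[pos]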
@@ -61,7 +61,8 @@ class ModelConfig:
            output when `served_model_name` is not specified.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-            available, and "slow" will always use the slow tokenizer.
+            available, "slow" will always use the slow tokenizer, and
+            "mistral" will always use the tokenizer from `mistral_common`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        dtype: Data type for model weights and activations. The "auto" option
@@ -246,10 +247,10 @@ class ModelConfig:

     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto' or 'slow'.")
+                "either 'auto', 'slow' or 'mistral'.")
         self.tokenizer_mode = tokenizer_mode

     def _verify_embedding_mode(self) -> None:
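The validation above is straightforward; a minimal standalone sketch of the behavior it enforces (not the actual `ModelConfig` method, just the same logic):

    def verify_tokenizer_mode(tokenizer_mode: str) -> str:
        mode = tokenizer_mode.lower()
        if mode not in ["auto", "slow", "mistral"]:
            raise ValueError(f"Unknown tokenizer mode: {tokenizer_mode}. "
                             "Must be either 'auto', 'slow' or 'mistral'.")
        return mode

    verify_tokenizer_mode("MISTRAL")   # returns "mistral" (case-insensitive)
    # verify_tokenizer_mode("fast")    # would raise ValueError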
@@ -198,10 +198,11 @@ class EngineArgs:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow'],
+            choices=['auto', 'slow', 'mistral'],
             help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
-            'always use the slow tokenizer.')
+            'always use the slow tokenizer. \n* '
+            '"mistral" will always use the `mistral_common` tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
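A minimal argparse sketch of how the CLI flag maps onto the engine argument (simplified; the real parser is built inside EngineArgs' CLI-args method):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer-mode',
                        type=str,
                        default='auto',
                        choices=['auto', 'slow', 'mistral'],
                        help='The tokenizer mode.')
    args = parser.parse_args(['--tokenizer-mode', 'mistral'])
    assert args.tokenizer_mode == 'mistral'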
@@ -267,7 +267,7 @@ def apply_chat_template(
     *,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
-) -> str:
+) -> Union[str, List[int]]:
     if chat_template is None and tokenizer.chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
@@ -280,6 +280,4 @@ def apply_chat_template(
         tokenize=tokenize,
         **kwargs,
     )
-    assert isinstance(prompt, str)
-
     return prompt
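Because a Mistral tokenizer renders a conversation directly to token ids, the helper's return type widens to Union[str, List[int]] and callers must branch on the result. A hedged sketch of that caller-side branch (names follow the diff; the conversation is the usual list of role/content dicts):

    prompt = apply_chat_template(tokenizer,
                                 conversation,
                                 chat_template=None,
                                 add_generation_prompt=True)

    if isinstance(prompt, list):
        # mistral tokenizer: already token ids, no re-tokenization needed
        token_ids = prompt
    else:
        # HF tokenizer: a formatted string that still has to be encoded
        token_ids = tokenizer.encode(prompt)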
@@ -390,15 +390,21 @@ class LLM:
         conversations, _ = parse_chat_messages(messages, model_config,
                                                tokenizer)

-        prompts = apply_chat_template(
+        prompt = apply_chat_template(
             tokenizer,
             conversations,
             chat_template=chat_template,
             add_generation_prompt=add_generation_prompt)

+        inputs: PromptInputs
+        if isinstance(prompt, list) and isinstance(prompt[0], int):
+            inputs = TokensPrompt(prompt_token_ids=prompt)
+        else:
+            inputs = TextPrompt(prompt=prompt)
+
         return self.generate(
-            prompts,
-            sampling_params,
+            inputs,
+            sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )
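With the branch above, LLM.chat feeds the rendered prompt to generate() either as token ids (TokensPrompt) or as text (TextPrompt). A short usage sketch, reusing the `llm` built with tokenizer_mode="mistral" in the earlier example:

    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize vLLM in one sentence."},
    ]
    # With a Mistral tokenizer the chat template yields token ids, so
    # generate() is driven by a TokensPrompt under the hood; with an HF
    # tokenizer it falls back to a TextPrompt.
    outputs = llm.chat(conversation)
    print(outputs[0].outputs[0].text)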
@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
     FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
-                                                    PromptAdapterPath)
+                                                    PromptAdapterPath,
+                                                    TextTokensPrompt)
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
@@ -130,13 +131,22 @@ class OpenAIServingChat(OpenAIServing):
             guided_decode_logits_processor = (
                 await self._guided_decode_logits_processor(request, tokenizer))

-            prompt_inputs = self._tokenize_prompt_input(
-                request,
-                tokenizer,
-                prompt,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
-                add_special_tokens=request.add_special_tokens,
-            )
+            if isinstance(prompt, str):
+                prompt_inputs = self._tokenize_prompt_input(
+                    request,
+                    tokenizer,
+                    prompt,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                assert isinstance(prompt, list) and isinstance(
+                    prompt[0], int
+                ), "Prompt has to be either a string or a list of token ids"
+                prompt_inputs = TextTokensPrompt(
+                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
+
+            assert prompt_inputs is not None

             sampling_params = request.to_sampling_params(
                 tokenizer,
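On the OpenAI-compatible server the same distinction shows up: a string prompt still goes through tokenization, while a list of token ids is wrapped directly (decoded once only to fill the text field) without being re-tokenized. A small standalone sketch of that branch, with the hypothetical `make_prompt_inputs` standing in for the real `_tokenize_prompt_input` / `TextTokensPrompt` plumbing:

    from typing import List, Union

    def make_prompt_inputs(tokenizer, prompt: Union[str, List[int]]) -> dict:
        if isinstance(prompt, str):
            token_ids = tokenizer.encode(prompt)
            return {"prompt": prompt, "prompt_token_ids": token_ids}
        # Already token ids, e.g. produced by MistralTokenizer's chat template.
        assert isinstance(prompt, list) and isinstance(prompt[0], int)
        return {"prompt": tokenizer.decode(prompt), "prompt_token_ids": prompt}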
|
||||
@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
|
||||
prefix_offset = max(
|
||||
read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
|
||||
# This is required to guard against out-of-vocab prompt token ids
|
||||
_replace_none_with_empty(new_tokens)
|
||||
_replace_none_with_empty(new_tokens) # type: ignore[arg-type]
|
||||
return new_tokens, prefix_offset, read_offset
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
 import os
+import warnings
 from pathlib import Path
 from typing import Optional, Union

@@ -9,12 +10,14 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import BaichuanTokenizer
+from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
+                                                MistralTokenizer)
 from vllm.utils import make_async

 logger = init_logger(__name__)

-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
+                     MistralTokenizer]


 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
@@ -99,45 +102,64 @@ def get_tokenizer(
         kwargs["gguf_file"] = Path(tokenizer_name).name
         tokenizer_name = Path(tokenizer_name).parent

-    try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            **kwargs)
-    except ValueError as e:
-        # If the error pertains to the tokenizer class not existing or not
-        # currently being imported, suggest using the --trust-remote-code flag.
-        if (not trust_remote_code and
-                ("does not exist or is not currently imported." in str(e)
-                 or "requires you to execute the tokenizer file" in str(e))):
-            err_msg = (
-                "Failed to load the tokenizer. If the tokenizer is a custom "
-                "tokenizer not yet available in the HuggingFace transformers "
-                "library, consider setting `trust_remote_code=True` in LLM "
-                "or using the `--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-    except AttributeError as e:
-        if "BaichuanTokenizer" in str(e):
-            # This is for the error "'BaichuanTokenizer' object has no
-            # attribute 'sp_model'".
-            tokenizer = BaichuanTokenizer.from_pretrained(
-                tokenizer_name,
-                *args,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                **kwargs)
-        else:
-            raise e
+    # if tokenizer is from official mistral org
+    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
+    if is_from_mistral_org and tokenizer_mode != "mistral":
+        warnings.warn(
+            'It is strongly recommended to run mistral models with '
+            '`--tokenizer_mode "mistral"` to ensure correct '
+            'encoding and decoding.',
+            FutureWarning,
+            stacklevel=2)
+
+    if tokenizer_mode == "mistral":
+        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
+                                                     revision=revision)
+    else:
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        except ValueError as e:
+            # If the error pertains to the tokenizer class not existing or not
+            # currently being imported,
+            # suggest using the --trust-remote-code flag.
+            if not trust_remote_code and (
+                    "does not exist or is not currently imported." in str(e)
+                    or "requires you to execute the tokenizer file" in str(e)):
+                err_msg = ("Failed to load the tokenizer. If the tokenizer "
+                           "is a custom tokenizer not yet available in the "
+                           "HuggingFace transformers library, consider "
+                           "setting `trust_remote_code=True` in LLM or using "
+                           "the `--trust-remote-code` flag in the CLI.")
+                raise RuntimeError(err_msg) from e
+            else:
+                raise e
+        except AttributeError as e:
+            if "BaichuanTokenizer" in str(e):
+                # This is for the error "'BaichuanTokenizer' object has no
+                # attribute 'sp_model'".
+                tokenizer = BaichuanTokenizer.from_pretrained(
+                    tokenizer_name,
+                    *args,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    **kwargs,
+                )
+            else:
+                raise e

-    if not isinstance(tokenizer, PreTrainedTokenizerFast):
-        logger.warning(
-            "Using a slow tokenizer. This might cause a significant "
-            "slowdown. Consider using a fast tokenizer instead.")
-    return get_cached_tokenizer(tokenizer)
+        if not isinstance(tokenizer, PreTrainedTokenizerFast):
+            logger.warning(
+                "Using a slow tokenizer. This might cause a significant "
+                "slowdown. Consider using a fast tokenizer instead.")
+    tokenizer = get_cached_tokenizer(tokenizer)
+
+    return tokenizer


 def get_lora_tokenizer(lora_request: LoRARequest, *args,
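A short sketch of how the refactored loader behaves, assuming the `mistral_common` package is installed and using `mistralai/Mistral-7B-Instruct-v0.3` only as an example repo id; the warning branch fires when a `mistralai/...` checkpoint is loaded without the new mode:

    from vllm.transformers_utils.tokenizer import get_tokenizer

    # Uses MistralTokenizer.from_pretrained under the hood.
    tok = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3",
                        tokenizer_mode="mistral")

    # The same repo with the default mode still loads the HF tokenizer, but
    # now emits a FutureWarning recommending --tokenizer_mode "mistral".
    hf_tok = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3")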
@@ -1,5 +1,4 @@
 from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

-__all__ = [
-    "BaichuanTokenizer",
-]
+__all__ = ["BaichuanTokenizer", "MistralTokenizer"]
vllm/transformers_utils/tokenizers/mistral.py (new file, 174 lines)
@@ -0,0 +1,174 @@
+import os
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from huggingface_hub import HfApi, hf_hub_download
+# yapf: disable
+from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import (
+    MistralTokenizer as PublicMistralTokenizer)
+# yapf: enable
+from mistral_common.tokens.tokenizers.sentencepiece import (
+    SentencePieceTokenizer)
+from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
+                                                     Tekkenizer)
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ConversationMessage
+
+
+@dataclass
+class Encoding:
+    input_ids: List[int]
+
+
+def find_tokenizer_file(files: List[str]):
+    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+
+    matched_files = [file for file in files if file_pattern.match(file)]
+    if len(matched_files) > 1:
+        raise OSError(f"Found {len(matched_files)} files matching the "
+                      "pattern: {matched_files}. Make sure only one Mistral "
+                      "tokenizer is present in {tokenizer_name}.")
+    elif len(matched_files) == 0:
+        raise OSError(f"Found {len(matched_files)} files matching the "
+                      "pattern: {matched_files}. Make sure that a Mistral "
+                      "tokenizer is present in {tokenizer_name}.")
+
+    return matched_files[0]
+
+
+class MistralTokenizer:
+
+    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
+        self.mistral = tokenizer
+        self.instruct = tokenizer.instruct_tokenizer
+        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
+
+        self.vocab_size = len(self.tokenizer.vocab())
+
+        assert isinstance(self.tokenizer,
+                          (Tekkenizer, SentencePieceTokenizer)), type(
+                              self.tokenizer)
+        self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
+
+        if self._is_tekken:
+            # Make sure special tokens will not raise
+            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
+
+        # the following attributes are set to fit VLLM's design
+        self.is_fast = True
+        self.chat_template = True
+        self.all_special_ids: List[Any] = []
+        self.all_special_tokens: List[Any] = []
+        self.all_special_tokens_extended: List[Any] = []
+
+    @classmethod
+    def from_pretrained(cls,
+                        path_or_repo_id: str,
+                        *,
+                        revision: Optional[str] = None) -> "MistralTokenizer":
+        if not Path(path_or_repo_id).exists():
+            assert len(path_or_repo_id.split("/")) == 2, (
+                "You have either provided a non-existent path: "
+                "{path_or_repo_id} or an invalid HF Hub repo id.")
+            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
+                path_or_repo_id, revision)
+        elif Path(path_or_repo_id).is_dir():
+            tokenizer_file_name = find_tokenizer_file(
+                os.listdir(path_or_repo_id))
+            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
+        else:
+            assert Path(
+                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
+
+        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
+        return cls(mistral_tokenizer)
+
+    @staticmethod
+    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
+                                            revision: Optional[str]) -> str:
+        api = HfApi()
+        repo_info = api.model_info(tokenizer_name)
+        files = [s.rfilename for s in repo_info.siblings]
+
+        filename = find_tokenizer_file(files)
+
+        tokenizer_file = hf_hub_download(tokenizer_name,
+                                         filename=filename,
+                                         revision=revision)
+        return tokenizer_file
+
+    def __call__(
+        self,
+        prompt: str,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        # Mistral Tokenizers should not add special tokens
+        input_ids = self.encode(prompt)
+
+        if truncation:
+            input_ids = input_ids[:max_length]
+
+        return Encoding(input_ids=input_ids)
+
+    def get_added_vocab(self) -> List[str]:
+        # Mistral tokenizers have no added vocabulary
+        return []
+
+    def encode(self, prompt: str) -> List[int]:
+        # `encode` should only be used for prompt completion
+        # it should never be used for chat_completion.
+        # For chat completion use `apply_chat_template`
+        return self.tokenizer.encode(prompt, bos=True, eos=False)
+
+    def apply_chat_template(self,
+                            conversation: List["ConversationMessage"],
+                            tools: Optional[Dict[str, Any]] = None,
+                            **kwargs) -> List[int]:
+        assert tools is None, "`tools` are not yet supported."
+
+        request = ChatCompletionRequest(
+            messages=conversation)  # type: ignore[type-var]
+        encoded = self.mistral.encode_chat_completion(request)
+
+        # encode-decode to get clean prompt
+        return encoded.tokens
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        if self._is_tekken:
+            return "".join(tokens)
+        else:
+            return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+
+    def decode(self, ids: Union[List[int], int]) -> str:
+        if isinstance(ids, int):
+            ids = [ids]
+        return self.tokenizer.decode(ids)
+
+    @property
+    def eos_token_id(self):
+        return self.tokenizer.eos_id
+
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: Optional[bool] = True) -> List[str]:
+        # TODO(Patrick) - potentially allow special tokens to not be skipped
+        assert (
+            skip_special_tokens
+        ), "Skipping special tokens is not supported for Mistral tokenizers."
+
+        assert isinstance(self.tokenizer,
+                          (Tekkenizer, SentencePieceTokenizer)), type(
+                              self.tokenizer)
+
+        tokens = [self.tokenizer.id_to_piece(id) for id in ids]
+        return tokens
+
+    def __len__(self):
+        return self.vocab_size
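Putting the new class together, a hedged end-to-end sketch (again using `mistralai/Mistral-7B-Instruct-v0.3` only as an example repo id, and assuming `mistral_common` is installed and the Hub is reachable):

    from vllm.transformers_utils.tokenizers import MistralTokenizer

    tok = MistralTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3")

    # Plain completion path: encode() adds BOS but no chat formatting.
    ids = tok.encode("Hello, world!")
    print(tok.decode(ids))

    # Chat path: the template is applied by mistral_common itself and the
    # result is a list of token ids, not a string.
    conversation = [{"role": "user", "content": "Hello!"}]
    chat_ids = tok.apply_chat_template(conversation)
    print(len(tok), tok.eos_token_id, len(chat_ids))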