[Model] Add Mistral Tokenization to improve robustness and chat encoding (#7739)
This commit is contained in:
parent 9606c7197d
commit 6fc4e6e07a
@@ -11,4 +11,5 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
+mistral_common >= 1.3.4
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
@@ -26,3 +26,4 @@ librosa # Required for audio processing
 soundfile # Required for audio processing
 gguf == 0.9.1
 importlib_metadata
+mistral_common >= 1.3.4
@@ -30,9 +30,11 @@ def test_models(
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype,
+                     tokenizer_mode="mistral") as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
 
     check_logprobs_close(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,
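The test above drives the new mode through the vllm_runner fixture; outside the test suite the equivalent knob is the tokenizer_mode argument on the public LLM class. A minimal sketch, assuming an illustrative Mistral checkpoint id (not taken from this diff):

    from vllm import LLM, SamplingParams

    # tokenizer_mode="mistral" routes tokenization through mistral_common
    # instead of the HF tokenizer; the model id below is illustrative.
    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
              tokenizer_mode="mistral")
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)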
@@ -61,7 +61,8 @@ class ModelConfig:
             output when `served_model_name` is not specified.
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-            available, and "slow" will always use the slow tokenizer.
+            available, "slow" will always use the slow tokenizer, and
+            "mistral" will always use the tokenizer from `mistral_common`.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
         dtype: Data type for model weights and activations. The "auto" option
@@ -246,10 +247,10 @@ class ModelConfig:
 
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto' or 'slow'.")
+                "either 'auto', 'slow' or 'mistral'.")
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_embedding_mode(self) -> None:
@@ -198,10 +198,11 @@ class EngineArgs:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow'],
+            choices=['auto', 'slow', 'mistral'],
             help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
-            'always use the slow tokenizer.')
+            'always use the slow tokenizer. \n* '
+            '"mistral" will always use the `mistral_common` tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
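The CLI flag above maps onto the tokenizer_mode field of EngineArgs, so engines built programmatically can opt in the same way. A rough sketch under the same illustrative model id assumption:

    from vllm.engine.arg_utils import EngineArgs

    # Equivalent to passing --tokenizer-mode mistral on the command line;
    # the model id is illustrative, not part of this diff.
    engine_args = EngineArgs(model="mistralai/Mistral-7B-Instruct-v0.3",
                             tokenizer_mode="mistral")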
@@ -267,7 +267,7 @@ def apply_chat_template(
     *,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
-) -> str:
+) -> Union[str, List[int]]:
     if chat_template is None and tokenizer.chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
@@ -280,6 +280,4 @@ def apply_chat_template(
         tokenize=tokenize,
         **kwargs,
     )
-    assert isinstance(prompt, str)
-
     return prompt
@@ -390,15 +390,21 @@ class LLM:
         conversations, _ = parse_chat_messages(messages, model_config,
                                                tokenizer)
 
-        prompts = apply_chat_template(
+        prompt = apply_chat_template(
             tokenizer,
             conversations,
             chat_template=chat_template,
             add_generation_prompt=add_generation_prompt)
 
+        inputs: PromptInputs
+        if isinstance(prompt, list) and isinstance(prompt[0], int):
+            inputs = TokensPrompt(prompt_token_ids=prompt)
+        else:
+            inputs = TextPrompt(prompt=prompt)
+
         return self.generate(
-            prompts,
-            sampling_params,
+            inputs,
+            sampling_params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
         )
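Because apply_chat_template may now return either a rendered string (HF chat template) or a list of token ids (mistral_common), LLM.chat wraps the result in a TokensPrompt or TextPrompt before calling generate. A hedged usage sketch of the chat entrypoint; the model id and message content are assumptions:

    from vllm import LLM, SamplingParams

    llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",  # illustrative id
              tokenizer_mode="mistral")
    # With the mistral tokenizer, chat() feeds the mistral_common token ids
    # straight to generate() as a TokensPrompt instead of re-encoding a string.
    outputs = llm.chat(
        [{"role": "user", "content": "Give me a one-line haiku about GPUs."}],
        SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)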
@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
     FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
-                                                    PromptAdapterPath)
+                                                    PromptAdapterPath,
+                                                    TextTokensPrompt)
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
@@ -130,13 +131,22 @@ class OpenAIServingChat(OpenAIServing):
             guided_decode_logits_processor = (
                 await self._guided_decode_logits_processor(request, tokenizer))
 
-            prompt_inputs = self._tokenize_prompt_input(
-                request,
-                tokenizer,
-                prompt,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
-                add_special_tokens=request.add_special_tokens,
-            )
+            if isinstance(prompt, str):
+                prompt_inputs = self._tokenize_prompt_input(
+                    request,
+                    tokenizer,
+                    prompt,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                assert isinstance(prompt, list) and isinstance(
+                    prompt[0], int
+                ), "Prompt has to be either a string or a list of token ids"
+                prompt_inputs = TextTokensPrompt(
+                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
+
+            assert prompt_inputs is not None
 
             sampling_params = request.to_sampling_params(
                 tokenizer,
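On the server side nothing changes for clients: a standard OpenAI-style chat request exercises the new token-id path whenever the server runs with the mistral tokenizer mode. An illustrative sketch; the endpoint, API key, and model name are assumptions, not part of this diff:

    from openai import OpenAI

    # Assumes a vLLM OpenAI-compatible server started with
    # --tokenizer-mode mistral and listening on localhost:8000.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.3",  # illustrative
        messages=[{"role": "user", "content": "Say hello."}])
    print(resp.choices[0].message.content)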
@@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens(
     prefix_offset = max(
         read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
     # This is required to guard against out-of-vocab prompt token ids
-    _replace_none_with_empty(new_tokens)
+    _replace_none_with_empty(new_tokens)  # type: ignore[arg-type]
     return new_tokens, prefix_offset, read_offset
 
 
@@ -1,4 +1,5 @@
 import os
+import warnings
 from pathlib import Path
 from typing import Optional, Union
 
@@ -9,12 +10,14 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import BaichuanTokenizer
+from vllm.transformers_utils.tokenizers import (BaichuanTokenizer,
+                                                MistralTokenizer)
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
 
-AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
+                     MistralTokenizer]
 
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
@@ -99,45 +102,64 @@ def get_tokenizer(
         kwargs["gguf_file"] = Path(tokenizer_name).name
         tokenizer_name = Path(tokenizer_name).parent
 
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            revision=revision,
-            **kwargs)
-    except ValueError as e:
-        # If the error pertains to the tokenizer class not existing or not
-        # currently being imported, suggest using the --trust-remote-code flag.
-        if (not trust_remote_code and
-                ("does not exist or is not currently imported." in str(e)
-                 or "requires you to execute the tokenizer file" in str(e))):
-            err_msg = (
-                "Failed to load the tokenizer. If the tokenizer is a custom "
-                "tokenizer not yet available in the HuggingFace transformers "
-                "library, consider setting `trust_remote_code=True` in LLM "
-                "or using the `--trust-remote-code` flag in the CLI.")
-            raise RuntimeError(err_msg) from e
-        else:
-            raise e
-    except AttributeError as e:
-        if "BaichuanTokenizer" in str(e):
-            # This is for the error "'BaichuanTokenizer' object has no
-            # attribute 'sp_model'".
-            tokenizer = BaichuanTokenizer.from_pretrained(
-                tokenizer_name,
-                *args,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                **kwargs)
-        else:
-            raise e
-
-    if not isinstance(tokenizer, PreTrainedTokenizerFast):
-        logger.warning(
-            "Using a slow tokenizer. This might cause a significant "
-            "slowdown. Consider using a fast tokenizer instead.")
-    return get_cached_tokenizer(tokenizer)
+    # if tokenizer is from official mistral org
+    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
+    if is_from_mistral_org and tokenizer_mode != "mistral":
+        warnings.warn(
+            'It is strongly recommended to run mistral models with '
+            '`--tokenizer_mode "mistral"` to ensure correct '
+            'encoding and decoding.',
+            FutureWarning,
+            stacklevel=2)
+
+    if tokenizer_mode == "mistral":
+        tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
+                                                     revision=revision)
+    else:
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        except ValueError as e:
+            # If the error pertains to the tokenizer class not existing or not
+            # currently being imported,
+            # suggest using the --trust-remote-code flag.
+            if not trust_remote_code and (
+                    "does not exist or is not currently imported." in str(e)
+                    or "requires you to execute the tokenizer file" in str(e)):
+                err_msg = ("Failed to load the tokenizer. If the tokenizer "
+                           "is a custom tokenizer not yet available in the "
+                           "HuggingFace transformers library, consider "
+                           "setting `trust_remote_code=True` in LLM or using "
+                           "the `--trust-remote-code` flag in the CLI.")
+                raise RuntimeError(err_msg) from e
+            else:
+                raise e
+        except AttributeError as e:
+            if "BaichuanTokenizer" in str(e):
+                # This is for the error "'BaichuanTokenizer' object has no
+                # attribute 'sp_model'".
+                tokenizer = BaichuanTokenizer.from_pretrained(
+                    tokenizer_name,
+                    *args,
+                    trust_remote_code=trust_remote_code,
+                    revision=revision,
+                    **kwargs,
+                )
+            else:
+                raise e
+
+        if not isinstance(tokenizer, PreTrainedTokenizerFast):
+            logger.warning(
+                "Using a slow tokenizer. This might cause a significant "
+                "slowdown. Consider using a fast tokenizer instead.")
+        tokenizer = get_cached_tokenizer(tokenizer)
+
+    return tokenizer
 
 
 def get_lora_tokenizer(lora_request: LoRARequest, *args,
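In practice the new branch means get_tokenizer returns a MistralTokenizer wrapper instead of an HF tokenizer when the mistral mode is selected, and emits a FutureWarning when a mistralai repo is loaded without it. A minimal sketch, assuming an illustrative repo id:

    from vllm.transformers_utils.tokenizer import get_tokenizer

    # Illustrative repo id; with tokenizer_mode="mistral" this returns the
    # MistralTokenizer wrapper rather than an HF PreTrainedTokenizer.
    tok = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3",
                        tokenizer_mode="mistral")
    ids = tok.encode("Hello, world")  # completion-style encoding (BOS, no EOS)
    print(tok.decode(ids))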
@@ -1,5 +1,4 @@
 from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
-__all__ = [
-    "BaichuanTokenizer",
-]
+__all__ = ["BaichuanTokenizer", "MistralTokenizer"]
vllm/transformers_utils/tokenizers/mistral.py (new file, 174 lines)
@@ -0,0 +1,174 @@
+import os
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from huggingface_hub import HfApi, hf_hub_download
+# yapf: disable
+from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import (
+    MistralTokenizer as PublicMistralTokenizer)
+# yapf: enable
+from mistral_common.tokens.tokenizers.sentencepiece import (
+    SentencePieceTokenizer)
+from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
+                                                     Tekkenizer)
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ConversationMessage
+
+
+@dataclass
+class Encoding:
+    input_ids: List[int]
+
+
+def find_tokenizer_file(files: List[str]):
+    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
+
+    matched_files = [file for file in files if file_pattern.match(file)]
+    if len(matched_files) > 1:
+        raise OSError(f"Found {len(matched_files)} files matching the "
+                      "pattern: {matched_files}. Make sure only one Mistral "
+                      "tokenizer is present in {tokenizer_name}.")
+    elif len(matched_files) == 0:
+        raise OSError(f"Found {len(matched_files)} files matching the "
+                      "pattern: {matched_files}. Make sure that a Mistral "
+                      "tokenizer is present in {tokenizer_name}.")
+
+    return matched_files[0]
+
+
+class MistralTokenizer:
+
+    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
+        self.mistral = tokenizer
+        self.instruct = tokenizer.instruct_tokenizer
+        self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
+
+        self.vocab_size = len(self.tokenizer.vocab())
+
+        assert isinstance(self.tokenizer,
+                          (Tekkenizer, SentencePieceTokenizer)), type(
+                              self.tokenizer)
+        self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
+
+        if self._is_tekken:
+            # Make sure special tokens will not raise
+            self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
+
+        # the following attributes are set to fit VLLM's design
+        self.is_fast = True
+        self.chat_template = True
+        self.all_special_ids: List[Any] = []
+        self.all_special_tokens: List[Any] = []
+        self.all_special_tokens_extended: List[Any] = []
+
+    @classmethod
+    def from_pretrained(cls,
+                        path_or_repo_id: str,
+                        *,
+                        revision: Optional[str] = None) -> "MistralTokenizer":
+        if not Path(path_or_repo_id).exists():
+            assert len(path_or_repo_id.split("/")) == 2, (
+                "You have either provided a non-existent path: "
+                "{path_or_repo_id} or an invalid HF Hub repo id.")
+            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
+                path_or_repo_id, revision)
+        elif Path(path_or_repo_id).is_dir():
+            tokenizer_file_name = find_tokenizer_file(
+                os.listdir(path_or_repo_id))
+            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
+        else:
+            assert Path(
+                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
+
+        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
+        return cls(mistral_tokenizer)
+
+    @staticmethod
+    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
+                                            revision: Optional[str]) -> str:
+        api = HfApi()
+        repo_info = api.model_info(tokenizer_name)
+        files = [s.rfilename for s in repo_info.siblings]
+
+        filename = find_tokenizer_file(files)
+
+        tokenizer_file = hf_hub_download(tokenizer_name,
+                                         filename=filename,
+                                         revision=revision)
+        return tokenizer_file
+
+    def __call__(
+        self,
+        prompt: str,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        # Mistral Tokenizers should not add special tokens
+        input_ids = self.encode(prompt)
+
+        if truncation:
+            input_ids = input_ids[:max_length]
+
+        return Encoding(input_ids=input_ids)
+
+    def get_added_vocab(self) -> List[str]:
+        # Mistral tokenizers have no added vocabulary
+        return []
+
+    def encode(self, prompt: str) -> List[int]:
+        # `encode` should only be used for prompt completion
+        # it should never be used for chat_completion.
+        # For chat completion use `apply_chat_template`
+        return self.tokenizer.encode(prompt, bos=True, eos=False)
+
+    def apply_chat_template(self,
+                            conversation: List["ConversationMessage"],
+                            tools: Optional[Dict[str, Any]] = None,
+                            **kwargs) -> List[int]:
+        assert tools is None, "`tools` are not yet supported."
+
+        request = ChatCompletionRequest(
+            messages=conversation)  # type: ignore[type-var]
+        encoded = self.mistral.encode_chat_completion(request)
+
+        # encode-decode to get clean prompt
+        return encoded.tokens
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        if self._is_tekken:
+            return "".join(tokens)
+        else:
+            return self.tokenizer.decode(tokens)  # type: ignore[arg-type]
+
+    def decode(self, ids: Union[List[int], int]) -> str:
+        if isinstance(ids, int):
+            ids = [ids]
+        return self.tokenizer.decode(ids)
+
+    @property
+    def eos_token_id(self):
+        return self.tokenizer.eos_id
+
+    def convert_ids_to_tokens(
+            self,
+            ids: List[int],
+            skip_special_tokens: Optional[bool] = True) -> List[str]:
+        # TODO(Patrick) - potentially allow special tokens to not be skipped
+        assert (
+            skip_special_tokens
+        ), "Skipping special tokens is not supported for Mistral tokenizers."
+
+        assert isinstance(self.tokenizer,
+                          (Tekkenizer, SentencePieceTokenizer)), type(
+                              self.tokenizer)
+
+        tokens = [self.tokenizer.id_to_piece(id) for id in ids]
+        return tokens
+
+    def __len__(self):
+        return self.vocab_size
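A rough end-to-end sketch of the new wrapper on its own; the repo id and message content are assumptions. Note that apply_chat_template returns token ids rather than a rendered string, which is what motivates the widened return type of vllm.entrypoints.chat_utils.apply_chat_template earlier in this diff:

    from vllm.transformers_utils.tokenizers import MistralTokenizer

    tok = MistralTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3")  # illustrative repo id
    # Chat encoding goes through mistral_common and yields token ids directly.
    token_ids = tok.apply_chat_template(
        [{"role": "user", "content": "Hello!"}])
    print(tok.decode(token_ids))
    print(len(tok), tok.eos_token_id)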