mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 13:03:04 +08:00
[mistral_common] Add v11 tokenizer (#19193)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
9bc8bb07cf
commit
f20f9f063b
@ -44,11 +44,17 @@ class MistralToolCall(ToolCall):
|
||||
return id.isalnum() and len(id) == 9
|
||||
|
||||
|
||||
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
|
||||
return isinstance(model_tokenizer, MistralTokenizer) \
|
||||
and model_tokenizer.version >= 11
|
||||
|
||||
|
||||
@ToolParserManager.register_module("mistral")
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
|
||||
examples/tool_chat_template_mistral.jinja template.
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
|
||||
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
|
||||
- the examples/tool_chat_template_mistral.jinja template.
|
||||
|
||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||
"""
|
||||
@ -70,6 +76,12 @@ class MistralToolParser(ToolParser):
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
|
||||
re.DOTALL)
|
||||
else:
|
||||
self.fn_name_regex = None
|
||||
|
||||
if self.bot_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Mistral Tool Parser could not locate the tool call token in "
|
||||
@ -109,11 +121,25 @@ class MistralToolParser(ToolParser):
|
||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||
|
||||
try:
|
||||
|
||||
# we first try to directly load the json as parsing very nested
|
||||
# jsons is difficult
|
||||
try:
|
||||
function_call_arr = json.loads(tool_content)
|
||||
if self.fn_name_regex:
|
||||
matches = self.fn_name_regex.findall(tool_content)
|
||||
|
||||
function_call_arr = []
|
||||
for match in matches:
|
||||
fn_name = match[0]
|
||||
args = match[1]
|
||||
|
||||
# fn_name is encoded outside serialized json dump
|
||||
# only arguments are serialized
|
||||
function_call_arr.append({
|
||||
"name": fn_name,
|
||||
"arguments": json.loads(args)
|
||||
})
|
||||
else:
|
||||
function_call_arr = json.loads(tool_content)
|
||||
except json.JSONDecodeError:
|
||||
# use a regex to find the part corresponding to the tool call.
|
||||
# NOTE: This use case should not happen if the model is trained
|
||||
|
||||
@ -187,6 +187,8 @@ class MistralTokenizer(TokenizerBase):
|
||||
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
||||
self.mistral = tokenizer
|
||||
self.instruct = tokenizer.instruct_tokenizer
|
||||
_mistral_version_str = self.instruct.tokenizer.version.value
|
||||
self.version: int = int(_mistral_version_str.split("v")[-1])
|
||||
|
||||
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user