[Bugfix] Support missing tool parameters in mistral tokenizer (#12884)
Signed-off-by: Florian Greinacher <florian.greinacher@siemens.com>
parent 2c0f58203c
commit cb080f32e3
tests/tokenization/test_mistral_tokenizer.py (new file, 50 lines)
vllm/transformers_utils/tokenizers/mistral.py (2 hunks)
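
In OpenAI-style requests the "parameters" field of a function tool is optional, while mistral-common expects it to be present. A minimal sketch of the request shape this commit makes work (the tool definition is taken from the new test below; the exact pre-fix failure mode is an assumption based on the comment in the diff):

    # An OpenAI-style tool definition with no "parameters" key. Before this
    # fix, forwarding it to mistral-common's ChatCompletionRequest is assumed
    # to fail validation; the new helper below fills in an empty dict instead.
    tools = [{
        "type": "function",
        "function": {
            "name": "get_current_time",
            "description": "Fetch the current local date and time.",
            # intentionally no "parameters" key
        },
    }]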
tests/tokenization/test_mistral_tokenizer.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from mistral_common.protocol.instruct.messages import UserMessage
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.protocol.instruct.tool_calls import Function, Tool
+
+from vllm.transformers_utils.tokenizers.mistral import (
+    make_mistral_chat_completion_request)
+
+
+# yapf: enable
+@pytest.mark.parametrize(
+    "openai_request,expected_mistral_request",
+    [(
+        {
+            "messages": [{
+                "role": "user",
+                "content": "What is the current local date and time?",
+            }],
+            "tools": [{
+                "type": "function",
+                "function": {
+                    "description": "Fetch the current local date and time.",
+                    "name": "get_current_time",
+                },
+            }],
+        },
+        ChatCompletionRequest(
+            messages=[
+                UserMessage(content="What is the current local date and time?")
+            ],
+            tools=[
+                Tool(
+                    type="function",
+                    function=Function(
+                        name="get_current_time",
+                        description="Fetch the current local date and time.",
+                        parameters={},
+                    ),
+                )
+            ],
+        ),
+    )],
+)
+def test_make_mistral_chat_completion_request(openai_request,
+                                              expected_mistral_request):
+    assert (make_mistral_chat_completion_request(
+        openai_request["messages"],
+        openai_request["tools"]) == expected_mistral_request)
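
For a quick check outside of pytest, the new helper can also be exercised directly; a minimal sketch, assuming mistral-common's Tool and Function objects expose the fields used in the assertion:

    from vllm.transformers_utils.tokenizers.mistral import (
        make_mistral_chat_completion_request)

    request = make_mistral_chat_completion_request(
        [{"role": "user",
          "content": "What is the current local date and time?"}],
        [{"type": "function",
          "function": {"name": "get_current_time",
                       "description": "Fetch the current local date and time."}}])
    # The missing "parameters" field is filled in with an empty dict.
    assert request.tools[0].function.parameters == {}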
vllm/transformers_utils/tokenizers/mistral.py
@@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
     return matched_files[0]
 
 
+def make_mistral_chat_completion_request(
+        messages: List["ChatCompletionMessageParam"],
+        tools: Optional[List[Dict[str,
+                                  Any]]] = None) -> "ChatCompletionRequest":
+    last_message = cast(Dict[str, Any], messages[-1])
+    if last_message["role"] == "assistant":
+        last_message["prefix"] = True
+
+    last_message = cast(Dict[str, Any], messages[-1])
+    if last_message["role"] == "assistant":
+        last_message["prefix"] = True
+
+    # mistral-common requires AssistantMessage content to be string [1].
+    #
+    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
+    for message in messages:
+        if message.get("role") == "assistant":
+            content = message.get("content")
+            if isinstance(content, list):
+                content = "\n".join(chunk.get("text") for chunk in content)
+            message["content"] = content
+
+    # The Mistral client, in comparison to the OpenAI client, requires the
+    # "parameters" dict to be present, even if it's empty.
+    if tools:
+        for function in [
+                tool["function"] for tool in tools
+                if tool["type"] == "function"
+        ]:
+            function.setdefault("parameters", {})
+
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    return ChatCompletionRequest(messages=messages,
+                                 tools=tools)  # type: ignore[type-var]
+
+
 class MistralTokenizer:
 
     def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
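
Besides filling in "parameters", the helper flattens list-style assistant content into a single string and marks a trailing assistant message as a prefix. A small sketch of that behavior (message values are illustrative, and it assumes mistral-common accepts a trailing assistant message once "prefix" is set):

    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant",
         "content": [{"type": "text", "text": "Hello"},
                     {"type": "text", "text": "there"}]},
    ]
    request = make_mistral_chat_completion_request(messages)
    # Content chunks are joined with newlines ...
    assert request.messages[-1].content == "Hello\nthere"
    # ... and the trailing assistant message is marked for continuation.
    assert request.messages[-1].prefix is True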
@@ -283,27 +319,10 @@ class MistralTokenizer:
 
     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[Dict[str, Any]] = None,
+                            tools: Optional[List[Dict[str, Any]]] = None,
                             **kwargs) -> List[int]:
-
-        last_message = cast(Dict[str, Any], messages[-1])
-        if last_message["role"] == "assistant":
-            last_message["prefix"] = True
-
-        from mistral_common.protocol.instruct.request import (
-            ChatCompletionRequest)
-
-        # mistral-common requires AssistantMessage content to be string [1].
-        #
-        # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
-        for message in messages:
-            if message.get("role") == "assistant":
-                content = message.get("content")
-                if isinstance(content, list):
-                    content = "\n".join(chunk.get("text") for chunk in content)
-                message["content"] = content
-
-        request = ChatCompletionRequest(messages=messages,
-                                        tools=tools)  # type: ignore[type-var]
+        request = make_mistral_chat_completion_request(messages, tools)
         encoded = self.mistral.encode_chat_completion(request)
 
         # encode-decode to get clean prompt
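
With apply_chat_template now delegating to the shared helper, tokenizer-level calls tolerate the same parameter-less tools. A usage sketch (the model id is a hypothetical choice, and from_pretrained needs network access to fetch the tokenizer files):

    from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

    tokenizer = MistralTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3")
    token_ids = tokenizer.apply_chat_template(
        messages=[{"role": "user",
                   "content": "What is the current local date and time?"}],
        tools=[{"type": "function",
                "function": {
                    "name": "get_current_time",
                    "description": "Fetch the current local date and time.",
                }}])
    # token_ids is the fully rendered prompt as a list of token ids.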