[Bugfix] Support missing tool parameters in mistral tokenizer (#12884)

Signed-off-by: Florian Greinacher <florian.greinacher@siemens.com>
This commit is contained in:
Florian Greinacher 2025-02-11 04:33:33 +01:00 committed by GitHub
parent 2c0f58203c
commit cb080f32e3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 19 deletions

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from vllm.transformers_utils.tokenizers.mistral import (
make_mistral_chat_completion_request)
# yapf: enable
# Single parametrized case: an OpenAI-style request whose function tool
# omits "parameters", and the mistral-common request it should map to
# (with "parameters" defaulted to an empty dict).
_OPENAI_REQUEST = {
    "messages": [{
        "role": "user",
        "content": "What is the current local date and time?",
    }],
    "tools": [{
        "type": "function",
        "function": {
            "description": "Fetch the current local date and time.",
            "name": "get_current_time",
        },
    }],
}

_EXPECTED_REQUEST = ChatCompletionRequest(
    messages=[
        UserMessage(content="What is the current local date and time?")
    ],
    tools=[
        Tool(
            type="function",
            function=Function(
                name="get_current_time",
                description="Fetch the current local date and time.",
                parameters={},
            ),
        )
    ],
)


@pytest.mark.parametrize(
    "openai_request,expected_mistral_request",
    [(_OPENAI_REQUEST, _EXPECTED_REQUEST)],
)
def test_make_mistral_chat_completion_request(openai_request,
                                              expected_mistral_request):
    """A tool missing "parameters" converts with parameters defaulted to {}."""
    actual = make_mistral_chat_completion_request(openai_request["messages"],
                                                  openai_request["tools"])
    assert actual == expected_mistral_request

View File

@@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
return matched_files[0]
def make_mistral_chat_completion_request(
        messages: List["ChatCompletionMessageParam"],
        tools: Optional[List[Dict[str,
                                  Any]]] = None) -> "ChatCompletionRequest":
    """Convert OpenAI-style messages and tools into a mistral-common
    ``ChatCompletionRequest``.

    Note: ``messages`` and ``tools`` are mutated in place (assistant list
    content is flattened to a string, a trailing assistant message gains a
    ``prefix`` flag, and function tools get a default ``parameters`` dict).

    Args:
        messages: OpenAI chat messages.
        tools: Optional OpenAI tool definitions.

    Returns:
        A mistral-common ``ChatCompletionRequest``.
    """
    # A trailing assistant message is treated as a partial completion the
    # model should continue, so mark it as a prefix.
    # (The original code repeated this three-line block twice verbatim;
    # the duplicate was redundant and has been removed.)
    last_message = cast(Dict[str, Any], messages[-1])
    if last_message["role"] == "assistant":
        last_message["prefix"] = True

    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
        if message.get("role") == "assistant":
            content = message.get("content")
            if isinstance(content, list):
                # assumes every content chunk carries a "text" key — a
                # chunk without one would make join() fail on None;
                # TODO(review): confirm against the callers.
                content = "\n".join(chunk.get("text") for chunk in content)
            message["content"] = content

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty.
    if tools:
        for function in [
                tool["function"] for tool in tools
                if tool["type"] == "function"
        ]:
            function.setdefault("parameters", {})

    # Imported lazily so the module can be imported without mistral_common
    # installed.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]
class MistralTokenizer:
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
@@ -283,27 +319,10 @@ class MistralTokenizer:
def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
**kwargs) -> List[int]:
last_message = cast(Dict[str, Any], messages[-1])
if last_message["role"] == "assistant":
last_message["prefix"] = True
from mistral_common.protocol.instruct.request import (
ChatCompletionRequest)
# mistral-common requires AssistantMessage content to be string [1].
#
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
for message in messages:
if message.get("role") == "assistant":
content = message.get("content")
if isinstance(content, list):
content = "\n".join(chunk.get("text") for chunk in content)
message["content"] = content
request = ChatCompletionRequest(messages=messages,
tools=tools) # type: ignore[type-var]
request = make_mistral_chat_completion_request(messages, tools)
encoded = self.mistral.encode_chat_completion(request)
# encode-decode to get clean prompt