mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 09:45:01 +08:00
[Bugfix] Mistral tool calling when content is list (#18729)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
696259ca01
commit
5873877241
@ -1,15 +1,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from mistral_common.protocol.instruct.messages import UserMessage
|
from mistral_common.protocol.instruct.messages import (AssistantMessage,
|
||||||
|
ToolMessage,
|
||||||
|
UserMessage)
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
from mistral_common.protocol.instruct.tool_calls import (Function,
|
||||||
|
FunctionCall, Tool,
|
||||||
|
ToolCall)
|
||||||
|
|
||||||
from vllm.transformers_utils.tokenizers.mistral import (
|
from vllm.transformers_utils.tokenizers.mistral import (
|
||||||
make_mistral_chat_completion_request)
|
make_mistral_chat_completion_request)
|
||||||
|
|
||||||
|
|
||||||
# yapf: enable
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"openai_request,expected_mistral_request",
|
"openai_request,expected_mistral_request",
|
||||||
[(
|
[(
|
||||||
@ -78,6 +81,107 @@ from vllm.transformers_utils.tokenizers.mistral import (
|
|||||||
)
|
)
|
||||||
def test_make_mistral_chat_completion_request(openai_request,
|
def test_make_mistral_chat_completion_request(openai_request,
|
||||||
expected_mistral_request):
|
expected_mistral_request):
|
||||||
assert (make_mistral_chat_completion_request(
|
actual_request = make_mistral_chat_completion_request(
|
||||||
openai_request["messages"],
|
openai_request["messages"], openai_request["tools"])
|
||||||
openai_request["tools"]) == expected_mistral_request)
|
assert actual_request == expected_mistral_request
|
||||||
|
|
||||||
|
|
||||||
|
# Tool use with list content and reasoning_content
|
||||||
|
@pytest.mark.parametrize("openai_request,expected_mistral_request", [(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What's the weather in Paris?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"assistant",
|
||||||
|
"reasoning_content":
|
||||||
|
None,
|
||||||
|
"content":
|
||||||
|
None,
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call123",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": '{"city": "Paris"}',
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content": [{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Rainy"
|
||||||
|
}],
|
||||||
|
"name": "get_weather",
|
||||||
|
"tool_call_id": "call123",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"tools": [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Gets the current weather in a city.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city name"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
ChatCompletionRequest(
|
||||||
|
messages=[
|
||||||
|
UserMessage(content="What's the weather in Paris?"),
|
||||||
|
AssistantMessage(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
id="call123",
|
||||||
|
function=FunctionCall(
|
||||||
|
name="get_weather",
|
||||||
|
arguments='{"city": "Paris"}',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ToolMessage(
|
||||||
|
content="Rainy",
|
||||||
|
tool_call_id="call123",
|
||||||
|
name="get_weather",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tools=[
|
||||||
|
Tool(
|
||||||
|
type="function",
|
||||||
|
function=Function(
|
||||||
|
name="get_weather",
|
||||||
|
description="Gets the current weather in a city.",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city name"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)])
|
||||||
|
def test_make_mistral_chat_completion_request_list_content(
|
||||||
|
openai_request, expected_mistral_request):
|
||||||
|
actual_request = make_mistral_chat_completion_request(
|
||||||
|
openai_request["messages"], openai_request["tools"])
|
||||||
|
assert actual_request == expected_mistral_request
|
||||||
|
|||||||
@ -156,7 +156,11 @@ def make_mistral_chat_completion_request(
|
|||||||
#
|
#
|
||||||
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
|
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.get("role") == "assistant":
|
# Remove reasoning_content as unsupported by Mistral
|
||||||
|
_ = message.pop("reasoning_content", None) # type: ignore
|
||||||
|
|
||||||
|
# Convert list text content to string
|
||||||
|
if message.get("role") in ("assistant", "tool"):
|
||||||
content = message.get("content")
|
content = message.get("content")
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
content = "\n".join(chunk.get("text") for chunk in content)
|
content = "\n".join(chunk.get("text") for chunk in content)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user