mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 20:57:58 +08:00
[Bugfix] [Frontend] Cleanup gpt-oss non-streaming chat tool calls (#25514)
Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
f52b991db6
commit
d7fb5a4ae8
@ -194,6 +194,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
|
|||||||
assert tc.function is not None and tc.function.name == "get_current_weather"
|
assert tc.function is not None and tc.function.name == "get_current_weather"
|
||||||
args1 = tc.function.arguments
|
args1 = tc.function.arguments
|
||||||
assert args1 is not None and len(args1) > 0
|
assert args1 is not None and len(args1) > 0
|
||||||
|
assert not first_msg.content
|
||||||
|
|
||||||
messages.append({"role": "assistant", "content": args1})
|
messages.append({"role": "assistant", "content": args1})
|
||||||
messages.append({
|
messages.append({
|
||||||
|
|||||||
@ -70,7 +70,12 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
|||||||
assert extracted_info.content == "This is a test"
|
assert extracted_info.content == "This is a test"
|
||||||
|
|
||||||
|
|
||||||
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
@pytest.mark.parametrize("tool_args", [
|
||||||
|
'{"location": "Tokyo"}',
|
||||||
|
'{\n"location": "Tokyo"\n}',
|
||||||
|
])
|
||||||
|
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
|
||||||
|
tool_args):
|
||||||
convo = Conversation.from_messages([
|
convo = Conversation.from_messages([
|
||||||
Message.from_role_and_content(Role.USER,
|
Message.from_role_and_content(Role.USER,
|
||||||
"What is the weather in Tokyo?"),
|
"What is the weather in Tokyo?"),
|
||||||
@ -80,7 +85,7 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
|||||||
).with_channel("analysis"),
|
).with_channel("analysis"),
|
||||||
Message.from_role_and_content(
|
Message.from_role_and_content(
|
||||||
Role.ASSISTANT,
|
Role.ASSISTANT,
|
||||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
tool_args).with_channel("commentary").with_recipient(
|
||||||
"functions.get_current_weather").with_content_type("json"),
|
"functions.get_current_weather").with_content_type("json"),
|
||||||
])
|
])
|
||||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
@ -121,6 +126,17 @@ def test_extract_tool_calls_multiple_tools(
|
|||||||
Role.ASSISTANT,
|
Role.ASSISTANT,
|
||||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||||
"functions.get_user_location").with_content_type("json"),
|
"functions.get_user_location").with_content_type("json"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT, '{"location": "Tokyo"}').with_channel(
|
||||||
|
"commentary").with_recipient("functions.no_content_type"),
|
||||||
|
Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel(
|
||||||
|
"commentary").with_recipient("functions.not_json_no_content_type"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT, '{}').with_channel("commentary").with_recipient(
|
||||||
|
"functions.empty_args").with_content_type("json"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT, '').with_channel("commentary").with_recipient(
|
||||||
|
"functions.no_args").with_content_type("json"),
|
||||||
])
|
])
|
||||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
convo,
|
convo,
|
||||||
@ -141,7 +157,63 @@ def test_extract_tool_calls_multiple_tools(
|
|||||||
ToolCall(function=FunctionCall(
|
ToolCall(function=FunctionCall(
|
||||||
name="get_user_location",
|
name="get_user_location",
|
||||||
arguments=json.dumps({"location": "Tokyo"}),
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="no_content_type",
|
||||||
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="not_json_no_content_type",
|
||||||
|
arguments="foo",
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="empty_args",
|
||||||
|
arguments=json.dumps({}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="no_args",
|
||||||
|
arguments="",
|
||||||
))
|
))
|
||||||
]
|
]
|
||||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||||
assert extracted_info.content is None
|
assert extracted_info.content is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_with_content(
|
||||||
|
openai_tool_parser,
|
||||||
|
harmony_encoding,
|
||||||
|
):
|
||||||
|
final_content = "This tool call will get the weather."
|
||||||
|
convo = Conversation.from_messages([
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||||
|
).with_channel("analysis"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||||
|
"functions.get_current_weather").with_content_type("json"),
|
||||||
|
Message.from_role_and_content(Role.ASSISTANT,
|
||||||
|
final_content).with_channel("final"),
|
||||||
|
])
|
||||||
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
|
convo,
|
||||||
|
Role.ASSISTANT,
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||||
|
"",
|
||||||
|
request=None,
|
||||||
|
token_ids=token_ids,
|
||||||
|
)
|
||||||
|
assert extracted_info.tools_called
|
||||||
|
expected_tool_calls = [
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
)),
|
||||||
|
]
|
||||||
|
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||||
|
assert extracted_info.content == final_content
|
||||||
|
|||||||
@ -1186,6 +1186,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
|
reasoning_content, content, _ = parse_chat_output(token_ids)
|
||||||
|
if not request.include_reasoning:
|
||||||
|
reasoning_content = None
|
||||||
|
|
||||||
if self.tool_parser is not None:
|
if self.tool_parser is not None:
|
||||||
tool_parser = self.tool_parser(tokenizer)
|
tool_parser = self.tool_parser(tokenizer)
|
||||||
# NOTE: We use token_ids for openai tool parser
|
# NOTE: We use token_ids for openai tool parser
|
||||||
@ -1194,10 +1198,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
request=request,
|
request=request,
|
||||||
token_ids=token_ids, # type: ignore
|
token_ids=token_ids, # type: ignore
|
||||||
)
|
)
|
||||||
reasoning_content, content = None, tool_call_info.content
|
content = tool_call_info.content
|
||||||
if request.include_reasoning:
|
|
||||||
reasoning_content, content, _ = parse_chat_output(
|
|
||||||
token_ids)
|
|
||||||
message = ChatMessage(
|
message = ChatMessage(
|
||||||
role=role,
|
role=role,
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
@ -1205,10 +1206,6 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tool_calls=tool_call_info.tool_calls,
|
tool_calls=tool_call_info.tool_calls,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
reasoning_content, content, _ = parse_chat_output(
|
|
||||||
token_ids)
|
|
||||||
if not request.include_reasoning:
|
|
||||||
reasoning_content = None
|
|
||||||
message = ChatMessage(
|
message = ChatMessage(
|
||||||
role=role,
|
role=role,
|
||||||
reasoning_content=reasoning_content,
|
reasoning_content=reasoning_content,
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
@ -12,10 +13,13 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
FunctionCall, ToolCall)
|
FunctionCall, ToolCall)
|
||||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
ToolParser, ToolParserManager)
|
ToolParser, ToolParserManager)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ToolParserManager.register_module("openai")
|
@ToolParserManager.register_module("openai")
|
||||||
class OpenAIToolParser(ToolParser):
|
class OpenAIToolParser(ToolParser):
|
||||||
@ -40,17 +44,33 @@ class OpenAIToolParser(ToolParser):
|
|||||||
|
|
||||||
if len(parser.messages) > 0:
|
if len(parser.messages) > 0:
|
||||||
for msg in parser.messages:
|
for msg in parser.messages:
|
||||||
|
if len(msg.content) < 1:
|
||||||
|
continue
|
||||||
|
msg_text = msg.content[0].text
|
||||||
if msg.recipient and msg.recipient.startswith("functions."):
|
if msg.recipient and msg.recipient.startswith("functions."):
|
||||||
|
# If no content-type is given assume JSON, as that's the
|
||||||
|
# most common case with gpt-oss models.
|
||||||
|
if not msg.content_type or "json" in msg.content_type:
|
||||||
|
# load and dump the JSON text to check validity and
|
||||||
|
# remove any extra newlines or other odd formatting
|
||||||
|
try:
|
||||||
|
tool_args = json.dumps(json.loads(msg_text))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.exception(
|
||||||
|
"Error decoding JSON tool call from response.")
|
||||||
|
tool_args = msg_text
|
||||||
|
else:
|
||||||
|
tool_args = msg_text
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
type="function",
|
type="function",
|
||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
name=msg.recipient.split("functions.")[1],
|
name=msg.recipient.split("functions.")[1],
|
||||||
arguments=msg.content[0].text,
|
arguments=tool_args,
|
||||||
),
|
),
|
||||||
))
|
))
|
||||||
elif msg.channel == "final":
|
elif msg.channel == "final":
|
||||||
final_content = msg.content[0].text
|
final_content = msg_text
|
||||||
|
|
||||||
return ExtractedToolCallInformation(
|
return ExtractedToolCallInformation(
|
||||||
tools_called=len(tool_calls) > 0,
|
tools_called=len(tool_calls) > 0,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user