mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 16:22:16 +08:00
[Frontend][Bug] allow tool calls in analysis channel (#28139)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
parent
086b96339f
commit
455949675d
212
tests/entrypoints/openai/test_serving_chat_stream_harmony.py
Normal file
212
tests/entrypoints/openai/test_serving_chat_stream_harmony.py
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Unit tests for harmony streaming delta extraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.serving_chat_stream_harmony import (
|
||||||
|
extract_harmony_streaming_delta,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockMessage:
|
||||||
|
"""Mock message object for testing."""
|
||||||
|
|
||||||
|
channel: str | None = None
|
||||||
|
recipient: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockStreamableParser:
|
||||||
|
"""Mock StreamableParser for testing without openai_harmony dependency."""
|
||||||
|
|
||||||
|
messages: list[MockMessage] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractHarmonyStreamingDelta:
|
||||||
|
"""Tests for extract_harmony_streaming_delta function."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"delta_text,expected_content",
|
||||||
|
[
|
||||||
|
("Hello, world!", "Hello, world!"),
|
||||||
|
("", ""),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_final_channel_returns_content_delta(self, delta_text, expected_content):
|
||||||
|
"""Test that final channel returns a DeltaMessage with content."""
|
||||||
|
parser = MockStreamableParser()
|
||||||
|
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel="final",
|
||||||
|
cur_recipient=None,
|
||||||
|
prev_recipient=None,
|
||||||
|
delta_text=delta_text,
|
||||||
|
include_reasoning=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert delta_message is not None
|
||||||
|
assert delta_message.content == expected_content
|
||||||
|
assert tools_streamed is False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"include_reasoning,expected_has_message",
|
||||||
|
[
|
||||||
|
(True, True),
|
||||||
|
(False, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
|
||||||
|
"""Test analysis channel respects include_reasoning flag."""
|
||||||
|
parser = MockStreamableParser()
|
||||||
|
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel="analysis",
|
||||||
|
cur_recipient=None,
|
||||||
|
prev_recipient=None,
|
||||||
|
delta_text="Let me think...",
|
||||||
|
include_reasoning=include_reasoning,
|
||||||
|
)
|
||||||
|
|
||||||
|
if expected_has_message:
|
||||||
|
assert delta_message is not None
|
||||||
|
assert delta_message.reasoning == "Let me think..."
|
||||||
|
else:
|
||||||
|
assert delta_message is None
|
||||||
|
assert tools_streamed is False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("channel", ["commentary", "analysis"])
|
||||||
|
@patch("vllm.entrypoints.openai.serving_chat_stream_harmony.make_tool_call_id")
|
||||||
|
def test_new_tool_call(self, mock_make_tool_call_id, channel):
|
||||||
|
"""Test new tool call creation when recipient changes."""
|
||||||
|
mock_make_tool_call_id.return_value = "call_test123"
|
||||||
|
parser = MockStreamableParser()
|
||||||
|
|
||||||
|
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel=channel,
|
||||||
|
cur_recipient="functions.get_weather",
|
||||||
|
prev_recipient=None,
|
||||||
|
delta_text="",
|
||||||
|
include_reasoning=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert delta_message is not None
|
||||||
|
assert len(delta_message.tool_calls) == 1
|
||||||
|
tool_call = delta_message.tool_calls[0]
|
||||||
|
assert tool_call.id == "call_test123"
|
||||||
|
assert tool_call.type == "function"
|
||||||
|
assert tool_call.function.name == "get_weather"
|
||||||
|
assert tool_call.function.arguments == ""
|
||||||
|
assert tool_call.index == 0
|
||||||
|
assert tools_streamed is True
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("channel", ["commentary", "analysis"])
|
||||||
|
def test_tool_call_argument_streaming(self, channel):
|
||||||
|
"""Test streaming tool call arguments (same recipient)."""
|
||||||
|
parser = MockStreamableParser()
|
||||||
|
|
||||||
|
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel=channel,
|
||||||
|
cur_recipient="functions.get_weather",
|
||||||
|
prev_recipient="functions.get_weather",
|
||||||
|
delta_text='{"location": "Paris"}',
|
||||||
|
include_reasoning=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert delta_message is not None
|
||||||
|
tool_call = delta_message.tool_calls[0]
|
||||||
|
assert tool_call.id is None
|
||||||
|
assert tool_call.function.arguments == '{"location": "Paris"}'
|
||||||
|
assert tool_call.index == 0
|
||||||
|
assert tools_streamed is True
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("channel", ["commentary", "analysis"])
|
||||||
|
def test_tool_call_empty_arguments_returns_none(self, channel):
|
||||||
|
"""Test empty delta_text with same recipient returns None."""
|
||||||
|
parser = MockStreamableParser()
|
||||||
|
|
||||||
|
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel=channel,
|
||||||
|
cur_recipient="functions.get_weather",
|
||||||
|
prev_recipient="functions.get_weather",
|
||||||
|
delta_text="",
|
||||||
|
include_reasoning=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert delta_message is None
|
||||||
|
assert tools_streamed is False
|
||||||
|
|
||||||
|
def test_tool_call_index_from_previous_messages(self):
|
||||||
|
"""Test tool call index accounts for previous function messages."""
|
||||||
|
messages = [
|
||||||
|
MockMessage(channel="analysis", recipient=None), # Not counted
|
||||||
|
MockMessage(channel="commentary", recipient="functions.tool1"), # Counted
|
||||||
|
MockMessage(channel="final", recipient=None), # Not counted
|
||||||
|
]
|
||||||
|
parser = MockStreamableParser(messages=messages)
|
||||||
|
|
||||||
|
delta_message, _ = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel="commentary",
|
||||||
|
cur_recipient="functions.tool2",
|
||||||
|
prev_recipient="functions.tool2",
|
||||||
|
delta_text="args",
|
||||||
|
include_reasoning=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert delta_message.tool_calls[0].index == 1
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"channel,recipient",
|
||||||
|
[
|
||||||
|
("commentary", None),
|
||||||
|
("commentary", "browser.search"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_returns_tool_call_preambles(self, channel, recipient):
|
||||||
|
"""Test that invalid channel/recipient combinations return None."""
|
||||||
|
parser = MockStreamableParser()
|
||||||
|
delta_text = "some text"
|
||||||
|
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel=channel,
|
||||||
|
cur_recipient=recipient,
|
||||||
|
prev_recipient=None,
|
||||||
|
delta_text=delta_text,
|
||||||
|
include_reasoning=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert delta_message.content == delta_text
|
||||||
|
assert tools_streamed is False
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"channel,recipient",
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
("unknown_channel", None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_returns_none_for_invalid_inputs(self, channel, recipient):
|
||||||
|
"""Test that invalid channel/recipient combinations return None."""
|
||||||
|
parser = MockStreamableParser()
|
||||||
|
|
||||||
|
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||||
|
harmony_parser=parser,
|
||||||
|
cur_channel=channel,
|
||||||
|
cur_recipient=recipient,
|
||||||
|
prev_recipient=None,
|
||||||
|
delta_text="some text",
|
||||||
|
include_reasoning=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert delta_message is None
|
||||||
|
assert tools_streamed is False
|
||||||
@ -51,6 +51,9 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
ToolCall,
|
ToolCall,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
|
from vllm.entrypoints.openai.serving_chat_stream_harmony import (
|
||||||
|
extract_harmony_streaming_delta,
|
||||||
|
)
|
||||||
from vllm.entrypoints.openai.serving_engine import (
|
from vllm.entrypoints.openai.serving_engine import (
|
||||||
GenerationError,
|
GenerationError,
|
||||||
OpenAIServing,
|
OpenAIServing,
|
||||||
@ -837,64 +840,17 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
current_token_ids = as_list(output.token_ids)
|
current_token_ids = as_list(output.token_ids)
|
||||||
|
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
if cur_channel == "final":
|
delta_message, tools_streamed_flag = (
|
||||||
delta_message = DeltaMessage(content=delta_text)
|
extract_harmony_streaming_delta(
|
||||||
elif cur_channel == "analysis":
|
harmony_parser=harmony_parser,
|
||||||
if request.include_reasoning:
|
cur_channel=cur_channel,
|
||||||
delta_message = DeltaMessage(reasoning=delta_text)
|
cur_recipient=cur_recipient,
|
||||||
else:
|
prev_recipient=prev_recipient,
|
||||||
delta_message = None
|
delta_text=delta_text,
|
||||||
elif (
|
include_reasoning=request.include_reasoning,
|
||||||
cur_channel == "commentary"
|
)
|
||||||
and cur_recipient
|
)
|
||||||
and cur_recipient.startswith("functions.")
|
harmony_tools_streamed[i] |= tools_streamed_flag
|
||||||
):
|
|
||||||
# Count completed tool calls to determine index
|
|
||||||
base_index = 0
|
|
||||||
for msg in harmony_parser.messages:
|
|
||||||
if (
|
|
||||||
msg.channel == "commentary"
|
|
||||||
and msg.recipient
|
|
||||||
and msg.recipient.startswith("functions.")
|
|
||||||
):
|
|
||||||
base_index += 1
|
|
||||||
|
|
||||||
if prev_recipient != cur_recipient:
|
|
||||||
tool_name = cur_recipient.split("functions.", 1)[1]
|
|
||||||
delta_message = DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
id=make_tool_call_id(),
|
|
||||||
type="function",
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
name=tool_name,
|
|
||||||
arguments="",
|
|
||||||
),
|
|
||||||
index=base_index,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
elif delta_text:
|
|
||||||
delta_message = DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
index=base_index,
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
arguments=delta_text
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
delta_message = None
|
|
||||||
|
|
||||||
if delta_message is not None:
|
|
||||||
harmony_tools_streamed[i] = True
|
|
||||||
elif cur_channel == "commentary":
|
|
||||||
# Tool call preambles meant to be shown to the user
|
|
||||||
delta_message = DeltaMessage(content=delta_text)
|
|
||||||
else:
|
|
||||||
delta_message = None
|
|
||||||
# handle streaming deltas for tools with named tool_choice
|
# handle streaming deltas for tools with named tool_choice
|
||||||
elif tool_choice_function_name:
|
elif tool_choice_function_name:
|
||||||
if (
|
if (
|
||||||
|
|||||||
101
vllm/entrypoints/openai/serving_chat_stream_harmony.py
Normal file
101
vllm/entrypoints/openai/serving_chat_stream_harmony.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Harmony-specific streaming delta extraction for chat completions.
|
||||||
|
|
||||||
|
This module handles the extraction of DeltaMessage objects from
|
||||||
|
harmony parser state during streaming chat completions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from openai_harmony import StreamableParser
|
||||||
|
|
||||||
|
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
DeltaFunctionCall,
|
||||||
|
DeltaMessage,
|
||||||
|
DeltaToolCall,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_harmony_streaming_delta(
|
||||||
|
harmony_parser: StreamableParser,
|
||||||
|
cur_channel: str | None,
|
||||||
|
cur_recipient: str | None,
|
||||||
|
prev_recipient: str | None,
|
||||||
|
delta_text: str,
|
||||||
|
include_reasoning: bool,
|
||||||
|
) -> tuple[DeltaMessage | None, bool]:
|
||||||
|
"""
|
||||||
|
Extract a DeltaMessage from harmony parser state during streaming.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
harmony_parser: The StreamableParser instance tracking parse state
|
||||||
|
cur_channel: Current channel ("final", "analysis", "commentary", etc.)
|
||||||
|
cur_recipient: Current recipient (e.g., "functions.my_func")
|
||||||
|
prev_recipient: Previous recipient for detecting tool call transitions
|
||||||
|
delta_text: The text delta to include in the message
|
||||||
|
include_reasoning: Whether to include reasoning content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (DeltaMessage or None, tools_streamed_flag)
|
||||||
|
"""
|
||||||
|
tools_streamed = False
|
||||||
|
|
||||||
|
if cur_channel == "final":
|
||||||
|
delta_message = DeltaMessage(content=delta_text)
|
||||||
|
elif (
|
||||||
|
(cur_channel == "commentary" or cur_channel == "analysis")
|
||||||
|
and cur_recipient
|
||||||
|
and cur_recipient.startswith("functions.")
|
||||||
|
):
|
||||||
|
# Count completed tool calls to determine index
|
||||||
|
base_index = 0
|
||||||
|
for msg in harmony_parser.messages:
|
||||||
|
if (
|
||||||
|
(msg.channel == "commentary" or msg.channel == "analysis")
|
||||||
|
and msg.recipient
|
||||||
|
and msg.recipient.startswith("functions.")
|
||||||
|
):
|
||||||
|
base_index += 1
|
||||||
|
|
||||||
|
if prev_recipient != cur_recipient:
|
||||||
|
tool_name = cur_recipient.split("functions.", 1)[1]
|
||||||
|
delta_message = DeltaMessage(
|
||||||
|
tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
id=make_tool_call_id(),
|
||||||
|
type="function",
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
name=tool_name,
|
||||||
|
arguments="",
|
||||||
|
),
|
||||||
|
index=base_index,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif delta_text:
|
||||||
|
delta_message = DeltaMessage(
|
||||||
|
tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=base_index,
|
||||||
|
function=DeltaFunctionCall(arguments=delta_text),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
delta_message = None
|
||||||
|
|
||||||
|
if delta_message is not None:
|
||||||
|
tools_streamed = True
|
||||||
|
elif cur_channel == "commentary":
|
||||||
|
# Tool call preambles meant to be shown to the user
|
||||||
|
delta_message = DeltaMessage(content=delta_text)
|
||||||
|
elif cur_channel == "analysis":
|
||||||
|
if include_reasoning:
|
||||||
|
delta_message = DeltaMessage(reasoning=delta_text)
|
||||||
|
else:
|
||||||
|
delta_message = None
|
||||||
|
else:
|
||||||
|
delta_message = None
|
||||||
|
|
||||||
|
return delta_message, tools_streamed
|
||||||
Loading…
x
Reference in New Issue
Block a user