[Bugfix] Multiple fixes for gpt-oss Chat Completion prompting (#28729)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Ben Browning 2025-12-11 23:59:53 -05:00 committed by GitHub
parent fe1787107e
commit 8f8fda261a
6 changed files with 1749 additions and 139 deletions

@@ -1,21 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
from openai.types.responses.response_output_item import McpCall
from openai_harmony import Author, Message, Role, TextContent
from tests.entrypoints.openai.utils import verify_harmony_messages
from vllm.entrypoints.openai.parser.harmony_utils import (
auto_drop_analysis_messages,
get_encoding,
has_custom_tools,
parse_chat_input_to_harmony_message,
parse_chat_output,
parse_input_to_harmony_message,
parse_output_message,
)
class TestParseInputToHarmonyMessage:
"""Tests for parse_input_to_harmony_message function."""
class TestCommonParseInputToHarmonyMessage:
"""
Tests for scenarios that are common to both Chat Completion
parse_chat_input_to_harmony_message and Responses API
parse_input_to_harmony_message functions.
"""
def test_assistant_message_with_tool_calls(self):
@pytest.fixture(
params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message]
)
def parse_function(self, request):
return request.param
def test_assistant_message_with_tool_calls(self, parse_function):
"""Test parsing assistant message with tool calls."""
chat_msg = {
"role": "assistant",
@@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_function(chat_msg)
assert len(messages) == 2
@@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
assert messages[1].recipient == "functions.search_web"
assert messages[1].content_type == "json"
def test_assistant_message_with_empty_tool_call_arguments(self):
def test_assistant_message_with_empty_tool_call_arguments(self, parse_function):
"""Test parsing assistant message with tool call having None arguments."""
chat_msg = {
"role": "assistant",
@@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
assert messages[0].recipient == "functions.get_current_time"
def test_system_message(self, parse_function):
"""Test parsing system message."""
chat_msg = {
"role": "system",
"content": "You are a helpful assistant",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
# System messages are converted using Message.from_dict
# which should preserve the role
assert messages[0].author.role == Role.SYSTEM
def test_developer_message(self, parse_function):
"""Test parsing developer message."""
chat_msg = {
"role": "developer",
"content": "Use concise language",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.DEVELOPER
def test_user_message_with_string_content(self, parse_function):
"""Test parsing user message with string content."""
chat_msg = {
"role": "user",
"content": "What's the weather in San Francisco?",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "What's the weather in San Francisco?"
def test_user_message_with_array_content(self, parse_function):
"""Test parsing user message with array content."""
chat_msg = {
"role": "user",
"content": [
{"text": "What's in this image? "},
{"text": "Please describe it."},
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert len(messages[0].content) == 2
assert messages[0].content[0].text == "What's in this image? "
assert messages[0].content[1].text == "Please describe it."
def test_assistant_message_with_string_content(self, parse_function):
"""Test parsing assistant message with string content (no tool calls)."""
chat_msg = {
"role": "assistant",
"content": "Hello! How can I help you today?",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.ASSISTANT
assert messages[0].content[0].text == "Hello! How can I help you today?"
def test_pydantic_model_input(self, parse_function):
"""Test parsing Pydantic model input (has model_dump method)."""
class MockPydanticModel:
def model_dump(self, exclude_none=True):
return {
"role": "user",
"content": "Test message",
}
chat_msg = MockPydanticModel()
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "Test message"
def test_tool_call_with_missing_function_fields(self, parse_function):
"""Test parsing tool call with missing name or arguments."""
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {} # Missing both name and arguments
}
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].recipient == "functions."
assert messages[0].content[0].text == ""
def test_array_content_with_missing_text(self, parse_function):
"""Test parsing array content where text field is missing."""
chat_msg = {
"role": "user",
"content": [
{}, # Missing text field
{"text": "actual text"},
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert len(messages[0].content) == 2
assert messages[0].content[0].text == ""
assert messages[0].content[1].text == "actual text"
class TestParseInputToHarmonyMessage:
"""
Tests for scenarios that are specific to the Responses API
parse_input_to_harmony_message function.
"""
def test_message_with_empty_content(self):
"""Test parsing message with empty string content."""
chat_msg = {
"role": "user",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
def test_tool_message_with_string_content(self):
"""Test parsing tool message with string content."""
chat_msg = {
@@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
assert len(messages) == 1
assert messages[0].author.role == Role.TOOL
assert messages[0].author.name == "functions.search_results"
assert messages[0].content[0].text == "Result 1: Result 2: Result 3"
def test_tool_message_with_empty_content(self):
@@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.TOOL
assert messages[0].author.name == "functions.empty_tool"
assert messages[0].content[0].text == ""
def test_system_message(self):
"""Test parsing system message."""
chat_msg = {
"role": "system",
"content": "You are a helpful assistant",
}
messages = parse_input_to_harmony_message(chat_msg)
class TestParseChatInputToHarmonyMessage:
"""
Tests for scenarios that are specific to the Chat Completion API
parse_chat_input_to_harmony_message function.
"""
assert len(messages) == 1
# System messages are converted using Message.from_dict
# which should preserve the role
assert messages[0].author.role == Role.SYSTEM
def test_developer_message(self):
"""Test parsing developer message."""
chat_msg = {
"role": "developer",
"content": "Use concise language",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.DEVELOPER
def test_user_message_with_string_content(self):
"""Test parsing user message with string content."""
chat_msg = {
"role": "user",
"content": "What's the weather in San Francisco?",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "What's the weather in San Francisco?"
def test_user_message_with_array_content(self):
"""Test parsing user message with array content."""
chat_msg = {
"role": "user",
"content": [
{"text": "What's in this image? "},
{"text": "Please describe it."},
],
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert len(messages[0].content) == 2
assert messages[0].content[0].text == "What's in this image? "
assert messages[0].content[1].text == "Please describe it."
def test_assistant_message_with_string_content(self):
"""Test parsing assistant message with string content (no tool calls)."""
chat_msg = {
"role": "assistant",
"content": "Hello! How can I help you today?",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.ASSISTANT
assert messages[0].content[0].text == "Hello! How can I help you today?"
def test_pydantic_model_input(self):
"""Test parsing Pydantic model input (has model_dump method)."""
class MockPydanticModel:
def model_dump(self, exclude_none=True):
return {
"role": "user",
"content": "Test message",
}
chat_msg = MockPydanticModel()
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "Test message"
def test_message_with_empty_content(self):
"""Test parsing message with empty string content."""
def test_user_message_with_empty_content(self):
chat_msg = {
"role": "user",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
verify_harmony_messages(
messages,
[
{
"role": "user",
"content": "",
},
],
)
def test_tool_call_with_missing_function_fields(self):
"""Test parsing tool call with missing name or arguments."""
def test_user_message_with_none_content(self):
chat_msg = {
"role": "user",
"content": None,
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "user",
"content": "",
},
],
)
def test_assistant_message_with_empty_content(self):
chat_msg = {
"role": "assistant",
"content": "",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 0
def test_assistant_message_with_none_content(self):
chat_msg = {
"role": "assistant",
"content": None,
}
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 0
def test_assistant_message_with_content_but_empty_reasoning(self):
chat_msg = {
"role": "assistant",
"content": "The answer is 4.",
"reasoning": "",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "final",
"content": "The answer is 4.",
},
],
)
def test_assistant_message_with_reasoning_but_empty_content(self):
chat_msg = {
"role": "assistant",
"reasoning": "I'm thinking about the user's question.",
"content": "",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I'm thinking about the user's question.",
},
],
)
def test_assistant_message_with_reasoning_but_none_content(self):
chat_msg = {
"role": "assistant",
"reasoning": "I'm thinking about the user's question.",
"content": None,
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I'm thinking about the user's question.",
},
],
)
def test_assistant_message_with_tool_calls_but_no_content(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {} # Missing both name and arguments
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].recipient == "functions."
assert messages[0].content[0].text == ""
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_array_content_with_missing_text(self):
"""Test parsing array content where text field is missing."""
def test_assistant_message_with_tool_calls_and_content(self):
chat_msg = {
"role": "user",
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"content": "I'll call the tool.",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"content": "I'll call the tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_assistant_message_with_tool_calls_and_reasoning(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"reasoning": "I should use the get_weather tool.",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I should use the get_weather tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_assistant_message_with_tool_calls_and_reasoning_and_content(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"reasoning": "I should use the get_weather tool.",
"content": "I'll call the tool.",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"content": "I'll call the tool.",
},
{
"role": "assistant",
"channel": "analysis",
"content": "I should use the get_weather tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_tool_message_with_string_content(self):
tool_id_names = {
"call_123": "get_weather",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": "The weather in San Francisco is sunny, 72°F",
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.get_weather",
"content": "The weather in San Francisco is sunny, 72°F",
"channel": "commentary",
},
],
)
def test_tool_message_with_array_content(self):
tool_id_names = {
"call_123": "search_results",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": [
{}, # Missing text field
{"text": "actual text"},
{"type": "text", "text": "Result 1: "},
{"type": "text", "text": "Result 2: "},
{
"type": "image",
"url": "http://example.com/img.png",
}, # Should be ignored
{"type": "text", "text": "Result 3"},
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
assert len(messages) == 1
assert len(messages[0].content) == 2
assert messages[0].content[0].text == ""
assert messages[0].content[1].text == "actual text"
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.search_results",
"content": "Result 1: Result 2: Result 3",
"channel": "commentary",
},
],
)
def test_tool_message_with_empty_content(self):
tool_id_names = {
"call_123": "empty_tool",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": "",
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.empty_tool",
"content": "",
"channel": "commentary",
},
],
)
def test_tool_message_with_none_content(self):
tool_id_names = {
"call_123": "empty_tool",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": None,
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.empty_tool",
"content": "",
"channel": "commentary",
},
],
)
class TestAutoDropAnalysisMessages:
def test_no_analysis_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_only_analysis_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_multiple_analysis_messages_without_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_only_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_drops_one_analysis_message_before_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the first analysis message
assert cleaned_messages == messages[1:]
def test_drops_all_analysis_messages_before_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the first 3 analysis messages
assert cleaned_messages == messages[3:]
def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 5."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped every analysis message before the last final message
assert len(cleaned_messages) == 2
assert cleaned_messages[0].content[0].text == "The answer is 4."
assert cleaned_messages[1].content[0].text == "The answer is 5."
def test_drops_non_assistant_analysis_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.TOOL, "The tool thinks we should think harder."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the analysis message
assert cleaned_messages == messages[1:]
class TestParseChatOutput:
def test_parse_chat_output_interrupted_first_message(self) -> None:
harmony_str = "<|channel|>final<|message|>I'm in the middle of answering"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "I'm in the middle of answering"
def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I'm in the middle of thinking"
assert final_content is None
def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I'm thinking.<|end|>"
"<|start|>assistant<|channel|>final"
"<|message|>I'm in the middle of answering"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I'm thinking."
assert final_content == "I'm in the middle of answering"
def test_parse_chat_output_complete_content(self) -> None:
harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "The answer is 4."
def test_parse_chat_output_complete_commentary(self) -> None:
harmony_str = (
"<|channel|>commentary<|message|>I need to call some tools.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "I need to call some tools."
def test_parse_chat_output_complete_reasoning(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I've thought hard about this."
assert final_content is None
def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
"<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I've thought hard about this."
assert final_content == "The answer is 4."
class TestParseOutputMessage:

@@ -11,13 +11,25 @@ import pytest_asyncio
from openai import OpenAI
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
from ...utils import RemoteOpenAIServer
from .utils import (
accumulate_streaming_response,
verify_chat_response,
verify_harmony_messages,
)
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
@@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
# Verify that data_parallel_rank defaults to None
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
class TestServingChatWithHarmony:
"""
These tests ensure Chat Completion requests are properly converted into Harmony
messages, and that Harmony response messages are properly converted back into
Chat Completion responses.
These tests are not exhaustive, but each one was created to cover a specific case
that we previously got wrong and have since fixed.
Any changes to the tests and their expectations may change the accuracy of model
prompting and of the responses generated. It is suggested to run an evaluation or
benchmarking suite (such as bfcl multi_turn) to understand the impact of any
changes in how we prompt Harmony models.
"""
@pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
def stream(self, request) -> bool:
"""Parameterize tests to run in both non-streaming and streaming modes."""
return request.param
@pytest.fixture()
def mock_engine(self) -> AsyncLLM:
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
return mock_engine
@pytest.fixture()
def serving_chat(self, mock_engine) -> OpenAIServingChat:
chat = _build_serving_chat(mock_engine)
chat.use_harmony = True
chat.tool_parser = ToolParserManager.get_tool_parser("openai")
return chat
def mock_request_output_from_req_and_token_ids(
self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
) -> RequestOutput:
# Our tests don't use most fields, so just get the token ids correct
completion_output = CompletionOutput(
index=0,
text="",
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
)
return RequestOutput(
request_id=req.request_id,
prompt=[],
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[completion_output],
finished=finished,
)
@pytest.fixture
def weather_tools(self) -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"},
},
"required": ["location"],
},
},
},
]
@pytest.fixture
def weather_messages_start(self) -> list[dict[str, Any]]:
return [
{
"role": "user",
"content": "What's the weather like in Paris today?",
},
]
async def generate_response_from_harmony_str(
self,
serving_chat: OpenAIServingChat,
req: ChatCompletionRequest,
harmony_str: str,
stream: bool = False,
) -> ChatCompletionResponse:
harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")
async def result_generator():
if stream:
for token_id in harmony_token_ids:
yield self.mock_request_output_from_req_and_token_ids(
req, [token_id]
)
yield self.mock_request_output_from_req_and_token_ids(
req, [], finished=True
)
else:
yield self.mock_request_output_from_req_and_token_ids(
req, harmony_token_ids, finished=True
)
generator_func = (
serving_chat.chat_completion_stream_generator
if stream
else serving_chat.chat_completion_full_generator
)
result = generator_func(
request=req,
result_generator=result_generator(),
request_id=req.request_id,
model_name=req.model,
conversation=[],
tokenizer=get_tokenizer(req.model),
request_metadata=RequestResponseMetadata(
request_id=req.request_id,
model_name=req.model,
),
)
if stream:
return await accumulate_streaming_response(result)
return await result
@pytest.mark.asyncio
async def test_simple_chat(self, serving_chat, stream):
messages = [{"role": "user", "content": "what is 1+1?"}]
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "We need to think really hard about this."
final_str = "The answer is 2."
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(response, content=final_str, reasoning=reasoning_str)
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
# The analysis message should be dropped on subsequent inputs because
# of the subsequent assistant message to the final channel.
{"role": "assistant", "channel": "final", "content": final_str},
],
)
@pytest.mark.asyncio
async def test_tool_call_response_with_content(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
commentary_str = "We'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
content=commentary_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"content": commentary_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_multi_turn_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
paris_tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", paris_tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
# Test the Chat Completion response for the second turn's output
paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
response_2 = await self.generate_response_from_harmony_str(
serving_chat, req_2, response_str, stream=stream
)
verify_chat_response(response_2, content=paris_weather_str)
# Add the output messages from the second turn as input to the third turn
for choice in response_2.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add a new user message for the third turn
messages.append(
{
"role": "user",
"content": "What's the weather like in Boston today?",
},
)
# Test the Harmony messages for the third turn's input
req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_3, _, _ = serving_chat._make_request_with_harmony(req_3)
verify_harmony_messages(
input_messages_3,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
{
"role": "assistant",
"channel": "final",
"content": paris_weather_str,
},
{"role": "user", "content": messages[-1]["content"]},
],
)
# Test the Chat Completion response for the third turn's output
reasoning_str = "I'll call get_weather."
boston_tool_args_str = '{"location": "Boston"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
)
response_3 = await self.generate_response_from_harmony_str(
serving_chat, req_3, response_str, stream=stream
)
verify_chat_response(
response_3,
reasoning=reasoning_str,
tool_calls=[("get_weather", boston_tool_args_str)],
)
tool_call = response_3.choices[0].message.tool_calls[0]
# Add the output messages from the third turn as input to the fourth turn
for choice in response_3.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "10 degrees Celsius",
},
)
# Test the Harmony messages for the fourth turn's input
req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_4, _, _ = serving_chat._make_request_with_harmony(req_4)
verify_harmony_messages(
input_messages_4,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{"role": "assistant"},
{"role": "tool"},
{
"role": "assistant",
"channel": "final",
},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": boston_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "10 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "4",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
# The reasoning that would have resulted in an analysis message is
# dropped because of a later assistant message to the final channel.
{
"role": "assistant",
"channel": "final",
"content": messages[1]["content"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": [],
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)

@@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator
from typing import Any
from vllm.entrypoints.openai.protocol import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatMessage,
UsageInfo,
)
async def accumulate_streaming_response(
stream_generator: AsyncGenerator[str, None],
) -> ChatCompletionResponse:
"""
Accumulate streaming SSE chunks into a complete ChatCompletionResponse.
This helper parses the SSE format and builds up the complete response
by combining all the delta chunks.
"""
accumulated_content = ""
accumulated_reasoning = None
accumulated_tool_calls: list[dict[str, Any]] = []
role = None
finish_reason = None
response_id = None
created = None
model = None
index = 0
async for chunk_str in stream_generator:
# Skip empty lines and [DONE] marker
if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]":
continue
# Parse SSE format: "data: {json}\n\n"
if chunk_str.startswith("data: "):
json_str = chunk_str[6:].strip()
try:
chunk_data = json.loads(json_str)
# print(f"DEBUG: Parsed chunk_data: {chunk_data}")
chunk = ChatCompletionStreamResponse(**chunk_data)
# Store metadata from first chunk
if response_id is None:
response_id = chunk.id
created = chunk.created
model = chunk.model
# Process each choice in the chunk
for choice in chunk.choices:
if choice.delta.role:
role = choice.delta.role
if choice.delta.content:
accumulated_content += choice.delta.content
if choice.delta.reasoning:
if accumulated_reasoning is None:
accumulated_reasoning = ""
accumulated_reasoning += choice.delta.reasoning
if choice.delta.tool_calls:
# Accumulate tool calls
for tool_call_delta in choice.delta.tool_calls:
# Find or create the tool call at this index
while len(accumulated_tool_calls) <= tool_call_delta.index:
accumulated_tool_calls.append(
{
"id": None,
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
if tool_call_delta.id:
accumulated_tool_calls[tool_call_delta.index]["id"] = (
tool_call_delta.id
)
if tool_call_delta.function:
if tool_call_delta.function.name:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["name"] += tool_call_delta.function.name
if tool_call_delta.function.arguments:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["arguments"] += tool_call_delta.function.arguments
if choice.finish_reason:
finish_reason = choice.finish_reason
if choice.index is not None:
index = choice.index
except json.JSONDecodeError:
continue
# Build the final message
message_kwargs = {
"role": role or "assistant",
"content": accumulated_content if accumulated_content else None,
"reasoning": accumulated_reasoning,
}
# Only include tool_calls if there are any
if accumulated_tool_calls:
message_kwargs["tool_calls"] = [
{"id": tc["id"], "type": tc["type"], "function": tc["function"]}
for tc in accumulated_tool_calls
]
message = ChatMessage(**message_kwargs)
# Build the final response
choice = ChatCompletionResponseChoice(
index=index,
message=message,
finish_reason=finish_reason or "stop",
)
# Create usage info (with dummy values for tests)
usage = UsageInfo(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
response = ChatCompletionResponse(
id=response_id or "chatcmpl-test",
object="chat.completion",
created=created or 0,
model=model or "test-model",
choices=[choice],
usage=usage,
)
return response
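For illustration only (not part of the commit), a minimal sketch of driving this helper with hand-built SSE chunks, assuming they satisfy vLLM's ChatCompletionStreamResponse schema:

import asyncio

async def _fake_stream():
    # Two content deltas followed by the [DONE] sentinel.
    yield (
        'data: {"id": "chatcmpl-1", "object": "chat.completion.chunk",'
        ' "created": 0, "model": "m", "choices":'
        ' [{"index": 0, "delta": {"role": "assistant", "content": "Hel"}}]}\n\n'
    )
    yield (
        'data: {"id": "chatcmpl-1", "object": "chat.completion.chunk",'
        ' "created": 0, "model": "m", "choices":'
        ' [{"index": 0, "delta": {"content": "lo"}, "finish_reason": "stop"}]}\n\n'
    )
    yield "data: [DONE]\n\n"

response = asyncio.run(accumulate_streaming_response(_fake_stream()))
assert response.choices[0].message.content == "Hello"
assert response.choices[0].finish_reason == "stop"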
def verify_harmony_messages(
messages: list[Any], expected_messages: list[dict[str, Any]]
):
assert len(messages) == len(expected_messages)
for msg, expected in zip(messages, expected_messages):
if "role" in expected:
assert msg.author.role == expected["role"]
if "author_name" in expected:
assert msg.author.name == expected["author_name"]
if "channel" in expected:
assert msg.channel == expected["channel"]
if "recipient" in expected:
assert msg.recipient == expected["recipient"]
if "content" in expected:
assert msg.content[0].text == expected["content"]
if "content_type" in expected:
assert msg.content_type == expected["content_type"]
if "tool_definitions" in expected:
# Check that the tool definitions match the expected list of tool names
actual_tools = [t.name for t in msg.content[0].tools["functions"].tools]
assert actual_tools == expected["tool_definitions"]
def verify_chat_response(
response: ChatCompletionResponse,
content: str | None = None,
reasoning: str | None = None,
tool_calls: list[tuple[str, str]] | None = None,
):
assert len(response.choices) == 1
message = response.choices[0].message
if content is not None:
assert message.content == content
else:
assert not message.content
if reasoning is not None:
assert message.reasoning == reasoning
else:
assert not message.reasoning
if tool_calls:
assert message.tool_calls is not None
assert len(message.tool_calls) == len(tool_calls)
for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls):
assert tc.function.name == expected_name
assert tc.function.arguments == expected_args
else:
assert not message.tool_calls

@@ -232,7 +232,177 @@ def parse_response_input(
return msg
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
"""
Parse a list of messages from request.messages in the Chat Completion API to
Harmony messages.
"""
msgs: list[Message] = []
tool_id_names: dict[str, str] = {}
# Collect tool id to name mappings for tool response recipient values
for chat_msg in chat_msgs:
for tool_call in chat_msg.get("tool_calls", []):
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
"name"
)
for chat_msg in chat_msgs:
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
msgs = auto_drop_analysis_messages(msgs)
return msgs
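A minimal usage sketch (illustrative, not part of the commit): the first pass over tool_calls builds the tool_call_id-to-name mapping that later lets a tool-role message recover its functions.* author:

chat_msgs = [
    {"role": "user", "content": "What's the weather like in Paris today?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_123",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"location": "Paris"}',
                },
            }
        ],
    },
    {"role": "tool", "tool_call_id": "call_123", "content": "20 degrees Celsius"},
]
harmony_msgs = parse_chat_inputs_to_harmony_messages(chat_msgs)
# The tool output is attributed to the tool that produced it.
assert harmony_msgs[-1].author.name == "functions.get_weather"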
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
"""
Harmony models expect the analysis messages (representing raw chain of thought) to
be dropped after an assistant message to the final channel is produced from the
reasoning of those messages.
The openai-harmony library does this if the very last assistant message is to the
final channel, but it does not handle longer multi-turn conversations where the
client gave us reasoning content from previous turns, leaving multiple assistant
messages to the final channel in the conversation history.
So, we find the index of the last assistant message to the final channel and drop
all analysis messages that precede it, leaving only the analysis messages that
are relevant to the current part of the conversation.
"""
last_assistant_final_index = -1
for i in range(len(msgs) - 1, -1, -1):
msg = msgs[i]
if msg.author.role == "assistant" and msg.channel == "final":
last_assistant_final_index = i
break
cleaned_msgs: list[Message] = []
for i, msg in enumerate(msgs):
if i < last_assistant_final_index and msg.channel == "analysis":
continue
cleaned_msgs.append(msg)
return cleaned_msgs
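An illustrative sketch (not part of the commit), mirroring the unit tests: analysis messages before the last assistant message to the final channel are dropped, while trailing analysis messages survive:

from openai_harmony import Message, Role

history = [
    Message.from_role_and_content(Role.ASSISTANT, "I'm thinking.").with_channel(
        "analysis"
    ),
    Message.from_role_and_content(Role.ASSISTANT, "The answer is 4.").with_channel(
        "final"
    ),
    Message.from_role_and_content(Role.ASSISTANT, "Thinking again.").with_channel(
        "analysis"
    ),
]
cleaned = auto_drop_analysis_messages(history)
assert [m.channel for m in cleaned] == ["final", "analysis"]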
def flatten_chat_text_content(content: str | list | None) -> str | None:
"""
Extract the text parts from a chat message content field and flatten them
into a single string.
"""
if isinstance(content, list):
return "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
return content
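Illustrative behavior (not part of the commit): only text parts are kept, concatenated in order, while string and None content pass through unchanged:

assert flatten_chat_text_content("hello") == "hello"
assert flatten_chat_text_content(None) is None
assert (
    flatten_chat_text_content(
        [
            {"type": "text", "text": "Result 1: "},
            {"type": "image", "url": "http://example.com/img.png"},  # ignored
            {"type": "text", "text": "Result 2"},
        ]
    )
    == "Result 1: Result 2"
)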
def parse_chat_input_to_harmony_message(
chat_msg, tool_id_names: dict[str, str] | None = None
) -> list[Message]:
"""
Parse a message from request.messages in the Chat Completion API to
Harmony messages.
"""
tool_id_names = tool_id_names or {}
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
msgs: list[Message] = []
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls", [])
if role == "assistant" and tool_calls:
content = flatten_chat_text_content(chat_msg.get("content"))
if content:
commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
commentary_msg = commentary_msg.with_channel("commentary")
msgs.append(commentary_msg)
reasoning_content = chat_msg.get("reasoning") or chat_msg.get(
"reasoning_content"
)
if reasoning_content:
analysis_msg = Message.from_role_and_content(
Role.ASSISTANT, reasoning_content
)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
# Officially, this should be `<|constrain|>json`, but there is no clear
# evidence that it improves accuracy over `json`, and some anecdotes suggest
# the contrary. Further testing of the different content_types is needed.
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
tool_call_id = chat_msg.get("tool_call_id", "")
name = tool_id_names.get(tool_call_id, "")
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)
msg = (
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
)
.with_channel("commentary")
.with_recipient("assistant")
)
return [msg]
# Non-tool reasoning content
reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content")
if role == "assistant" and reasoning_content:
analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content)
analysis_msg = analysis_msg.with_channel("analysis")
msgs.append(analysis_msg)
# Default: user/assistant/system messages with content
content = chat_msg.get("content") or ""
if content is None:
content = ""
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
# Only add assistant messages if they have content, as reasoning or tool calling
# assistant messages were already added above.
if role == "assistant" and contents and contents[0].text:
msg = Message.from_role_and_contents(role, contents)
# Send non-tool assistant messages to the final channel
msg = msg.with_channel("final")
msgs.append(msg)
# For user/system/developer messages, add them directly even if no content.
elif role != "assistant":
msg = Message.from_role_and_contents(role, contents)
msgs.append(msg)
return msgs
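An illustrative sketch (not part of the commit), matching the unit tests above: a single assistant turn carrying preamble content, reasoning, and a tool call expands into three Harmony messages:

msgs = parse_chat_input_to_harmony_message(
    {
        "role": "assistant",
        "content": "I'll call the tool.",
        "reasoning": "I should use the get_weather tool.",
        "tool_calls": [
            {
                "function": {
                    "name": "get_weather",
                    "arguments": '{"location": "Paris"}',
                }
            }
        ],
    }
)
assert [m.channel for m in msgs] == ["commentary", "analysis", "commentary"]
assert msgs[-1].recipient == "functions.get_weather"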
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
"""
Parse a message from request.previous_input_messages in the Responses API to
Harmony messages.
"""
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
@@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
if role == "tool":
name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or ""
if isinstance(content, list):
# Handle array format for tool message content
# by concatenating all text parts.
content = "".join(
item.get("text", "")
for item in content
if isinstance(item, dict) and item.get("type") == "text"
)
content = flatten_chat_text_content(content)
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"), content
@@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
def parse_chat_output(
token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]:
"""
Parse the output of a Harmony chat completion into reasoning and final content.
Note that when the `openai` tool parser is used, serving_chat only uses this
for the reasoning content and gets the final content from the tool call parser.
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
in use, this needs to return the final content without any tool calls parsed.
Empty reasoning or final content is returned as None instead of an empty string.
"""
parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported
if len(output_msgs) == 0:
# The generation has stopped during reasoning.
reasoning = parser.current_content
final_content = None
elif len(output_msgs) == 1:
# The generation has stopped during final message.
reasoning = output_msgs[0].content[0].text
final_content = parser.current_content
else:
reasoning_msg = output_msgs[:-1]
final_msg = output_msgs[-1]
reasoning = "\n".join([msg.content[0].text for msg in reasoning_msg])
final_content = final_msg.content[0].text
# Get completed messages from the parser
reasoning_texts = [
msg.content[0].text for msg in output_msgs if msg.channel == "analysis"
]
final_texts = [
msg.content[0].text for msg in output_msgs if msg.channel != "analysis"
]
# Extract partial messages from the parser
if parser.current_channel == "analysis" and parser.current_content:
reasoning_texts.append(parser.current_content)
elif parser.current_channel != "analysis" and parser.current_content:
final_texts.append(parser.current_content)
# Flatten multiple messages into a single string
reasoning: str | None = "\n".join(reasoning_texts)
final_content: str | None = "\n".join(final_texts)
# Return None instead of empty string since existing callers check for None
reasoning = reasoning or None
final_content = final_content or None
return reasoning, final_content, is_tool_call
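An illustrative sketch (not part of the commit), mirroring TestParseChatOutput above: reasoning and final content are split by channel, and an interrupted final message is still surfaced:

harmony_str = (
    "<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
    "<|start|>assistant<|channel|>final<|message|>The answer is 4."
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I've thought hard about this."
assert final_content == "The answer is 4."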

@@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant,
get_system_message,
parse_chat_inputs_to_harmony_messages,
parse_chat_output,
parse_input_to_harmony_message,
render_for_completion,
)
from vllm.entrypoints.openai.protocol import (
@@ -822,6 +822,9 @@ class OpenAIServingChat(OpenAIServing):
if delta_message is not None:
harmony_tools_streamed[i] = True
elif cur_channel == "commentary":
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
else:
delta_message = None
# handle streaming deltas for tools with named tool_choice
@@ -1770,6 +1773,11 @@
):
messages: list[OpenAIMessage] = []
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
# if the model supports it. TODO: Support browsing.
@@ -1788,8 +1796,7 @@
messages.append(dev_msg)
# Add conversation messages.
for chat_msg in request.messages:
messages.extend(parse_input_to_harmony_message(chat_msg))
messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))
# Render prompt token ids.
prompt_token_ids = render_for_completion(messages)

@@ -43,6 +43,7 @@ class OpenAIToolParser(ToolParser):
parser = parse_output_into_messages(token_ids)
tool_calls = []
final_content = None
commentary_content = None
if len(parser.messages) > 0:
for msg in parser.messages:
@@ -75,11 +76,15 @@
)
elif msg.channel == "final":
final_content = msg_text
elif msg.channel == "commentary" and not msg.recipient:
commentary_content = msg_text
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=final_content,
# Prefer final content over commentary content if both are present;
# commentary content holds tool call preambles meant to be shown to the user.
content=final_content or commentary_content,
)
def extract_tool_calls_streaming(