[gpt-oss][2/N] Support input_messages in responsesRequest (#26962)

Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
Andrew Xia 2025-10-27 16:15:49 -07:00 committed by GitHub
parent 69f064062b
commit 53a56e658b
7 changed files with 481 additions and 17 deletions

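In brief, this change lets a client replay a prior turn's harmony-format messages into a new Responses API call via `previous_input_messages`. Below is a minimal sketch of that flow, assuming a running vLLM server with a gpt-oss model behind the OpenAI-compatible API; the base URL and model name are placeholders, and the tool result is computed inline rather than by a real tool. It mirrors the new test added in this commit.

import json
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
MODEL = "openai/gpt-oss-20b"  # placeholder model name

tools = [{
    "type": "function",
    "name": "get_horoscope",
    "description": "Get today's horoscope for an astrological sign.",
    "parameters": {
        "type": "object",
        "properties": {"sign": {"type": "string"}},
        "required": ["sign"],
        "additionalProperties": False,
    },
    "strict": True,
}]

# Turn 1: ask for the harmony-format messages back with the response.
first = client.responses.create(
    model=MODEL,
    input="What is the horoscope for Aquarius today?",
    tools=tools,
    extra_body={"enable_response_messages": True},
)
call = next(item for item in first.output if item.type == "function_call")
result = f"{json.loads(call.arguments)['sign']}: all clear today."  # stand-in tool result

# Turn 2: replay the full history via previous_input_messages, appending
# the tool result as a tool-role message.
history = first.input_messages + first.output_messages + [{
    "role": "tool",
    "name": "functions.get_horoscope",
    "content": [{"type": "text", "text": result}],
}]
second = client.responses.create(
    model=MODEL,
    tools=tools,
    input="",
    extra_body={
        "previous_input_messages": history,
        "enable_response_messages": True,
    },
)
print(second.output_text)

Server-side, system and developer messages in the replayed history are dropped and rebuilt from the new request, as the harmony_utils change below shows.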
View File

@@ -535,11 +535,17 @@ def get_place_to_travel():
    return "Paris"


def get_horoscope(sign):
    return f"{sign}: Next Tuesday you will befriend a baby otter."


def call_function(name, args):
    if name == "get_weather":
        return get_weather(**args)
    elif name == "get_place_to_travel":
        return get_place_to_travel()
    elif name == "get_horoscope":
        return get_horoscope(**args)
    else:
        raise ValueError(f"Unknown function: {name}")
@@ -828,3 +834,126 @@ async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
    assert response.status == "completed"
    assert len(response.input_messages) > 0
    assert len(response.output_messages) > 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_with_previous_input_messages(
    client: OpenAI, model_name: str
):
    """Test function calling using previous_input_messages
    for a multi-turn conversation with a function call."""
    # Define the get_horoscope tool
    tools = [
        {
            "type": "function",
            "name": "get_horoscope",
            "description": "Get today's horoscope for an astrological sign.",
            "parameters": {
                "type": "object",
                "properties": {
                    "sign": {"type": "string"},
                },
                "required": ["sign"],
                "additionalProperties": False,
            },
            "strict": True,
        }
    ]

    # Step 1: First call with the function tool
    stream_response = await client.responses.create(
        model=model_name,
        input="What is the horoscope for Aquarius today?",
        tools=tools,
        extra_body={"enable_response_messages": True},
        stream=True,
    )
    response = None
    async for event in stream_response:
        if event.type == "response.completed":
            response = event.response
    assert response is not None
    assert response.status == "completed"

    # Step 2: Parse the first output to find the function_call item
    function_call = None
    for item in response.output:
        if item.type == "function_call":
            function_call = item
            break
    assert function_call is not None, "Expected a function_call in the output"
    assert function_call.name == "get_horoscope"
    assert function_call.call_id is not None
    # Verify the arguments match the expected format
    args = json.loads(function_call.arguments)
    assert "sign" in args

    # Step 3: Call the get_horoscope function
    result = call_function(function_call.name, args)
    assert "Aquarius" in result
    assert "baby otter" in result

    # Get the input_messages and output_messages from the first response
    first_input_messages = response.input_messages
    first_output_messages = response.output_messages

    # Construct the full conversation history using previous_input_messages
    previous_messages = (
        first_input_messages
        + first_output_messages
        + [
            {
                "role": "tool",
                "name": "functions.get_horoscope",
                "content": [{"type": "text", "text": str(result)}],
            }
        ]
    )

    # Step 4: Make another responses.create() call with previous_input_messages
    stream_response_2 = await client.responses.create(
        model=model_name,
        tools=tools,
        input="",
        extra_body={
            "previous_input_messages": previous_messages,
            "enable_response_messages": True,
        },
        stream=True,
    )
    response_2 = None
    async for event in stream_response_2:
        if event.type == "response.completed":
            response_2 = event.response
    assert response_2 is not None
    assert response_2.status == "completed"
    assert response_2.output_text is not None

    # Verify there is exactly one system / developer / tool message in the input
    num_system_messages_input = 0
    num_developer_messages_input = 0
    num_function_call_input = 0
    for message_dict in response_2.input_messages:
        message = Message.from_dict(message_dict)
        if message.author.role == "system":
            num_system_messages_input += 1
        elif message.author.role == "developer":
            num_developer_messages_input += 1
        elif message.author.role == "tool":
            num_function_call_input += 1
    assert num_system_messages_input == 1
    assert num_developer_messages_input == 1
    assert num_function_call_input == 1

    # Verify the output makes sense: it should contain information about the horoscope
    output_text = response_2.output_text.lower()
    assert (
        "aquarius" in output_text or "otter" in output_text or "tuesday" in output_text
    )

View File

@@ -125,6 +125,28 @@ class TestInitializeToolSessions:
        # Verify that init_tool_sessions was called
        assert mock_context.init_tool_sessions_called

    def test_validate_create_responses_input(
        self, serving_responses_instance, mock_context, mock_exit_stack
    ):
        request = ResponsesRequest(
            input="test input",
            previous_input_messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "What is my horoscope? I am an Aquarius.",
                        }
                    ],
                }
            ],
            previous_response_id="lol",
        )
        error = serving_responses_instance._validate_create_responses_input(request)
        assert error is not None
        assert error.error.type == "invalid_request_error"


class TestValidateGeneratorInput:
    """Test class for _validate_generator_input method"""

View File

@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from openai_harmony import Role

from vllm.entrypoints.harmony_utils import parse_input_to_harmony_message


class TestParseInputToHarmonyMessage:
    """Tests for parse_input_to_harmony_message function."""

    def test_assistant_message_with_tool_calls(self):
        """Test parsing assistant message with tool calls."""
        chat_msg = {
            "role": "assistant",
            "tool_calls": [
                {
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"location": "San Francisco"}',
                    }
                },
                {
                    "function": {
                        "name": "search_web",
                        "arguments": '{"query": "latest news"}',
                    }
                },
            ],
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 2
        # First tool call
        assert messages[0].author.role == Role.ASSISTANT
        assert messages[0].content[0].text == '{"location": "San Francisco"}'
        assert messages[0].channel == "commentary"
        assert messages[0].recipient == "functions.get_weather"
        assert messages[0].content_type == "json"
        # Second tool call
        assert messages[1].author.role == Role.ASSISTANT
        assert messages[1].content[0].text == '{"query": "latest news"}'
        assert messages[1].channel == "commentary"
        assert messages[1].recipient == "functions.search_web"
        assert messages[1].content_type == "json"

    def test_assistant_message_with_empty_tool_call_arguments(self):
        """Test parsing assistant message with tool call having None arguments."""
        chat_msg = {
            "role": "assistant",
            "tool_calls": [
                {
                    "function": {
                        "name": "get_current_time",
                        "arguments": None,
                    }
                }
            ],
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].content[0].text == ""
        assert messages[0].recipient == "functions.get_current_time"

    def test_tool_message_with_string_content(self):
        """Test parsing tool message with string content."""
        chat_msg = {
            "role": "tool",
            "name": "get_weather",
            "content": "The weather in San Francisco is sunny, 72°F",
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].author.role == Role.TOOL
        assert messages[0].author.name == "functions.get_weather"
        assert (
            messages[0].content[0].text == "The weather in San Francisco is sunny, 72°F"
        )
        assert messages[0].channel == "commentary"

    def test_tool_message_with_array_content(self):
        """Test parsing tool message with array content."""
        chat_msg = {
            "role": "tool",
            "name": "search_results",
            "content": [
                {"type": "text", "text": "Result 1: "},
                {"type": "text", "text": "Result 2: "},
                {
                    "type": "image",
                    "url": "http://example.com/img.png",
                },  # Should be ignored
                {"type": "text", "text": "Result 3"},
            ],
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].author.role == Role.TOOL
        assert messages[0].content[0].text == "Result 1: Result 2: Result 3"

    def test_tool_message_with_empty_content(self):
        """Test parsing tool message with None content."""
        chat_msg = {
            "role": "tool",
            "name": "empty_tool",
            "content": None,
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].content[0].text == ""

    def test_system_message(self):
        """Test parsing system message."""
        chat_msg = {
            "role": "system",
            "content": "You are a helpful assistant",
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        # System messages are converted using Message.from_dict,
        # which should preserve the role
        assert messages[0].author.role == Role.SYSTEM

    def test_developer_message(self):
        """Test parsing developer message."""
        chat_msg = {
            "role": "developer",
            "content": "Use concise language",
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].author.role == Role.DEVELOPER

    def test_user_message_with_string_content(self):
        """Test parsing user message with string content."""
        chat_msg = {
            "role": "user",
            "content": "What's the weather in San Francisco?",
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].author.role == Role.USER
        assert messages[0].content[0].text == "What's the weather in San Francisco?"

    def test_user_message_with_array_content(self):
        """Test parsing user message with array content."""
        chat_msg = {
            "role": "user",
            "content": [
                {"text": "What's in this image? "},
                {"text": "Please describe it."},
            ],
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].author.role == Role.USER
        assert len(messages[0].content) == 2
        assert messages[0].content[0].text == "What's in this image? "
        assert messages[0].content[1].text == "Please describe it."

    def test_assistant_message_with_string_content(self):
        """Test parsing assistant message with string content (no tool calls)."""
        chat_msg = {
            "role": "assistant",
            "content": "Hello! How can I help you today?",
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].author.role == Role.ASSISTANT
        assert messages[0].content[0].text == "Hello! How can I help you today?"

    def test_pydantic_model_input(self):
        """Test parsing Pydantic model input (has model_dump method)."""

        class MockPydanticModel:
            def model_dump(self, exclude_none=True):
                return {
                    "role": "user",
                    "content": "Test message",
                }

        chat_msg = MockPydanticModel()
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].author.role == Role.USER
        assert messages[0].content[0].text == "Test message"

    def test_message_with_empty_content(self):
        """Test parsing message with empty string content."""
        chat_msg = {
            "role": "user",
            "content": "",
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].content[0].text == ""

    def test_tool_call_with_missing_function_fields(self):
        """Test parsing tool call with missing name or arguments."""
        chat_msg = {
            "role": "assistant",
            "tool_calls": [
                {
                    "function": {}  # Missing both name and arguments
                }
            ],
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert messages[0].recipient == "functions."
        assert messages[0].content[0].text == ""

    def test_array_content_with_missing_text(self):
        """Test parsing array content where the text field is missing."""
        chat_msg = {
            "role": "user",
            "content": [
                {},  # Missing text field
                {"text": "actual text"},
            ],
        }
        messages = parse_input_to_harmony_message(chat_msg)

        assert len(messages) == 1
        assert len(messages[0].content) == 2
        assert messages[0].content[0].text == ""
        assert messages[0].content[1].text == "actual text"

View File

@@ -38,11 +38,14 @@ from openai_harmony import (
    ToolDescription,
    load_harmony_encoding,
)
from openai_harmony import Message as OpenAIHarmonyMessage
from openai_harmony import Role as OpenAIHarmonyRole

from vllm import envs
from vllm.entrypoints.openai.protocol import (
    ChatCompletionToolsParam,
    ResponseInputOutputItem,
    ResponsesRequest,
)
from vllm.utils import random_uuid
@@ -228,7 +231,7 @@ def parse_response_input(
    return msg


def parse_chat_input(chat_msg) -> list[Message]:
def parse_input_to_harmony_message(chat_msg) -> list[Message]:
    if not isinstance(chat_msg, dict):
        # Handle Pydantic models
        chat_msg = chat_msg.model_dump(exclude_none=True)
@@ -279,6 +282,40 @@ def parse_chat_input(chat_msg) -> list[Message]:
    return [msg]


def construct_harmony_previous_input_messages(
    request: ResponsesRequest,
) -> list[OpenAIHarmonyMessage]:
    messages: list[OpenAIHarmonyMessage] = []
    if request.previous_input_messages:
        for message in request.previous_input_messages:
            # Handle both OpenAIHarmonyMessage objects and dictionary inputs
            if isinstance(message, OpenAIHarmonyMessage):
                message_role = message.author.role
                # To match OpenAI, instructions, reasoning, and tools are
                # always taken from the most recent Responses API request,
                # not carried over from previous requests.
                if (
                    message_role == OpenAIHarmonyRole.SYSTEM
                    or message_role == OpenAIHarmonyRole.DEVELOPER
                ):
                    continue
                messages.append(message)
            else:
                harmony_messages = parse_input_to_harmony_message(message)
                for harmony_msg in harmony_messages:
                    message_role = harmony_msg.author.role
                    # To match OpenAI, instructions, reasoning, and tools are
                    # always taken from the most recent Responses API request,
                    # not carried over from previous requests.
                    if (
                        message_role == OpenAIHarmonyRole.SYSTEM
                        or message_role == OpenAIHarmonyRole.DEVELOPER
                    ):
                        continue
                    messages.append(harmony_msg)
    return messages


def render_for_completion(messages: list[Message]) -> list[int]:
    conversation = Conversation.from_messages(messages)
    token_ids = get_encoding().render_conversation_for_completion(

View File

@@ -47,6 +47,7 @@ from openai.types.responses import (
from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent,
)
from openai_harmony import Message as OpenAIHarmonyMessage

from vllm.utils.serial_utils import (
    EmbedDType,
@@ -383,10 +384,15 @@ class ResponsesRequest(OpenAIBaseModel):
        default=False,
        description=(
            "Dictates whether or not to return messages as part of the "
            "response object. Currently only supported for non-streaming "
            "response object. Currently only supported for "
            "non-background requests and gpt-oss. "
        ),
    )
    # Similar to input_messages / output_messages in ResponsesResponse,
    # we take in previous_input_messages (i.e., in harmony format).
    # This cannot be used in conjunction with previous_response_id.
    # TODO: consider supporting non-harmony messages as well
    previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
    # --8<-- [end:responses-extra-params]

    _DEFAULT_SAMPLING_PARAMS = {

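The new field is mutually exclusive with previous_response_id. A sketch of that failure mode, with placeholder client setup and model name (the validation itself is added in the serving_responses change below):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.responses.create(
        model="openai/gpt-oss-20b",  # placeholder
        input="test input",
        previous_response_id="resp_123",  # placeholder; cannot be combined
        extra_body={
            "previous_input_messages": [
                {"role": "user", "content": [{"type": "text", "text": "hi"}]}
            ],
        },
    )
except openai.BadRequestError as err:
    # Expected: "Only one of `previous_input_messages` and
    # `previous_response_id` can be set."
    print(err)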
View File

@@ -27,8 +27,8 @@ from vllm.entrypoints.harmony_utils import (
    get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant,
    get_system_message,
    parse_chat_input,
    parse_chat_output,
    parse_input_to_harmony_message,
    render_for_completion,
)
from vllm.entrypoints.logger import RequestLogger
@@ -1351,11 +1351,13 @@ class OpenAIServingChat(OpenAIServing):
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls"
                if (tool_call_info is not None and tool_call_info.tools_called)
                else output.finish_reason
                if output.finish_reason
                else "stop",
                finish_reason=(
                    "tool_calls"
                    if (tool_call_info is not None and tool_call_info.tools_called)
                    else output.finish_reason
                    if output.finish_reason
                    else "stop"
                ),
                stop_reason=output.stop_reason,
            )
            choices.append(choice_data)
@@ -1522,11 +1524,13 @@ class OpenAIServingChat(OpenAIServing):
                index=output.index,
                message=message,
                logprobs=logprobs,
                finish_reason="tool_calls"
                if auto_tools_called
                else output.finish_reason
                if output.finish_reason
                else "stop",
                finish_reason=(
                    "tool_calls"
                    if auto_tools_called
                    else output.finish_reason
                    if output.finish_reason
                    else "stop"
                ),
                stop_reason=output.stop_reason,
                token_ids=(
                    as_list(output.token_ids) if request.return_token_ids else None
@@ -1685,9 +1689,11 @@ class OpenAIServingChat(OpenAIServing):
                        should_return_as_token_id,
                    ),
                    logprob=max(step_token.logprob, -9999.0),
                    bytes=None
                    if step_decoded is None
                    else list(step_decoded.encode("utf-8", errors="replace")),
                    bytes=(
                        None
                        if step_decoded is None
                        else list(step_decoded.encode("utf-8", errors="replace"))
                    ),
                    top_logprobs=self._get_top_logprobs(
                        step_top_logprobs,
                        num_output_top_logprobs,
@@ -1764,7 +1770,7 @@ class OpenAIServingChat(OpenAIServing):
        # Add user message.
        for chat_msg in request.messages:
            messages.extend(parse_chat_input(chat_msg))
            messages.extend(parse_input_to_harmony_message(chat_msg))

        # Render prompt token ids.
        prompt_token_ids = render_for_completion(messages)

View File

@@ -63,6 +63,7 @@ from vllm.entrypoints.context import (
    StreamingHarmonyContext,
)
from vllm.entrypoints.harmony_utils import (
    construct_harmony_previous_input_messages,
    get_developer_message,
    get_stop_tokens_for_assistant_actions,
    get_system_message,
@@ -248,6 +249,13 @@ class OpenAIServingResponses(OpenAIServing):
                ),
                status_code=HTTPStatus.BAD_REQUEST,
            )

        if request.previous_input_messages and request.previous_response_id:
            return self.create_error_response(
                err_type="invalid_request_error",
                message="Only one of `previous_input_messages` and "
                "`previous_response_id` can be set.",
                status_code=HTTPStatus.BAD_REQUEST,
            )
        return None

    async def create_responses(
@@ -941,6 +949,8 @@ class OpenAIServingResponses(OpenAIServing):
                instructions=request.instructions, tools=request.tools
            )
            messages.append(dev_msg)

            messages += construct_harmony_previous_input_messages(request)
        else:
            # Continue the previous conversation.
            # FIXME(woosuk): Currently, request params like reasoning and