[gpt-oss] tool parser support for /chat/completions [1/n] (#22386)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Authored by Aaron Pham on 2025-09-04 23:39:12 -04:00; committed by GitHub.
parent 65e038931d
commit c29fb540ff
8 changed files with 573 additions and 63 deletions

View File

@@ -1,13 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from __future__ import annotations
+
 import asyncio
 from contextlib import suppress
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from unittest.mock import MagicMock
 
 import pytest
+import pytest_asyncio
 
 from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
@@ -17,6 +20,164 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                     OpenAIServingModels)
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
+from ...utils import RemoteOpenAIServer
+
+if TYPE_CHECKING:
+    from openai import OpenAI
+
+GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
+
+
+@pytest.fixture(scope="module")
+def monkeypatch_module():
+    from _pytest.monkeypatch import MonkeyPatch
+    mpatch = MonkeyPatch()
+    yield mpatch
+    mpatch.undo()
+
+
+@pytest.fixture(scope="module")
+def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
+    with monkeypatch_module.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
+        args = [
+            "--enforce-eager",
+            "--max-model-len",
+            "8192",
+            "--tool-call-parser",
+            "openai",
+            "--reasoning-parser",
+            "openai_gptoss",
+            "--enable-auto-tool-choice",
+        ]
+        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
+            yield remote_server
+
+
+@pytest_asyncio.fixture
+async def gptoss_client(gptoss_server):
+    async with gptoss_server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string"
+                    },
+                    "state": {
+                        "type": "string"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["city", "state", "unit"],
+            },
+        },
+    }]
+    messages = [
+        {
+            "role": "user",
+            "content": "What is the weather in Dallas, TX?"
+        },
+    ]
+
+    stream = await gptoss_client.chat.completions.create(
+        model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
+
+    name = None
+    args_buf = ""
+    async for chunk in stream:
+        delta = chunk.choices[0].delta
+        if delta.tool_calls:
+            tc = delta.tool_calls[0]
+            if tc.function and tc.function.name:
+                name = tc.function.name
+            if tc.function and tc.function.arguments:
+                args_buf += tc.function.arguments
+
+    assert name is not None
+    assert len(args_buf) > 0
+
+
+@pytest.mark.asyncio
+async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "city": {
+                        "type": "string"
+                    },
+                    "state": {
+                        "type": "string"
+                    },
+                    "unit": {
+                        "type": "string",
+                        "enum": ["celsius", "fahrenheit"],
+                    },
+                },
+                "required": ["city", "state", "unit"],
+            },
+        },
+    }]
+    messages = [
+        {
+            "role": "system",
+            "content": "you are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "What is the weather in Dallas, TX?"
+        },
+    ]
+
+    first = await gptoss_client.chat.completions.create(
+        model=GPT_OSS_MODEL_NAME,
+        messages=messages,
+        tools=tools,
+        temperature=0.0,
+    )
+    first_msg = first.choices[0].message
+    assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
+    tc = first_msg.tool_calls[0]
+    assert tc.function is not None and tc.function.name == "get_current_weather"
+    args1 = tc.function.arguments
+    assert args1 is not None and len(args1) > 0
+
+    messages.append({"role": "assistant", "content": args1})
+    messages.append({
+        "role": "user",
+        "content": "Now convert to celsius and return JSON only"
+    })
+
+    second = await gptoss_client.chat.completions.create(
+        model=GPT_OSS_MODEL_NAME,
+        messages=messages,
+        tools=tools,
+        temperature=0.0,
+    )
+    second_msg = second.choices[0].message
+    assert (second_msg.content is not None and len(second_msg.content) > 0) or \
+        (second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)  # noqa: E501
+
+
 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"
 BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

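The fixtures above pin down the server configuration this feature needs. For orientation (not part of the diff), here is a minimal synchronous client run against a server launched with the same flags; the base_url and api_key are placeholders for a local deployment:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city", "state", "unit"],
        },
    },
}]
resp = client.chat.completions.create(
    model="openai/gpt-oss-20b",
    messages=[{"role": "user", "content": "What is the weather in Dallas, TX?"}],
    tools=tools,
)
msg = resp.choices[0].message
if msg.tool_calls:
    # With the new parser, finish_reason is "tool_calls" and the arguments
    # arrive as a JSON string on each tool call.
    print(msg.tool_calls[0].function.name, msg.tool_calls[0].function.arguments)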
View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json

import pytest
from openai_harmony import (Conversation, DeveloperContent,
                            HarmonyEncodingName, Message, Role, SystemContent,
                            load_harmony_encoding)

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL = "gpt2"


@pytest.fixture(scope="module")
def openai_tokenizer():
    # The parser does not use the tokenizer, but the constructor requires it.
    return get_tokenizer(MODEL)


@pytest.fixture
def openai_tool_parser(openai_tokenizer):
    return OpenAIToolParser(openai_tokenizer)


@pytest.fixture(scope="module")
def harmony_encoding():
    return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)


def assert_tool_calls(
    actual_tool_calls: list[ToolCall],
    expected_tool_calls: list[ToolCall],
):
    assert len(actual_tool_calls) == len(expected_tool_calls)
    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
                                                    expected_tool_calls):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 16  # Default from protocol.py
        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function


def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
    convo = Conversation.from_messages([
        Message.from_role_and_content(
            Role.SYSTEM,
            SystemContent.new(),
        ),
        Message.from_role_and_content(
            Role.DEVELOPER,
            DeveloperContent.new().with_instructions("Talk like a pirate!")),
        Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
        Message.from_role_and_content(Role.ASSISTANT,
                                      "This is a test").with_channel("final")
    ])
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo, Role.ASSISTANT)
    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert not extracted_info.tools_called
    assert extracted_info.tool_calls == []
    assert extracted_info.content == "This is a test"


def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
    convo = Conversation.from_messages([
        Message.from_role_and_content(Role.USER,
                                      "What is the weather in Tokyo?"),
        Message.from_role_and_content(
            Role.ASSISTANT,
            'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.',  # noqa: E501
        ).with_channel("analysis"),
        Message.from_role_and_content(
            Role.ASSISTANT,
            '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
                "functions.get_current_weather").with_content_type("json"),
    ])
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo, Role.ASSISTANT)
    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert extracted_info.tools_called
    expected_tool_calls = [
        ToolCall(function=FunctionCall(
            name="get_current_weather",
            arguments=json.dumps({"location": "Tokyo"}),
        ))
    ]
    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
    assert extracted_info.content is None


def test_extract_tool_calls_multiple_tools(
    openai_tool_parser,
    harmony_encoding,
):
    convo = Conversation.from_messages([
        Message.from_role_and_content(
            Role.USER, "What is the weather in Tokyo based on where I'm at?"),
        Message.from_role_and_content(
            Role.ASSISTANT,
            'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.',  # noqa: E501
        ).with_channel("analysis"),
        Message.from_role_and_content(
            Role.ASSISTANT,
            '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
                "functions.get_current_weather").with_content_type("json"),
        Message.from_role_and_content(
            Role.ASSISTANT,
            '{"location": "Tokyo"}').with_channel("commentary").with_recipient(
                "functions.get_user_location").with_content_type("json"),
    ])
    token_ids = harmony_encoding.render_conversation_for_completion(
        convo,
        Role.ASSISTANT,
    )
    extracted_info = openai_tool_parser.extract_tool_calls(
        "",
        request=None,
        token_ids=token_ids,
    )
    assert extracted_info.tools_called
    expected_tool_calls = [
        ToolCall(function=FunctionCall(
            name="get_current_weather",
            arguments=json.dumps({"location": "Tokyo"}),
        )),
        ToolCall(function=FunctionCall(
            name="get_user_location",
            arguments=json.dumps({"location": "Tokyo"}),
        ))
    ]
    assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
    assert extracted_info.content is None

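The tests above hinge on one design decision: the parser consumes harmony token IDs, not decoded text. A condensed sketch of that round trip (not part of the diff), using only the APIs already imported in the test file:

import json

from openai_harmony import (Conversation, HarmonyEncodingName, Message, Role,
                            load_harmony_encoding)

from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
# A tool call in harmony is an assistant message on the "commentary" channel
# addressed to a "functions.<name>" recipient with JSON content.
convo = Conversation.from_messages([
    Message.from_role_and_content(
        Role.ASSISTANT,
        json.dumps({"location": "Tokyo"})).with_channel(
            "commentary").with_recipient(
                "functions.get_current_weather").with_content_type("json"),
])
token_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

parser = OpenAIToolParser(get_tokenizer("gpt2"))  # tokenizer is unused here
info = parser.extract_tool_calls("", request=None, token_ids=token_ids)
assert info.tools_called
assert info.tool_calls[0].function.name == "get_current_weather"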
View File

@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
 import datetime
 import json
 from collections.abc import Iterable, Sequence
@@ -18,7 +21,8 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
                             Role, StreamableParser, SystemContent, TextContent,
                             ToolDescription, load_harmony_encoding)
 
-from vllm.entrypoints.openai.protocol import ResponseInputOutputItem
+from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
+                                              ResponseInputOutputItem)
 from vllm.utils import random_uuid
 
 REASONING_EFFORT = {
@@ -63,13 +67,29 @@ def get_system_message(
     return sys_msg
 
 
-def get_developer_message(instructions: Optional[str] = None,
-                          tools: Optional[list[Tool]] = None) -> Message:
+def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
+    if isinstance(tool, ChatCompletionToolsParam):
+        return ToolDescription.new(
+            name=tool.function.name,
+            description=tool.function.description,
+            parameters=tool.function.parameters,
+        )
+    return ToolDescription.new(
+        name=tool.name,
+        description=tool.description,
+        parameters=tool.parameters,
+    )
+
+
+def get_developer_message(
+    instructions: Optional[str] = None,
+    tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
+) -> Message:
     dev_msg_content = DeveloperContent.new()
     if instructions is not None:
         dev_msg_content = dev_msg_content.with_instructions(instructions)
     if tools is not None:
-        function_tools = []
+        function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
         for tool in tools:
             if tool.type in ("web_search_preview", "code_interpreter"):
                 # These are built-in tools that are added to the system message.
@@ -80,11 +100,7 @@ def get_developer_message(instructions: Optional[str] = None,
                 raise ValueError(f"tool type {tool.type} not supported")
         if function_tools:
             function_tool_descriptions = [
-                ToolDescription.new(
-                    name=tool.name,
-                    description=tool.description,
-                    parameters=tool.parameters,
-                ) for tool in function_tools
+                create_tool_definition(tool) for tool in function_tools
             ]
             dev_msg_content = dev_msg_content.with_function_tools(
                 function_tool_descriptions)
@@ -148,16 +164,46 @@
     return msg
 
 
-def parse_chat_input(chat_msg) -> Message:
-    role = chat_msg["role"]
-    content = chat_msg["content"]
+def parse_chat_input(chat_msg) -> list[Message]:
+    if not isinstance(chat_msg, dict):
+        # Handle Pydantic models
+        chat_msg = chat_msg.model_dump(exclude_none=True)
+    role = chat_msg.get("role")
+
+    # Assistant message with tool calls
+    tool_calls = chat_msg.get("tool_calls")
+    if role == "assistant" and tool_calls:
+        msgs: list[Message] = []
+        for call in tool_calls:
+            func = call.get("function", {})
+            name = func.get("name", "")
+            arguments = func.get("arguments", "") or ""
+            msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
+            msg = msg.with_channel("commentary")
+            msg = msg.with_recipient(f"functions.{name}")
+            msg = msg.with_content_type("json")
+            msgs.append(msg)
+        return msgs
+
+    # Tool role message (tool output)
+    if role == "tool":
+        name = chat_msg.get("name", "")
+        content = chat_msg.get("content", "") or ""
+        msg = Message.from_author_and_content(
+            Author.new(Role.TOOL, f"functions.{name}"),
+            content).with_channel("commentary")
+        return [msg]
+
+    # Default: user/assistant/system messages with content
+    content = chat_msg.get("content", "")
     if isinstance(content, str):
         contents = [TextContent(text=content)]
     else:
         # TODO: Support refusal.
         contents = [TextContent(text=c.get("text", "")) for c in content]
     msg = Message.from_role_and_contents(role, contents)
-    return msg
+    return [msg]
 
 
 def render_for_completion(messages: list[Message]) -> list[int]:

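For reference (not part of the diff), the new parse_chat_input contract in a nutshell: it returns a list so that a single assistant turn with tool calls can expand into several harmony messages, and tool outputs are attributed to a TOOL author on the commentary channel. A small sketch with illustrative message dicts:

from vllm.entrypoints.harmony_utils import parse_chat_input

assistant_turn = {
    "role": "assistant",
    "tool_calls": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "arguments": '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}',
        },
    }],
}
msgs = parse_chat_input(assistant_turn)
# -> one ASSISTANT message on the "commentary" channel with recipient
#    "functions.get_current_weather" and content type "json"

tool_turn = {"role": "tool", "name": "get_current_weather", "content": "75F"}
msgs = parse_chat_input(tool_turn)
# -> one TOOL-authored message ("functions.get_current_weather") on the
#    "commentary" channel carrying the tool output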
View File

@@ -6,7 +6,7 @@ import json
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
-from typing import Callable, Final, Optional, Union
+from typing import TYPE_CHECKING, Callable, Final, Optional, Union
 
 import jinja2
 import partial_json_parser
@@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing):
             get_streamable_parser_for_assistant()
             for _ in range(num_choices)
         ]
+        harmony_tools_streamed = [False] * num_choices
+        tools_streamed = [False] * num_choices
 
         if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
             tool_choice_function_name = request.tool_choice.function.name
@@ -662,13 +664,11 @@
                     if self.use_harmony:
                         harmony_parser = harmony_parsers[i]
+                        prev_recipient = harmony_parser.current_recipient
                         for token_id in output.token_ids:
                             harmony_parser.process(token_id)
-                        is_reasoning = \
-                            harmony_parser.current_channel == "analysis"
-                        if not request.include_reasoning and is_reasoning:
-                            # Skip the reasoning content.
-                            continue
+                        cur_channel = harmony_parser.current_channel
+                        cur_recipient = harmony_parser.current_recipient
                         delta_text = harmony_parser.last_content_delta or ""
                     else:
                         delta_text = output.text
@@ -681,8 +681,7 @@
                     delta_message: Optional[DeltaMessage]
 
                     # just update previous_texts and previous_token_ids
-                    if ((tool_choice_auto or self.reasoning_parser)
-                            and not self.use_harmony):
+                    if tool_choice_auto or self.reasoning_parser:
                         assert previous_texts is not None
                         assert all_previous_token_ids is not None
                         previous_text = previous_texts[i]
@@ -696,11 +695,54 @@
                         current_token_ids = as_list(output.token_ids)
 
                     if self.use_harmony:
-                        if is_reasoning:
-                            delta_message = DeltaMessage(
-                                reasoning_content=delta_text)
-                        else:
+                        if cur_channel == "final":
                             delta_message = DeltaMessage(content=delta_text)
+                        elif cur_channel == "analysis":
+                            if request.include_reasoning:
+                                delta_message = DeltaMessage(
+                                    reasoning_content=delta_text)
+                            else:
+                                delta_message = None
+                        elif (cur_channel == "commentary" and cur_recipient
+                              and cur_recipient.startswith("functions.")):
+                            # Count completed tool calls to determine index
+                            base_index = 0
+                            for msg in harmony_parser.messages:
+                                if (msg.channel == "commentary"
+                                        and msg.recipient
+                                        and msg.recipient.startswith(
+                                            "functions.")):
+                                    base_index += 1
+                            if prev_recipient != cur_recipient:
+                                tool_name = cur_recipient.split(
+                                    "functions.", 1)[1]
+                                delta_message = DeltaMessage(tool_calls=[
+                                    DeltaToolCall(
+                                        id=make_tool_call_id(),
+                                        type="function",
+                                        function=DeltaFunctionCall(
+                                            name=tool_name,
+                                            arguments="",
+                                        ),
+                                        index=base_index,
+                                    )
+                                ])
+                            elif delta_text:
+                                delta_message = DeltaMessage(tool_calls=[
+                                    DeltaToolCall(
+                                        index=base_index,
+                                        function=DeltaFunctionCall(
+                                            arguments=delta_text),
+                                    )
+                                ])
+                            else:
+                                delta_message = None
+                            if delta_message is not None:
+                                harmony_tools_streamed[i] = True
+                        else:
+                            delta_message = None
 
                     # handle streaming deltas for tools with named tool_choice
                     elif tool_choice_function_name:
                         if (self.reasoning_parser and not reasoning_end_arr[i]
@@ -758,6 +800,7 @@
                             delta_message = DeltaMessage(tool_calls=[
                                 delta_tool_call,
                             ])
+                            tools_streamed[i] = True
 
                     elif request.tool_choice == "required":
                         assert previous_texts is not None
@@ -783,6 +826,7 @@
                         if (delta_message and delta_message.tool_calls and
                                 delta_message.tool_calls[0].id is not None):
                             history_tool_call_cnt += 1
+                            tools_streamed[i] = True
 
                         # update the previous values for the next iteration
                         previous_texts[i] = current_text
@@ -859,6 +903,8 @@
                                 current_token_ids=current_token_ids,
                                 delta_token_ids=delta_token_ids,
                                 request=request))
+                        if delta_message and delta_message.tool_calls:
+                            tools_streamed[i] = True
 
                     # when only tool calls
                     elif tool_choice_auto:
                         assert tool_parser is not None
@@ -871,6 +917,8 @@
                                 current_token_ids=current_token_ids,
                                 delta_token_ids=output.token_ids,
                                 request=request))
+                        if delta_message and delta_message.tool_calls:
+                            tools_streamed[i] = True
 
                     # when only reasoning
                     elif self.reasoning_parser:
@@ -907,7 +955,10 @@
                     # wasn't ready to send a token, then
                     # get the next token without streaming a chunk
                     if delta_message is None:
-                        continue
+                        if output.finish_reason is None:
+                            continue
+                        else:
+                            delta_message = DeltaMessage()
 
                     # Log streaming delta if output logging is enabled
                     if self.enable_log_outputs and self.request_logger:
@@ -993,12 +1044,18 @@
                         ])
 
                     # Send the finish response for each request.n only once
+                    if auto_tools_called or tools_streamed[i] or (
+                            self.use_harmony
+                            and harmony_tools_streamed[i]):
+                        finish_reason_ = "tool_calls"
+                    else:
+                        finish_reason_ = output.finish_reason \
+                            if output.finish_reason else "stop"
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=delta_message,
                         logprobs=logprobs,
-                        finish_reason=output.finish_reason
-                        if not auto_tools_called else "tool_calls",
+                        finish_reason=finish_reason_,
                         stop_reason=output.stop_reason,
                         token_ids=(as_list(output.token_ids)
                                    if request.return_token_ids else None))
@@ -1131,31 +1188,32 @@
             logprobs = None
 
             if self.use_harmony:
-                reasoning_content, final_content, is_tool_call = (
-                    parse_chat_output(token_ids))
-                if not request.include_reasoning:
-                    reasoning_content = None
-
-                if is_tool_call:
-                    # TODO(woosuk): Implement tool call for gpt-oss.
-                    # For now, only Responses API supports tool call for
-                    # gpt-oss.
-                    raise NotImplementedError(
-                        "Tool call in Chat Completion API is not supported "
-                        "for gpt-oss yet. Please use Responses API instead.")
-                else:
-                    # Normal message
-                    message = ChatMessage(
-                        role=role,
-                        reasoning_content=reasoning_content,
-                        content=final_content,
-                    )
+                if TYPE_CHECKING:
+                    assert self.tool_parser is not None
+                tool_parser = self.tool_parser(tokenizer)
+                # NOTE: We use token_ids for openai tool parser
+                tool_call_info = tool_parser.extract_tool_calls(
+                    "",
+                    request=request,
+                    token_ids=token_ids,  # type: ignore
+                )
+                reasoning_content, content = None, tool_call_info.content
+                if request.include_reasoning:
+                    reasoning_content, content, _ = parse_chat_output(
+                        token_ids)
+                message = ChatMessage(
+                    role=role,
+                    reasoning_content=reasoning_content,
+                    content=content,
+                    tool_calls=tool_call_info.tool_calls,
+                )
                 choice_data = ChatCompletionResponseChoice(
                     index=output.index,
                     message=message,
                     logprobs=logprobs,
-                    finish_reason="tool_calls" if is_tool_call else
+                    finish_reason="tool_calls"
+                    if tool_call_info.tools_called else
                     output.finish_reason if output.finish_reason else "stop",
                     stop_reason=output.stop_reason,
                 )
@@ -1504,12 +1562,12 @@
         messages.append(sys_msg)
 
         # Add developer message.
-        dev_msg = get_developer_message()
+        dev_msg = get_developer_message(tools=request.tools)
         messages.append(dev_msg)
 
         # Add user message.
         for chat_msg in request.messages:
-            messages.append(parse_chat_input(chat_msg))
+            messages.extend(parse_chat_input(chat_msg))
 
         # Render prompt token ids.
         prompt_token_ids = render_for_completion(messages)

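From the client side, the streaming logic above yields a stable contract: the first chunk of each tool call carries the id, type, and function name with empty arguments, and later chunks append argument fragments at the same index. A reassembly sketch (not part of the diff; base_url, api_key, and the tool schema are placeholders):

import asyncio

from openai import AsyncOpenAI

TOOLS = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="openai/gpt-oss-20b",
        messages=[{"role": "user", "content": "Weather in Dallas, TX?"}],
        tools=TOOLS,
        stream=True,
    )
    # Reassemble tool calls by index: the name (plus id) arrives once,
    # the arguments arrive as incremental JSON fragments.
    calls: dict[int, dict] = {}
    async for chunk in stream:
        for tc in chunk.choices[0].delta.tool_calls or []:
            entry = calls.setdefault(tc.index, {"name": None, "args": ""})
            if tc.function and tc.function.name:
                entry["name"] = tc.function.name
            if tc.function and tc.function.arguments:
                entry["args"] += tc.function.arguments
    print(calls)


asyncio.run(main())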
View File

@@ -16,6 +16,7 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
 from .llama_tool_parser import Llama3JsonToolParser
 from .minimax_tool_parser import MinimaxToolParser
 from .mistral_tool_parser import MistralToolParser
+from .openai_tool_parser import OpenAIToolParser
 from .phi4mini_tool_parser import Phi4MiniJsonToolParser
 from .pythonic_tool_parser import PythonicToolParser
 from .qwen3coder_tool_parser import Qwen3CoderToolParser
@@ -46,4 +47,5 @@ __all__ = [
     "Qwen3CoderToolParser",
     "SeedOssToolParser",
     "Step3ToolParser",
+    "OpenAIToolParser",
 ]

View File

@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING

from vllm.entrypoints.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)

if TYPE_CHECKING:
    from vllm.transformers_utils.tokenizer import AnyTokenizer


@ToolParserManager.register_module("openai")
class OpenAIToolParser(ToolParser):

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
        token_ids: Sequence[int] | None = None,
    ) -> ExtractedToolCallInformation:
        if token_ids is None:
            raise NotImplementedError(
                "OpenAIToolParser requires token IDs and does not support text-based extraction."  # noqa: E501
            )

        parser = parse_output_into_messages(token_ids)
        tool_calls = []
        final_content = None

        if len(parser.messages) > 0:
            for msg in parser.messages:
                if msg.recipient and msg.recipient.startswith("functions."):
                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=msg.recipient.split("functions.")[1],
                                arguments=msg.content[0].text,
                            ),
                        ))
                elif msg.channel == "final":
                    final_content = msg.content[0].text

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=final_content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        raise NotImplementedError(
            "Not being used, manual parsing in serving_chat.py"  # noqa: E501
        )

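The register_module("openai") decorator is what lets --tool-call-parser openai resolve to this class at startup. A minimal lookup sketch (not part of the diff), assuming the ToolParserManager.get_tool_parser registry helper re-exported alongside the parsers:

from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.transformers_utils.tokenizer import get_tokenizer

# Resolve the registered name to the class, as the server does at startup.
parser_cls = ToolParserManager.get_tool_parser("openai")
parser = parser_cls(get_tokenizer("gpt2"))
print(type(parser).__name__)  # OpenAIToolParser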
View File

@@ -256,7 +256,7 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         decoding_config = vllm_config.decoding_config
         if decoding_config.reasoning_backend == "":
-            decoding_config.reasoning_backend = "GptOss"
+            decoding_config.reasoning_backend = "openai_gptoss"
 
         # Increase the max capture size from 512 to 1024 for performance.
         # NOTE(woosuk): This will increase the number of CUDA graphs

View File

@@ -6,6 +6,7 @@ from typing import Optional, Union
 
 from transformers import PreTrainedTokenizerBase
 
+from vllm.entrypoints.harmony_utils import parse_chat_output
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage)
 from vllm.logger import init_logger
@@ -14,7 +15,7 @@ from vllm.reasoning import ReasoningParser, ReasoningParserManager
 logger = init_logger(__name__)
 
 
-@ReasoningParserManager.register_module("GptOss")
+@ReasoningParserManager.register_module("openai_gptoss")
 class GptOssReasoningParser(ReasoningParser):
     """
     Reasoning parser for GptOss model.
@@ -39,9 +40,10 @@
         return False
 
     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
-        raise RuntimeError(
-            "GptOss model uses harmony to extract reasoning content. This "
-            "function should not be called.")
+        _, content, _ = parse_chat_output(input_ids)
+        if content is None:
+            return []
+        return self.model_tokenizer.encode(content)
 
     def extract_reasoning_content_streaming(
         self,
@@ -52,13 +54,34 @@
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
     ) -> Union[DeltaMessage, None]:
-        raise RuntimeError(
-            "GptOss model uses harmony to extract reasoning content. This "
-            "function should not be called.")
+        prev_reasoning, prev_content, _ = parse_chat_output(
+            list(previous_token_ids))
+        cur_reasoning, cur_content, _ = parse_chat_output(
+            list(current_token_ids))
+
+        reasoning_delta = None
+        content_delta = None
+        if cur_reasoning is not None:
+            prev_r = prev_reasoning or ""
+            if cur_reasoning.startswith(prev_r):
+                reasoning_delta = cur_reasoning[len(prev_r):] or None
+            else:
+                reasoning_delta = cur_reasoning
+        if cur_content is not None:
+            prev_c = prev_content or ""
+            if cur_content.startswith(prev_c):
+                content_delta = cur_content[len(prev_c):] or None
+            else:
+                content_delta = cur_content
+
+        if reasoning_delta is None and content_delta is None:
+            return None
+        return DeltaMessage(reasoning_content=reasoning_delta,
+                            content=content_delta)
 
     def extract_reasoning_content(
-        self, model_output: str, request: ChatCompletionRequest
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
     ) -> tuple[Optional[str], Optional[str]]:
-        raise RuntimeError(
-            "GptOss model uses harmony to extract reasoning content. This "
-            "function should not be called.")
+        raise NotImplementedError(
+            "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used."  # noqa: E501
+        )