mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 04:00:12 +08:00
[gpt-oss] tool parser supports for /chat/completions [1/n] (#22386)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
parent
65e038931d
commit
c29fb540ff
@ -1,13 +1,16 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
from vllm.config import MultiModalConfig
|
from vllm.config import MultiModalConfig
|
||||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
@ -17,6 +20,164 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
|||||||
OpenAIServingModels)
|
OpenAIServingModels)
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def monkeypatch_module():
|
||||||
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
|
mpatch = MonkeyPatch()
|
||||||
|
yield mpatch
|
||||||
|
mpatch.undo()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
|
||||||
|
with monkeypatch_module.context() as m:
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
|
||||||
|
args = [
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-model-len",
|
||||||
|
"8192",
|
||||||
|
"--tool-call-parser",
|
||||||
|
"openai",
|
||||||
|
"--reasoning-parser",
|
||||||
|
"openai_gptoss",
|
||||||
|
"--enable-auto-tool-choice",
|
||||||
|
]
|
||||||
|
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
|
||||||
|
yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def gptoss_client(gptoss_server):
|
||||||
|
async with gptoss_server.get_async_client() as async_client:
|
||||||
|
yield async_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
|
||||||
|
tools = [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"state": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["city", "state", "unit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the weather in Dallas, TX?"
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
stream = await gptoss_client.chat.completions.create(
|
||||||
|
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
|
||||||
|
|
||||||
|
name = None
|
||||||
|
args_buf = ""
|
||||||
|
async for chunk in stream:
|
||||||
|
delta = chunk.choices[0].delta
|
||||||
|
if delta.tool_calls:
|
||||||
|
tc = delta.tool_calls[0]
|
||||||
|
if tc.function and tc.function.name:
|
||||||
|
name = tc.function.name
|
||||||
|
if tc.function and tc.function.arguments:
|
||||||
|
args_buf += tc.function.arguments
|
||||||
|
|
||||||
|
assert name is not None
|
||||||
|
assert len(args_buf) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
|
||||||
|
tools = [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_current_weather",
|
||||||
|
"description": "Get the current weather in a given location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"state": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["city", "state", "unit"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "you are a helpful assistant"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the weather in Dallas, TX?"
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
first = await gptoss_client.chat.completions.create(
|
||||||
|
model=GPT_OSS_MODEL_NAME,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
first_msg = first.choices[0].message
|
||||||
|
assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
|
||||||
|
tc = first_msg.tool_calls[0]
|
||||||
|
assert tc.function is not None and tc.function.name == "get_current_weather"
|
||||||
|
args1 = tc.function.arguments
|
||||||
|
assert args1 is not None and len(args1) > 0
|
||||||
|
|
||||||
|
messages.append({"role": "assistant", "content": args1})
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": "Now convert to celsius and return JSON only"
|
||||||
|
})
|
||||||
|
|
||||||
|
second = await gptoss_client.chat.completions.create(
|
||||||
|
model=GPT_OSS_MODEL_NAME,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
second_msg = second.choices[0].message
|
||||||
|
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
|
||||||
|
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501
|
||||||
|
|
||||||
|
|
||||||
MODEL_NAME = "openai-community/gpt2"
|
MODEL_NAME = "openai-community/gpt2"
|
||||||
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
||||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||||
|
|||||||
147
tests/tool_use/test_openai_tool_parser.py
Normal file
147
tests/tool_use/test_openai_tool_parser.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai_harmony import (Conversation, DeveloperContent,
|
||||||
|
HarmonyEncodingName, Message, Role, SystemContent,
|
||||||
|
load_harmony_encoding)
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||||
|
from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
|
MODEL = "gpt2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def openai_tokenizer():
|
||||||
|
# The parser does not use the tokenizer, but the constructor requires it.
|
||||||
|
return get_tokenizer(MODEL)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def openai_tool_parser(openai_tokenizer):
|
||||||
|
return OpenAIToolParser(openai_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def harmony_encoding():
|
||||||
|
return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_tool_calls(
|
||||||
|
actual_tool_calls: list[ToolCall],
|
||||||
|
expected_tool_calls: list[ToolCall],
|
||||||
|
):
|
||||||
|
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||||
|
|
||||||
|
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||||
|
expected_tool_calls):
|
||||||
|
assert isinstance(actual_tool_call.id, str)
|
||||||
|
assert len(actual_tool_call.id) > 16 # Default from protocol.py
|
||||||
|
assert actual_tool_call.type == "function"
|
||||||
|
assert actual_tool_call.function == expected_tool_call.function
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
||||||
|
convo = Conversation.from_messages([
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.SYSTEM,
|
||||||
|
SystemContent.new(),
|
||||||
|
),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.DEVELOPER,
|
||||||
|
DeveloperContent.new().with_instructions("Talk like a pirate!")),
|
||||||
|
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
|
||||||
|
Message.from_role_and_content(Role.ASSISTANT,
|
||||||
|
"This is a test").with_channel("final")
|
||||||
|
])
|
||||||
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
|
convo, Role.ASSISTANT)
|
||||||
|
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||||
|
"",
|
||||||
|
request=None,
|
||||||
|
token_ids=token_ids,
|
||||||
|
)
|
||||||
|
assert not extracted_info.tools_called
|
||||||
|
assert extracted_info.tool_calls == []
|
||||||
|
assert extracted_info.content == "This is a test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
||||||
|
convo = Conversation.from_messages([
|
||||||
|
Message.from_role_and_content(Role.USER,
|
||||||
|
"What is the weather in Tokyo?"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
|
||||||
|
).with_channel("analysis"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||||
|
"functions.get_current_weather").with_content_type("json"),
|
||||||
|
])
|
||||||
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
|
convo, Role.ASSISTANT)
|
||||||
|
|
||||||
|
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||||
|
"",
|
||||||
|
request=None,
|
||||||
|
token_ids=token_ids,
|
||||||
|
)
|
||||||
|
assert extracted_info.tools_called
|
||||||
|
expected_tool_calls = [
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
))
|
||||||
|
]
|
||||||
|
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||||
|
assert extracted_info.content is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_multiple_tools(
|
||||||
|
openai_tool_parser,
|
||||||
|
harmony_encoding,
|
||||||
|
):
|
||||||
|
convo = Conversation.from_messages([
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||||
|
).with_channel("analysis"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||||
|
"functions.get_current_weather").with_content_type("json"),
|
||||||
|
Message.from_role_and_content(
|
||||||
|
Role.ASSISTANT,
|
||||||
|
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||||
|
"functions.get_user_location").with_content_type("json"),
|
||||||
|
])
|
||||||
|
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||||
|
convo,
|
||||||
|
Role.ASSISTANT,
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||||
|
"",
|
||||||
|
request=None,
|
||||||
|
token_ids=token_ids,
|
||||||
|
)
|
||||||
|
assert extracted_info.tools_called
|
||||||
|
expected_tool_calls = [
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_current_weather",
|
||||||
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
)),
|
||||||
|
ToolCall(function=FunctionCall(
|
||||||
|
name="get_user_location",
|
||||||
|
arguments=json.dumps({"location": "Tokyo"}),
|
||||||
|
))
|
||||||
|
]
|
||||||
|
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||||
|
assert extracted_info.content is None
|
||||||
@ -1,5 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
from collections.abc import Iterable, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
@ -18,7 +21,8 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
|
|||||||
Role, StreamableParser, SystemContent, TextContent,
|
Role, StreamableParser, SystemContent, TextContent,
|
||||||
ToolDescription, load_harmony_encoding)
|
ToolDescription, load_harmony_encoding)
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem
|
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||||
|
ResponseInputOutputItem)
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
REASONING_EFFORT = {
|
REASONING_EFFORT = {
|
||||||
@ -63,13 +67,29 @@ def get_system_message(
|
|||||||
return sys_msg
|
return sys_msg
|
||||||
|
|
||||||
|
|
||||||
def get_developer_message(instructions: Optional[str] = None,
|
def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
|
||||||
tools: Optional[list[Tool]] = None) -> Message:
|
if isinstance(tool, ChatCompletionToolsParam):
|
||||||
|
return ToolDescription.new(
|
||||||
|
name=tool.function.name,
|
||||||
|
description=tool.function.description,
|
||||||
|
parameters=tool.function.parameters,
|
||||||
|
)
|
||||||
|
return ToolDescription.new(
|
||||||
|
name=tool.name,
|
||||||
|
description=tool.description,
|
||||||
|
parameters=tool.parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_developer_message(
|
||||||
|
instructions: Optional[str] = None,
|
||||||
|
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
|
||||||
|
) -> Message:
|
||||||
dev_msg_content = DeveloperContent.new()
|
dev_msg_content = DeveloperContent.new()
|
||||||
if instructions is not None:
|
if instructions is not None:
|
||||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||||
if tools is not None:
|
if tools is not None:
|
||||||
function_tools = []
|
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.type in ("web_search_preview", "code_interpreter"):
|
if tool.type in ("web_search_preview", "code_interpreter"):
|
||||||
# These are built-in tools that are added to the system message.
|
# These are built-in tools that are added to the system message.
|
||||||
@ -80,11 +100,7 @@ def get_developer_message(instructions: Optional[str] = None,
|
|||||||
raise ValueError(f"tool type {tool.type} not supported")
|
raise ValueError(f"tool type {tool.type} not supported")
|
||||||
if function_tools:
|
if function_tools:
|
||||||
function_tool_descriptions = [
|
function_tool_descriptions = [
|
||||||
ToolDescription.new(
|
create_tool_definition(tool) for tool in function_tools
|
||||||
name=tool.name,
|
|
||||||
description=tool.description,
|
|
||||||
parameters=tool.parameters,
|
|
||||||
) for tool in function_tools
|
|
||||||
]
|
]
|
||||||
dev_msg_content = dev_msg_content.with_function_tools(
|
dev_msg_content = dev_msg_content.with_function_tools(
|
||||||
function_tool_descriptions)
|
function_tool_descriptions)
|
||||||
@ -148,16 +164,46 @@ def parse_response_input(
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def parse_chat_input(chat_msg) -> Message:
|
def parse_chat_input(chat_msg) -> list[Message]:
|
||||||
role = chat_msg["role"]
|
if not isinstance(chat_msg, dict):
|
||||||
content = chat_msg["content"]
|
# Handle Pydantic models
|
||||||
|
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||||
|
|
||||||
|
role = chat_msg.get("role")
|
||||||
|
|
||||||
|
# Assistant message with tool calls
|
||||||
|
tool_calls = chat_msg.get("tool_calls")
|
||||||
|
if role == "assistant" and tool_calls:
|
||||||
|
msgs: list[Message] = []
|
||||||
|
for call in tool_calls:
|
||||||
|
func = call.get("function", {})
|
||||||
|
name = func.get("name", "")
|
||||||
|
arguments = func.get("arguments", "") or ""
|
||||||
|
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||||
|
msg = msg.with_channel("commentary")
|
||||||
|
msg = msg.with_recipient(f"functions.{name}")
|
||||||
|
msg = msg.with_content_type("json")
|
||||||
|
msgs.append(msg)
|
||||||
|
return msgs
|
||||||
|
|
||||||
|
# Tool role message (tool output)
|
||||||
|
if role == "tool":
|
||||||
|
name = chat_msg.get("name", "")
|
||||||
|
content = chat_msg.get("content", "") or ""
|
||||||
|
msg = Message.from_author_and_content(
|
||||||
|
Author.new(Role.TOOL, f"functions.{name}"),
|
||||||
|
content).with_channel("commentary")
|
||||||
|
return [msg]
|
||||||
|
|
||||||
|
# Default: user/assistant/system messages with content
|
||||||
|
content = chat_msg.get("content", "")
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
contents = [TextContent(text=content)]
|
contents = [TextContent(text=content)]
|
||||||
else:
|
else:
|
||||||
# TODO: Support refusal.
|
# TODO: Support refusal.
|
||||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||||
msg = Message.from_role_and_contents(role, contents)
|
msg = Message.from_role_and_contents(role, contents)
|
||||||
return msg
|
return [msg]
|
||||||
|
|
||||||
|
|
||||||
def render_for_completion(messages: list[Message]) -> list[int]:
|
def render_for_completion(messages: list[Message]) -> list[int]:
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator, AsyncIterator
|
from collections.abc import AsyncGenerator, AsyncIterator
|
||||||
from collections.abc import Sequence as GenericSequence
|
from collections.abc import Sequence as GenericSequence
|
||||||
from typing import Callable, Final, Optional, Union
|
from typing import TYPE_CHECKING, Callable, Final, Optional, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
get_streamable_parser_for_assistant()
|
get_streamable_parser_for_assistant()
|
||||||
for _ in range(num_choices)
|
for _ in range(num_choices)
|
||||||
]
|
]
|
||||||
|
harmony_tools_streamed = [False] * num_choices
|
||||||
|
tools_streamed = [False] * num_choices
|
||||||
|
|
||||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||||
tool_choice_function_name = request.tool_choice.function.name
|
tool_choice_function_name = request.tool_choice.function.name
|
||||||
@ -662,13 +664,11 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
harmony_parser = harmony_parsers[i]
|
harmony_parser = harmony_parsers[i]
|
||||||
|
prev_recipient = harmony_parser.current_recipient
|
||||||
for token_id in output.token_ids:
|
for token_id in output.token_ids:
|
||||||
harmony_parser.process(token_id)
|
harmony_parser.process(token_id)
|
||||||
is_reasoning = \
|
cur_channel = harmony_parser.current_channel
|
||||||
harmony_parser.current_channel == "analysis"
|
cur_recipient = harmony_parser.current_recipient
|
||||||
if not request.include_reasoning and is_reasoning:
|
|
||||||
# Skip the reasoning content.
|
|
||||||
continue
|
|
||||||
delta_text = harmony_parser.last_content_delta or ""
|
delta_text = harmony_parser.last_content_delta or ""
|
||||||
else:
|
else:
|
||||||
delta_text = output.text
|
delta_text = output.text
|
||||||
@ -681,8 +681,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
delta_message: Optional[DeltaMessage]
|
delta_message: Optional[DeltaMessage]
|
||||||
|
|
||||||
# just update previous_texts and previous_token_ids
|
# just update previous_texts and previous_token_ids
|
||||||
if ((tool_choice_auto or self.reasoning_parser)
|
if tool_choice_auto or self.reasoning_parser:
|
||||||
and not self.use_harmony):
|
|
||||||
assert previous_texts is not None
|
assert previous_texts is not None
|
||||||
assert all_previous_token_ids is not None
|
assert all_previous_token_ids is not None
|
||||||
previous_text = previous_texts[i]
|
previous_text = previous_texts[i]
|
||||||
@ -696,11 +695,54 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
current_token_ids = as_list(output.token_ids)
|
current_token_ids = as_list(output.token_ids)
|
||||||
|
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
if is_reasoning:
|
if cur_channel == "final":
|
||||||
delta_message = DeltaMessage(
|
|
||||||
reasoning_content=delta_text)
|
|
||||||
else:
|
|
||||||
delta_message = DeltaMessage(content=delta_text)
|
delta_message = DeltaMessage(content=delta_text)
|
||||||
|
elif cur_channel == "analysis":
|
||||||
|
if request.include_reasoning:
|
||||||
|
delta_message = DeltaMessage(
|
||||||
|
reasoning_content=delta_text)
|
||||||
|
else:
|
||||||
|
delta_message = None
|
||||||
|
elif (cur_channel == "commentary" and cur_recipient
|
||||||
|
and cur_recipient.startswith("functions.")):
|
||||||
|
# Count completed tool calls to determine index
|
||||||
|
base_index = 0
|
||||||
|
for msg in harmony_parser.messages:
|
||||||
|
if (msg.channel == "commentary"
|
||||||
|
and msg.recipient
|
||||||
|
and msg.recipient.startswith(
|
||||||
|
"functions.")):
|
||||||
|
base_index += 1
|
||||||
|
|
||||||
|
if prev_recipient != cur_recipient:
|
||||||
|
tool_name = cur_recipient.split(
|
||||||
|
"functions.", 1)[1]
|
||||||
|
delta_message = DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
id=make_tool_call_id(),
|
||||||
|
type="function",
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
name=tool_name,
|
||||||
|
arguments="",
|
||||||
|
),
|
||||||
|
index=base_index,
|
||||||
|
)
|
||||||
|
])
|
||||||
|
elif delta_text:
|
||||||
|
delta_message = DeltaMessage(tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=base_index,
|
||||||
|
function=DeltaFunctionCall(
|
||||||
|
arguments=delta_text),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
delta_message = None
|
||||||
|
|
||||||
|
if delta_message is not None:
|
||||||
|
harmony_tools_streamed[i] = True
|
||||||
|
else:
|
||||||
|
delta_message = None
|
||||||
# handle streaming deltas for tools with named tool_choice
|
# handle streaming deltas for tools with named tool_choice
|
||||||
elif tool_choice_function_name:
|
elif tool_choice_function_name:
|
||||||
if (self.reasoning_parser and not reasoning_end_arr[i]
|
if (self.reasoning_parser and not reasoning_end_arr[i]
|
||||||
@ -758,6 +800,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
delta_message = DeltaMessage(tool_calls=[
|
delta_message = DeltaMessage(tool_calls=[
|
||||||
delta_tool_call,
|
delta_tool_call,
|
||||||
])
|
])
|
||||||
|
tools_streamed[i] = True
|
||||||
|
|
||||||
elif request.tool_choice == "required":
|
elif request.tool_choice == "required":
|
||||||
assert previous_texts is not None
|
assert previous_texts is not None
|
||||||
@ -783,6 +826,7 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
if (delta_message and delta_message.tool_calls and
|
if (delta_message and delta_message.tool_calls and
|
||||||
delta_message.tool_calls[0].id is not None):
|
delta_message.tool_calls[0].id is not None):
|
||||||
history_tool_call_cnt += 1
|
history_tool_call_cnt += 1
|
||||||
|
tools_streamed[i] = True
|
||||||
|
|
||||||
# update the previous values for the next iteration
|
# update the previous values for the next iteration
|
||||||
previous_texts[i] = current_text
|
previous_texts[i] = current_text
|
||||||
@ -859,6 +903,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
current_token_ids=current_token_ids,
|
current_token_ids=current_token_ids,
|
||||||
delta_token_ids=delta_token_ids,
|
delta_token_ids=delta_token_ids,
|
||||||
request=request))
|
request=request))
|
||||||
|
if delta_message and delta_message.tool_calls:
|
||||||
|
tools_streamed[i] = True
|
||||||
# when only tool calls
|
# when only tool calls
|
||||||
elif tool_choice_auto:
|
elif tool_choice_auto:
|
||||||
assert tool_parser is not None
|
assert tool_parser is not None
|
||||||
@ -871,6 +917,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
current_token_ids=current_token_ids,
|
current_token_ids=current_token_ids,
|
||||||
delta_token_ids=output.token_ids,
|
delta_token_ids=output.token_ids,
|
||||||
request=request))
|
request=request))
|
||||||
|
if delta_message and delta_message.tool_calls:
|
||||||
|
tools_streamed[i] = True
|
||||||
|
|
||||||
# when only reasoning
|
# when only reasoning
|
||||||
elif self.reasoning_parser:
|
elif self.reasoning_parser:
|
||||||
@ -907,7 +955,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
# wasn't ready to send a token, then
|
# wasn't ready to send a token, then
|
||||||
# get the next token without streaming a chunk
|
# get the next token without streaming a chunk
|
||||||
if delta_message is None:
|
if delta_message is None:
|
||||||
continue
|
if output.finish_reason is None:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
delta_message = DeltaMessage()
|
||||||
|
|
||||||
# Log streaming delta if output logging is enabled
|
# Log streaming delta if output logging is enabled
|
||||||
if self.enable_log_outputs and self.request_logger:
|
if self.enable_log_outputs and self.request_logger:
|
||||||
@ -993,12 +1044,18 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# Send the finish response for each request.n only once
|
# Send the finish response for each request.n only once
|
||||||
|
if auto_tools_called or tools_streamed[i] or (
|
||||||
|
self.use_harmony
|
||||||
|
and harmony_tools_streamed[i]):
|
||||||
|
finish_reason_ = "tool_calls"
|
||||||
|
else:
|
||||||
|
finish_reason_ = output.finish_reason \
|
||||||
|
if output.finish_reason else "stop"
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=i,
|
index=i,
|
||||||
delta=delta_message,
|
delta=delta_message,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
finish_reason=output.finish_reason
|
finish_reason=finish_reason_,
|
||||||
if not auto_tools_called else "tool_calls",
|
|
||||||
stop_reason=output.stop_reason,
|
stop_reason=output.stop_reason,
|
||||||
token_ids=(as_list(output.token_ids)
|
token_ids=(as_list(output.token_ids)
|
||||||
if request.return_token_ids else None))
|
if request.return_token_ids else None))
|
||||||
@ -1131,31 +1188,32 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
reasoning_content, final_content, is_tool_call = (
|
if TYPE_CHECKING:
|
||||||
parse_chat_output(token_ids))
|
assert self.tool_parser is not None
|
||||||
if not request.include_reasoning:
|
tool_parser = self.tool_parser(tokenizer)
|
||||||
reasoning_content = None
|
# NOTE: We use token_ids for openai tool parser
|
||||||
|
tool_call_info = tool_parser.extract_tool_calls(
|
||||||
if is_tool_call:
|
"",
|
||||||
# TODO(woosuk): Implement tool call for gpt-oss.
|
request=request,
|
||||||
# For now, only Responses API supports tool call for
|
token_ids=token_ids, # type: ignore
|
||||||
# gpt-oss.
|
)
|
||||||
raise NotImplementedError(
|
reasoning_content, content = None, tool_call_info.content
|
||||||
"Tool call in Chat Completion API is not supported "
|
if request.include_reasoning:
|
||||||
"for gpt-oss yet. Please use Responses API instead.")
|
reasoning_content, content, _ = parse_chat_output(
|
||||||
else:
|
token_ids)
|
||||||
# Normal message
|
message = ChatMessage(
|
||||||
message = ChatMessage(
|
role=role,
|
||||||
role=role,
|
reasoning_content=reasoning_content,
|
||||||
reasoning_content=reasoning_content,
|
content=content,
|
||||||
content=final_content,
|
tool_calls=tool_call_info.tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseChoice(
|
choice_data = ChatCompletionResponseChoice(
|
||||||
index=output.index,
|
index=output.index,
|
||||||
message=message,
|
message=message,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
finish_reason="tool_calls" if is_tool_call else
|
finish_reason="tool_calls"
|
||||||
|
if tool_call_info.tools_called else
|
||||||
output.finish_reason if output.finish_reason else "stop",
|
output.finish_reason if output.finish_reason else "stop",
|
||||||
stop_reason=output.stop_reason,
|
stop_reason=output.stop_reason,
|
||||||
)
|
)
|
||||||
@ -1504,12 +1562,12 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
messages.append(sys_msg)
|
messages.append(sys_msg)
|
||||||
|
|
||||||
# Add developer message.
|
# Add developer message.
|
||||||
dev_msg = get_developer_message()
|
dev_msg = get_developer_message(tools=request.tools)
|
||||||
messages.append(dev_msg)
|
messages.append(dev_msg)
|
||||||
|
|
||||||
# Add user message.
|
# Add user message.
|
||||||
for chat_msg in request.messages:
|
for chat_msg in request.messages:
|
||||||
messages.append(parse_chat_input(chat_msg))
|
messages.extend(parse_chat_input(chat_msg))
|
||||||
|
|
||||||
# Render prompt token ids.
|
# Render prompt token ids.
|
||||||
prompt_token_ids = render_for_completion(messages)
|
prompt_token_ids = render_for_completion(messages)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
|
|||||||
from .llama_tool_parser import Llama3JsonToolParser
|
from .llama_tool_parser import Llama3JsonToolParser
|
||||||
from .minimax_tool_parser import MinimaxToolParser
|
from .minimax_tool_parser import MinimaxToolParser
|
||||||
from .mistral_tool_parser import MistralToolParser
|
from .mistral_tool_parser import MistralToolParser
|
||||||
|
from .openai_tool_parser import OpenAIToolParser
|
||||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||||
from .pythonic_tool_parser import PythonicToolParser
|
from .pythonic_tool_parser import PythonicToolParser
|
||||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||||
@ -46,4 +47,5 @@ __all__ = [
|
|||||||
"Qwen3CoderToolParser",
|
"Qwen3CoderToolParser",
|
||||||
"SeedOssToolParser",
|
"SeedOssToolParser",
|
||||||
"Step3ToolParser",
|
"Step3ToolParser",
|
||||||
|
"OpenAIToolParser",
|
||||||
]
|
]
|
||||||
|
|||||||
73
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
Normal file
73
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from vllm.entrypoints.harmony_utils import parse_output_into_messages
|
||||||
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
DeltaMessage,
|
||||||
|
ExtractedToolCallInformation,
|
||||||
|
FunctionCall, ToolCall)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
|
ToolParser, ToolParserManager)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@ToolParserManager.register_module("openai")
|
||||||
|
class OpenAIToolParser(ToolParser):
|
||||||
|
|
||||||
|
def __init__(self, tokenizer: AnyTokenizer):
|
||||||
|
super().__init__(tokenizer)
|
||||||
|
|
||||||
|
def extract_tool_calls(
|
||||||
|
self,
|
||||||
|
model_output: str,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
token_ids: Sequence[int] | None = None,
|
||||||
|
) -> ExtractedToolCallInformation:
|
||||||
|
if token_ids is None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = parse_output_into_messages(token_ids)
|
||||||
|
tool_calls = []
|
||||||
|
final_content = None
|
||||||
|
|
||||||
|
if len(parser.messages) > 0:
|
||||||
|
for msg in parser.messages:
|
||||||
|
if msg.recipient and msg.recipient.startswith("functions."):
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
type="function",
|
||||||
|
function=FunctionCall(
|
||||||
|
name=msg.recipient.split("functions.")[1],
|
||||||
|
arguments=msg.content[0].text,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
elif msg.channel == "final":
|
||||||
|
final_content = msg.content[0].text
|
||||||
|
|
||||||
|
return ExtractedToolCallInformation(
|
||||||
|
tools_called=len(tool_calls) > 0,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
content=final_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
def extract_tool_calls_streaming(
|
||||||
|
self,
|
||||||
|
previous_text: str,
|
||||||
|
current_text: str,
|
||||||
|
delta_text: str,
|
||||||
|
previous_token_ids: Sequence[int],
|
||||||
|
current_token_ids: Sequence[int],
|
||||||
|
delta_token_ids: Sequence[int],
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
) -> DeltaMessage | None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Not being used, manual parsing in serving_chat.py" # noqa: E501
|
||||||
|
)
|
||||||
@ -256,7 +256,7 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
|||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||||
decoding_config = vllm_config.decoding_config
|
decoding_config = vllm_config.decoding_config
|
||||||
if decoding_config.reasoning_backend == "":
|
if decoding_config.reasoning_backend == "":
|
||||||
decoding_config.reasoning_backend = "GptOss"
|
decoding_config.reasoning_backend = "openai_gptoss"
|
||||||
|
|
||||||
# Increase the max capture size from 512 to 1024 for performance.
|
# Increase the max capture size from 512 to 1024 for performance.
|
||||||
# NOTE(woosuk): This will increase the number of CUDA graphs
|
# NOTE(woosuk): This will increase the number of CUDA graphs
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Optional, Union
|
|||||||
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.entrypoints.harmony_utils import parse_chat_output
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
DeltaMessage)
|
DeltaMessage)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -14,7 +15,7 @@ from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ReasoningParserManager.register_module("GptOss")
|
@ReasoningParserManager.register_module("openai_gptoss")
|
||||||
class GptOssReasoningParser(ReasoningParser):
|
class GptOssReasoningParser(ReasoningParser):
|
||||||
"""
|
"""
|
||||||
Reasoning parser for GptOss model.
|
Reasoning parser for GptOss model.
|
||||||
@ -39,9 +40,10 @@ class GptOssReasoningParser(ReasoningParser):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||||
raise RuntimeError(
|
_, content, _ = parse_chat_output(input_ids)
|
||||||
"GptOss model uses harmony to extract reasoning content. This "
|
if content is None:
|
||||||
"function should not be called.")
|
return []
|
||||||
|
return self.model_tokenizer.encode(content)
|
||||||
|
|
||||||
def extract_reasoning_content_streaming(
|
def extract_reasoning_content_streaming(
|
||||||
self,
|
self,
|
||||||
@ -52,13 +54,34 @@ class GptOssReasoningParser(ReasoningParser):
|
|||||||
current_token_ids: Sequence[int],
|
current_token_ids: Sequence[int],
|
||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
) -> Union[DeltaMessage, None]:
|
) -> Union[DeltaMessage, None]:
|
||||||
raise RuntimeError(
|
prev_reasoning, prev_content, _ = parse_chat_output(
|
||||||
"GptOss model uses harmony to extract reasoning content. This "
|
list(previous_token_ids))
|
||||||
"function should not be called.")
|
cur_reasoning, cur_content, _ = parse_chat_output(
|
||||||
|
list(current_token_ids))
|
||||||
|
reasoning_delta = None
|
||||||
|
content_delta = None
|
||||||
|
if cur_reasoning is not None:
|
||||||
|
prev_r = prev_reasoning or ""
|
||||||
|
if cur_reasoning.startswith(prev_r):
|
||||||
|
reasoning_delta = cur_reasoning[len(prev_r):] or None
|
||||||
|
else:
|
||||||
|
reasoning_delta = cur_reasoning
|
||||||
|
if cur_content is not None:
|
||||||
|
prev_c = prev_content or ""
|
||||||
|
if cur_content.startswith(prev_c):
|
||||||
|
content_delta = cur_content[len(prev_c):] or None
|
||||||
|
else:
|
||||||
|
content_delta = cur_content
|
||||||
|
if reasoning_delta is None and content_delta is None:
|
||||||
|
return None
|
||||||
|
return DeltaMessage(reasoning_content=reasoning_delta,
|
||||||
|
content=content_delta)
|
||||||
|
|
||||||
def extract_reasoning_content(
|
def extract_reasoning_content(
|
||||||
self, model_output: str, request: ChatCompletionRequest
|
self,
|
||||||
|
model_output: str,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
) -> tuple[Optional[str], Optional[str]]:
|
) -> tuple[Optional[str], Optional[str]]:
|
||||||
raise RuntimeError(
|
raise NotImplementedError(
|
||||||
"GptOss model uses harmony to extract reasoning content. This "
|
"gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501
|
||||||
"function should not be called.")
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user