mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
[gpt-oss] tool parser supports for /chat/completions [1/n] (#22386)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
parent
65e038931d
commit
c29fb540ff
@ -1,13 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
@ -17,6 +20,164 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import OpenAI
|
||||
|
||||
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def monkeypatch_module():
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
mpatch = MonkeyPatch()
|
||||
yield mpatch
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
|
||||
args = [
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--tool-call-parser",
|
||||
"openai",
|
||||
"--reasoning-parser",
|
||||
"openai_gptoss",
|
||||
"--enable-auto-tool-choice",
|
||||
]
|
||||
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def gptoss_client(gptoss_server):
|
||||
async with gptoss_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in Dallas, TX?"
|
||||
},
|
||||
]
|
||||
|
||||
stream = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
|
||||
|
||||
name = None
|
||||
args_buf = ""
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.tool_calls:
|
||||
tc = delta.tool_calls[0]
|
||||
if tc.function and tc.function.name:
|
||||
name = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
args_buf += tc.function.arguments
|
||||
|
||||
assert name is not None
|
||||
assert len(args_buf) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"type": "string"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in Dallas, TX?"
|
||||
},
|
||||
]
|
||||
|
||||
first = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
first_msg = first.choices[0].message
|
||||
assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
|
||||
tc = first_msg.tool_calls[0]
|
||||
assert tc.function is not None and tc.function.name == "get_current_weather"
|
||||
args1 = tc.function.arguments
|
||||
assert args1 is not None and len(args1) > 0
|
||||
|
||||
messages.append({"role": "assistant", "content": args1})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "Now convert to celsius and return JSON only"
|
||||
})
|
||||
|
||||
second = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
second_msg = second.choices[0].message
|
||||
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
|
||||
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501
|
||||
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
||||
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
|
||||
|
||||
147
tests/tool_use/test_openai_tool_parser.py
Normal file
147
tests/tool_use/test_openai_tool_parser.py
Normal file
@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from openai_harmony import (Conversation, DeveloperContent,
|
||||
HarmonyEncodingName, Message, Role, SystemContent,
|
||||
load_harmony_encoding)
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
MODEL = "gpt2"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def openai_tokenizer():
|
||||
# The parser does not use the tokenizer, but the constructor requires it.
|
||||
return get_tokenizer(MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_tool_parser(openai_tokenizer):
|
||||
return OpenAIToolParser(openai_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def harmony_encoding():
|
||||
return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall],
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 16 # Default from protocol.py
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(
|
||||
Role.SYSTEM,
|
||||
SystemContent.new(),
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.DEVELOPER,
|
||||
DeveloperContent.new().with_instructions("Talk like a pirate!")),
|
||||
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
|
||||
Message.from_role_and_content(Role.ASSISTANT,
|
||||
"This is a test").with_channel("final")
|
||||
])
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo, Role.ASSISTANT)
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert not extracted_info.tools_called
|
||||
assert extracted_info.tool_calls == []
|
||||
assert extracted_info.content == "This is a test"
|
||||
|
||||
|
||||
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(Role.USER,
|
||||
"What is the weather in Tokyo?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
"functions.get_current_weather").with_content_type("json"),
|
||||
])
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo, Role.ASSISTANT)
|
||||
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
))
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content is None
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_tools(
|
||||
openai_tool_parser,
|
||||
harmony_encoding,
|
||||
):
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(
|
||||
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
"functions.get_current_weather").with_content_type("json"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
"functions.get_user_location").with_content_type("json"),
|
||||
])
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo,
|
||||
Role.ASSISTANT,
|
||||
)
|
||||
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_user_location",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
))
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content is None
|
||||
@ -1,5 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from collections.abc import Iterable, Sequence
|
||||
@ -18,7 +21,8 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
|
||||
Role, StreamableParser, SystemContent, TextContent,
|
||||
ToolDescription, load_harmony_encoding)
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||
ResponseInputOutputItem)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
REASONING_EFFORT = {
|
||||
@ -63,13 +67,29 @@ def get_system_message(
|
||||
return sys_msg
|
||||
|
||||
|
||||
def get_developer_message(instructions: Optional[str] = None,
|
||||
tools: Optional[list[Tool]] = None) -> Message:
|
||||
def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
|
||||
if isinstance(tool, ChatCompletionToolsParam):
|
||||
return ToolDescription.new(
|
||||
name=tool.function.name,
|
||||
description=tool.function.description,
|
||||
parameters=tool.function.parameters,
|
||||
)
|
||||
return ToolDescription.new(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.parameters,
|
||||
)
|
||||
|
||||
|
||||
def get_developer_message(
|
||||
instructions: Optional[str] = None,
|
||||
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
|
||||
) -> Message:
|
||||
dev_msg_content = DeveloperContent.new()
|
||||
if instructions is not None:
|
||||
dev_msg_content = dev_msg_content.with_instructions(instructions)
|
||||
if tools is not None:
|
||||
function_tools = []
|
||||
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||
for tool in tools:
|
||||
if tool.type in ("web_search_preview", "code_interpreter"):
|
||||
# These are built-in tools that are added to the system message.
|
||||
@ -80,11 +100,7 @@ def get_developer_message(instructions: Optional[str] = None,
|
||||
raise ValueError(f"tool type {tool.type} not supported")
|
||||
if function_tools:
|
||||
function_tool_descriptions = [
|
||||
ToolDescription.new(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=tool.parameters,
|
||||
) for tool in function_tools
|
||||
create_tool_definition(tool) for tool in function_tools
|
||||
]
|
||||
dev_msg_content = dev_msg_content.with_function_tools(
|
||||
function_tool_descriptions)
|
||||
@ -148,16 +164,46 @@ def parse_response_input(
|
||||
return msg
|
||||
|
||||
|
||||
def parse_chat_input(chat_msg) -> Message:
|
||||
role = chat_msg["role"]
|
||||
content = chat_msg["content"]
|
||||
def parse_chat_input(chat_msg) -> list[Message]:
|
||||
if not isinstance(chat_msg, dict):
|
||||
# Handle Pydantic models
|
||||
chat_msg = chat_msg.model_dump(exclude_none=True)
|
||||
|
||||
role = chat_msg.get("role")
|
||||
|
||||
# Assistant message with tool calls
|
||||
tool_calls = chat_msg.get("tool_calls")
|
||||
if role == "assistant" and tool_calls:
|
||||
msgs: list[Message] = []
|
||||
for call in tool_calls:
|
||||
func = call.get("function", {})
|
||||
name = func.get("name", "")
|
||||
arguments = func.get("arguments", "") or ""
|
||||
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
|
||||
msg = msg.with_channel("commentary")
|
||||
msg = msg.with_recipient(f"functions.{name}")
|
||||
msg = msg.with_content_type("json")
|
||||
msgs.append(msg)
|
||||
return msgs
|
||||
|
||||
# Tool role message (tool output)
|
||||
if role == "tool":
|
||||
name = chat_msg.get("name", "")
|
||||
content = chat_msg.get("content", "") or ""
|
||||
msg = Message.from_author_and_content(
|
||||
Author.new(Role.TOOL, f"functions.{name}"),
|
||||
content).with_channel("commentary")
|
||||
return [msg]
|
||||
|
||||
# Default: user/assistant/system messages with content
|
||||
content = chat_msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
contents = [TextContent(text=content)]
|
||||
else:
|
||||
# TODO: Support refusal.
|
||||
contents = [TextContent(text=c.get("text", "")) for c in content]
|
||||
msg = Message.from_role_and_contents(role, contents)
|
||||
return msg
|
||||
return [msg]
|
||||
|
||||
|
||||
def render_for_completion(messages: list[Message]) -> list[int]:
|
||||
|
||||
@ -6,7 +6,7 @@ import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Callable, Final, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
import partial_json_parser
|
||||
@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
get_streamable_parser_for_assistant()
|
||||
for _ in range(num_choices)
|
||||
]
|
||||
harmony_tools_streamed = [False] * num_choices
|
||||
tools_streamed = [False] * num_choices
|
||||
|
||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||
tool_choice_function_name = request.tool_choice.function.name
|
||||
@ -662,13 +664,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if self.use_harmony:
|
||||
harmony_parser = harmony_parsers[i]
|
||||
prev_recipient = harmony_parser.current_recipient
|
||||
for token_id in output.token_ids:
|
||||
harmony_parser.process(token_id)
|
||||
is_reasoning = \
|
||||
harmony_parser.current_channel == "analysis"
|
||||
if not request.include_reasoning and is_reasoning:
|
||||
# Skip the reasoning content.
|
||||
continue
|
||||
cur_channel = harmony_parser.current_channel
|
||||
cur_recipient = harmony_parser.current_recipient
|
||||
delta_text = harmony_parser.last_content_delta or ""
|
||||
else:
|
||||
delta_text = output.text
|
||||
@ -681,8 +681,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_message: Optional[DeltaMessage]
|
||||
|
||||
# just update previous_texts and previous_token_ids
|
||||
if ((tool_choice_auto or self.reasoning_parser)
|
||||
and not self.use_harmony):
|
||||
if tool_choice_auto or self.reasoning_parser:
|
||||
assert previous_texts is not None
|
||||
assert all_previous_token_ids is not None
|
||||
previous_text = previous_texts[i]
|
||||
@ -696,11 +695,54 @@ class OpenAIServingChat(OpenAIServing):
|
||||
current_token_ids = as_list(output.token_ids)
|
||||
|
||||
if self.use_harmony:
|
||||
if is_reasoning:
|
||||
delta_message = DeltaMessage(
|
||||
reasoning_content=delta_text)
|
||||
else:
|
||||
if cur_channel == "final":
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
elif cur_channel == "analysis":
|
||||
if request.include_reasoning:
|
||||
delta_message = DeltaMessage(
|
||||
reasoning_content=delta_text)
|
||||
else:
|
||||
delta_message = None
|
||||
elif (cur_channel == "commentary" and cur_recipient
|
||||
and cur_recipient.startswith("functions.")):
|
||||
# Count completed tool calls to determine index
|
||||
base_index = 0
|
||||
for msg in harmony_parser.messages:
|
||||
if (msg.channel == "commentary"
|
||||
and msg.recipient
|
||||
and msg.recipient.startswith(
|
||||
"functions.")):
|
||||
base_index += 1
|
||||
|
||||
if prev_recipient != cur_recipient:
|
||||
tool_name = cur_recipient.split(
|
||||
"functions.", 1)[1]
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
id=make_tool_call_id(),
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_name,
|
||||
arguments="",
|
||||
),
|
||||
index=base_index,
|
||||
)
|
||||
])
|
||||
elif delta_text:
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=base_index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_text),
|
||||
)
|
||||
])
|
||||
else:
|
||||
delta_message = None
|
||||
|
||||
if delta_message is not None:
|
||||
harmony_tools_streamed[i] = True
|
||||
else:
|
||||
delta_message = None
|
||||
# handle streaming deltas for tools with named tool_choice
|
||||
elif tool_choice_function_name:
|
||||
if (self.reasoning_parser and not reasoning_end_arr[i]
|
||||
@ -758,6 +800,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
delta_tool_call,
|
||||
])
|
||||
tools_streamed[i] = True
|
||||
|
||||
elif request.tool_choice == "required":
|
||||
assert previous_texts is not None
|
||||
@ -783,6 +826,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if (delta_message and delta_message.tool_calls and
|
||||
delta_message.tool_calls[0].id is not None):
|
||||
history_tool_call_cnt += 1
|
||||
tools_streamed[i] = True
|
||||
|
||||
# update the previous values for the next iteration
|
||||
previous_texts[i] = current_text
|
||||
@ -859,6 +903,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=delta_token_ids,
|
||||
request=request))
|
||||
if delta_message and delta_message.tool_calls:
|
||||
tools_streamed[i] = True
|
||||
# when only tool calls
|
||||
elif tool_choice_auto:
|
||||
assert tool_parser is not None
|
||||
@ -871,6 +917,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=output.token_ids,
|
||||
request=request))
|
||||
if delta_message and delta_message.tool_calls:
|
||||
tools_streamed[i] = True
|
||||
|
||||
# when only reasoning
|
||||
elif self.reasoning_parser:
|
||||
@ -907,7 +955,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# wasn't ready to send a token, then
|
||||
# get the next token without streaming a chunk
|
||||
if delta_message is None:
|
||||
continue
|
||||
if output.finish_reason is None:
|
||||
continue
|
||||
else:
|
||||
delta_message = DeltaMessage()
|
||||
|
||||
# Log streaming delta if output logging is enabled
|
||||
if self.enable_log_outputs and self.request_logger:
|
||||
@ -993,12 +1044,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
])
|
||||
|
||||
# Send the finish response for each request.n only once
|
||||
if auto_tools_called or tools_streamed[i] or (
|
||||
self.use_harmony
|
||||
and harmony_tools_streamed[i]):
|
||||
finish_reason_ = "tool_calls"
|
||||
else:
|
||||
finish_reason_ = output.finish_reason \
|
||||
if output.finish_reason else "stop"
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=delta_message,
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason
|
||||
if not auto_tools_called else "tool_calls",
|
||||
finish_reason=finish_reason_,
|
||||
stop_reason=output.stop_reason,
|
||||
token_ids=(as_list(output.token_ids)
|
||||
if request.return_token_ids else None))
|
||||
@ -1131,31 +1188,32 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logprobs = None
|
||||
|
||||
if self.use_harmony:
|
||||
reasoning_content, final_content, is_tool_call = (
|
||||
parse_chat_output(token_ids))
|
||||
if not request.include_reasoning:
|
||||
reasoning_content = None
|
||||
|
||||
if is_tool_call:
|
||||
# TODO(woosuk): Implement tool call for gpt-oss.
|
||||
# For now, only Responses API supports tool call for
|
||||
# gpt-oss.
|
||||
raise NotImplementedError(
|
||||
"Tool call in Chat Completion API is not supported "
|
||||
"for gpt-oss yet. Please use Responses API instead.")
|
||||
else:
|
||||
# Normal message
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
content=final_content,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
assert self.tool_parser is not None
|
||||
tool_parser = self.tool_parser(tokenizer)
|
||||
# NOTE: We use token_ids for openai tool parser
|
||||
tool_call_info = tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=request,
|
||||
token_ids=token_ids, # type: ignore
|
||||
)
|
||||
reasoning_content, content = None, tool_call_info.content
|
||||
if request.include_reasoning:
|
||||
reasoning_content, content, _ = parse_chat_output(
|
||||
token_ids)
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
content=content,
|
||||
tool_calls=tool_call_info.tool_calls,
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=message,
|
||||
logprobs=logprobs,
|
||||
finish_reason="tool_calls" if is_tool_call else
|
||||
finish_reason="tool_calls"
|
||||
if tool_call_info.tools_called else
|
||||
output.finish_reason if output.finish_reason else "stop",
|
||||
stop_reason=output.stop_reason,
|
||||
)
|
||||
@ -1504,12 +1562,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
messages.append(sys_msg)
|
||||
|
||||
# Add developer message.
|
||||
dev_msg = get_developer_message()
|
||||
dev_msg = get_developer_message(tools=request.tools)
|
||||
messages.append(dev_msg)
|
||||
|
||||
# Add user message.
|
||||
for chat_msg in request.messages:
|
||||
messages.append(parse_chat_input(chat_msg))
|
||||
messages.extend(parse_chat_input(chat_msg))
|
||||
|
||||
# Render prompt token ids.
|
||||
prompt_token_ids = render_for_completion(messages)
|
||||
|
||||
@ -16,6 +16,7 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .minimax_tool_parser import MinimaxToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
from .openai_tool_parser import OpenAIToolParser
|
||||
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
|
||||
from .pythonic_tool_parser import PythonicToolParser
|
||||
from .qwen3coder_tool_parser import Qwen3CoderToolParser
|
||||
@ -46,4 +47,5 @@ __all__ = [
|
||||
"Qwen3CoderToolParser",
|
||||
"SeedOssToolParser",
|
||||
"Step3ToolParser",
|
||||
"OpenAIToolParser",
|
||||
]
|
||||
|
||||
73
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
Normal file
73
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
Normal file
@ -0,0 +1,73 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.entrypoints.harmony_utils import parse_output_into_messages
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
@ToolParserManager.register_module("openai")
|
||||
class OpenAIToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
token_ids: Sequence[int] | None = None,
|
||||
) -> ExtractedToolCallInformation:
|
||||
if token_ids is None:
|
||||
raise NotImplementedError(
|
||||
"OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501
|
||||
)
|
||||
|
||||
parser = parse_output_into_messages(token_ids)
|
||||
tool_calls = []
|
||||
final_content = None
|
||||
|
||||
if len(parser.messages) > 0:
|
||||
for msg in parser.messages:
|
||||
if msg.recipient and msg.recipient.startswith("functions."):
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=msg.recipient.split("functions.")[1],
|
||||
arguments=msg.content[0].text,
|
||||
),
|
||||
))
|
||||
elif msg.channel == "final":
|
||||
final_content = msg.content[0].text
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=len(tool_calls) > 0,
|
||||
tool_calls=tool_calls,
|
||||
content=final_content,
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
raise NotImplementedError(
|
||||
"Not being used, manual parsing in serving_chat.py" # noqa: E501
|
||||
)
|
||||
@ -256,7 +256,7 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
decoding_config = vllm_config.decoding_config
|
||||
if decoding_config.reasoning_backend == "":
|
||||
decoding_config.reasoning_backend = "GptOss"
|
||||
decoding_config.reasoning_backend = "openai_gptoss"
|
||||
|
||||
# Increase the max capture size from 512 to 1024 for performance.
|
||||
# NOTE(woosuk): This will increase the number of CUDA graphs
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Optional, Union
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.harmony_utils import parse_chat_output
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage)
|
||||
from vllm.logger import init_logger
|
||||
@ -14,7 +15,7 @@ from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("GptOss")
|
||||
@ReasoningParserManager.register_module("openai_gptoss")
|
||||
class GptOssReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for GptOss model.
|
||||
@ -39,9 +40,10 @@ class GptOssReasoningParser(ReasoningParser):
|
||||
return False
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
raise RuntimeError(
|
||||
"GptOss model uses harmony to extract reasoning content. This "
|
||||
"function should not be called.")
|
||||
_, content, _ = parse_chat_output(input_ids)
|
||||
if content is None:
|
||||
return []
|
||||
return self.model_tokenizer.encode(content)
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
@ -52,13 +54,34 @@ class GptOssReasoningParser(ReasoningParser):
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
raise RuntimeError(
|
||||
"GptOss model uses harmony to extract reasoning content. This "
|
||||
"function should not be called.")
|
||||
prev_reasoning, prev_content, _ = parse_chat_output(
|
||||
list(previous_token_ids))
|
||||
cur_reasoning, cur_content, _ = parse_chat_output(
|
||||
list(current_token_ids))
|
||||
reasoning_delta = None
|
||||
content_delta = None
|
||||
if cur_reasoning is not None:
|
||||
prev_r = prev_reasoning or ""
|
||||
if cur_reasoning.startswith(prev_r):
|
||||
reasoning_delta = cur_reasoning[len(prev_r):] or None
|
||||
else:
|
||||
reasoning_delta = cur_reasoning
|
||||
if cur_content is not None:
|
||||
prev_c = prev_content or ""
|
||||
if cur_content.startswith(prev_c):
|
||||
content_delta = cur_content[len(prev_c):] or None
|
||||
else:
|
||||
content_delta = cur_content
|
||||
if reasoning_delta is None and content_delta is None:
|
||||
return None
|
||||
return DeltaMessage(reasoning_content=reasoning_delta,
|
||||
content=content_delta)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
raise RuntimeError(
|
||||
"GptOss model uses harmony to extract reasoning content. This "
|
||||
"function should not be called.")
|
||||
raise NotImplementedError(
|
||||
"gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user