[gpt-oss] tool parser supports for /chat/completions [1/n] (#22386)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Aaron Pham 2025-09-04 23:39:12 -04:00 committed by GitHub
parent 65e038931d
commit c29fb540ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 573 additions and 63 deletions

View File

@ -1,13 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
@ -17,6 +20,164 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
if TYPE_CHECKING:
from openai import OpenAI
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
args = [
"--enforce-eager",
"--max-model-len",
"8192",
"--tool-call-parser",
"openai",
"--reasoning-parser",
"openai_gptoss",
"--enable-auto-tool-choice",
]
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
async with gptoss_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string"
},
"state": {
"type": "string"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}]
messages = [
{
"role": "user",
"content": "What is the weather in Dallas, TX?"
},
]
stream = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
name = None
args_buf = ""
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.tool_calls:
tc = delta.tool_calls[0]
if tc.function and tc.function.name:
name = tc.function.name
if tc.function and tc.function.arguments:
args_buf += tc.function.arguments
assert name is not None
assert len(args_buf) > 0
@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string"
},
"state": {
"type": "string"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}]
messages = [
{
"role": "system",
"content": "you are a helpful assistant"
},
{
"role": "user",
"content": "What is the weather in Dallas, TX?"
},
]
first = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME,
messages=messages,
tools=tools,
temperature=0.0,
)
first_msg = first.choices[0].message
assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
tc = first_msg.tool_calls[0]
assert tc.function is not None and tc.function.name == "get_current_weather"
args1 = tc.function.arguments
assert args1 is not None and len(args1) > 0
messages.append({"role": "assistant", "content": args1})
messages.append({
"role": "user",
"content": "Now convert to celsius and return JSON only"
})
second = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME,
messages=messages,
tools=tools,
temperature=0.0,
)
second_msg = second.choices[0].message
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501
MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from openai_harmony import (Conversation, DeveloperContent,
HarmonyEncodingName, Message, Role, SystemContent,
load_harmony_encoding)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
MODEL = "gpt2"
@pytest.fixture(scope="module")
def openai_tokenizer():
# The parser does not use the tokenizer, but the constructor requires it.
return get_tokenizer(MODEL)
@pytest.fixture
def openai_tool_parser(openai_tokenizer):
return OpenAIToolParser(openai_tokenizer)
@pytest.fixture(scope="module")
def harmony_encoding():
return load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def assert_tool_calls(
actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall],
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16 # Default from protocol.py
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
convo = Conversation.from_messages([
Message.from_role_and_content(
Role.SYSTEM,
SystemContent.new(),
),
Message.from_role_and_content(
Role.DEVELOPER,
DeveloperContent.new().with_instructions("Talk like a pirate!")),
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
Message.from_role_and_content(Role.ASSISTANT,
"This is a test").with_channel("final")
])
token_ids = harmony_encoding.render_conversation_for_completion(
convo, Role.ASSISTANT)
extracted_info = openai_tool_parser.extract_tool_calls(
"",
request=None,
token_ids=token_ids,
)
assert not extracted_info.tools_called
assert extracted_info.tool_calls == []
assert extracted_info.content == "This is a test"
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding):
convo = Conversation.from_messages([
Message.from_role_and_content(Role.USER,
"What is the weather in Tokyo?"),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
"functions.get_current_weather").with_content_type("json"),
])
token_ids = harmony_encoding.render_conversation_for_completion(
convo, Role.ASSISTANT)
extracted_info = openai_tool_parser.extract_tool_calls(
"",
request=None,
token_ids=token_ids,
)
assert extracted_info.tools_called
expected_tool_calls = [
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
))
]
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
assert extracted_info.content is None
def test_extract_tool_calls_multiple_tools(
openai_tool_parser,
harmony_encoding,
):
convo = Conversation.from_messages([
Message.from_role_and_content(
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
Message.from_role_and_content(
Role.ASSISTANT,
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
"functions.get_current_weather").with_content_type("json"),
Message.from_role_and_content(
Role.ASSISTANT,
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
"functions.get_user_location").with_content_type("json"),
])
token_ids = harmony_encoding.render_conversation_for_completion(
convo,
Role.ASSISTANT,
)
extracted_info = openai_tool_parser.extract_tool_calls(
"",
request=None,
token_ids=token_ids,
)
assert extracted_info.tools_called
expected_tool_calls = [
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({"location": "Tokyo"}),
)),
ToolCall(function=FunctionCall(
name="get_user_location",
arguments=json.dumps({"location": "Tokyo"}),
))
]
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
assert extracted_info.content is None

View File

@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import datetime
import json
from collections.abc import Iterable, Sequence
@ -18,7 +21,8 @@ from openai_harmony import (Author, Conversation, DeveloperContent,
Role, StreamableParser, SystemContent, TextContent,
ToolDescription, load_harmony_encoding)
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
ResponseInputOutputItem)
from vllm.utils import random_uuid
REASONING_EFFORT = {
@ -63,13 +67,29 @@ def get_system_message(
return sys_msg
def get_developer_message(instructions: Optional[str] = None,
tools: Optional[list[Tool]] = None) -> Message:
def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
if isinstance(tool, ChatCompletionToolsParam):
return ToolDescription.new(
name=tool.function.name,
description=tool.function.description,
parameters=tool.function.parameters,
)
return ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
def get_developer_message(
instructions: Optional[str] = None,
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None:
dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None:
function_tools = []
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter"):
# These are built-in tools that are added to the system message.
@ -80,11 +100,7 @@ def get_developer_message(instructions: Optional[str] = None,
raise ValueError(f"tool type {tool.type} not supported")
if function_tools:
function_tool_descriptions = [
ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
) for tool in function_tools
create_tool_definition(tool) for tool in function_tools
]
dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions)
@ -148,16 +164,46 @@ def parse_response_input(
return msg
def parse_chat_input(chat_msg) -> Message:
role = chat_msg["role"]
content = chat_msg["content"]
def parse_chat_input(chat_msg) -> list[Message]:
if not isinstance(chat_msg, dict):
# Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls")
if role == "assistant" and tool_calls:
msgs: list[Message] = []
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or ""
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"),
content).with_channel("commentary")
return [msg]
# Default: user/assistant/system messages with content
content = chat_msg.get("content", "")
if isinstance(content, str):
contents = [TextContent(text=content)]
else:
# TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content]
msg = Message.from_role_and_contents(role, contents)
return msg
return [msg]
def render_for_completion(messages: list[Message]) -> list[int]:

View File

@ -6,7 +6,7 @@ import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
from typing import TYPE_CHECKING, Callable, Final, Optional, Union
import jinja2
import partial_json_parser
@ -489,6 +489,8 @@ class OpenAIServingChat(OpenAIServing):
get_streamable_parser_for_assistant()
for _ in range(num_choices)
]
harmony_tools_streamed = [False] * num_choices
tools_streamed = [False] * num_choices
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
@ -662,13 +664,11 @@ class OpenAIServingChat(OpenAIServing):
if self.use_harmony:
harmony_parser = harmony_parsers[i]
prev_recipient = harmony_parser.current_recipient
for token_id in output.token_ids:
harmony_parser.process(token_id)
is_reasoning = \
harmony_parser.current_channel == "analysis"
if not request.include_reasoning and is_reasoning:
# Skip the reasoning content.
continue
cur_channel = harmony_parser.current_channel
cur_recipient = harmony_parser.current_recipient
delta_text = harmony_parser.last_content_delta or ""
else:
delta_text = output.text
@ -681,8 +681,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message: Optional[DeltaMessage]
# just update previous_texts and previous_token_ids
if ((tool_choice_auto or self.reasoning_parser)
and not self.use_harmony):
if tool_choice_auto or self.reasoning_parser:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
@ -696,11 +695,54 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids = as_list(output.token_ids)
if self.use_harmony:
if is_reasoning:
delta_message = DeltaMessage(
reasoning_content=delta_text)
else:
if cur_channel == "final":
delta_message = DeltaMessage(content=delta_text)
elif cur_channel == "analysis":
if request.include_reasoning:
delta_message = DeltaMessage(
reasoning_content=delta_text)
else:
delta_message = None
elif (cur_channel == "commentary" and cur_recipient
and cur_recipient.startswith("functions.")):
# Count completed tool calls to determine index
base_index = 0
for msg in harmony_parser.messages:
if (msg.channel == "commentary"
and msg.recipient
and msg.recipient.startswith(
"functions.")):
base_index += 1
if prev_recipient != cur_recipient:
tool_name = cur_recipient.split(
"functions.", 1)[1]
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
id=make_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_name,
arguments="",
),
index=base_index,
)
])
elif delta_text:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
index=base_index,
function=DeltaFunctionCall(
arguments=delta_text),
)
])
else:
delta_message = None
if delta_message is not None:
harmony_tools_streamed[i] = True
else:
delta_message = None
# handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name:
if (self.reasoning_parser and not reasoning_end_arr[i]
@ -758,6 +800,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message = DeltaMessage(tool_calls=[
delta_tool_call,
])
tools_streamed[i] = True
elif request.tool_choice == "required":
assert previous_texts is not None
@ -783,6 +826,7 @@ class OpenAIServingChat(OpenAIServing):
if (delta_message and delta_message.tool_calls and
delta_message.tool_calls[0].id is not None):
history_tool_call_cnt += 1
tools_streamed[i] = True
# update the previous values for the next iteration
previous_texts[i] = current_text
@ -859,6 +903,8 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request))
if delta_message and delta_message.tool_calls:
tools_streamed[i] = True
# when only tool calls
elif tool_choice_auto:
assert tool_parser is not None
@ -871,6 +917,8 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids,
request=request))
if delta_message and delta_message.tool_calls:
tools_streamed[i] = True
# when only reasoning
elif self.reasoning_parser:
@ -907,7 +955,10 @@ class OpenAIServingChat(OpenAIServing):
# wasn't ready to send a token, then
# get the next token without streaming a chunk
if delta_message is None:
continue
if output.finish_reason is None:
continue
else:
delta_message = DeltaMessage()
# Log streaming delta if output logging is enabled
if self.enable_log_outputs and self.request_logger:
@ -993,12 +1044,18 @@ class OpenAIServingChat(OpenAIServing):
])
# Send the finish response for each request.n only once
if auto_tools_called or tools_streamed[i] or (
self.use_harmony
and harmony_tools_streamed[i]):
finish_reason_ = "tool_calls"
else:
finish_reason_ = output.finish_reason \
if output.finish_reason else "stop"
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=output.finish_reason
if not auto_tools_called else "tool_calls",
finish_reason=finish_reason_,
stop_reason=output.stop_reason,
token_ids=(as_list(output.token_ids)
if request.return_token_ids else None))
@ -1131,31 +1188,32 @@ class OpenAIServingChat(OpenAIServing):
logprobs = None
if self.use_harmony:
reasoning_content, final_content, is_tool_call = (
parse_chat_output(token_ids))
if not request.include_reasoning:
reasoning_content = None
if is_tool_call:
# TODO(woosuk): Implement tool call for gpt-oss.
# For now, only Responses API supports tool call for
# gpt-oss.
raise NotImplementedError(
"Tool call in Chat Completion API is not supported "
"for gpt-oss yet. Please use Responses API instead.")
else:
# Normal message
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=final_content,
)
if TYPE_CHECKING:
assert self.tool_parser is not None
tool_parser = self.tool_parser(tokenizer)
# NOTE: We use token_ids for openai tool parser
tool_call_info = tool_parser.extract_tool_calls(
"",
request=request,
token_ids=token_ids, # type: ignore
)
reasoning_content, content = None, tool_call_info.content
if request.include_reasoning:
reasoning_content, content, _ = parse_chat_output(
token_ids)
message = ChatMessage(
role=role,
reasoning_content=reasoning_content,
content=content,
tool_calls=tool_call_info.tool_calls,
)
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=message,
logprobs=logprobs,
finish_reason="tool_calls" if is_tool_call else
finish_reason="tool_calls"
if tool_call_info.tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason,
)
@ -1504,12 +1562,12 @@ class OpenAIServingChat(OpenAIServing):
messages.append(sys_msg)
# Add developer message.
dev_msg = get_developer_message()
dev_msg = get_developer_message(tools=request.tools)
messages.append(dev_msg)
# Add user message.
for chat_msg in request.messages:
messages.append(parse_chat_input(chat_msg))
messages.extend(parse_chat_input(chat_msg))
# Render prompt token ids.
prompt_token_ids = render_for_completion(messages)

View File

@ -16,6 +16,7 @@ from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .openai_tool_parser import OpenAIToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
@ -46,4 +47,5 @@ __all__ = [
"Qwen3CoderToolParser",
"SeedOssToolParser",
"Step3ToolParser",
"OpenAIToolParser",
]

View File

@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ToolParserManager.register_module("openai")
class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
token_ids: Sequence[int] | None = None,
) -> ExtractedToolCallInformation:
if token_ids is None:
raise NotImplementedError(
"OpenAIToolParser requires token IDs and does not support text-based extraction." # noqa: E501
)
parser = parse_output_into_messages(token_ids)
tool_calls = []
final_content = None
if len(parser.messages) > 0:
for msg in parser.messages:
if msg.recipient and msg.recipient.startswith("functions."):
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(
name=msg.recipient.split("functions.")[1],
arguments=msg.content[0].text,
),
))
elif msg.channel == "final":
final_content = msg.content[0].text
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=final_content,
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
raise NotImplementedError(
"Not being used, manual parsing in serving_chat.py" # noqa: E501
)

View File

@ -256,7 +256,7 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
decoding_config = vllm_config.decoding_config
if decoding_config.reasoning_backend == "":
decoding_config.reasoning_backend = "GptOss"
decoding_config.reasoning_backend = "openai_gptoss"
# Increase the max capture size from 512 to 1024 for performance.
# NOTE(woosuk): This will increase the number of CUDA graphs

View File

@ -6,6 +6,7 @@ from typing import Optional, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.harmony_utils import parse_chat_output
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
@ -14,7 +15,7 @@ from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@ReasoningParserManager.register_module("GptOss")
@ReasoningParserManager.register_module("openai_gptoss")
class GptOssReasoningParser(ReasoningParser):
"""
Reasoning parser for GptOss model.
@ -39,9 +40,10 @@ class GptOssReasoningParser(ReasoningParser):
return False
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
raise RuntimeError(
"GptOss model uses harmony to extract reasoning content. This "
"function should not be called.")
_, content, _ = parse_chat_output(input_ids)
if content is None:
return []
return self.model_tokenizer.encode(content)
def extract_reasoning_content_streaming(
self,
@ -52,13 +54,34 @@ class GptOssReasoningParser(ReasoningParser):
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
raise RuntimeError(
"GptOss model uses harmony to extract reasoning content. This "
"function should not be called.")
prev_reasoning, prev_content, _ = parse_chat_output(
list(previous_token_ids))
cur_reasoning, cur_content, _ = parse_chat_output(
list(current_token_ids))
reasoning_delta = None
content_delta = None
if cur_reasoning is not None:
prev_r = prev_reasoning or ""
if cur_reasoning.startswith(prev_r):
reasoning_delta = cur_reasoning[len(prev_r):] or None
else:
reasoning_delta = cur_reasoning
if cur_content is not None:
prev_c = prev_content or ""
if cur_content.startswith(prev_c):
content_delta = cur_content[len(prev_c):] or None
else:
content_delta = cur_content
if reasoning_delta is None and content_delta is None:
return None
return DeltaMessage(reasoning_content=reasoning_delta,
content=content_delta)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
self,
model_output: str,
request: ChatCompletionRequest,
) -> tuple[Optional[str], Optional[str]]:
raise RuntimeError(
"GptOss model uses harmony to extract reasoning content. This "
"function should not be called.")
raise NotImplementedError(
"gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501
)