[GptOss] Add GptOss reasoning parser to support structured output (#22322)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Chen Zhang 2025-08-05 23:39:13 -07:00 committed by GitHub
parent 98a3a81024
commit a47e6ffe93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 69 additions and 3 deletions

View File

@ -247,13 +247,13 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
config.max_model_len)
class GptOssConfig(VerifyAndUpdateConfig):
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
decoding_config = vllm_config.decoding_config
if decoding_config.reasoning_backend == "":
decoding_config.reasoning_backend = "openai"
decoding_config.reasoning_backend = "GptOss"
# Increase the max capture size from 512 to 1024 for performance.
# NOTE(woosuk): This will increase the number of CUDA graphs
@ -373,5 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
"GptOssForCausalLM": GptOssConfig,
"GptOssForCausalLM": GptOssForCausalLMConfig,
}

View File

@ -4,6 +4,7 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
from .gptoss_reasoning_parser import GptOssReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
@ -20,4 +21,5 @@ __all__ = [
"Glm4MoeModelReasoningParser",
"MistralReasoningParser",
"Step3ReasoningParser",
"GptOssReasoningParser",
]

View File

@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@ReasoningParserManager.register_module("GptOss")
class GptOssReasoningParser(ReasoningParser):
    """Reasoning parser for the GptOss model.

    GptOss relies on the harmony library for the actual extraction of
    reasoning content; this parser exists solely to detect where the
    reasoning section of a token stream ends. All extraction entry points
    therefore raise deliberately if invoked.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # Token-id sequence that marks the start of the final
        # (non-reasoning) assistant message in the harmony chat format.
        self.reasoning_end_token_ids = self.model_tokenizer.encode(
            "<|start|>assistant<|channel|>final<|message|>")

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the end-of-reasoning marker occurs in input_ids."""
        marker = self.reasoning_end_token_ids
        assert len(marker) > 0, "reasoning_end_token_ids is empty"
        width = len(marker)
        # A match at any window position means the reasoning has ended.
        # (If input_ids is shorter than the marker, the range is empty
        # and we correctly report False.)
        return any(input_ids[pos:pos + width] == marker
                   for pos in range(len(input_ids) - width + 1))

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        # Content extraction is handled by harmony, not this parser.
        raise RuntimeError(
            "GptOss model uses harmony to extract reasoning content. This "
            "function should not be called.")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        # Streaming extraction is handled by harmony, not this parser.
        raise RuntimeError(
            "GptOss model uses harmony to extract reasoning content. This "
            "function should not be called.")

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        # Non-streaming extraction is handled by harmony, not this parser.
        raise RuntimeError(
            "GptOss model uses harmony to extract reasoning content. This "
            "function should not be called.")