[GptOss] Add GptOss reasoning parser to support structure output (#22322)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
parent 98a3a81024
commit a47e6ffe93
@@ -247,13 +247,13 @@ class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
                               config.max_model_len)
 
 
-class GptOssConfig(VerifyAndUpdateConfig):
+class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
 
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         decoding_config = vllm_config.decoding_config
         if decoding_config.reasoning_backend == "":
-            decoding_config.reasoning_backend = "openai"
+            decoding_config.reasoning_backend = "GptOss"
 
         # Increase the max capture size from 512 to 1024 for performance.
         # NOTE(woosuk): This will increase the number of CUDA graphs
@@ -373,5 +373,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
-    "GptOssForCausalLM": GptOssConfig,
+    "GptOssForCausalLM": GptOssForCausalLMConfig,
 }
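The two hunks above rename GptOssConfig to GptOssForCausalLMConfig and change the default reasoning backend from "openai" to "GptOss". A minimal, self-contained sketch of this verify-and-update dispatch pattern follows; the SimpleNamespace stand-ins are illustrative only and are not vLLM's real VllmConfig/DecodingConfig classes.

    from types import SimpleNamespace

    class GptOssForCausalLMConfig:
        @staticmethod
        def verify_and_update_config(vllm_config) -> None:
            decoding_config = vllm_config.decoding_config
            # Only fill in a default when the user left the backend unset.
            if decoding_config.reasoning_backend == "":
                decoding_config.reasoning_backend = "GptOss"

    MODELS_CONFIG_MAP = {"GptOssForCausalLM": GptOssForCausalLMConfig}

    # Dispatch: look up the model architecture, then let the matching
    # config class patch the decoding config in place.
    vllm_config = SimpleNamespace(
        decoding_config=SimpleNamespace(reasoning_backend=""))
    MODELS_CONFIG_MAP["GptOssForCausalLM"].verify_and_update_config(vllm_config)
    assert vllm_config.decoding_config.reasoning_backend == "GptOss"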
@@ -4,6 +4,7 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
+from .gptoss_reasoning_parser import GptOssReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
 from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
 from .mistral_reasoning_parser import MistralReasoningParser
@@ -20,4 +21,5 @@ __all__ = [
     "Glm4MoeModelReasoningParser",
     "MistralReasoningParser",
     "Step3ReasoningParser",
+    "GptOssReasoningParser",
 ]
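This import in the package __init__ matters beyond re-export: importing the module executes the @ReasoningParserManager.register_module("GptOss") decorator in the new file below, which is what lets the reasoning_backend string "GptOss" resolve to the parser class. A toy sketch of that registration pattern, assuming nothing about vLLM's actual ReasoningParserManager internals:

    class ToyReasoningParserManager:
        _registry: dict[str, type] = {}

        @classmethod
        def register_module(cls, name: str):
            # Returns a class decorator that records the class under `name`.
            def decorator(parser_cls: type) -> type:
                cls._registry[name] = parser_cls
                return parser_cls
            return decorator

        @classmethod
        def get(cls, name: str) -> type:
            return cls._registry[name]

    @ToyReasoningParserManager.register_module("GptOss")
    class GptOssReasoningParser:
        pass

    # The name carried in decoding_config.reasoning_backend now resolves
    # to the registered parser class.
    assert ToyReasoningParserManager.get("GptOss") is GptOssReasoningParser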
vllm/reasoning/gptoss_reasoning_parser.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage)
+from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("GptOss")
+class GptOssReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for the GptOss model.
+
+    The GptOss model uses harmony to extract reasoning content, so this
+    parser is only used to detect the end of the reasoning content.
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+        self.reasoning_end_token_ids = self.model_tokenizer.encode(
+            "<|start|>assistant<|channel|>final<|message|>")
+
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        end_token_ids = self.reasoning_end_token_ids
+        assert len(end_token_ids) > 0, "reasoning_end_token_ids is empty"
+        # Check if the end sequence is present in input_ids.
+        # We search backwards from the end to find the last match.
+        for i in range(len(input_ids) - len(end_token_ids), -1, -1):
+            if input_ids[i:i + len(end_token_ids)] == end_token_ids:
+                return True
+        return False
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        raise RuntimeError(
+            "GptOss model uses harmony to extract reasoning content. This "
+            "function should not be called.")
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+        raise RuntimeError(
+            "GptOss model uses harmony to extract reasoning content. This "
+            "function should not be called.")
+
+    def extract_reasoning_content(
+        self, model_output: str, request: ChatCompletionRequest
+    ) -> tuple[Optional[str], Optional[str]]:
+        raise RuntimeError(
+            "GptOss model uses harmony to extract reasoning content. This "
+            "function should not be called.")
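The backward subsequence search in is_reasoning_end can be exercised in isolation. Below is a standalone sketch of the same logic; the marker ids here are made up for illustration, whereas the real parser gets them from tokenizer.encode("<|start|>assistant<|channel|>final<|message|>").

    def contains_end_sequence(input_ids: list[int],
                              end_token_ids: list[int]) -> bool:
        assert len(end_token_ids) > 0, "end_token_ids is empty"
        # Scan candidate start positions from the end of input_ids backwards,
        # returning as soon as one window matches the marker sequence.
        for i in range(len(input_ids) - len(end_token_ids), -1, -1):
            if input_ids[i:i + len(end_token_ids)] == end_token_ids:
                return True
        return False

    end_marker = [200006, 173781, 200005, 17196, 200008]  # hypothetical ids
    assert contains_end_sequence([1, 2] + end_marker + [3, 4], end_marker)
    assert not contains_end_sequence([1, 2, 3, 4], end_marker)

Searching backwards is a small optimization: in the common case the final-channel header appears near the end of the generated ids, so the last-match scan terminates quickly.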