mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:55:50 +08:00
[Model] Add Holo2 reasoning parser (#30048)
Signed-off-by: hdlj-h <hubert@hcompany.ai>
This commit is contained in:
parent
aaddc9c82a
commit
befb59e5b1
@ -18,6 +18,7 @@ vLLM currently supports the following reasoning models:
|
||||
| [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ |
|
||||
| [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ |
|
||||
| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ |
|
||||
| [Holo2 series](https://huggingface.co/collections/Hcompany/holo2) | `holo2` | `json`, `regex` | ✅ |
|
||||
| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` | ✅ |
|
||||
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
|
||||
| [MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) | `minimax_m2_append_think` | `json`, `regex` | ✅ |
|
||||
@ -28,6 +29,7 @@ vLLM currently supports the following reasoning models:
|
||||
IBM Granite 3.2 and DeepSeek-V3.1 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
|
||||
The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`.
|
||||
DeepSeek-V3.1 tool calling is supported in non-thinking mode.
|
||||
Holo2 reasoning is enabled by default. To disable it, you must also pass `thinking=False` in your `chat_template_kwargs`.
|
||||
|
||||
## Quickstart
|
||||
|
||||
|
||||
188
tests/reasoning/test_holo2_reasoning_parser.py
Normal file
188
tests/reasoning/test_holo2_reasoning_parser.py
Normal file
@ -0,0 +1,188 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from vllm.reasoning.holo2_reasoning_parser import Holo2ReasoningParser
|
||||
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
|
||||
|
||||
REASONING_MODEL_NAME = "HCompany/Holo2-4B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer():
|
||||
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"thinking,expected_parser_type",
|
||||
[
|
||||
(True, DeepSeekR1ReasoningParser),
|
||||
(False, IdentityReasoningParser),
|
||||
],
|
||||
)
|
||||
def test_parser_selection(tokenizer, thinking, expected_parser_type):
|
||||
parser = Holo2ReasoningParser(
|
||||
tokenizer,
|
||||
chat_template_kwargs={
|
||||
"thinking": thinking,
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(parser._parser, expected_parser_type)
|
||||
|
||||
|
||||
def test_holo2_default_parser_is_deepseekr1(tokenizer):
|
||||
parser = Holo2ReasoningParser(tokenizer)
|
||||
|
||||
assert isinstance(parser._parser, DeepSeekR1ReasoningParser)
|
||||
|
||||
|
||||
def test_holo2_supports_structured_output(tokenizer):
|
||||
# Structured output manager uses the reasoning parser to check if the
|
||||
# reasoning content is ended before applying the grammar. The main function
|
||||
# used is is_reasoning_end. This test checks if the parser is able to
|
||||
# correctly identify the end of the reasoning content.
|
||||
|
||||
# important to not pass chat_template_kwargs here as it is done in the
|
||||
# StructuredOutputManager
|
||||
parser = Holo2ReasoningParser(tokenizer)
|
||||
|
||||
end_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]
|
||||
|
||||
assert parser.is_reasoning_end([1, 2, 4, end_token_id])
|
||||
assert not parser.is_reasoning_end([1, 2, 4])
|
||||
assert parser.is_reasoning_end([1, 2, 4, end_token_id, 5])
|
||||
|
||||
|
||||
# thinking is True, non-streaming
|
||||
WITH_THINK = {
|
||||
"output": "This is a reasoning section</think>This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
# thinking is True, streaming
|
||||
WITH_THINK_STREAM = {
|
||||
"output": "This is a reasoning section</think>This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
# thinking is False, non-streaming
|
||||
THINKING_DISABLED = {
|
||||
"output": "This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
}
|
||||
# thinking is False, streaming
|
||||
THINKING_DISABLED_STREAM = {
|
||||
"output": "This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is the rest",
|
||||
}
|
||||
# thinking is False but the model output </think>, non-streaming
|
||||
THINKING_DISABLED_WITH_CLOSE_TAG = {
|
||||
"output": "</think>This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "</think>This is the rest",
|
||||
}
|
||||
# thinking is False but the model output </think>, streaming
|
||||
THINKING_DISABLED_WITH_CLOSE_TAG_STREAM = {
|
||||
"output": "some text</think>This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "some text</think>This is the rest",
|
||||
}
|
||||
COMPLETE_REASONING = {
|
||||
"output": "This is a reasoning section</think>",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
WITH_THINK,
|
||||
None,
|
||||
id="with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
WITH_THINK_STREAM,
|
||||
None,
|
||||
id="with_think_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
WITH_THINK,
|
||||
{"thinking": True},
|
||||
id="with_think_enabled",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
WITH_THINK_STREAM,
|
||||
{"thinking": True},
|
||||
id="with_think_stream_enabled",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
THINKING_DISABLED,
|
||||
{"thinking": False},
|
||||
id="thinking_disabled",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
THINKING_DISABLED_STREAM,
|
||||
{"thinking": False},
|
||||
id="thinking_disabled_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
THINKING_DISABLED_WITH_CLOSE_TAG,
|
||||
{"thinking": False},
|
||||
id="thinking_disabled_with_close_tag",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
THINKING_DISABLED_WITH_CLOSE_TAG_STREAM,
|
||||
{"thinking": False},
|
||||
id="thinking_disabled_with_close_tag_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
None,
|
||||
id="complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
None,
|
||||
id="complete_reasoning_stream",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict, chat_template_kwargs", TEST_CASES)
|
||||
def test_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
chat_template_kwargs: dict | None,
|
||||
tokenizer,
|
||||
):
|
||||
output = tokenizer.tokenize(param_dict["output"])
|
||||
output_tokens: list[str] = [
|
||||
tokenizer.convert_tokens_to_string([token]) for token in output
|
||||
]
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser("holo2")(
|
||||
tokenizer,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
)
|
||||
|
||||
reasoning, content = run_reasoning_extraction(
|
||||
parser, output_tokens, streaming=streaming
|
||||
)
|
||||
|
||||
assert reasoning == param_dict["reasoning"]
|
||||
assert content == param_dict["content"]
|
||||
@ -44,6 +44,10 @@ _REASONING_PARSERS_TO_REGISTER = {
|
||||
"granite_reasoning_parser",
|
||||
"GraniteReasoningParser",
|
||||
),
|
||||
"holo2": (
|
||||
"holo2_reasoning_parser",
|
||||
"Holo2ReasoningParser",
|
||||
),
|
||||
"hunyuan_a13b": (
|
||||
"hunyuan_a13b_reasoning_parser",
|
||||
"HunyuanA13BReasoningParser",
|
||||
|
||||
83
vllm/reasoning/holo2_reasoning_parser.py
Normal file
83
vllm/reasoning/holo2_reasoning_parser.py
Normal file
@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import (
|
||||
ReasoningParser,
|
||||
)
|
||||
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Holo2ReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for the Holo2 models which are based on Qwen3.
|
||||
|
||||
The Holo2 model uses <think>...</think> tokens to denote reasoning text but <think>
|
||||
is part of the chat template. This parser extracts the reasoning content until
|
||||
</think> in the model's output.
|
||||
|
||||
The model provides a switch to enable or disable reasoning
|
||||
output via the 'thinking=False' parameter.
|
||||
|
||||
Chat template args:
|
||||
- thinking: Whether to enable reasoning output (default: True)
|
||||
|
||||
|
||||
Parsing rules on model output:
|
||||
- thinking == False
|
||||
-> Model output is treated as purely the content |content|
|
||||
- thinking == True
|
||||
-> Model output is |reasoning_content|</think>|content|
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
|
||||
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
|
||||
# Deepseek V3 and Holo2 are similar. However, Holo2 models think by default.
|
||||
# this parser without user specified chat template args is initiated once for
|
||||
# all requests in the structured output manager. So it is important that without
|
||||
# user specified chat template args, the default thinking is True.
|
||||
|
||||
enable_thinking = bool(chat_kwargs.get("thinking", True))
|
||||
|
||||
if enable_thinking:
|
||||
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
|
||||
else:
|
||||
self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
|
||||
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
return self._parser.is_reasoning_end(input_ids)
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
return self._parser.extract_content_ids(input_ids)
|
||||
|
||||
def extract_reasoning(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> tuple[str | None, str | None]:
|
||||
return self._parser.extract_reasoning(model_output, request)
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> DeltaMessage | None:
|
||||
return self._parser.extract_reasoning_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user