[Model] Add Holo2 reasoning parser (#30048)

Signed-off-by: hdlj-h <hubert@hcompany.ai>
Hubert de La Jonquiere 2025-12-05 03:38:45 +01:00 committed by GitHub
parent aaddc9c82a
commit befb59e5b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 277 additions and 0 deletions

View File

@@ -18,6 +18,7 @@ vLLM currently supports the following reasoning models:
| [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ |
| [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ |
| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ |
| [Holo2 series](https://huggingface.co/collections/Hcompany/holo2) | `holo2` | `json`, `regex` | ✅ |
| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` | ✅ |
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
| [MiniMax-M2](https://huggingface.co/MiniMaxAI/MiniMax-M2) | `minimax_m2_append_think` | `json`, `regex` | ✅ |
@@ -28,6 +29,7 @@ vLLM currently supports the following reasoning models:
IBM Granite 3.2 and DeepSeek-V3.1 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`.
DeepSeek-V3.1 tool calling is supported in non-thinking mode.
Holo2 reasoning is enabled by default. To disable it, you must also pass `thinking=False` in your `chat_template_kwargs`.
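For example, with the OpenAI-compatible server the switch can be passed per request through `extra_body` (a minimal sketch, not part of this commit; the server address is a placeholder and the model name is taken from the test file below):

```python
# Minimal sketch: disabling Holo2 reasoning for a single request via
# chat_template_kwargs. Assumes a server launched with the holo2 parser,
# e.g. `vllm serve HCompany/Holo2-4B --reasoning-parser holo2`.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="HCompany/Holo2-4B",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    extra_body={"chat_template_kwargs": {"thinking": False}},
)
# With thinking disabled, the reply carries no reasoning content.
print(response.choices[0].message.content)
```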
## Quickstart

View File

@@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer

from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.reasoning.holo2_reasoning_parser import Holo2ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser

REASONING_MODEL_NAME = "HCompany/Holo2-4B"


@pytest.fixture(scope="module")
def tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


@pytest.mark.parametrize(
    "thinking,expected_parser_type",
    [
        (True, DeepSeekR1ReasoningParser),
        (False, IdentityReasoningParser),
    ],
)
def test_parser_selection(tokenizer, thinking, expected_parser_type):
    parser = Holo2ReasoningParser(
        tokenizer,
        chat_template_kwargs={
            "thinking": thinking,
        },
    )
    assert isinstance(parser._parser, expected_parser_type)


def test_holo2_default_parser_is_deepseekr1(tokenizer):
    parser = Holo2ReasoningParser(tokenizer)
    assert isinstance(parser._parser, DeepSeekR1ReasoningParser)


def test_holo2_supports_structured_output(tokenizer):
    # The structured output manager uses the reasoning parser to check whether
    # the reasoning content has ended before applying the grammar, primarily
    # through is_reasoning_end. This test checks that the parser correctly
    # identifies the end of the reasoning content.
    # It is important not to pass chat_template_kwargs here, since the
    # StructuredOutputManager instantiates the parser without them.
    parser = Holo2ReasoningParser(tokenizer)
    end_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]
    assert parser.is_reasoning_end([1, 2, 4, end_token_id])
    assert not parser.is_reasoning_end([1, 2, 4])
    assert parser.is_reasoning_end([1, 2, 4, end_token_id, 5])


# thinking is True, non-streaming
WITH_THINK = {
    "output": "This is a reasoning section</think>This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
}
# thinking is True, streaming
WITH_THINK_STREAM = {
    "output": "This is a reasoning section</think>This is the rest",
    "reasoning": "This is a reasoning section",
    "content": "This is the rest",
}
# thinking is False, non-streaming
THINKING_DISABLED = {
    "output": "This is the rest",
    "reasoning": None,
    "content": "This is the rest",
}
# thinking is False, streaming
THINKING_DISABLED_STREAM = {
    "output": "This is the rest",
    "reasoning": None,
    "content": "This is the rest",
}
# thinking is False but the model outputs </think>, non-streaming
THINKING_DISABLED_WITH_CLOSE_TAG = {
    "output": "</think>This is the rest",
    "reasoning": None,
    "content": "</think>This is the rest",
}
# thinking is False but the model outputs </think>, streaming
THINKING_DISABLED_WITH_CLOSE_TAG_STREAM = {
    "output": "some text</think>This is the rest",
    "reasoning": None,
    "content": "some text</think>This is the rest",
}
COMPLETE_REASONING = {
    "output": "This is a reasoning section</think>",
    "reasoning": "This is a reasoning section",
    "content": None,
}

TEST_CASES = [
    pytest.param(
        False,
        WITH_THINK,
        None,
        id="with_think",
    ),
    pytest.param(
        True,
        WITH_THINK_STREAM,
        None,
        id="with_think_stream",
    ),
    pytest.param(
        False,
        WITH_THINK,
        {"thinking": True},
        id="with_think_enabled",
    ),
    pytest.param(
        True,
        WITH_THINK_STREAM,
        {"thinking": True},
        id="with_think_stream_enabled",
    ),
    pytest.param(
        False,
        THINKING_DISABLED,
        {"thinking": False},
        id="thinking_disabled",
    ),
    pytest.param(
        True,
        THINKING_DISABLED_STREAM,
        {"thinking": False},
        id="thinking_disabled_stream",
    ),
    pytest.param(
        False,
        THINKING_DISABLED_WITH_CLOSE_TAG,
        {"thinking": False},
        id="thinking_disabled_with_close_tag",
    ),
    pytest.param(
        True,
        THINKING_DISABLED_WITH_CLOSE_TAG_STREAM,
        {"thinking": False},
        id="thinking_disabled_with_close_tag_stream",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        None,
        id="complete_reasoning",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        None,
        id="complete_reasoning_stream",
    ),
]
@pytest.mark.parametrize("streaming, param_dict, chat_template_kwargs", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
chat_template_kwargs: dict | None,
tokenizer,
):
output = tokenizer.tokenize(param_dict["output"])
output_tokens: list[str] = [
tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser("holo2")(
tokenizer,
chat_template_kwargs=chat_template_kwargs,
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]

View File

@@ -44,6 +44,10 @@ _REASONING_PARSERS_TO_REGISTER = {
        "granite_reasoning_parser",
        "GraniteReasoningParser",
    ),
    "holo2": (
        "holo2_reasoning_parser",
        "Holo2ReasoningParser",
    ),
    "hunyuan_a13b": (
        "hunyuan_a13b_reasoning_parser",
        "HunyuanA13BReasoningParser",

View File

@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
from vllm.tokenizers import TokenizerLike

logger = init_logger(__name__)


class Holo2ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Holo2 models, which are based on Qwen3.

    Holo2 uses <think>...</think> tokens to delimit reasoning text, but the
    opening <think> tag is emitted by the chat template, so this parser only
    extracts the reasoning content up to </think> in the model's output.

    The model provides a switch to enable or disable reasoning output via the
    'thinking' chat template argument.

    Chat template args:
    - thinking: Whether to enable reasoning output (default: True)

    Parsing rules on model output:
    - thinking == False
      -> Model output is treated as pure content: |content|
    - thinking == True
      -> Model output is |reasoning_content|</think>|content|
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
        # DeepSeek V3 and Holo2 behave similarly, but Holo2 models think by
        # default. The structured output manager instantiates this parser once
        # for all requests, without user-specified chat template args, so the
        # default for "thinking" must be True.
        enable_thinking = bool(chat_kwargs.get("thinking", True))
        if enable_thinking:
            self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
        else:
            self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return self._parser.is_reasoning_end(input_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return self._parser.extract_content_ids(input_ids)

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        return self._parser.extract_reasoning(model_output, request)

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        return self._parser.extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
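To make the delegation concrete, here is a minimal non-streaming sketch mirroring the `WITH_THINK` test case from the test file above; it assumes the underlying DeepSeek-R1 delegate does not consult the request object, so `None` stands in for a `ChatCompletionRequest`:

```python
# Minimal sketch: with the default thinking=True, the parser delegates to
# DeepSeek-R1 and splits the model output at </think>.
from transformers import AutoTokenizer

from vllm.reasoning.holo2_reasoning_parser import Holo2ReasoningParser

tokenizer = AutoTokenizer.from_pretrained("HCompany/Holo2-4B")
parser = Holo2ReasoningParser(tokenizer)  # defaults to the DeepSeek-R1 delegate

# The request argument is assumed unused by the delegate in this illustration.
reasoning, content = parser.extract_reasoning(
    "This is a reasoning section</think>This is the rest", request=None
)
assert reasoning == "This is a reasoning section"
assert content == "This is the rest"
```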