[Model] Add reason parser for Hunyuan A13B Model. (#20625)
Signed-off-by: Asher Zhang <asherszhang@tencent.com>
parent 5b8366b61a
commit b140416abf
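In brief: the commit registers a new reasoning parser under the name "hunyuan_a13b", exports it from vllm.reasoning, and adds tests for both the streaming and non-streaming paths. A minimal sketch of non-streaming use, assuming a checkout that includes this commit (the output string is illustrative):

from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParserManager

# Resolve the parser class registered by this commit and bind it to the
# Hunyuan tokenizer.
tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct",
                                          trust_remote_code=True)
parser = ReasoningParserManager.get_reasoning_parser("hunyuan_a13b")(tokenizer)

# Split a finished completion into its reasoning and answer sections.
# request=None is an assumption that holds for this parser, which does not
# read the request object.
reasoning, answer = parser.extract_reasoning_content(
    "<think>\nchain of thought\n</think>\n<answer>\nfinal answer\n</answer>",
    request=None)
assert reasoning == "chain of thought" and answer == "final answer"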
tests/reasoning/test_hunyuan_reasoning_parser.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager

parser_name = "hunyuan_a13b"
START_REASONING = "<think>\n"
START_RESPONSE = "\n</think>\n<answer>\n"
END_RESPONSE = "\n</answer>"

NO_REASONING_QUICK_THOUGHT = {
    "output":
    f"{START_REASONING}{START_RESPONSE}This is the rest{END_RESPONSE}",  # noqa: E501
    "reasoning_content": None,
    "content": "This is the rest",
}

SIMPLE_REASONING = {
    "output":
    f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest{END_RESPONSE}",  # noqa: E501
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
COMPLETE_REASONING = {
    "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}
NO_REASONING = {
    "output": "This is content",
    "reasoning_content": None,
    "content": "This is content",
}
MULTIPLE_LINES = {
    "output":
    f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
}
REASONING_WITH_THINK = {
    "output":
    f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest",  # noqa: E501
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
COMPLETE_REASONING_WITH_THINK = {
    "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}
MULTIPLE_LINES_WITH_THINK = {
    "output":
    f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
}

TEST_CASES = [
    pytest.param(
        False,
        SIMPLE_REASONING,
        id="simple_reasoning",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_reasoning",
    ),
    pytest.param(
        False,
        NO_REASONING,
        id="no_reasoning",
    ),
    pytest.param(False, NO_REASONING_QUICK_THOUGHT, id="no_reasoning_quick"),
    pytest.param(
        False,
        MULTIPLE_LINES,
        id="multiple_lines",
    ),
    pytest.param(
        False,
        REASONING_WITH_THINK,
        id="reasoning_with_think",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think",
    ),
    pytest.param(
        True,
        SIMPLE_REASONING,
        id="simple_reasoning_streaming",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_reasoning_streaming",
    ),
    pytest.param(
        True,
        NO_REASONING,
        id="no_reasoning_streaming",
    ),
    pytest.param(True,
                 NO_REASONING_QUICK_THOUGHT,
                 id="no_reasoning_quick_stream"),
    pytest.param(
        True,
        MULTIPLE_LINES,
        id="multiple_lines_streaming",
    ),
    pytest.param(
        True,
        REASONING_WITH_THINK,
        id="reasoning_with_think_streaming",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think_streaming",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think_streaming",
    ),
]

# Global tokenizer initialization to avoid repeated loading.
tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct",
                                          trust_remote_code=True)


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
):
    output = tokenizer.tokenize(param_dict["output"])
    # Convert every token back to its string form.
    output_tokens: list[str] = [
        tokenizer.convert_tokens_to_string([token]) for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(tokenizer)

    reasoning, content = run_reasoning_extraction(parser,
                                                  output_tokens,
                                                  streaming=streaming)

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]
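The new tests can be run directly with pytest tests/reasoning/test_hunyuan_reasoning_parser.py; note that they download the tencent/Hunyuan-A13B-Instruct tokenizer from the Hugging Face Hub on first use.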
vllm/reasoning/__init__.py
@@ -4,6 +4,7 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
+from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
 from .qwen3_reasoning_parser import Qwen3ReasoningParser

 __all__ = [
@@ -11,5 +12,6 @@ __all__ = [
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
     "GraniteReasoningParser",
+    "HunyuanA13BReasoningParser",
     "Qwen3ReasoningParser",
 ]
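Note that this import is what makes the parser discoverable at runtime: importing vllm.reasoning evaluates the @ReasoningParserManager.register_module("hunyuan_a13b") decorator in the new file below, which registers the class under that name.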
vllm/reasoning/hunyuan_a13b_reasoning_parser.py (new file, 238 lines)
@@ -0,0 +1,238 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re
from collections.abc import Sequence
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("hunyuan_a13b")
class HunyuanA13BReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Hunyuan A13B model.

    This class implements a reasoning parser specifically designed for the
    Hunyuan A13B model. It is responsible for parsing and extracting
    structured reasoning ("think") and answer ("answer") segments from model
    outputs that follow a specific pattern.

    Key features:
    - For non-streaming output, it recognizes and extracts the reasoning
      and answer sections using regular expressions.
    - For streaming output, state changes are triggered by fixed token-id
      sequences, so the parser maintains internal state to manage parsing
      across multiple tokens.

    think start:  "<think>\n": [14023, 771, 397]
    think end:    "\n</think>\n<answer>\n": [198, 524, 27963, 397, 27, 9399, 397]
    response end: "\n</answer>": [524, 9399, 29]
    """
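
    # Illustrative sanity check for the hard-coded delimiter ids above; it
    # assumes the Hunyuan tokenizer is available locally:
    #
    #   tok = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct",
    #                                       trust_remote_code=True)
    #   assert tok.encode("<think>\n",
    #                     add_special_tokens=False) == [14023, 771, 397]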

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_start_expr = r"<think>\n"
        self.think_end_expr = r"\n</think>\n"

        self.response_start_expr = r"\n</think>\n<answer>\n"
        self.response_end_expr = r"\n</answer>"

        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?(.*?){self.response_end_expr}",
            re.DOTALL)

        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)

        self.think_start_ids = [14023, 771, 397]
        self.think_start_ids_fast = [14023, 771, 1363]
        self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397]
        self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397]
        self.response_end_ids = [198, 524, 9399, 29]
        self.fast_think_ids = [
            14023, 771, 1363, 524, 27963, 397, 27, 9399, 397
        ]

        # When the state changes, send out all the text buffered in the
        # previous state.
        self.buffered_text = []
        self.buffered_ids = []

        self.current_state = "idle"
        self.all_states = ["idle", "think", "response"]

        self.expected_sequence = self.think_start_ids
        # This side sequence exists only for the think start, which can be
        # produced in two ways.
        self.expected_sequence_side = self.think_start_ids_fast
        self.sequence_index = 0
        self.token_buffer = []
        self.text_buffer = ""

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.current_state == "response"

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract the reasoning content and response content, respectively.

        If the sequence doesn't match what we expect, i.e., the model
        generates something else, all content is considered non-reasoning
        content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed.

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """

        re_match = self.full_match_reasoning_regex.findall(model_output)
        if re_match:
            reasoning_content, response_content = re_match[0]
            if len(reasoning_content) == 0:
                reasoning_content = None
            if len(response_content) == 0:
                response_content = None
            return reasoning_content, response_content

        fallback_regex = self.half_match_reasoning_regex
        fallback_match = fallback_regex.findall(model_output)
        if fallback_match:
            reasoning_content, response_content = fallback_match[0]

            if response_content.endswith(self.response_end_expr):
                response_content = response_content[:-len(
                    self.response_end_expr)]

            if len(reasoning_content) == 0:
                reasoning_content = None
            if len(response_content) == 0:
                response_content = None

            return reasoning_content, response_content

        return None, model_output

    def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
                                          sequence: Sequence[int]) -> bool:
        if not subsequence:
            return False

        sub_idx = 0
        for num in sequence:
            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
                sub_idx += 1
        return sub_idx == len(subsequence)
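
    # For illustration: the helper above is an ordered-subsequence check,
    # e.g. _is_strict_increasing_subsequence([1, 3], [1, 2, 3]) -> True,
    # while _is_strict_increasing_subsequence([3, 1], [1, 2, 3]) -> False,
    # because order must be preserved.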

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Extract content using a token-id sequence state machine."""
        # Delimiter sequences.
        think_start_sequence = self.think_start_ids
        response_start_sequence = self.response_start_ids
        response_end_sequence = self.response_end_ids

        # Deltas are fed one token at a time.
        assert len(delta_token_ids) == 1
        token = delta_token_ids[0]

        def check_token_with_sequence(token):
            if self.current_state == "idle" or self.current_state == "think":
                return (token == self.expected_sequence[self.sequence_index]
                        or token ==
                        self.expected_sequence_side[self.sequence_index])
            else:
                return token == self.expected_sequence[self.sequence_index]

        def check_last_token(token):
            if self.current_state == "idle" or self.current_state == "think":
                # Use the side-sequence length only if the match was made
                # against the side sequence.
                if (self.sequence_index - 1 < len(self.expected_sequence_side)
                        and token == self.expected_sequence_side[
                            self.sequence_index - 1]):
                    return self.sequence_index == len(
                        self.expected_sequence_side)
                else:
                    return self.sequence_index == len(self.expected_sequence)
            else:
                return self.sequence_index == len(self.expected_sequence)

        # Check if the token matches the expected sequence.
        token_in_state_seq = check_token_with_sequence(token)

        if token_in_state_seq:
            # Store the matching token.
            self.token_buffer.append(token)
            self.text_buffer += delta_text
            self.sequence_index += 1
            # States change along idle -> think -> response -> idle.

            # Check if the sequence is fully matched.
            if check_last_token(token):
                # State transition.
                if self.current_state == "idle":
                    self.current_state = "think"
                    self.expected_sequence = response_start_sequence
                    self.expected_sequence_side = self.response_start_ids_fast
                elif self.current_state == "think":
                    self.current_state = "response"
                    self.expected_sequence = response_end_sequence
                elif self.current_state == "response":
                    self.current_state = "idle"
                    self.expected_sequence = think_start_sequence
                    self.expected_sequence_side = self.think_start_ids_fast

                # Reset the matching state.
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""
                # Do not send content for state-transition text.
        else:
            # Sequence broken: handle buffered content.
            if self.token_buffer:
                # Send the buffered tokens together with this delta.
                buffered_content = self.text_buffer + delta_text
                # Reset the matching state.
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""

                # Return content based on the current state.
                if self.current_state == "think":
                    return DeltaMessage(reasoning_content=buffered_content,
                                        content=None)
                else:
                    return DeltaMessage(reasoning_content=None,
                                        content=buffered_content)
            else:
                # No buffered content; send the delta normally.
                if self.current_state == "think":
                    return DeltaMessage(reasoning_content=delta_text,
                                        content=None)
                else:
                    return DeltaMessage(reasoning_content=None,
                                        content=delta_text)

        # No content to send for this delta.
        return None
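For reference, a minimal sketch of driving the streaming path by hand, feeding one token at a time the way run_reasoning_extraction does in the tests; it reuses the tokenizer and parser from the sketch above, and the output string is illustrative:

previous_text, previous_ids = "", []
deltas = []
token_ids = tokenizer.encode("<think>\nhi\n</think>\n<answer>\nok\n</answer>",
                             add_special_tokens=False)
for tid in token_ids:
    delta_text = tokenizer.decode([tid])
    current_text = previous_text + delta_text
    current_ids = previous_ids + [tid]
    # The parser returns None for delimiter tokens that only drive a state
    # transition; otherwise it yields a reasoning or content delta.
    msg = parser.extract_reasoning_content_streaming(previous_text,
                                                     current_text, delta_text,
                                                     previous_ids, current_ids,
                                                     [tid])
    if msg is not None:
        deltas.append(msg)
    previous_text, previous_ids = current_text, current_ids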