[Structured Output][Reasoning] Improves decoding throughput for models using single-token reasoning endings. (#30056)

This commit is contained in:
Hubert de La Jonquiere 2025-12-09 11:54:08 +01:00 committed by GitHub
parent 67475a6e81
commit c72ea10723
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 89 additions and 1 deletions

View File

@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
return self.end_token_id in delta_token_ids
...
```

View File

@ -132,6 +132,41 @@ class TestBaseThinkingReasoningParserMethods:
is False
)
def test_is_reasoning_end_streaming(self, test_tokenizer):
"""Test the is_reasoning_end_streaming method."""
parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id
start_token_id = parser.start_token_id
assert (
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
is True
)
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
assert parser.is_reasoning_end_streaming([], []) is False
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id], [end_token_id]
)
is True
)
assert (
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
)
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id, 2, start_token_id, 2],
[2],
)
is False
)
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id, 2, 2], [2]
)
is False
)
def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer)

View File

@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
input_tokens = tokenizer.tokenize(input_text)
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
assert parser.is_reasoning_end(input_ids) is True
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
# Test extract_content_ids returns all input_ids
assert parser.extract_content_ids(input_ids) == input_ids

View File

@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
request.use_structured_output = True
request.prompt_token_ids = [1, 2, 3, 4, 5]
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
request.num_computed_tokens = 5
return request
def test_should_fill_bitmask_with_enable_in_reasoning(

View File

@ -63,6 +63,31 @@ class ReasoningParser:
True if the reasoning content ends in the input_ids.
"""
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
) -> bool:
"""
Check if the reasoning content ends in the input_ids on a
decode step.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output during a decode step.
`input_ids` the entire model output and `delta_ids` are the last few
computed tokens of the model output (like during a decode step).
Parameters:
input_ids: list[int]
The entire model output.
delta_ids: list[int]
The last few computed tokens of the model output at the current decode step.
Returns:
bool
True if the reasoning content ends in the `delta_ids` on a
decode step.
"""
return self.is_reasoning_end(input_ids)
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""

View File

@ -74,6 +74,12 @@ class BaseThinkingReasoningParser(ReasoningParser):
return True
return False
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
) -> bool:
end_token_id = self.end_token_id
return end_token_id in delta_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens

View File

@ -35,6 +35,11 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
return self._parser.extract_content_ids(input_ids)

View File

@ -56,6 +56,11 @@ class Holo2ReasoningParser(ReasoningParser):
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
return self._parser.extract_content_ids(input_ids)

View File

@ -32,6 +32,11 @@ class IdentityReasoningParser(ReasoningParser):
# Always return True, since we never treat reasoning specially
return True
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
) -> bool:
return True
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
# Identity: return all tokens as content
return input_ids

View File

@ -339,7 +339,9 @@ class StructuredOutputManager:
return True
# Check if reasoning ends in *this* step
if self.reasoner.is_reasoning_end(request.all_token_ids):
if self.reasoner.is_reasoning_end_streaming(
request.all_token_ids, request.all_token_ids[request.num_computed_tokens :]
):
# Reasoning just ended, so we shouldn't advance til
# next pass
structured_req.reasoning_ended = True