mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 03:29:38 +08:00
[Structured Output][Reasoning] Improves decoding throughput for models using single-token reasoning endings. (#30056)
This commit is contained in:
parent
67475a6e81
commit
c72ea10723
@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.end_token_id in input_ids
|
||||
|
||||
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
|
||||
return self.end_token_id in delta_token_ids
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@ -132,6 +132,41 @@ class TestBaseThinkingReasoningParserMethods:
|
||||
is False
|
||||
)
|
||||
|
||||
def test_is_reasoning_end_streaming(self, test_tokenizer):
|
||||
"""Test the is_reasoning_end_streaming method."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
end_token_id = parser.end_token_id
|
||||
start_token_id = parser.start_token_id
|
||||
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
|
||||
is True
|
||||
)
|
||||
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
|
||||
assert parser.is_reasoning_end_streaming([], []) is False
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming(
|
||||
[1, start_token_id, 2, end_token_id], [end_token_id]
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
|
||||
)
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming(
|
||||
[1, start_token_id, 2, end_token_id, 2, start_token_id, 2],
|
||||
[2],
|
||||
)
|
||||
is False
|
||||
)
|
||||
assert (
|
||||
parser.is_reasoning_end_streaming(
|
||||
[1, start_token_id, 2, end_token_id, 2, 2], [2]
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_extract_content_ids(self, test_tokenizer):
|
||||
"""Test the extract_content_ids method."""
|
||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||
|
||||
@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
|
||||
input_tokens = tokenizer.tokenize(input_text)
|
||||
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
||||
assert parser.is_reasoning_end(input_ids) is True
|
||||
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
|
||||
|
||||
# Test extract_content_ids returns all input_ids
|
||||
assert parser.extract_content_ids(input_ids) == input_ids
|
||||
|
||||
@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
|
||||
request.use_structured_output = True
|
||||
request.prompt_token_ids = [1, 2, 3, 4, 5]
|
||||
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
request.num_computed_tokens = 5
|
||||
return request
|
||||
|
||||
def test_should_fill_bitmask_with_enable_in_reasoning(
|
||||
|
||||
@ -63,6 +63,31 @@ class ReasoningParser:
|
||||
True if the reasoning content ends in the input_ids.
|
||||
"""
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: list[int], delta_ids: list[int]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the reasoning content ends in the input_ids on a
|
||||
decode step.
|
||||
|
||||
It is used in structured engines like `xgrammar` to check if the
|
||||
reasoning content ends in the model output during a decode step.
|
||||
`input_ids` the entire model output and `delta_ids` are the last few
|
||||
computed tokens of the model output (like during a decode step).
|
||||
|
||||
Parameters:
|
||||
input_ids: list[int]
|
||||
The entire model output.
|
||||
delta_ids: list[int]
|
||||
The last few computed tokens of the model output at the current decode step.
|
||||
|
||||
Returns:
|
||||
bool
|
||||
True if the reasoning content ends in the `delta_ids` on a
|
||||
decode step.
|
||||
"""
|
||||
return self.is_reasoning_end(input_ids)
|
||||
|
||||
@abstractmethod
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
|
||||
@ -74,6 +74,12 @@ class BaseThinkingReasoningParser(ReasoningParser):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: list[int], delta_ids: list[int]
|
||||
) -> bool:
|
||||
end_token_id = self.end_token_id
|
||||
return end_token_id in delta_ids
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Extract the content after the end tokens
|
||||
|
||||
@ -35,6 +35,11 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
return self._parser.is_reasoning_end(input_ids)
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: list[int], delta_ids: list[int]
|
||||
) -> bool:
|
||||
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
return self._parser.extract_content_ids(input_ids)
|
||||
|
||||
|
||||
@ -56,6 +56,11 @@ class Holo2ReasoningParser(ReasoningParser):
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
return self._parser.is_reasoning_end(input_ids)
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: list[int], delta_ids: list[int]
|
||||
) -> bool:
|
||||
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
return self._parser.extract_content_ids(input_ids)
|
||||
|
||||
|
||||
@ -32,6 +32,11 @@ class IdentityReasoningParser(ReasoningParser):
|
||||
# Always return True, since we never treat reasoning specially
|
||||
return True
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: list[int], delta_ids: list[int]
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
# Identity: return all tokens as content
|
||||
return input_ids
|
||||
|
||||
@ -339,7 +339,9 @@ class StructuredOutputManager:
|
||||
return True
|
||||
|
||||
# Check if reasoning ends in *this* step
|
||||
if self.reasoner.is_reasoning_end(request.all_token_ids):
|
||||
if self.reasoner.is_reasoning_end_streaming(
|
||||
request.all_token_ids, request.all_token_ids[request.num_computed_tokens :]
|
||||
):
|
||||
# Reasoning just ended, so we shouldn't advance til
|
||||
# next pass
|
||||
structured_req.reasoning_ended = True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user