mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 05:11:19 +08:00
[Structured Output][Reasoning] Improves decoding throughput for models using single-token reasoning endings. (#30056)
This commit is contained in:
parent
67475a6e81
commit
c72ea10723
@ -299,6 +299,9 @@ Additionally, to enable structured output, you'll need to create a new `Reasoner
|
|||||||
|
|
||||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||||
return self.end_token_id in input_ids
|
return self.end_token_id in input_ids
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(self, input_ids: list[int], delta_ids: list[int]) -> bool:
|
||||||
|
return self.end_token_id in delta_ids
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@ -132,6 +132,41 @@ class TestBaseThinkingReasoningParserMethods:
|
|||||||
is False
|
is False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_is_reasoning_end_streaming(self, test_tokenizer):
|
||||||
|
"""Test the is_reasoning_end_streaming method."""
|
||||||
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
end_token_id = parser.end_token_id
|
||||||
|
start_token_id = parser.start_token_id
|
||||||
|
|
||||||
|
assert (
|
||||||
|
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
|
||||||
|
assert parser.is_reasoning_end_streaming([], []) is False
|
||||||
|
assert (
|
||||||
|
parser.is_reasoning_end_streaming(
|
||||||
|
[1, start_token_id, 2, end_token_id], [end_token_id]
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
parser.is_reasoning_end_streaming(
|
||||||
|
[1, start_token_id, 2, end_token_id, 2, start_token_id, 2],
|
||||||
|
[2],
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
parser.is_reasoning_end_streaming(
|
||||||
|
[1, start_token_id, 2, end_token_id, 2, 2], [2]
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
def test_extract_content_ids(self, test_tokenizer):
|
def test_extract_content_ids(self, test_tokenizer):
|
||||||
"""Test the extract_content_ids method."""
|
"""Test the extract_content_ids method."""
|
||||||
parser = TestThinkingReasoningParser(test_tokenizer)
|
parser = TestThinkingReasoningParser(test_tokenizer)
|
||||||
|
|||||||
@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
|
|||||||
input_tokens = tokenizer.tokenize(input_text)
|
input_tokens = tokenizer.tokenize(input_text)
|
||||||
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
|
||||||
assert parser.is_reasoning_end(input_ids) is True
|
assert parser.is_reasoning_end(input_ids) is True
|
||||||
|
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
|
||||||
|
|
||||||
# Test extract_content_ids returns all input_ids
|
# Test extract_content_ids returns all input_ids
|
||||||
assert parser.extract_content_ids(input_ids) == input_ids
|
assert parser.extract_content_ids(input_ids) == input_ids
|
||||||
|
|||||||
@ -70,6 +70,7 @@ class TestReasoningStructuredOutput:
|
|||||||
request.use_structured_output = True
|
request.use_structured_output = True
|
||||||
request.prompt_token_ids = [1, 2, 3, 4, 5]
|
request.prompt_token_ids = [1, 2, 3, 4, 5]
|
||||||
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
|
request.all_token_ids = [1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
request.num_computed_tokens = 5
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def test_should_fill_bitmask_with_enable_in_reasoning(
|
def test_should_fill_bitmask_with_enable_in_reasoning(
|
||||||
|
|||||||
@ -63,6 +63,31 @@ class ReasoningParser:
|
|||||||
True if the reasoning content ends in the input_ids.
|
True if the reasoning content ends in the input_ids.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(
|
||||||
|
self, input_ids: list[int], delta_ids: list[int]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the reasoning content ends in the input_ids on a
|
||||||
|
decode step.
|
||||||
|
|
||||||
|
It is used in structured engines like `xgrammar` to check if the
|
||||||
|
reasoning content ends in the model output during a decode step.
|
||||||
|
`input_ids` is the entire model output and `delta_ids` are the last few
|
||||||
|
computed tokens of the model output (like during a decode step).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
input_ids: list[int]
|
||||||
|
The entire model output.
|
||||||
|
delta_ids: list[int]
|
||||||
|
The last few computed tokens of the model output at the current decode step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool
|
||||||
|
True if the reasoning content ends in the `delta_ids` on a
|
||||||
|
decode step.
|
||||||
|
"""
|
||||||
|
return self.is_reasoning_end(input_ids)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -74,6 +74,12 @@ class BaseThinkingReasoningParser(ReasoningParser):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(
|
||||||
|
self, input_ids: list[int], delta_ids: list[int]
|
||||||
|
) -> bool:
|
||||||
|
end_token_id = self.end_token_id
|
||||||
|
return end_token_id in delta_ids
|
||||||
|
|
||||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Extract the content after the end tokens
|
Extract the content after the end tokens
|
||||||
|
|||||||
@ -35,6 +35,11 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
|
|||||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||||
return self._parser.is_reasoning_end(input_ids)
|
return self._parser.is_reasoning_end(input_ids)
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(
|
||||||
|
self, input_ids: list[int], delta_ids: list[int]
|
||||||
|
) -> bool:
|
||||||
|
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
|
||||||
|
|
||||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||||
return self._parser.extract_content_ids(input_ids)
|
return self._parser.extract_content_ids(input_ids)
|
||||||
|
|
||||||
|
|||||||
@ -56,6 +56,11 @@ class Holo2ReasoningParser(ReasoningParser):
|
|||||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||||
return self._parser.is_reasoning_end(input_ids)
|
return self._parser.is_reasoning_end(input_ids)
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(
|
||||||
|
self, input_ids: list[int], delta_ids: list[int]
|
||||||
|
) -> bool:
|
||||||
|
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
|
||||||
|
|
||||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||||
return self._parser.extract_content_ids(input_ids)
|
return self._parser.extract_content_ids(input_ids)
|
||||||
|
|
||||||
|
|||||||
@ -32,6 +32,11 @@ class IdentityReasoningParser(ReasoningParser):
|
|||||||
# Always return True, since we never treat reasoning specially
|
# Always return True, since we never treat reasoning specially
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def is_reasoning_end_streaming(
|
||||||
|
self, input_ids: list[int], delta_ids: list[int]
|
||||||
|
) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||||
# Identity: return all tokens as content
|
# Identity: return all tokens as content
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|||||||
@ -339,7 +339,9 @@ class StructuredOutputManager:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# Check if reasoning ends in *this* step
|
# Check if reasoning ends in *this* step
|
||||||
if self.reasoner.is_reasoning_end(request.all_token_ids):
|
if self.reasoner.is_reasoning_end_streaming(
|
||||||
|
request.all_token_ids, request.all_token_ids[request.num_computed_tokens :]
|
||||||
|
):
|
||||||
# Reasoning just ended, so we shouldn't advance til
|
# Reasoning just ended, so we shouldn't advance til
|
||||||
# next pass
|
# next pass
|
||||||
structured_req.reasoning_ended = True
|
structured_req.reasoning_ended = True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user