mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 20:45:44 +08:00
[Frontend][Responses API] Support reporting tool output tokens and fix reasoning token count (#24285)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent
fb691ee4e7
commit
a3645ed94d
425
tests/entrypoints/test_context.py
Normal file
425
tests/entrypoints/test_context.py
Normal file
@ -0,0 +1,425 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai_harmony import StreamState
|
||||||
|
|
||||||
|
from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
|
||||||
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function for Python < 3.10 compatibility
|
||||||
|
async def async_next(async_iterator):
|
||||||
|
"""Compatibility function equivalent to Python 3.10's anext()."""
|
||||||
|
return await async_iterator.__anext__()
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_request_output(
|
||||||
|
prompt_token_ids=None,
|
||||||
|
output_token_ids=None,
|
||||||
|
num_cached_tokens=0,
|
||||||
|
finished=True,
|
||||||
|
):
|
||||||
|
"""Helper function to create a mock RequestOutput object for testing."""
|
||||||
|
outputs = []
|
||||||
|
token_ids = output_token_ids if output_token_ids is not None else []
|
||||||
|
outputs = [
|
||||||
|
CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="Test output",
|
||||||
|
token_ids=token_ids,
|
||||||
|
cumulative_logprob=0.0,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None,
|
||||||
|
stop_reason=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
return RequestOutput(
|
||||||
|
request_id="test-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=outputs,
|
||||||
|
finished=finished,
|
||||||
|
num_cached_tokens=num_cached_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_mock_outputs(num_turns,
|
||||||
|
prompt_token_counts,
|
||||||
|
output_token_counts,
|
||||||
|
cached_token_counts=None):
|
||||||
|
"""Generate a sequence of mock RequestOutput objects to simulate multiple
|
||||||
|
turns."""
|
||||||
|
if cached_token_counts is None:
|
||||||
|
cached_token_counts = [0] * num_turns
|
||||||
|
|
||||||
|
for i in range(num_turns):
|
||||||
|
# Create mock prompt token IDs and output token IDs
|
||||||
|
prompt_token_ids = list(range(1, prompt_token_counts[i] + 1))
|
||||||
|
output_token_ids = list(range(1, output_token_counts[i] + 1))
|
||||||
|
|
||||||
|
# Create and yield the RequestOutput
|
||||||
|
yield create_mock_request_output(
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
output_token_ids=output_token_ids,
|
||||||
|
num_cached_tokens=cached_token_counts[i],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_parser():
|
||||||
|
"""Set up a mock parser for tests."""
|
||||||
|
with patch("vllm.entrypoints.context.get_streamable_parser_for_assistant"
|
||||||
|
) as mock_parser_factory:
|
||||||
|
# Create a mock parser object
|
||||||
|
parser = MagicMock()
|
||||||
|
parser.messages = []
|
||||||
|
parser.current_channel = None
|
||||||
|
parser.state = StreamState.EXPECT_START
|
||||||
|
mock_parser_factory.return_value = parser
|
||||||
|
yield parser
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_turn_token_counting():
|
||||||
|
"""Test token counting behavior for a single turn."""
|
||||||
|
# Create a context
|
||||||
|
context = HarmonyContext(messages=[], available_tools=[])
|
||||||
|
|
||||||
|
# Create a mock RequestOutput with specific token counts
|
||||||
|
mock_output = create_mock_request_output(
|
||||||
|
prompt_token_ids=[1, 2, 3, 4, 5], # 5 prompt tokens
|
||||||
|
output_token_ids=[6, 7, 8], # 3 output tokens
|
||||||
|
num_cached_tokens=2, # 2 cached tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# Append the output to the context
|
||||||
|
context.append_output(mock_output)
|
||||||
|
|
||||||
|
# Verify the token counts
|
||||||
|
assert context.num_prompt_tokens == 5
|
||||||
|
assert context.num_output_tokens == 3
|
||||||
|
assert context.num_cached_tokens == 2
|
||||||
|
assert context.num_tool_output_tokens == 0 # No tool tokens in first turn
|
||||||
|
|
||||||
|
# Verify internal state tracking
|
||||||
|
assert not context.is_first_turn
|
||||||
|
assert context.previous_turn.input_tokens == 5
|
||||||
|
assert context.previous_turn.output_tokens == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_turn_token_counting():
|
||||||
|
"""Test token counting behavior across multiple turns with tool output."""
|
||||||
|
# Create a context
|
||||||
|
context = HarmonyContext(messages=[], available_tools=["browser"])
|
||||||
|
|
||||||
|
# Simulate a conversation with 3 turns
|
||||||
|
# Turn 1: prefill 5, decode 3, tool 7
|
||||||
|
# Turn 2: prefill 15, cached 5, decode 4, tool 1
|
||||||
|
# Turn 3: prefill 20, cached 15, decode 5
|
||||||
|
prompt_token_counts = [5, 15, 20]
|
||||||
|
output_token_counts = [3, 4, 5]
|
||||||
|
cached_token_counts = [0, 5, 15]
|
||||||
|
mock_generator = generate_mock_outputs(3, prompt_token_counts,
|
||||||
|
output_token_counts,
|
||||||
|
cached_token_counts)
|
||||||
|
|
||||||
|
# First turn - initial prompt and response
|
||||||
|
mock_output1 = await async_next(mock_generator)
|
||||||
|
context.append_output(mock_output1)
|
||||||
|
|
||||||
|
# At this point, we should have 5 prompt tokens and 3 output tokens
|
||||||
|
assert context.num_prompt_tokens == 5
|
||||||
|
assert context.num_output_tokens == 3
|
||||||
|
assert context.num_tool_output_tokens == 0
|
||||||
|
|
||||||
|
# Second turn - after tool output
|
||||||
|
mock_output2 = await async_next(mock_generator)
|
||||||
|
context.append_output(mock_output2)
|
||||||
|
# Current prompt tokens (15) - last_turn_input_tokens (5) -
|
||||||
|
# last_turn_output_tokens (3) = 7
|
||||||
|
expected_tool_output = 7
|
||||||
|
|
||||||
|
assert context.num_prompt_tokens == 5 + 15
|
||||||
|
assert context.num_output_tokens == 3 + 4
|
||||||
|
assert context.num_tool_output_tokens == expected_tool_output
|
||||||
|
assert context.num_cached_tokens == 5
|
||||||
|
|
||||||
|
# Third turn - final response
|
||||||
|
mock_output3 = await async_next(mock_generator)
|
||||||
|
context.append_output(mock_output3)
|
||||||
|
# Additional tool output tokens from third turn:
|
||||||
|
# Current prompt (20) - last_turn_input_tokens (15) -
|
||||||
|
# last_turn_output_tokens (4) = 1
|
||||||
|
expected_tool_output = 7 + 1
|
||||||
|
|
||||||
|
assert context.num_prompt_tokens == 5 + 15 + 20
|
||||||
|
assert context.num_output_tokens == 3 + 4 + 5
|
||||||
|
assert context.num_tool_output_tokens == expected_tool_output
|
||||||
|
assert context.num_cached_tokens == 5 + 15
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_output_tokens():
|
||||||
|
"""Test behavior when RequestOutput has empty output tokens."""
|
||||||
|
context = HarmonyContext(messages=[], available_tools=[])
|
||||||
|
|
||||||
|
# Create a RequestOutput with empty output tokens
|
||||||
|
mock_output = create_mock_request_output(
|
||||||
|
prompt_token_ids=[1, 2, 3], # 3 prompt tokens
|
||||||
|
output_token_ids=[], # Empty output tokens list
|
||||||
|
num_cached_tokens=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
context.append_output(mock_output)
|
||||||
|
|
||||||
|
# Should handle empty outputs gracefully
|
||||||
|
assert context.num_prompt_tokens == 3
|
||||||
|
assert context.num_output_tokens == 0 # No output tokens
|
||||||
|
assert context.num_cached_tokens == 1
|
||||||
|
assert context.num_tool_output_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_prompt_token_ids():
|
||||||
|
"""Test behavior when RequestOutput has None prompt_token_ids."""
|
||||||
|
context = HarmonyContext(messages=[], available_tools=[])
|
||||||
|
|
||||||
|
mock_output = create_mock_request_output(
|
||||||
|
prompt_token_ids=None, # No prompt token IDs
|
||||||
|
output_token_ids=[1, 2], # 2 output tokens
|
||||||
|
num_cached_tokens=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Logger.error will be called, but we don't need to check for warnings
|
||||||
|
# here Just ensure it doesn't raise an exception
|
||||||
|
context.append_output(mock_output)
|
||||||
|
|
||||||
|
# Should handle missing prompt tokens gracefully
|
||||||
|
assert context.num_prompt_tokens == 0
|
||||||
|
assert context.num_output_tokens == 2
|
||||||
|
assert context.num_cached_tokens == 0
|
||||||
|
assert context.num_tool_output_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_reasoning_tokens_counting(mock_parser):
|
||||||
|
"""Test that reasoning tokens are counted correctly."""
|
||||||
|
context = HarmonyContext(messages=[], available_tools=[])
|
||||||
|
|
||||||
|
# Mock parser to simulate reasoning channel
|
||||||
|
mock_parser.current_channel = "analysis" # Reasoning channel
|
||||||
|
|
||||||
|
mock_output = create_mock_request_output(
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
output_token_ids=[4, 5, 6, 7], # 4 tokens, all in reasoning
|
||||||
|
num_cached_tokens=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
context.append_output(mock_output)
|
||||||
|
|
||||||
|
# All output tokens should be counted as reasoning
|
||||||
|
assert context.num_reasoning_tokens == 4
|
||||||
|
assert context.num_output_tokens == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_tokens_edge_case():
|
||||||
|
"""Test behavior with all zero token counts."""
|
||||||
|
context = HarmonyContext(messages=[], available_tools=[])
|
||||||
|
|
||||||
|
# Create a request with empty lists (not None) for both prompt and
|
||||||
|
# output tokens
|
||||||
|
mock_output = create_mock_request_output(
|
||||||
|
prompt_token_ids=[], # Empty prompt tokens
|
||||||
|
output_token_ids=[], # Empty output tokens
|
||||||
|
num_cached_tokens=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
context.append_output(mock_output)
|
||||||
|
|
||||||
|
# All counts should be zero
|
||||||
|
assert context.num_prompt_tokens == 0
|
||||||
|
assert context.num_output_tokens == 0
|
||||||
|
assert context.num_cached_tokens == 0
|
||||||
|
assert context.num_tool_output_tokens == 0
|
||||||
|
assert context.num_reasoning_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_turn_no_tool_output():
|
||||||
|
"""Test that first turn never generates tool output tokens."""
|
||||||
|
context = HarmonyContext(
|
||||||
|
messages=[],
|
||||||
|
available_tools=["browser"] # Tools available
|
||||||
|
)
|
||||||
|
|
||||||
|
# Even with large prompt in first turn, no tool tokens should be counted
|
||||||
|
mock_output = create_mock_request_output(
|
||||||
|
prompt_token_ids=list(range(100)), # 100 tokens
|
||||||
|
output_token_ids=[1, 2, 3],
|
||||||
|
num_cached_tokens=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
context.append_output(mock_output)
|
||||||
|
|
||||||
|
# First turn should never have tool output tokens
|
||||||
|
assert context.num_tool_output_tokens == 0
|
||||||
|
assert context.is_first_turn is False # Should be updated after first turn
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_negative_tool_tokens_edge_case():
|
||||||
|
"""Test edge case where calculation could result in negative tool
|
||||||
|
tokens. We should log an error and clamp the value to 0."""
|
||||||
|
# Use patch to check if logger.error was called
|
||||||
|
with patch("vllm.entrypoints.context.logger.error") as mock_log:
|
||||||
|
context = HarmonyContext(messages=[], available_tools=["browser"])
|
||||||
|
|
||||||
|
# First turn
|
||||||
|
mock_output1 = create_mock_request_output(
|
||||||
|
prompt_token_ids=list(range(10)), # 10 tokens
|
||||||
|
output_token_ids=[1, 2, 3, 4, 5], # 5 tokens
|
||||||
|
)
|
||||||
|
context.append_output(mock_output1)
|
||||||
|
|
||||||
|
# Second turn with fewer new tokens than previous output
|
||||||
|
# This could happen in edge cases with aggressive caching
|
||||||
|
mock_output2 = create_mock_request_output(
|
||||||
|
prompt_token_ids=list(range(12)), # 12 tokens (only 2 new)
|
||||||
|
output_token_ids=[6, 7], # 2 tokens
|
||||||
|
)
|
||||||
|
context.append_output(mock_output2)
|
||||||
|
|
||||||
|
# Calculated negative tool tokens (12 - 10 - 5 = -3) should be clamped
|
||||||
|
# to 0 and an error should be logged
|
||||||
|
assert context.num_tool_output_tokens == 0
|
||||||
|
assert context.num_prompt_tokens == 10 + 12
|
||||||
|
assert context.num_output_tokens == 5 + 2
|
||||||
|
|
||||||
|
# Verify the error was logged properly
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
|
||||||
|
# Extract the actual log message and arguments from the call
|
||||||
|
args, _ = mock_log.call_args
|
||||||
|
log_message = args[0]
|
||||||
|
|
||||||
|
# Check for key parts of the message
|
||||||
|
assert "Negative tool output tokens calculated" in log_message
|
||||||
|
assert "-3" in str(args) # Check that -3 is in the arguments
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_multi_turn_token_counting(mock_parser):
|
||||||
|
"""Test token counting for streaming multi-turn conversations.
|
||||||
|
|
||||||
|
This test focuses on how StreamingHarmonyContext counts tokens in a
|
||||||
|
multi-turn conversation with streaming (token-by-token) outputs and
|
||||||
|
message boundaries.
|
||||||
|
"""
|
||||||
|
# Create a streaming context
|
||||||
|
context = StreamingHarmonyContext(messages=[], available_tools=["browser"])
|
||||||
|
|
||||||
|
# Simulate three turns of conversation:
|
||||||
|
# Turn 1: stream tokens one by one, then finish the message
|
||||||
|
# Turn 2: new prompt, stream more tokens with a reasoning segment
|
||||||
|
# Turn 3: new prompt with tool output and cached tokens
|
||||||
|
|
||||||
|
# First turn: 3 tokens streamed one by one
|
||||||
|
# First token of first turn
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
prompt_token_ids=[1, 2, 3], # 3 prompt tokens
|
||||||
|
output_token_ids=[101], # Single token
|
||||||
|
num_cached_tokens=0,
|
||||||
|
finished=False, # Not end of message yet
|
||||||
|
))
|
||||||
|
|
||||||
|
# Second token of first turn
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
output_token_ids=[102],
|
||||||
|
finished=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Last token of first turn (finished=True signals end of message)
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
output_token_ids=[103],
|
||||||
|
finished=True, # End of message
|
||||||
|
))
|
||||||
|
|
||||||
|
# Check token counts after first turn
|
||||||
|
assert context.num_prompt_tokens == 3 # Initial prompt tokens
|
||||||
|
assert context.num_output_tokens == 3 # Three output tokens
|
||||||
|
assert context.num_cached_tokens == 0
|
||||||
|
assert context.num_tool_output_tokens == 0 # No tool output in first turn
|
||||||
|
assert context.first_tok_of_message is True # Ready for next message
|
||||||
|
|
||||||
|
# Second turn: reasoning tokens in analysis channel
|
||||||
|
mock_parser.current_channel = "analysis" # Set to reasoning channel
|
||||||
|
|
||||||
|
# First token of second turn
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
prompt_token_ids=[1, 2, 3, 101, 102, 103, 4,
|
||||||
|
5], # 8 tokens (includes previous)
|
||||||
|
output_token_ids=[201],
|
||||||
|
num_cached_tokens=3, # Some tokens cached
|
||||||
|
finished=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
# More tokens in reasoning channel
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
output_token_ids=[202],
|
||||||
|
finished=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
output_token_ids=[203],
|
||||||
|
finished=True, # End of reasoning message
|
||||||
|
))
|
||||||
|
|
||||||
|
# Check counts after second turn (reasoning message)
|
||||||
|
assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt
|
||||||
|
assert context.num_output_tokens == 3 + 3 # First turn + second turn
|
||||||
|
assert context.num_reasoning_tokens == 3 # All tokens in analysis channel
|
||||||
|
assert context.num_cached_tokens == 3 # Cached tokens from second turn
|
||||||
|
|
||||||
|
# Formula: this turn prompt tokens - last turn prompt - last turn output
|
||||||
|
expected_tool_tokens = 8 - 3 - 3 # = 2
|
||||||
|
assert context.num_tool_output_tokens == expected_tool_tokens
|
||||||
|
|
||||||
|
# Third turn: regular output channel
|
||||||
|
mock_parser.current_channel = "final" # Switch back to regular channel
|
||||||
|
|
||||||
|
# Third turn (with more cached tokens)
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
prompt_token_ids=[
|
||||||
|
1, 2, 3, 101, 102, 103, 4, 5, 201, 202, 203, 6, 7
|
||||||
|
], # 13 tokens
|
||||||
|
output_token_ids=[301],
|
||||||
|
num_cached_tokens=8, # More cached tokens
|
||||||
|
finished=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
context.append_output(
|
||||||
|
create_mock_request_output(
|
||||||
|
output_token_ids=[302],
|
||||||
|
finished=True,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Final token counts check
|
||||||
|
assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts
|
||||||
|
assert context.num_output_tokens == 3 + 3 + 2 # All outputs
|
||||||
|
assert context.num_reasoning_tokens == 3 # Unchanged from second turn
|
||||||
|
assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens
|
||||||
|
|
||||||
|
# Additional tool tokens from third turn
|
||||||
|
# Formula: this turn prompt - last turn prompt - last turn output
|
||||||
|
additional_tool_tokens = 13 - 8 - 3 # = 2
|
||||||
|
assert context.num_tool_output_tokens == expected_tool_tokens \
|
||||||
|
+ additional_tool_tokens
|
||||||
@ -3,7 +3,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
@ -21,6 +20,23 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TurnTokens:
|
||||||
|
"""Tracks token counts for a single conversation turn."""
|
||||||
|
|
||||||
|
def __init__(self, input_tokens=0, output_tokens=0):
|
||||||
|
self.input_tokens = input_tokens
|
||||||
|
self.output_tokens = output_tokens
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset counters for a new turn."""
|
||||||
|
self.input_tokens = 0
|
||||||
|
self.output_tokens = 0
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
"""Create a copy of this turn's token counts."""
|
||||||
|
return TurnTokens(self.input_tokens, self.output_tokens)
|
||||||
|
|
||||||
|
|
||||||
class ConversationContext(ABC):
|
class ConversationContext(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -92,52 +108,124 @@ class HarmonyContext(ConversationContext):
|
|||||||
self.num_init_messages = len(messages)
|
self.num_init_messages = len(messages)
|
||||||
self.num_prompt_tokens = 0
|
self.num_prompt_tokens = 0
|
||||||
self.num_output_tokens = 0
|
self.num_output_tokens = 0
|
||||||
# TODO(woosuk): Implement the following fields.
|
|
||||||
self.num_cached_tokens = 0
|
self.num_cached_tokens = 0
|
||||||
self.num_reasoning_tokens = 0
|
self.num_reasoning_tokens = 0
|
||||||
|
self.num_tool_output_tokens = 0
|
||||||
|
|
||||||
def _update_num_prompt_tokens(self, output: RequestOutput):
|
# Turn tracking - replaces multiple individual tracking variables
|
||||||
if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
|
self.current_turn = TurnTokens()
|
||||||
# NOTE: with built-in tools, there might be multiple rounds in
|
self.previous_turn = TurnTokens()
|
||||||
# the conversation, with the full conversation being resent
|
self.is_first_turn = True
|
||||||
# as new prompt each time. Hence the sum.
|
self.first_tok_of_message = True # For streaming support
|
||||||
self.num_prompt_tokens += len(output.prompt_token_ids)
|
|
||||||
|
|
||||||
def _update_num_cached_tokens(self, output: RequestOutput):
|
def _update_num_reasoning_tokens(self):
|
||||||
if output.num_cached_tokens is not None:
|
# Count all analysis and commentary channels as reasoning tokens
|
||||||
#Similar to num_prompt_tokens
|
if self.parser.current_channel in {"analysis", "commentary"}:
|
||||||
self.num_cached_tokens += output.num_cached_tokens
|
self.num_reasoning_tokens += 1
|
||||||
|
|
||||||
def _update_num_output_tokens(self, token_ids: Sequence[int]):
|
|
||||||
self.num_output_tokens += len(token_ids)
|
|
||||||
|
|
||||||
def _update_num_reasoning_tokens(self, token_ids: Sequence[int]):
|
|
||||||
# Count tokens that are part of reasoning content (analysis channel
|
|
||||||
# or tool-directed messages like python/browser calls)
|
|
||||||
is_analysis = self.parser.current_channel == "analysis"
|
|
||||||
is_tool_call = (self.parser.current_recipient is not None and
|
|
||||||
(self.parser.current_recipient.startswith("python") or
|
|
||||||
self.parser.current_recipient.startswith("browser.")))
|
|
||||||
if is_analysis or is_tool_call:
|
|
||||||
self.num_reasoning_tokens += len(token_ids)
|
|
||||||
|
|
||||||
def append_output(self, output) -> None:
|
def append_output(self, output) -> None:
|
||||||
if isinstance(output, RequestOutput):
|
if isinstance(output, RequestOutput):
|
||||||
self._update_num_prompt_tokens(output)
|
|
||||||
self._update_num_cached_tokens(output)
|
|
||||||
output_token_ids = output.outputs[0].token_ids
|
output_token_ids = output.outputs[0].token_ids
|
||||||
self._update_num_output_tokens(output_token_ids)
|
|
||||||
self.parser = get_streamable_parser_for_assistant()
|
self.parser = get_streamable_parser_for_assistant()
|
||||||
for token_id in output_token_ids:
|
for token_id in output_token_ids:
|
||||||
self.parser.process(token_id)
|
self.parser.process(token_id)
|
||||||
# Check if the current token is part of reasoning content
|
# Check if the current token is part of reasoning content
|
||||||
self._update_num_reasoning_tokens([token_id])
|
self._update_num_reasoning_tokens()
|
||||||
|
self._update_prefill_token_usage(output)
|
||||||
|
# Reset current turn output tokens for this turn
|
||||||
|
self.current_turn.output_tokens = 0
|
||||||
|
self._update_decode_token_usage(output)
|
||||||
|
# Move current turn to previous turn for next turn's calculations
|
||||||
|
self.previous_turn = self.current_turn.copy()
|
||||||
output_msgs = self.parser.messages
|
output_msgs = self.parser.messages
|
||||||
else:
|
else:
|
||||||
# Tool output.
|
# Tool output.
|
||||||
output_msgs = output
|
output_msgs = output
|
||||||
self._messages.extend(output_msgs)
|
self._messages.extend(output_msgs)
|
||||||
|
|
||||||
|
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
|
||||||
|
"""Update token usage statistics for the prefill phase of generation.
|
||||||
|
|
||||||
|
The prefill phase processes the input prompt tokens. This method:
|
||||||
|
1. Counts the prompt tokens for this turn
|
||||||
|
2. Calculates tool output tokens for multi-turn conversations
|
||||||
|
3. Updates cached token counts
|
||||||
|
4. Tracks state for next turn calculations
|
||||||
|
|
||||||
|
Tool output tokens are calculated as:
|
||||||
|
current_prompt_tokens - last_turn_prompt_tokens -
|
||||||
|
last_turn_output_tokens
|
||||||
|
This represents tokens added between turns (typically tool responses).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: The RequestOutput containing prompt token information
|
||||||
|
"""
|
||||||
|
if output.prompt_token_ids is not None:
|
||||||
|
this_turn_input_tokens = len(output.prompt_token_ids)
|
||||||
|
else:
|
||||||
|
this_turn_input_tokens = 0
|
||||||
|
logger.error(
|
||||||
|
"RequestOutput appended contains no prompt_token_ids.")
|
||||||
|
|
||||||
|
# Update current turn input tokens
|
||||||
|
self.current_turn.input_tokens = this_turn_input_tokens
|
||||||
|
self.num_prompt_tokens += this_turn_input_tokens
|
||||||
|
|
||||||
|
# Calculate tool tokens (except on first turn)
|
||||||
|
if self.is_first_turn:
|
||||||
|
self.is_first_turn = False
|
||||||
|
else:
|
||||||
|
# start counting tool after first turn
|
||||||
|
# tool tokens = this turn prefill - last turn prefill -
|
||||||
|
# last turn decode
|
||||||
|
this_turn_tool_tokens = (self.current_turn.input_tokens -
|
||||||
|
self.previous_turn.input_tokens -
|
||||||
|
self.previous_turn.output_tokens)
|
||||||
|
|
||||||
|
# Handle negative tool token counts (shouldn't happen in normal
|
||||||
|
# cases)
|
||||||
|
if this_turn_tool_tokens < 0:
|
||||||
|
logger.error(
|
||||||
|
"Negative tool output tokens calculated: %d "
|
||||||
|
"(current_input=%d, previous_input=%d, "
|
||||||
|
"previous_output=%d). Setting to 0.",
|
||||||
|
this_turn_tool_tokens, self.current_turn.input_tokens,
|
||||||
|
self.previous_turn.input_tokens,
|
||||||
|
self.previous_turn.output_tokens)
|
||||||
|
this_turn_tool_tokens = 0
|
||||||
|
|
||||||
|
self.num_tool_output_tokens += this_turn_tool_tokens
|
||||||
|
|
||||||
|
# Update cached tokens
|
||||||
|
if output.num_cached_tokens is not None:
|
||||||
|
self.num_cached_tokens += output.num_cached_tokens
|
||||||
|
|
||||||
|
def _update_decode_token_usage(self, output: RequestOutput) -> int:
|
||||||
|
"""Update token usage statistics for the decode phase of generation.
|
||||||
|
|
||||||
|
The decode phase processes the generated output tokens. This method:
|
||||||
|
1. Counts output tokens from all completion outputs
|
||||||
|
2. Updates the total output token count
|
||||||
|
3. Tracks tokens generated in the current turn
|
||||||
|
|
||||||
|
In streaming mode, this is called for each token generated.
|
||||||
|
In non-streaming mode, this is called once with all output tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: The RequestOutput containing generated token information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of output tokens processed in this call
|
||||||
|
"""
|
||||||
|
updated_output_token_count = 0
|
||||||
|
if output.outputs:
|
||||||
|
for completion_output in output.outputs:
|
||||||
|
# only keep last round
|
||||||
|
updated_output_token_count += len(completion_output.token_ids)
|
||||||
|
self.num_output_tokens += updated_output_token_count
|
||||||
|
self.current_turn.output_tokens += updated_output_token_count
|
||||||
|
return updated_output_token_count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def messages(self) -> list:
|
def messages(self) -> list:
|
||||||
return self._messages
|
return self._messages
|
||||||
@ -231,8 +319,8 @@ class StreamingHarmonyContext(HarmonyContext):
|
|||||||
# append_output is called for each output token in streaming case,
|
# append_output is called for each output token in streaming case,
|
||||||
# so we only want to add the prompt tokens once for each message.
|
# so we only want to add the prompt tokens once for each message.
|
||||||
if self.first_tok_of_message:
|
if self.first_tok_of_message:
|
||||||
self._update_num_prompt_tokens(output)
|
self._update_prefill_token_usage(output)
|
||||||
self._update_num_cached_tokens(output)
|
self.current_turn.output_tokens = 0
|
||||||
# Reset self.first_tok_of_message if needed:
|
# Reset self.first_tok_of_message if needed:
|
||||||
# if the current token is the last one of the current message
|
# if the current token is the last one of the current message
|
||||||
# (finished=True), then the next token processed will mark the
|
# (finished=True), then the next token processed will mark the
|
||||||
@ -240,9 +328,13 @@ class StreamingHarmonyContext(HarmonyContext):
|
|||||||
self.first_tok_of_message = output.finished
|
self.first_tok_of_message = output.finished
|
||||||
for tok in output.outputs[0].token_ids:
|
for tok in output.outputs[0].token_ids:
|
||||||
self.parser.process(tok)
|
self.parser.process(tok)
|
||||||
self._update_num_output_tokens(output.outputs[0].token_ids)
|
self._update_decode_token_usage(output)
|
||||||
|
|
||||||
|
# For streaming, update previous turn when message is complete
|
||||||
|
if output.finished:
|
||||||
|
self.previous_turn = self.current_turn.copy()
|
||||||
# Check if the current token is part of reasoning content
|
# Check if the current token is part of reasoning content
|
||||||
self._update_num_reasoning_tokens(output.outputs[0].token_ids)
|
self._update_num_reasoning_tokens()
|
||||||
self.last_tok = tok
|
self.last_tok = tok
|
||||||
else:
|
else:
|
||||||
# Handle the case of tool output in direct message format
|
# Handle the case of tool output in direct message format
|
||||||
|
|||||||
@ -1841,7 +1841,8 @@ class InputTokensDetails(OpenAIBaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class OutputTokensDetails(OpenAIBaseModel):
|
class OutputTokensDetails(OpenAIBaseModel):
|
||||||
reasoning_tokens: int
|
reasoning_tokens: int = 0
|
||||||
|
tool_output_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
class ResponseUsage(OpenAIBaseModel):
|
class ResponseUsage(OpenAIBaseModel):
|
||||||
|
|||||||
@ -460,7 +460,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
assert isinstance(context, HarmonyContext)
|
assert isinstance(context, HarmonyContext)
|
||||||
output = self._make_response_output_items_with_harmony(context)
|
output = self._make_response_output_items_with_harmony(context)
|
||||||
# TODO: these are all 0 for now!
|
num_tool_output_tokens = context.num_tool_output_tokens
|
||||||
else:
|
else:
|
||||||
assert isinstance(context, SimpleContext)
|
assert isinstance(context, SimpleContext)
|
||||||
final_res = context.last_output
|
final_res = context.last_output
|
||||||
@ -473,6 +473,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
# Calculate usage.
|
# Calculate usage.
|
||||||
assert final_res.prompt_token_ids is not None
|
assert final_res.prompt_token_ids is not None
|
||||||
|
num_tool_output_tokens = 0
|
||||||
|
|
||||||
assert isinstance(context, (SimpleContext, HarmonyContext))
|
assert isinstance(context, (SimpleContext, HarmonyContext))
|
||||||
num_prompt_tokens = context.num_prompt_tokens
|
num_prompt_tokens = context.num_prompt_tokens
|
||||||
num_generated_tokens = context.num_output_tokens
|
num_generated_tokens = context.num_output_tokens
|
||||||
@ -486,7 +488,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
input_tokens_details=InputTokensDetails(
|
input_tokens_details=InputTokensDetails(
|
||||||
cached_tokens=num_cached_tokens),
|
cached_tokens=num_cached_tokens),
|
||||||
output_tokens_details=OutputTokensDetails(
|
output_tokens_details=OutputTokensDetails(
|
||||||
reasoning_tokens=num_reasoning_tokens),
|
reasoning_tokens=num_reasoning_tokens,
|
||||||
|
tool_output_tokens=num_tool_output_tokens),
|
||||||
)
|
)
|
||||||
response = ResponsesResponse.from_request(
|
response = ResponsesResponse.from_request(
|
||||||
request,
|
request,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user