[Frontend][Responses API] Support reporting tool output tokens and fix reasoning token count (#24285)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Ye (Charlotte) Qi 2025-09-06 13:27:15 -07:00 committed by GitHub
parent fb691ee4e7
commit a3645ed94d
4 changed files with 557 additions and 36 deletions

View File

@@ -0,0 +1,425 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from openai_harmony import StreamState
from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
from vllm.outputs import CompletionOutput, RequestOutput
# Helper function for Python < 3.10 compatibility
async def async_next(async_iterator):
"""Compatibility function equivalent to Python 3.10's anext()."""
return await async_iterator.__anext__()
def create_mock_request_output(
prompt_token_ids=None,
output_token_ids=None,
num_cached_tokens=0,
finished=True,
):
"""Helper function to create a mock RequestOutput object for testing."""
token_ids = output_token_ids if output_token_ids is not None else []
outputs = [
CompletionOutput(
index=0,
text="Test output",
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
finish_reason=None,
stop_reason=None,
)
]
return RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
outputs=outputs,
finished=finished,
num_cached_tokens=num_cached_tokens,
)
async def generate_mock_outputs(num_turns,
prompt_token_counts,
output_token_counts,
cached_token_counts=None):
"""Generate a sequence of mock RequestOutput objects to simulate multiple
turns."""
if cached_token_counts is None:
cached_token_counts = [0] * num_turns
for i in range(num_turns):
# Create mock prompt token IDs and output token IDs
prompt_token_ids = list(range(1, prompt_token_counts[i] + 1))
output_token_ids = list(range(1, output_token_counts[i] + 1))
# Create and yield the RequestOutput
yield create_mock_request_output(
prompt_token_ids=prompt_token_ids,
output_token_ids=output_token_ids,
num_cached_tokens=cached_token_counts[i],
)
@pytest.fixture
def mock_parser():
"""Set up a mock parser for tests."""
with patch("vllm.entrypoints.context.get_streamable_parser_for_assistant"
) as mock_parser_factory:
# Create a mock parser object
parser = MagicMock()
parser.messages = []
parser.current_channel = None
parser.state = StreamState.EXPECT_START
mock_parser_factory.return_value = parser
yield parser
def test_single_turn_token_counting():
"""Test token counting behavior for a single turn."""
# Create a context
context = HarmonyContext(messages=[], available_tools=[])
# Create a mock RequestOutput with specific token counts
mock_output = create_mock_request_output(
prompt_token_ids=[1, 2, 3, 4, 5], # 5 prompt tokens
output_token_ids=[6, 7, 8], # 3 output tokens
num_cached_tokens=2, # 2 cached tokens
)
# Append the output to the context
context.append_output(mock_output)
# Verify the token counts
assert context.num_prompt_tokens == 5
assert context.num_output_tokens == 3
assert context.num_cached_tokens == 2
assert context.num_tool_output_tokens == 0 # No tool tokens in first turn
# Verify internal state tracking
assert not context.is_first_turn
assert context.previous_turn.input_tokens == 5
assert context.previous_turn.output_tokens == 3
@pytest.mark.asyncio
async def test_multi_turn_token_counting():
"""Test token counting behavior across multiple turns with tool output."""
# Create a context
context = HarmonyContext(messages=[], available_tools=["browser"])
# Simulate a conversation with 3 turns
# Turn 1: prefill 5, decode 3, tool 7
# Turn 2: prefill 15, cached 5, decode 4, tool 1
# Turn 3: prefill 20, cached 15, decode 5
prompt_token_counts = [5, 15, 20]
output_token_counts = [3, 4, 5]
cached_token_counts = [0, 5, 15]
mock_generator = generate_mock_outputs(3, prompt_token_counts,
output_token_counts,
cached_token_counts)
# First turn - initial prompt and response
mock_output1 = await async_next(mock_generator)
context.append_output(mock_output1)
# At this point, we should have 5 prompt tokens and 3 output tokens
assert context.num_prompt_tokens == 5
assert context.num_output_tokens == 3
assert context.num_tool_output_tokens == 0
# Second turn - after tool output
mock_output2 = await async_next(mock_generator)
context.append_output(mock_output2)
# Current prompt tokens (15) - last_turn_input_tokens (5) -
# last_turn_output_tokens (3) = 7
expected_tool_output = 7
assert context.num_prompt_tokens == 5 + 15
assert context.num_output_tokens == 3 + 4
assert context.num_tool_output_tokens == expected_tool_output
assert context.num_cached_tokens == 5
# Third turn - final response
mock_output3 = await async_next(mock_generator)
context.append_output(mock_output3)
# Additional tool output tokens from third turn:
# Current prompt (20) - last_turn_input_tokens (15) -
# last_turn_output_tokens (4) = 1
expected_tool_output = 7 + 1
assert context.num_prompt_tokens == 5 + 15 + 20
assert context.num_output_tokens == 3 + 4 + 5
assert context.num_tool_output_tokens == expected_tool_output
assert context.num_cached_tokens == 5 + 15
def test_empty_output_tokens():
"""Test behavior when RequestOutput has empty output tokens."""
context = HarmonyContext(messages=[], available_tools=[])
# Create a RequestOutput with empty output tokens
mock_output = create_mock_request_output(
prompt_token_ids=[1, 2, 3], # 3 prompt tokens
output_token_ids=[], # Empty output tokens list
num_cached_tokens=1,
)
context.append_output(mock_output)
# Should handle empty outputs gracefully
assert context.num_prompt_tokens == 3
assert context.num_output_tokens == 0 # No output tokens
assert context.num_cached_tokens == 1
assert context.num_tool_output_tokens == 0
def test_missing_prompt_token_ids():
"""Test behavior when RequestOutput has None prompt_token_ids."""
context = HarmonyContext(messages=[], available_tools=[])
mock_output = create_mock_request_output(
prompt_token_ids=None, # No prompt token IDs
output_token_ids=[1, 2], # 2 output tokens
num_cached_tokens=0,
)
# Logger.error will be called, but we don't need to check for warnings
# here. Just ensure it doesn't raise an exception.
context.append_output(mock_output)
# Should handle missing prompt tokens gracefully
assert context.num_prompt_tokens == 0
assert context.num_output_tokens == 2
assert context.num_cached_tokens == 0
assert context.num_tool_output_tokens == 0
def test_reasoning_tokens_counting(mock_parser):
"""Test that reasoning tokens are counted correctly."""
context = HarmonyContext(messages=[], available_tools=[])
# Mock parser to simulate reasoning channel
mock_parser.current_channel = "analysis" # Reasoning channel
mock_output = create_mock_request_output(
prompt_token_ids=[1, 2, 3],
output_token_ids=[4, 5, 6, 7], # 4 tokens, all in reasoning
num_cached_tokens=0,
)
context.append_output(mock_output)
# All output tokens should be counted as reasoning
assert context.num_reasoning_tokens == 4
assert context.num_output_tokens == 4
def test_zero_tokens_edge_case():
"""Test behavior with all zero token counts."""
context = HarmonyContext(messages=[], available_tools=[])
# Create a request with empty lists (not None) for both prompt and
# output tokens
mock_output = create_mock_request_output(
prompt_token_ids=[], # Empty prompt tokens
output_token_ids=[], # Empty output tokens
num_cached_tokens=0,
)
context.append_output(mock_output)
# All counts should be zero
assert context.num_prompt_tokens == 0
assert context.num_output_tokens == 0
assert context.num_cached_tokens == 0
assert context.num_tool_output_tokens == 0
assert context.num_reasoning_tokens == 0
@pytest.mark.asyncio
async def test_single_turn_no_tool_output():
"""Test that first turn never generates tool output tokens."""
context = HarmonyContext(
messages=[],
available_tools=["browser"] # Tools available
)
# Even with large prompt in first turn, no tool tokens should be counted
mock_output = create_mock_request_output(
prompt_token_ids=list(range(100)), # 100 tokens
output_token_ids=[1, 2, 3],
num_cached_tokens=0,
)
context.append_output(mock_output)
# First turn should never have tool output tokens
assert context.num_tool_output_tokens == 0
assert context.is_first_turn is False # Should be updated after first turn
@pytest.mark.asyncio
async def test_negative_tool_tokens_edge_case():
"""Test edge case where calculation could result in negative tool
tokens. We should log an error and clamp the value to 0."""
# Use patch to check if logger.error was called
with patch("vllm.entrypoints.context.logger.error") as mock_log:
context = HarmonyContext(messages=[], available_tools=["browser"])
# First turn
mock_output1 = create_mock_request_output(
prompt_token_ids=list(range(10)), # 10 tokens
output_token_ids=[1, 2, 3, 4, 5], # 5 tokens
)
context.append_output(mock_output1)
# Second turn with fewer new tokens than previous output
# This could happen in edge cases with aggressive caching
mock_output2 = create_mock_request_output(
prompt_token_ids=list(range(12)), # 12 tokens (only 2 new)
output_token_ids=[6, 7], # 2 tokens
)
context.append_output(mock_output2)
# Calculated negative tool tokens (12 - 10 - 5 = -3) should be clamped
# to 0 and an error should be logged
assert context.num_tool_output_tokens == 0
assert context.num_prompt_tokens == 10 + 12
assert context.num_output_tokens == 5 + 2
# Verify the error was logged properly
mock_log.assert_called_once()
# Extract the actual log message and arguments from the call
args, _ = mock_log.call_args
log_message = args[0]
# Check for key parts of the message
assert "Negative tool output tokens calculated" in log_message
assert "-3" in str(args) # Check that -3 is in the arguments
@pytest.mark.asyncio
async def test_streaming_multi_turn_token_counting(mock_parser):
"""Test token counting for streaming multi-turn conversations.
This test focuses on how StreamingHarmonyContext counts tokens in a
multi-turn conversation with streaming (token-by-token) outputs and
message boundaries.
"""
# Create a streaming context
context = StreamingHarmonyContext(messages=[], available_tools=["browser"])
# Simulate three turns of conversation:
# Turn 1: stream tokens one by one, then finish the message
# Turn 2: new prompt, stream more tokens with a reasoning segment
# Turn 3: new prompt with tool output and cached tokens
# First turn: 3 tokens streamed one by one
# First token of first turn
context.append_output(
create_mock_request_output(
prompt_token_ids=[1, 2, 3], # 3 prompt tokens
output_token_ids=[101], # Single token
num_cached_tokens=0,
finished=False, # Not end of message yet
))
# Second token of first turn
context.append_output(
create_mock_request_output(
output_token_ids=[102],
finished=False,
))
# Last token of first turn (finished=True signals end of message)
context.append_output(
create_mock_request_output(
output_token_ids=[103],
finished=True, # End of message
))
# Check token counts after first turn
assert context.num_prompt_tokens == 3 # Initial prompt tokens
assert context.num_output_tokens == 3 # Three output tokens
assert context.num_cached_tokens == 0
assert context.num_tool_output_tokens == 0 # No tool output in first turn
assert context.first_tok_of_message is True # Ready for next message
# Second turn: reasoning tokens in analysis channel
mock_parser.current_channel = "analysis" # Set to reasoning channel
# First token of second turn
context.append_output(
create_mock_request_output(
prompt_token_ids=[1, 2, 3, 101, 102, 103, 4,
5], # 8 tokens (includes previous)
output_token_ids=[201],
num_cached_tokens=3, # Some tokens cached
finished=False,
))
# More tokens in reasoning channel
context.append_output(
create_mock_request_output(
output_token_ids=[202],
finished=False,
))
context.append_output(
create_mock_request_output(
output_token_ids=[203],
finished=True, # End of reasoning message
))
# Check counts after second turn (reasoning message)
assert context.num_prompt_tokens == 3 + 8 # Initial + second prompt
assert context.num_output_tokens == 3 + 3 # First turn + second turn
assert context.num_reasoning_tokens == 3 # All tokens in analysis channel
assert context.num_cached_tokens == 3 # Cached tokens from second turn
# Formula: this turn prompt tokens - last turn prompt - last turn output
expected_tool_tokens = 8 - 3 - 3 # = 2
assert context.num_tool_output_tokens == expected_tool_tokens
# Third turn: regular output channel
mock_parser.current_channel = "final" # Switch back to regular channel
# Third turn (with more cached tokens)
context.append_output(
create_mock_request_output(
prompt_token_ids=[
1, 2, 3, 101, 102, 103, 4, 5, 201, 202, 203, 6, 7
], # 13 tokens
output_token_ids=[301],
num_cached_tokens=8, # More cached tokens
finished=False,
))
context.append_output(
create_mock_request_output(
output_token_ids=[302],
finished=True,
))
# Final token counts check
assert context.num_prompt_tokens == 3 + 8 + 13 # All prompts
assert context.num_output_tokens == 3 + 3 + 2 # All outputs
assert context.num_reasoning_tokens == 3 # Unchanged from second turn
assert context.num_cached_tokens == 3 + 8 # Accumulated cached tokens
# Additional tool tokens from third turn
# Formula: this turn prompt - last turn prompt - last turn output
additional_tool_tokens = 13 - 8 - 3 # = 2
assert context.num_tool_output_tokens == expected_tool_tokens \
+ additional_tool_tokens
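
The tests above all exercise the same bookkeeping rule: on every turn after the first, whatever appears in the prompt beyond the previous turn's prompt and completion is attributed to tool output, clamped at zero. A minimal standalone sketch of that rule, checked against the numbers in test_multi_turn_token_counting (the helper name and tuple layout are illustrative, not vLLM APIs):

def count_turns(turns):
    """turns: list of (prompt_tokens, output_tokens, cached_tokens) per turn."""
    totals = {"prompt": 0, "output": 0, "cached": 0, "tool_output": 0}
    prev_prompt = prev_output = 0
    for i, (prompt, output, cached) in enumerate(turns):
        totals["prompt"] += prompt
        totals["output"] += output
        totals["cached"] += cached
        if i > 0:
            # Prompt tokens not explained by the previous prompt or the
            # previous completion must have been inserted between turns,
            # i.e. tool responses; never let the estimate go negative.
            totals["tool_output"] += max(prompt - prev_prompt - prev_output, 0)
        prev_prompt, prev_output = prompt, output
    return totals

# Mirrors test_multi_turn_token_counting:
# tool output = (15 - 5 - 3) + (20 - 15 - 4) = 8
assert count_turns([(5, 3, 0), (15, 4, 5), (20, 5, 15)])["tool_output"] == 8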

View File

@@ -3,7 +3,6 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
@@ -21,6 +20,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class TurnTokens:
"""Tracks token counts for a single conversation turn."""
def __init__(self, input_tokens=0, output_tokens=0):
self.input_tokens = input_tokens
self.output_tokens = output_tokens
def reset(self):
"""Reset counters for a new turn."""
self.input_tokens = 0
self.output_tokens = 0
def copy(self):
"""Create a copy of this turn's token counts."""
return TurnTokens(self.input_tokens, self.output_tokens)
class ConversationContext(ABC):
@abstractmethod
@@ -92,52 +108,124 @@ class HarmonyContext(ConversationContext):
self.num_init_messages = len(messages)
self.num_prompt_tokens = 0
self.num_output_tokens = 0
# TODO(woosuk): Implement the following fields.
self.num_cached_tokens = 0
self.num_reasoning_tokens = 0
self.num_tool_output_tokens = 0
def _update_num_prompt_tokens(self, output: RequestOutput):
if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
# NOTE: with built-in tools, there might be multiple rounds in
# the conversation, with the full conversation being resent
# as new prompt each time. Hence the sum.
self.num_prompt_tokens += len(output.prompt_token_ids)
# Turn tracking - replaces multiple individual tracking variables
self.current_turn = TurnTokens()
self.previous_turn = TurnTokens()
self.is_first_turn = True
self.first_tok_of_message = True # For streaming support
def _update_num_cached_tokens(self, output: RequestOutput):
if output.num_cached_tokens is not None:
# Similar to num_prompt_tokens
self.num_cached_tokens += output.num_cached_tokens
def _update_num_output_tokens(self, token_ids: Sequence[int]):
self.num_output_tokens += len(token_ids)
def _update_num_reasoning_tokens(self, token_ids: Sequence[int]):
# Count tokens that are part of reasoning content (analysis channel
# or tool-directed messages like python/browser calls)
is_analysis = self.parser.current_channel == "analysis"
is_tool_call = (self.parser.current_recipient is not None and
(self.parser.current_recipient.startswith("python") or
self.parser.current_recipient.startswith("browser.")))
if is_analysis or is_tool_call:
self.num_reasoning_tokens += len(token_ids)
def _update_num_reasoning_tokens(self):
# Count all analysis and commentary channels as reasoning tokens
if self.parser.current_channel in {"analysis", "commentary"}:
self.num_reasoning_tokens += 1
def append_output(self, output) -> None:
if isinstance(output, RequestOutput):
self._update_num_prompt_tokens(output)
self._update_num_cached_tokens(output)
output_token_ids = output.outputs[0].token_ids
self._update_num_output_tokens(output_token_ids)
self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids:
self.parser.process(token_id)
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens([token_id])
self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output)
# Reset current turn output tokens for this turn
self.current_turn.output_tokens = 0
self._update_decode_token_usage(output)
# Move current turn to previous turn for next turn's calculations
self.previous_turn = self.current_turn.copy()
output_msgs = self.parser.messages
else:
# Tool output.
output_msgs = output
self._messages.extend(output_msgs)
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
"""Update token usage statistics for the prefill phase of generation.
The prefill phase processes the input prompt tokens. This method:
1. Counts the prompt tokens for this turn
2. Calculates tool output tokens for multi-turn conversations
3. Updates cached token counts
4. Tracks state for next turn calculations
Tool output tokens are calculated as:
current_prompt_tokens - last_turn_prompt_tokens -
last_turn_output_tokens
This represents tokens added between turns (typically tool responses).
Args:
output: The RequestOutput containing prompt token information
"""
if output.prompt_token_ids is not None:
this_turn_input_tokens = len(output.prompt_token_ids)
else:
this_turn_input_tokens = 0
logger.error(
"RequestOutput appended contains no prompt_token_ids.")
# Update current turn input tokens
self.current_turn.input_tokens = this_turn_input_tokens
self.num_prompt_tokens += this_turn_input_tokens
# Calculate tool tokens (except on first turn)
if self.is_first_turn:
self.is_first_turn = False
else:
# Start counting tool output tokens after the first turn:
# tool tokens = this turn prefill - last turn prefill -
# last turn decode
this_turn_tool_tokens = (self.current_turn.input_tokens -
self.previous_turn.input_tokens -
self.previous_turn.output_tokens)
# Handle negative tool token counts (shouldn't happen in normal
# cases)
if this_turn_tool_tokens < 0:
logger.error(
"Negative tool output tokens calculated: %d "
"(current_input=%d, previous_input=%d, "
"previous_output=%d). Setting to 0.",
this_turn_tool_tokens, self.current_turn.input_tokens,
self.previous_turn.input_tokens,
self.previous_turn.output_tokens)
this_turn_tool_tokens = 0
self.num_tool_output_tokens += this_turn_tool_tokens
# Update cached tokens
if output.num_cached_tokens is not None:
self.num_cached_tokens += output.num_cached_tokens
def _update_decode_token_usage(self, output: RequestOutput) -> int:
"""Update token usage statistics for the decode phase of generation.
The decode phase processes the generated output tokens. This method:
1. Counts output tokens from all completion outputs
2. Updates the total output token count
3. Tracks tokens generated in the current turn
In streaming mode, this is called for each token generated.
In non-streaming mode, this is called once with all output tokens.
Args:
output: The RequestOutput containing generated token information
Returns:
int: Number of output tokens processed in this call
"""
updated_output_token_count = 0
if output.outputs:
for completion_output in output.outputs:
# only keep last round
updated_output_token_count += len(completion_output.token_ids)
self.num_output_tokens += updated_output_token_count
self.current_turn.output_tokens += updated_output_token_count
return updated_output_token_count
@property
def messages(self) -> list:
return self._messages
@@ -231,8 +319,8 @@ class StreamingHarmonyContext(HarmonyContext):
# append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message.
if self.first_tok_of_message:
self._update_num_prompt_tokens(output)
self._update_num_cached_tokens(output)
self._update_prefill_token_usage(output)
self.current_turn.output_tokens = 0
# Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the
@@ -240,9 +328,13 @@ class StreamingHarmonyContext(HarmonyContext):
self.first_tok_of_message = output.finished
for tok in output.outputs[0].token_ids:
self.parser.process(tok)
self._update_num_output_tokens(output.outputs[0].token_ids)
self._update_decode_token_usage(output)
# For streaming, update previous turn when message is complete
if output.finished:
self.previous_turn = self.current_turn.copy()
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens(output.outputs[0].token_ids)
self._update_num_reasoning_tokens()
self.last_tok = tok
else:
# Handle the case of tool output in direct message format
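
The reasoning-token fix in this file comes down to the channel check in the new _update_num_reasoning_tokens: a token processed while the Harmony parser sits on the "analysis" or "commentary" channel counts as reasoning, while "final"-channel tokens do not. A rough standalone sketch of that classification (the helper is illustrative, not vLLM code):

REASONING_CHANNELS = {"analysis", "commentary"}

def count_reasoning(token_channels):
    """token_channels: the parser channel observed for each generated token."""
    return sum(1 for channel in token_channels if channel in REASONING_CHANNELS)

# Three analysis tokens followed by two final tokens -> 3 reasoning tokens,
# matching the streaming test above.
assert count_reasoning(["analysis"] * 3 + ["final"] * 2) == 3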

View File

@@ -1841,7 +1841,8 @@ class InputTokensDetails(OpenAIBaseModel):
class OutputTokensDetails(OpenAIBaseModel):
reasoning_tokens: int
reasoning_tokens: int = 0
tool_output_tokens: int = 0
class ResponseUsage(OpenAIBaseModel):

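A standalone sketch of the schema change above, using plain pydantic BaseModel as a stand-in for OpenAIBaseModel: both detail fields now default to 0, so usage objects built before tool-output reporting existed still validate.

from pydantic import BaseModel

class OutputTokensDetails(BaseModel):
    reasoning_tokens: int = 0
    tool_output_tokens: int = 0

print(OutputTokensDetails())                      # reasoning_tokens=0 tool_output_tokens=0
print(OutputTokensDetails(tool_output_tokens=8))  # populated from the context's tool output count
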
View File

@@ -460,7 +460,7 @@ class OpenAIServingResponses(OpenAIServing):
if self.use_harmony:
assert isinstance(context, HarmonyContext)
output = self._make_response_output_items_with_harmony(context)
# TODO: these are all 0 for now!
num_tool_output_tokens = context.num_tool_output_tokens
else:
assert isinstance(context, SimpleContext)
final_res = context.last_output
@@ -473,6 +473,8 @@ class OpenAIServingResponses(OpenAIServing):
# Calculate usage.
assert final_res.prompt_token_ids is not None
num_tool_output_tokens = 0
assert isinstance(context, (SimpleContext, HarmonyContext))
num_prompt_tokens = context.num_prompt_tokens
num_generated_tokens = context.num_output_tokens
@@ -486,7 +488,8 @@ class OpenAIServingResponses(OpenAIServing):
input_tokens_details=InputTokensDetails(
cached_tokens=num_cached_tokens),
output_tokens_details=OutputTokensDetails(
reasoning_tokens=num_reasoning_tokens),
reasoning_tokens=num_reasoning_tokens,
tool_output_tokens=num_tool_output_tokens),
)
response = ResponsesResponse.from_request(
request,
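
With these pieces in place, the usage block of a Responses API reply would plausibly look like the following for the three-turn scenario in the tests. The nested detail fields come from this diff; the top-level field names are assumed from the OpenAI Responses usage schema rather than shown here.

usage = {
    "input_tokens": 40,                      # 5 + 15 + 20 prompt tokens
    "input_tokens_details": {"cached_tokens": 20},
    "output_tokens": 12,                     # 3 + 4 + 5 generated tokens
    "output_tokens_details": {
        "reasoning_tokens": 0,               # no analysis-channel tokens in that scenario
        "tool_output_tokens": 8,             # (15 - 5 - 3) + (20 - 15 - 4)
    },
    "total_tokens": 52,
}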