[frontend][gptoss] Add per turn stats into Harmony Context (#25061)

Signed-off-by: lacora <hyelacora@gmail.com>
Co-authored-by: Ye Hu <yehu@fb.com>
Authored by Ye Hu on 2025-10-14 16:48:13 -07:00; committed by GitHub
parent 7e0ef4084a
commit 0512c04aee
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
4 changed files with 188 additions and 62 deletions

View File

@@ -6,7 +6,11 @@ from unittest.mock import MagicMock, patch
 import pytest
 from openai_harmony import Author, Message, Role, StreamState, TextContent
-from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
+from vllm.entrypoints.context import (
+    HarmonyContext,
+    StreamingHarmonyContext,
+    TurnMetrics,
+)
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -101,8 +105,12 @@ def test_single_turn_token_counting():
     # Verify internal state tracking
     assert not context.is_first_turn
-    assert context.previous_turn.input_tokens == 5
-    assert context.previous_turn.output_tokens == 3
+    assert len(context.all_turn_metrics) == 1
+    previous_turn = context.all_turn_metrics[0]
+    assert previous_turn.input_tokens == 5
+    assert previous_turn.output_tokens == 3
+    assert previous_turn.cached_input_tokens == 2
+    assert previous_turn.tool_output_tokens == 0
 
 
 @pytest.mark.asyncio
@@ -156,6 +164,15 @@ async def test_multi_turn_token_counting():
     assert context.num_tool_output_tokens == expected_tool_output
     assert context.num_cached_tokens == 5 + 15
+
+    # Validate all turn metrics
+    assert len(context.all_turn_metrics) == 3
+    for i, turn in enumerate(context.all_turn_metrics):
+        assert turn.input_tokens == prompt_token_counts[i]
+        assert turn.output_tokens == output_token_counts[i]
+        assert turn.cached_input_tokens == cached_token_counts[i]
+    assert context.all_turn_metrics[1].tool_output_tokens == 7
+    assert context.all_turn_metrics[2].tool_output_tokens == 1
 
 
 def test_empty_output_tokens():
     """Test behavior when RequestOutput has empty output tokens."""
@@ -314,6 +331,10 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     # Create a streaming context
     context = StreamingHarmonyContext(messages=[], available_tools=["browser"])
 
+    num_prompt_tokens = [3, 8, 13]
+    num_output_tokens = [3, 3, 2]
+    num_cached_tokens = [0, 3, 8]
+
     # Simulate three turns of conversation:
     # Turn 1: stream tokens one by one, then finish the message
     # Turn 2: new prompt, stream more tokens with a reasoning segment
@@ -325,7 +346,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
         create_mock_request_output(
             prompt_token_ids=[1, 2, 3],  # 3 prompt tokens
             output_token_ids=[101],  # Single token
-            num_cached_tokens=0,
+            num_cached_tokens=num_cached_tokens[0],
             finished=False,  # Not end of message yet
         )
     )
@ -370,7 +391,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
5, 5,
], # 8 tokens (includes previous) ], # 8 tokens (includes previous)
output_token_ids=[201], output_token_ids=[201],
num_cached_tokens=3, # Some tokens cached num_cached_tokens=num_cached_tokens[1], # Some tokens cached
finished=False, finished=False,
) )
) )
@@ -422,7 +443,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
                 7,
             ],  # 13 tokens
             output_token_ids=[301],
-            num_cached_tokens=8,  # More cached tokens
+            num_cached_tokens=num_cached_tokens[2],  # More cached tokens
             finished=False,
         )
     )
@@ -435,10 +456,12 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     )
 
     # Final token counts check
-    assert context.num_prompt_tokens == 3 + 8 + 13  # All prompts
-    assert context.num_output_tokens == 3 + 3 + 2  # All outputs
+    assert context.num_prompt_tokens == sum(num_prompt_tokens)  # All prompts
+    assert context.num_output_tokens == sum(num_output_tokens)  # All outputs
     assert context.num_reasoning_tokens == 3  # Unchanged from second turn
-    assert context.num_cached_tokens == 3 + 8  # Accumulated cached tokens
+    assert context.num_cached_tokens == sum(
+        num_cached_tokens
+    )  # Accumulated cached tokens
 
     # Additional tool tokens from third turn
     # Formula: this turn prompt - last turn prompt - last turn output
@@ -447,6 +470,15 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
         context.num_tool_output_tokens == expected_tool_tokens + additional_tool_tokens
     )
+
+    # Validate all turn metrics
+    assert len(context.all_turn_metrics) == 3
+    for i, turn in enumerate(context.all_turn_metrics):
+        assert turn.input_tokens == num_prompt_tokens[i]
+        assert turn.output_tokens == num_output_tokens[i]
+        assert turn.cached_input_tokens == num_cached_tokens[i]
+    assert context.all_turn_metrics[1].tool_output_tokens == 2
+    assert context.all_turn_metrics[2].tool_output_tokens == 2
 
 
 @pytest.mark.asyncio
 async def test_streaming_message_synchronization(mock_parser):
@@ -522,3 +554,46 @@ async def test_streaming_message_synchronization(mock_parser):
     assert len(context._messages) == 3
     assert context.num_init_messages == 1
     assert context._messages[2].content[0].text == "Response 4"
+
+
+def test_turn_metrics_copy_and_reset():
+    """Test TurnMetrics copy and reset methods work correctly."""
+    # Create a TurnMetrics with specific values
+    original_metrics = TurnMetrics(
+        input_tokens=10,
+        output_tokens=20,
+        cached_input_tokens=5,
+        tool_output_tokens=3,
+    )
+
+    # Test copy functionality
+    copied_metrics = original_metrics.copy()
+
+    # Verify copy has same values
+    assert copied_metrics.input_tokens == 10
+    assert copied_metrics.output_tokens == 20
+    assert copied_metrics.cached_input_tokens == 5
+    assert copied_metrics.tool_output_tokens == 3
+
+    # Verify they are separate objects
+    assert copied_metrics is not original_metrics
+
+    # Modify copy to ensure independence
+    copied_metrics.input_tokens = 999
+    assert original_metrics.input_tokens == 10  # Original unchanged
+    assert copied_metrics.input_tokens == 999
+
+    # Test reset functionality
+    original_metrics.reset()
+
+    # Verify all fields are reset to zero
+    assert original_metrics.input_tokens == 0
+    assert original_metrics.output_tokens == 0
+    assert original_metrics.cached_input_tokens == 0
+    assert original_metrics.tool_output_tokens == 0
+
+    # Verify copied metrics are unaffected by reset
+    assert copied_metrics.input_tokens == 999
+    assert copied_metrics.output_tokens == 20
+    assert copied_metrics.cached_input_tokens == 5
+    assert copied_metrics.tool_output_tokens == 3

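Note (not part of the diff): the per-turn tool token expectations asserted in the streaming test above follow directly from the formula in the test comments (tool tokens = this turn prompt - last turn prompt - last turn output). A minimal sketch of that arithmetic, reusing the token counts the test defines:

# Illustrative arithmetic only; values mirror the streaming test above.
num_prompt_tokens = [3, 8, 13]
num_output_tokens = [3, 3, 2]

# The first turn never counts tool tokens; later turns apply the formula.
expected_tool_tokens_per_turn = [0] + [
    num_prompt_tokens[i] - num_prompt_tokens[i - 1] - num_output_tokens[i - 1]
    for i in range(1, len(num_prompt_tokens))
]
assert expected_tool_tokens_per_turn == [0, 2, 2]  # matches the asserts above
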
View File

@@ -45,21 +45,36 @@ def _map_tool_name_to_tool_type(tool_name: str) -> str:
     return _TOOL_NAME_TO_TYPE_MAP[tool_name]
 
 
-class TurnTokens:
-    """Tracks token counts for a single conversation turn."""
+class TurnMetrics:
+    """Tracks token and toolcall details for a single conversation turn."""
 
-    def __init__(self, input_tokens=0, output_tokens=0):
+    def __init__(
+        self,
+        input_tokens=0,
+        output_tokens=0,
+        cached_input_tokens=0,
+        tool_output_tokens=0,
+    ):
         self.input_tokens = input_tokens
         self.output_tokens = output_tokens
+        self.cached_input_tokens = cached_input_tokens
+        self.tool_output_tokens = tool_output_tokens
 
     def reset(self):
         """Reset counters for a new turn."""
         self.input_tokens = 0
         self.output_tokens = 0
+        self.cached_input_tokens = 0
+        self.tool_output_tokens = 0
 
     def copy(self):
         """Create a copy of this turn's token counts."""
-        return TurnTokens(self.input_tokens, self.output_tokens)
+        return TurnMetrics(
+            self.input_tokens,
+            self.output_tokens,
+            self.cached_input_tokens,
+            self.tool_output_tokens,
+        )
 
 
 class ConversationContext(ABC):
@@ -102,6 +117,8 @@ class SimpleContext(ConversationContext):
         self.num_cached_tokens = 0
         # todo num_reasoning_tokens is not implemented yet.
         self.num_reasoning_tokens = 0
+        # not implemented yet for SimpleContext
+        self.all_turn_metrics = []
 
     def append_output(self, output) -> None:
         self.last_output = output
@@ -154,8 +171,9 @@ class HarmonyContext(ConversationContext):
         self.num_tool_output_tokens = 0
 
         # Turn tracking - replaces multiple individual tracking variables
-        self.current_turn = TurnTokens()
-        self.previous_turn = TurnTokens()
+        self.current_turn_metrics = TurnMetrics()
+        # Track metrics for all turns
+        self.all_turn_metrics: list[TurnMetrics] = []
         self.is_first_turn = True
         self.first_tok_of_message = True  # For streaming support
@@ -173,11 +191,10 @@
             # Check if the current token is part of reasoning content
             self._update_num_reasoning_tokens()
             self._update_prefill_token_usage(output)
-            # Reset current turn output tokens for this turn
-            self.current_turn.output_tokens = 0
             self._update_decode_token_usage(output)
-            # Move current turn to previous turn for next turn's calculations
-            self.previous_turn = self.current_turn.copy()
+            # Append current turn to all turn list for next turn's calculations
+            self.all_turn_metrics.append(self.current_turn_metrics.copy())
+            self.current_turn_metrics.reset()
             # append_output is called only once before tool calling
             # in non-streaming case
             # so we can append all the parser messages to _messages
@@ -213,20 +230,21 @@
            logger.error("RequestOutput appended contains no prompt_token_ids.")
 
        # Update current turn input tokens
-        self.current_turn.input_tokens = this_turn_input_tokens
+        self.current_turn_metrics.input_tokens = this_turn_input_tokens
        self.num_prompt_tokens += this_turn_input_tokens
 
        # Calculate tool tokens (except on first turn)
        if self.is_first_turn:
            self.is_first_turn = False
        else:
+            previous_turn = self.all_turn_metrics[-1]
            # start counting tool after first turn
            # tool tokens = this turn prefill - last turn prefill -
            # last turn decode
            this_turn_tool_tokens = (
-                self.current_turn.input_tokens
-                - self.previous_turn.input_tokens
-                - self.previous_turn.output_tokens
+                self.current_turn_metrics.input_tokens
+                - previous_turn.input_tokens
+                - previous_turn.output_tokens
            )
 
            # Handle negative tool token counts (shouldn't happen in normal
@@ -237,17 +255,20 @@
                    "(current_input=%d, previous_input=%d, "
                    "previous_output=%d). Setting to 0.",
                    this_turn_tool_tokens,
-                    self.current_turn.input_tokens,
-                    self.previous_turn.input_tokens,
-                    self.previous_turn.output_tokens,
+                    self.current_turn_metrics.input_tokens,
+                    previous_turn.input_tokens,
+                    previous_turn.output_tokens,
                )
                this_turn_tool_tokens = 0
 
            self.num_tool_output_tokens += this_turn_tool_tokens
+            self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens
 
        # Update cached tokens
-        if output.num_cached_tokens is not None:
-            self.num_cached_tokens += output.num_cached_tokens
+        num_cached_token = output.num_cached_tokens
+        if num_cached_token is not None:
+            self.num_cached_tokens += num_cached_token
+            self.current_turn_metrics.cached_input_tokens = num_cached_token
 
    def _update_decode_token_usage(self, output: RequestOutput) -> int:
        """Update token usage statistics for the decode phase of generation.
@@ -272,7 +293,7 @@
                # only keep last round
                updated_output_token_count += len(completion_output.token_ids)
        self.num_output_tokens += updated_output_token_count
-        self.current_turn.output_tokens += updated_output_token_count
+        self.current_turn_metrics.output_tokens += updated_output_token_count
        return updated_output_token_count
 
    @property
@@ -452,7 +473,6 @@ class StreamingHarmonyContext(HarmonyContext):
            # so we only want to add the prompt tokens once for each message.
            if self.first_tok_of_message:
                self._update_prefill_token_usage(output)
-                self.current_turn.output_tokens = 0
            # Reset self.first_tok_of_message if needed:
            # if the current token is the last one of the current message
            # (finished=True), then the next token processed will mark the
@@ -464,7 +484,8 @@
            # For streaming, update previous turn when message is complete
            if output.finished:
-                self.previous_turn = self.current_turn.copy()
+                self.all_turn_metrics.append(self.current_turn_metrics.copy())
+                self.current_turn_metrics.reset()
 
            # Check if the current token is part of reasoning content
            self._update_num_reasoning_tokens()
            self.last_tok = tok

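Note (illustrative, not part of the diff): the lifecycle HarmonyContext now follows with TurnMetrics is fill, snapshot, reset: the current turn's counters are populated during prefill/decode accounting, a copy is appended to all_turn_metrics, and the working object is reset for the next turn. A minimal sketch, assuming only the TurnMetrics class from vllm.entrypoints.context shown above; the turn values are made up:

from vllm.entrypoints.context import TurnMetrics

current_turn_metrics = TurnMetrics()
all_turn_metrics: list[TurnMetrics] = []

# Two hypothetical turns: (input, output, cached, tool) token counts.
for input_toks, output_toks, cached_toks, tool_toks in [(5, 3, 0, 0), (12, 4, 5, 4)]:
    current_turn_metrics.input_tokens = input_toks
    current_turn_metrics.output_tokens = output_toks
    current_turn_metrics.cached_input_tokens = cached_toks
    current_turn_metrics.tool_output_tokens = tool_toks
    all_turn_metrics.append(current_turn_metrics.copy())  # snapshot this turn
    current_turn_metrics.reset()  # start the next turn from zero

assert [t.input_tokens for t in all_turn_metrics] == [5, 12]
assert [t.tool_output_tokens for t in all_turn_metrics] == [0, 4]
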
View File

@@ -2103,11 +2103,15 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
 class InputTokensDetails(OpenAIBaseModel):
     cached_tokens: int
+    input_tokens_per_turn: list[int] = Field(default_factory=list)
+    cached_tokens_per_turn: list[int] = Field(default_factory=list)
 
 
 class OutputTokensDetails(OpenAIBaseModel):
     reasoning_tokens: int = 0
     tool_output_tokens: int = 0
+    output_tokens_per_turn: list[int] = Field(default_factory=list)
+    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
 
 
 class ResponseUsage(OpenAIBaseModel):

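Note (illustrative, not part of the diff): with the new fields, a response's usage details can carry per-turn breakdowns alongside the existing totals. A minimal sketch of constructing the extended models with made-up numbers, assuming they live in vllm.entrypoints.openai.protocol (the file shown above):

from vllm.entrypoints.openai.protocol import InputTokensDetails, OutputTokensDetails

input_details = InputTokensDetails(
    cached_tokens=8,                   # total cached tokens for the request
    input_tokens_per_turn=[3, 8, 13],  # one entry per conversation turn
    cached_tokens_per_turn=[0, 3, 5],  # sums to cached_tokens
)
output_details = OutputTokensDetails(
    reasoning_tokens=3,
    tool_output_tokens=4,              # total tool output tokens
    output_tokens_per_turn=[3, 3, 2],
    tool_output_tokens_per_turn=[0, 2, 2],  # sums to tool_output_tokens
)
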
View File

@@ -589,10 +589,24 @@ class OpenAIServingResponses(OpenAIServing):
             input_tokens=num_prompt_tokens,
             output_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
-            input_tokens_details=InputTokensDetails(cached_tokens=num_cached_tokens),
+            input_tokens_details=InputTokensDetails(
+                cached_tokens=num_cached_tokens,
+                input_tokens_per_turn=[
+                    turn.input_tokens for turn in context.all_turn_metrics
+                ],
+                cached_tokens_per_turn=[
+                    turn.cached_input_tokens for turn in context.all_turn_metrics
+                ],
+            ),
             output_tokens_details=OutputTokensDetails(
                 reasoning_tokens=num_reasoning_tokens,
                 tool_output_tokens=num_tool_output_tokens,
+                output_tokens_per_turn=[
+                    turn.output_tokens for turn in context.all_turn_metrics
+                ],
+                tool_output_tokens_per_turn=[
+                    turn.tool_output_tokens for turn in context.all_turn_metrics
+                ],
             ),
         )
         response = ResponsesResponse.from_request(
@@ -665,11 +679,13 @@
                    token=text,
                    logprob=max(token_logprob.logprob, -9999.0),
                    bytes=list(text.encode("utf-8", errors="replace")),
-                    top_logprobs=self._topk_logprobs(
-                        logprob, top_logprobs=top_logprobs, tokenizer=tokenizer
-                    )
-                    if top_logprobs
-                    else [],
+                    top_logprobs=(
+                        self._topk_logprobs(
+                            logprob, top_logprobs=top_logprobs, tokenizer=tokenizer
+                        )
+                        if top_logprobs
+                        else []
+                    ),
                )
            )
        return out
@@ -758,14 +774,16 @@
            text=content,
            annotations=[],  # TODO
            type="output_text",
-            logprobs=self._create_response_logprobs(
-                token_ids=final_output.token_ids,
-                logprobs=final_output.logprobs,
-                tokenizer=tokenizer,
-                top_logprobs=request.top_logprobs,
-            )
-            if request.is_include_output_logprobs()
-            else None,
+            logprobs=(
+                self._create_response_logprobs(
+                    token_ids=final_output.token_ids,
+                    logprobs=final_output.logprobs,
+                    tokenizer=tokenizer,
+                    top_logprobs=request.top_logprobs,
+                )
+                if request.is_include_output_logprobs()
+                else None
+            ),
        )
        message = ResponseOutputMessage(
            id=f"msg_{random_uuid()}",
@@ -870,15 +888,21 @@
        with_custom_tools = has_custom_tools(tool_types)
        sys_msg = get_system_message(
            reasoning_effort=reasoning_effort,
-            browser_description=self.tool_server.get_tool_description("browser")
-            if enable_browser and self.tool_server is not None
-            else None,
-            python_description=self.tool_server.get_tool_description("python")
-            if enable_code_interpreter and self.tool_server is not None
-            else None,
-            container_description=self.tool_server.get_tool_description("container")
-            if enable_container and self.tool_server is not None
-            else None,
+            browser_description=(
+                self.tool_server.get_tool_description("browser")
+                if enable_browser and self.tool_server is not None
+                else None
+            ),
+            python_description=(
+                self.tool_server.get_tool_description("python")
+                if enable_code_interpreter and self.tool_server is not None
+                else None
+            ),
+            container_description=(
+                self.tool_server.get_tool_description("container")
+                if enable_container and self.tool_server is not None
+                else None
+            ),
            instructions=request.instructions,
            with_custom_tools=with_custom_tools,
        )
@@ -1283,14 +1307,16 @@
                        output_index=current_output_index,
                        item_id=current_item_id,
                        delta=delta_message.content,
-                        logprobs=self._create_stream_response_logprobs(
-                            token_ids=output.token_ids,
-                            logprobs=output.logprobs,
-                            tokenizer=tokenizer,
-                            top_logprobs=request.top_logprobs,
-                        )
-                        if request.is_include_output_logprobs()
-                        else [],
+                        logprobs=(
+                            self._create_stream_response_logprobs(
+                                token_ids=output.token_ids,
+                                logprobs=output.logprobs,
+                                tokenizer=tokenizer,
+                                top_logprobs=request.top_logprobs,
+                            )
+                            if request.is_include_output_logprobs()
+                            else []
+                        ),
                    )
                )
                current_content_index += 1