diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py
index b0faa870a927..31ea856224f9 100644
--- a/tests/entrypoints/test_context.py
+++ b/tests/entrypoints/test_context.py
@@ -6,7 +6,11 @@ from unittest.mock import MagicMock, patch
 import pytest
 from openai_harmony import Author, Message, Role, StreamState, TextContent
 
-from vllm.entrypoints.context import HarmonyContext, StreamingHarmonyContext
+from vllm.entrypoints.context import (
+    HarmonyContext,
+    StreamingHarmonyContext,
+    TurnMetrics,
+)
 from vllm.outputs import CompletionOutput, RequestOutput
 
 
@@ -101,8 +105,12 @@ def test_single_turn_token_counting():
 
     # Verify internal state tracking
     assert not context.is_first_turn
-    assert context.previous_turn.input_tokens == 5
-    assert context.previous_turn.output_tokens == 3
+    assert len(context.all_turn_metrics) == 1
+    previous_turn = context.all_turn_metrics[0]
+    assert previous_turn.input_tokens == 5
+    assert previous_turn.output_tokens == 3
+    assert previous_turn.cached_input_tokens == 2
+    assert previous_turn.tool_output_tokens == 0
 
 
 @pytest.mark.asyncio
@@ -156,6 +164,15 @@ async def test_multi_turn_token_counting():
     assert context.num_tool_output_tokens == expected_tool_output
     assert context.num_cached_tokens == 5 + 15
 
+    # Validate all turn metrics
+    assert len(context.all_turn_metrics) == 3
+    for i, turn in enumerate(context.all_turn_metrics):
+        assert turn.input_tokens == prompt_token_counts[i]
+        assert turn.output_tokens == output_token_counts[i]
+        assert turn.cached_input_tokens == cached_token_counts[i]
+    assert context.all_turn_metrics[1].tool_output_tokens == 7
+    assert context.all_turn_metrics[2].tool_output_tokens == 1
+
 
 def test_empty_output_tokens():
     """Test behavior when RequestOutput has empty output tokens."""
@@ -314,6 +331,10 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     # Create a streaming context
     context = StreamingHarmonyContext(messages=[], available_tools=["browser"])
 
+    num_prompt_tokens = [3, 8, 13]
+    num_output_tokens = [3, 3, 2]
+    num_cached_tokens = [0, 3, 8]
+
     # Simulate three turns of conversation:
     # Turn 1: stream tokens one by one, then finish the message
    # Turn 2: new prompt, stream more tokens with a reasoning segment
@@ -325,7 +346,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
         create_mock_request_output(
             prompt_token_ids=[1, 2, 3],  # 3 prompt tokens
             output_token_ids=[101],  # Single token
-            num_cached_tokens=0,
+            num_cached_tokens=num_cached_tokens[0],
             finished=False,  # Not end of message yet
         )
     )
@@ -370,7 +391,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
                 5,
             ],  # 8 tokens (includes previous)
             output_token_ids=[201],
-            num_cached_tokens=3,  # Some tokens cached
+            num_cached_tokens=num_cached_tokens[1],  # Some tokens cached
             finished=False,
         )
     )
@@ -422,7 +443,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
                 7,
             ],  # 13 tokens
             output_token_ids=[301],
-            num_cached_tokens=8,  # More cached tokens
+            num_cached_tokens=num_cached_tokens[2],  # More cached tokens
             finished=False,
         )
     )
@@ -435,10 +456,12 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
     )
 
     # Final token counts check
-    assert context.num_prompt_tokens == 3 + 8 + 13  # All prompts
-    assert context.num_output_tokens == 3 + 3 + 2  # All outputs
+    assert context.num_prompt_tokens == sum(num_prompt_tokens)  # All prompts
+    assert context.num_output_tokens == sum(num_output_tokens)  # All outputs
     assert context.num_reasoning_tokens == 3  # Unchanged from second turn
-    assert context.num_cached_tokens == 3 + 8  # Accumulated cached tokens
+    assert context.num_cached_tokens == sum(
+        num_cached_tokens
+    )  # Accumulated cached tokens
 
     # Additional tool tokens from third turn
     # Formula: this turn prompt - last turn prompt - last turn output
@@ -447,6 +470,15 @@ async def test_streaming_multi_turn_token_counting(mock_parser):
         context.num_tool_output_tokens == expected_tool_tokens + additional_tool_tokens
     )
 
+    # Validate all turn metrics
+    assert len(context.all_turn_metrics) == 3
+    for i, turn in enumerate(context.all_turn_metrics):
+        assert turn.input_tokens == num_prompt_tokens[i]
+        assert turn.output_tokens == num_output_tokens[i]
+        assert turn.cached_input_tokens == num_cached_tokens[i]
+    assert context.all_turn_metrics[1].tool_output_tokens == 2
+    assert context.all_turn_metrics[2].tool_output_tokens == 2
+
 
 @pytest.mark.asyncio
 async def test_streaming_message_synchronization(mock_parser):
@@ -522,3 +554,46 @@ async def test_streaming_message_synchronization(mock_parser):
     assert len(context._messages) == 3
     assert context.num_init_messages == 1
     assert context._messages[2].content[0].text == "Response 4"
+
+
+def test_turn_metrics_copy_and_reset():
+    """Test TurnMetrics copy and reset methods work correctly."""
+    # Create a TurnMetrics with specific values
+    original_metrics = TurnMetrics(
+        input_tokens=10,
+        output_tokens=20,
+        cached_input_tokens=5,
+        tool_output_tokens=3,
+    )
+
+    # Test copy functionality
+    copied_metrics = original_metrics.copy()
+
+    # Verify copy has same values
+    assert copied_metrics.input_tokens == 10
+    assert copied_metrics.output_tokens == 20
+    assert copied_metrics.cached_input_tokens == 5
+    assert copied_metrics.tool_output_tokens == 3
+
+    # Verify they are separate objects
+    assert copied_metrics is not original_metrics
+
+    # Modify copy to ensure independence
+    copied_metrics.input_tokens = 999
+    assert original_metrics.input_tokens == 10  # Original unchanged
+    assert copied_metrics.input_tokens == 999
+
+    # Test reset functionality
+    original_metrics.reset()
+
+    # Verify all fields are reset to zero
+    assert original_metrics.input_tokens == 0
+    assert original_metrics.output_tokens == 0
+    assert original_metrics.cached_input_tokens == 0
+    assert original_metrics.tool_output_tokens == 0
+
+    # Verify copied metrics are unaffected by reset
+    assert copied_metrics.input_tokens == 999
+    assert copied_metrics.output_tokens == 20
+    assert copied_metrics.cached_input_tokens == 5
+    assert copied_metrics.tool_output_tokens == 3
diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py
index c694bcfaaa75..8f94880e431b 100644
--- a/vllm/entrypoints/context.py
+++ b/vllm/entrypoints/context.py
@@ -45,21 +45,36 @@ def _map_tool_name_to_tool_type(tool_name: str) -> str:
     return _TOOL_NAME_TO_TYPE_MAP[tool_name]
 
 
-class TurnTokens:
-    """Tracks token counts for a single conversation turn."""
+class TurnMetrics:
+    """Tracks token and toolcall details for a single conversation turn."""
 
-    def __init__(self, input_tokens=0, output_tokens=0):
+    def __init__(
+        self,
+        input_tokens=0,
+        output_tokens=0,
+        cached_input_tokens=0,
+        tool_output_tokens=0,
+    ):
         self.input_tokens = input_tokens
         self.output_tokens = output_tokens
+        self.cached_input_tokens = cached_input_tokens
+        self.tool_output_tokens = tool_output_tokens
 
     def reset(self):
         """Reset counters for a new turn."""
         self.input_tokens = 0
         self.output_tokens = 0
+        self.cached_input_tokens = 0
+        self.tool_output_tokens = 0
 
     def copy(self):
         """Create a copy of this turn's token counts."""
-        return TurnTokens(self.input_tokens, self.output_tokens)
+        return TurnMetrics(
+            self.input_tokens,
+            self.output_tokens,
+            self.cached_input_tokens,
+            self.tool_output_tokens,
+        )
 
 
 class ConversationContext(ABC):
@@ -102,6 +117,8 @@ class SimpleContext(ConversationContext):
         self.num_cached_tokens = 0
         # todo num_reasoning_tokens is not implemented yet.
         self.num_reasoning_tokens = 0
+        # not implemented yet for SimpleContext
+        self.all_turn_metrics = []
 
     def append_output(self, output) -> None:
         self.last_output = output
@@ -154,8 +171,9 @@ class HarmonyContext(ConversationContext):
         self.num_tool_output_tokens = 0
 
         # Turn tracking - replaces multiple individual tracking variables
-        self.current_turn = TurnTokens()
-        self.previous_turn = TurnTokens()
+        self.current_turn_metrics = TurnMetrics()
+        # Track metrics for all turns
+        self.all_turn_metrics: list[TurnMetrics] = []
         self.is_first_turn = True
         self.first_tok_of_message = True  # For streaming support
 
@@ -173,11 +191,10 @@ class HarmonyContext(ConversationContext):
                 # Check if the current token is part of reasoning content
                 self._update_num_reasoning_tokens()
             self._update_prefill_token_usage(output)
-            # Reset current turn output tokens for this turn
-            self.current_turn.output_tokens = 0
             self._update_decode_token_usage(output)
-            # Move current turn to previous turn for next turn's calculations
-            self.previous_turn = self.current_turn.copy()
+            # Append current turn to all turn list for next turn's calculations
+            self.all_turn_metrics.append(self.current_turn_metrics.copy())
+            self.current_turn_metrics.reset()
             # append_output is called only once before tool calling
             # in non-streaming case
             # so we can append all the parser messages to _messages
@@ -213,20 +230,21 @@ class HarmonyContext(ConversationContext):
             logger.error("RequestOutput appended contains no prompt_token_ids.")
 
         # Update current turn input tokens
-        self.current_turn.input_tokens = this_turn_input_tokens
+        self.current_turn_metrics.input_tokens = this_turn_input_tokens
         self.num_prompt_tokens += this_turn_input_tokens
 
         # Calculate tool tokens (except on first turn)
         if self.is_first_turn:
             self.is_first_turn = False
         else:
+            previous_turn = self.all_turn_metrics[-1]
             # start counting tool after first turn
             # tool tokens = this turn prefill - last turn prefill -
             # last turn decode
             this_turn_tool_tokens = (
-                self.current_turn.input_tokens
-                - self.previous_turn.input_tokens
-                - self.previous_turn.output_tokens
+                self.current_turn_metrics.input_tokens
+                - previous_turn.input_tokens
+                - previous_turn.output_tokens
             )
 
             # Handle negative tool token counts (shouldn't happen in normal
@@ -237,17 +255,20 @@ class HarmonyContext(ConversationContext):
                     "(current_input=%d, previous_input=%d, "
                     "previous_output=%d). Setting to 0.",
                     this_turn_tool_tokens,
-                    self.current_turn.input_tokens,
-                    self.previous_turn.input_tokens,
-                    self.previous_turn.output_tokens,
+                    self.current_turn_metrics.input_tokens,
+                    previous_turn.input_tokens,
+                    previous_turn.output_tokens,
                 )
                 this_turn_tool_tokens = 0
 
             self.num_tool_output_tokens += this_turn_tool_tokens
+            self.current_turn_metrics.tool_output_tokens = this_turn_tool_tokens
 
         # Update cached tokens
-        if output.num_cached_tokens is not None:
-            self.num_cached_tokens += output.num_cached_tokens
+        num_cached_token = output.num_cached_tokens
+        if num_cached_token is not None:
+            self.num_cached_tokens += num_cached_token
+            self.current_turn_metrics.cached_input_tokens = num_cached_token
 
     def _update_decode_token_usage(self, output: RequestOutput) -> int:
         """Update token usage statistics for the decode phase of generation.
@@ -272,7 +293,7 @@ class HarmonyContext(ConversationContext):
                 # only keep last round
                 updated_output_token_count += len(completion_output.token_ids)
         self.num_output_tokens += updated_output_token_count
-        self.current_turn.output_tokens += updated_output_token_count
+        self.current_turn_metrics.output_tokens += updated_output_token_count
         return updated_output_token_count
 
     @property
@@ -452,7 +473,6 @@ class StreamingHarmonyContext(HarmonyContext):
         # so we only want to add the prompt tokens once for each message.
         if self.first_tok_of_message:
             self._update_prefill_token_usage(output)
-            self.current_turn.output_tokens = 0
         # Reset self.first_tok_of_message if needed:
         # if the current token is the last one of the current message
         # (finished=True), then the next token processed will mark the
@@ -464,7 +484,8 @@ class StreamingHarmonyContext(HarmonyContext):
 
         # For streaming, update previous turn when message is complete
         if output.finished:
-            self.previous_turn = self.current_turn.copy()
+            self.all_turn_metrics.append(self.current_turn_metrics.copy())
+            self.current_turn_metrics.reset()
         # Check if the current token is part of reasoning content
         self._update_num_reasoning_tokens()
         self.last_tok = tok
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index f41fa196acd8..86e1e62ff437 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -2103,11 +2103,15 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
 
 class InputTokensDetails(OpenAIBaseModel):
     cached_tokens: int
+    input_tokens_per_turn: list[int] = Field(default_factory=list)
+    cached_tokens_per_turn: list[int] = Field(default_factory=list)
 
 
 class OutputTokensDetails(OpenAIBaseModel):
     reasoning_tokens: int = 0
     tool_output_tokens: int = 0
+    output_tokens_per_turn: list[int] = Field(default_factory=list)
+    tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
 
 
 class ResponseUsage(OpenAIBaseModel):
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 51e2856a5a9d..6cdabff6e709 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -589,10 +589,24 @@ class OpenAIServingResponses(OpenAIServing):
             input_tokens=num_prompt_tokens,
             output_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
-            input_tokens_details=InputTokensDetails(cached_tokens=num_cached_tokens),
+            input_tokens_details=InputTokensDetails(
+                cached_tokens=num_cached_tokens,
+                input_tokens_per_turn=[
+                    turn.input_tokens for turn in context.all_turn_metrics
+                ],
+                cached_tokens_per_turn=[
+                    turn.cached_input_tokens for turn in context.all_turn_metrics
+                ],
+            ),
             output_tokens_details=OutputTokensDetails(
                 reasoning_tokens=num_reasoning_tokens,
                 tool_output_tokens=num_tool_output_tokens,
+                output_tokens_per_turn=[
+                    turn.output_tokens for turn in context.all_turn_metrics
+                ],
+                tool_output_tokens_per_turn=[
+                    turn.tool_output_tokens for turn in context.all_turn_metrics
+                ],
             ),
         )
         response = ResponsesResponse.from_request(
@@ -665,11 +679,13 @@ class OpenAIServingResponses(OpenAIServing):
                     token=text,
                     logprob=max(token_logprob.logprob, -9999.0),
                     bytes=list(text.encode("utf-8", errors="replace")),
-                    top_logprobs=self._topk_logprobs(
-                        logprob, top_logprobs=top_logprobs, tokenizer=tokenizer
-                    )
-                    if top_logprobs
-                    else [],
+                    top_logprobs=(
+                        self._topk_logprobs(
+                            logprob, top_logprobs=top_logprobs, tokenizer=tokenizer
+                        )
+                        if top_logprobs
+                        else []
+                    ),
                 )
             )
         return out
@@ -758,14 +774,16 @@ class OpenAIServingResponses(OpenAIServing):
             text=content,
             annotations=[],  # TODO
             type="output_text",
-            logprobs=self._create_response_logprobs(
-                token_ids=final_output.token_ids,
-                logprobs=final_output.logprobs,
-                tokenizer=tokenizer,
-                top_logprobs=request.top_logprobs,
-            )
-            if request.is_include_output_logprobs()
-            else None,
+            logprobs=(
+                self._create_response_logprobs(
+                    token_ids=final_output.token_ids,
+                    logprobs=final_output.logprobs,
+                    tokenizer=tokenizer,
+                    top_logprobs=request.top_logprobs,
+                )
+                if request.is_include_output_logprobs()
+                else None
+            ),
         )
         message = ResponseOutputMessage(
             id=f"msg_{random_uuid()}",
@@ -870,15 +888,21 @@ class OpenAIServingResponses(OpenAIServing):
         with_custom_tools = has_custom_tools(tool_types)
         sys_msg = get_system_message(
             reasoning_effort=reasoning_effort,
-            browser_description=self.tool_server.get_tool_description("browser")
-            if enable_browser and self.tool_server is not None
-            else None,
-            python_description=self.tool_server.get_tool_description("python")
-            if enable_code_interpreter and self.tool_server is not None
-            else None,
-            container_description=self.tool_server.get_tool_description("container")
-            if enable_container and self.tool_server is not None
-            else None,
+            browser_description=(
+                self.tool_server.get_tool_description("browser")
+                if enable_browser and self.tool_server is not None
+                else None
+            ),
+            python_description=(
+                self.tool_server.get_tool_description("python")
+                if enable_code_interpreter and self.tool_server is not None
+                else None
+            ),
+            container_description=(
+                self.tool_server.get_tool_description("container")
+                if enable_container and self.tool_server is not None
+                else None
+            ),
             instructions=request.instructions,
             with_custom_tools=with_custom_tools,
         )
@@ -1283,14 +1307,16 @@ class OpenAIServingResponses(OpenAIServing):
                     output_index=current_output_index,
                     item_id=current_item_id,
                     delta=delta_message.content,
-                    logprobs=self._create_stream_response_logprobs(
-                        token_ids=output.token_ids,
-                        logprobs=output.logprobs,
-                        tokenizer=tokenizer,
-                        top_logprobs=request.top_logprobs,
-                    )
-                    if request.is_include_output_logprobs()
-                    else [],
+                    logprobs=(
+                        self._create_stream_response_logprobs(
+                            token_ids=output.token_ids,
+                            logprobs=output.logprobs,
+                            tokenizer=tokenizer,
+                            top_logprobs=request.top_logprobs,
+                        )
+                        if request.is_include_output_logprobs()
+                        else []
+                    ),
                 )
             )
             current_content_index += 1
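
For reference, a minimal sketch of how the per-turn metrics collected in HarmonyContext.all_turn_metrics map onto the new usage detail fields. The token values below are invented for illustration only; the actual aggregation in serving_responses.py uses the context's own totals, as shown in the hunks above.

# Illustrative sketch only: two hypothetical turns, where the second turn's
# prompt carries 7 tool output tokens. All field names come from this patch.
from vllm.entrypoints.context import TurnMetrics
from vllm.entrypoints.openai.protocol import InputTokensDetails, OutputTokensDetails

all_turn_metrics = [
    TurnMetrics(input_tokens=5, output_tokens=3, cached_input_tokens=2),
    TurnMetrics(
        input_tokens=15, output_tokens=4, cached_input_tokens=8, tool_output_tokens=7
    ),
]

# Per-turn lists sit alongside the existing aggregate counters.
input_details = InputTokensDetails(
    cached_tokens=sum(t.cached_input_tokens for t in all_turn_metrics),
    input_tokens_per_turn=[t.input_tokens for t in all_turn_metrics],
    cached_tokens_per_turn=[t.cached_input_tokens for t in all_turn_metrics],
)
output_details = OutputTokensDetails(
    tool_output_tokens=sum(t.tool_output_tokens for t in all_turn_metrics),
    output_tokens_per_turn=[t.output_tokens for t in all_turn_metrics],
    tool_output_tokens_per_turn=[t.tool_output_tokens for t in all_turn_metrics],
)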