[Bugfix] Fix streaming final output for non-harmony (#30237)

Signed-off-by: penfree <qiupengfei@baidu.com>
Co-authored-by: penfree <qiupengfei@baidu.com>
penfree 2025-12-16 09:03:11 +08:00 committed by GitHub
parent 511e81e7c9
commit bbd850e597
3 changed files with 79 additions and 3 deletions


@@ -87,3 +87,48 @@ async def test_reasoning_item(client: OpenAI, model_name: str):
     assert response.output[0].type == "reasoning"
     assert response.output[1].type == "message"
     assert type(response.output[1].content[0].text) is str
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_streaming_output_consistency(client: OpenAI, model_name: str):
+    """Test that streaming delta text matches the final response output_text.
+
+    This test verifies that, in streaming mode, the concatenated text from
+    all 'response.output_text.delta' events matches the 'output_text' in
+    the final 'response.completed' event.
+    """
+    response = await client.responses.create(
+        model=model_name,
+        input="Say hello in one sentence.",
+        stream=True,
+    )
+    events = []
+    async for event in response:
+        events.append(event)
+    assert len(events) > 0
+
+    # Concatenate all delta text from streaming events
+    streaming_text = "".join(
+        event.delta for event in events if event.type == "response.output_text.delta"
+    )
+
+    # Get the final response from the last event
+    response_completed_event = events[-1]
+    assert response_completed_event.type == "response.completed"
+    assert response_completed_event.response.status == "completed"
+
+    # Get output_text from the final response
+    final_output_text = response_completed_event.response.output_text
+
+    # Verify the final response has output
+    assert len(response_completed_event.response.output) > 0
+
+    # Verify the streaming text matches the final output_text
+    assert streaming_text == final_output_text, (
+        f"Streaming text does not match final output_text.\n"
+        f"Streaming: {streaming_text!r}\n"
+        f"Final: {final_output_text!r}"
+    )
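For manual verification outside the test suite, the same consistency check can be run as a standalone script. This is a hedged sketch: the base URL, API key, and model name are assumptions for a locally served vLLM instance, not values from this change.

import asyncio
from openai import AsyncOpenAI

async def main() -> None:
    # Assumed local endpoint; adjust base_url/model to your deployment.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.responses.create(
        model="your-model",  # hypothetical model name
        input="Say hello in one sentence.",
        stream=True,
    )
    deltas, final = [], None
    async for event in stream:
        if event.type == "response.output_text.delta":
            deltas.append(event.delta)
        elif event.type == "response.completed":
            final = event.response.output_text
    # The invariant the bugfix restores: streamed deltas == final output_text.
    assert "".join(deltas) == final, "streamed deltas diverge from final output_text"

asyncio.run(main())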


@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import contextlib
+import copy
 import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from contextlib import AsyncExitStack
+from dataclasses import replace
 from typing import TYPE_CHECKING, Union
 from openai.types.responses.response_function_tool_call_output_item import (
@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):
     def __init__(self):
         self.last_output = None
+        # Accumulated final output for streaming mode
+        self._accumulated_text: str = ""
+        self._accumulated_token_ids: list[int] = []
+        self._accumulated_logprobs: list = []
         self.num_prompt_tokens = 0
         self.num_output_tokens = 0
         self.num_cached_tokens = 0
@@ -183,6 +191,13 @@
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
+        # Accumulate text, token_ids, and logprobs for streaming mode
+        delta_output = output.outputs[0]
+        self._accumulated_text += delta_output.text
+        self._accumulated_token_ids.extend(delta_output.token_ids)
+        if delta_output.logprobs is not None:
+            self._accumulated_logprobs.extend(delta_output.logprobs)
         if len(self.input_messages) == 0:
             output_prompt = output.prompt or ""
             output_prompt_token_ids = output.prompt_token_ids or []
@@ -194,11 +209,26 @@
             )
         self.output_messages.append(
             ResponseRawMessageAndToken(
-                message=output.outputs[0].text,
-                tokens=output.outputs[0].token_ids,
+                message=delta_output.text,
+                tokens=delta_output.token_ids,
             )
         )
 
+    @property
+    def final_output(self) -> RequestOutput | None:
+        """Return the final output, with complete text/token_ids/logprobs."""
+        if self.last_output is not None and self.last_output.outputs:
+            assert isinstance(self.last_output, RequestOutput)
+            final_output = copy.copy(self.last_output)
+            # Copy the inner items to avoid modifying last_output
+            final_output.outputs = [replace(item) for item in self.last_output.outputs]
+            final_output.outputs[0].text = self._accumulated_text
+            final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids)
+            if self._accumulated_logprobs:
+                final_output.outputs[0].logprobs = self._accumulated_logprobs
+            return final_output
+        return self.last_output
+
     def append_tool_output(self, output) -> None:
         raise NotImplementedError("Should not be called.")
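The copy.copy(...) plus dataclasses.replace(...) combination is what keeps final_output from mutating the stored last_output: the outer object is shallow-copied, and each inner completion is rebuilt as a fresh dataclass instance before its fields are overwritten. A minimal sketch of that pattern, using stand-in dataclasses rather than vLLM's real RequestOutput/CompletionOutput types:

import copy
from dataclasses import dataclass, replace

@dataclass
class Item:  # stand-in for an inner completion output
    text: str

@dataclass
class Result:  # stand-in for the outer request output
    outputs: list

original = Result(outputs=[Item(text="last chunk only")])
final = copy.copy(original)                             # shallow copy of the outer object
final.outputs = [replace(i) for i in original.outputs]  # fresh inner instances
final.outputs[0].text = "full accumulated text"

assert original.outputs[0].text == "last chunk only"    # original left untouched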


@@ -675,7 +675,8 @@ class OpenAIServingResponses(OpenAIServing):
             num_tool_output_tokens = 0
         else:
             assert isinstance(context, SimpleContext)
-            final_res = context.last_output
+            # Use final_output, which has the accumulated text/token_ids/logprobs
+            final_res = context.final_output
             assert final_res is not None
             assert len(final_res.outputs) == 1
             final_output = final_res.outputs[0]
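To make the before/after concrete: in streaming mode each engine step delivers only the newly generated chunk, so building the final response from context.last_output yielded just the last delta. A hedged sketch of the failure mode; FakeContext and its fields are illustrative stand-ins, not vLLM types:

class FakeContext:
    """Illustrative stand-in for SimpleContext's streaming behavior."""

    def __init__(self):
        self.last_output = None  # only the most recent delta
        self._text = ""          # running accumulation (what the fix adds)

    def append_output(self, delta: str) -> None:
        self.last_output = delta
        self._text += delta

    @property
    def final_output(self) -> str:
        return self._text

ctx = FakeContext()
for delta in ("Hello", ", ", "world!"):
    ctx.append_output(delta)

print(ctx.last_output)   # "world!"        <- old behavior: truncated final text
print(ctx.final_output)  # "Hello, world!" <- fixed behavior: full response text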