[Bugfix] fix streaming final output for non harmony (#30237)
Signed-off-by: penfree <qiupengfei@baidu.com>
Co-authored-by: penfree <qiupengfei@baidu.com>

commit bbd850e597
parent 511e81e7c9
@@ -87,3 +87,48 @@ async def test_reasoning_item(client: OpenAI, model_name: str):
     assert response.output[0].type == "reasoning"
     assert response.output[1].type == "message"
     assert type(response.output[1].content[0].text) is str
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_streaming_output_consistency(client: OpenAI, model_name: str):
+    """Test that streaming delta text matches the final response output_text.
+
+    This test verifies that when using streaming mode:
+    1. The concatenated text from all 'response.output_text.delta' events
+    2. Matches the 'output_text' in the final 'response.completed' event
+    """
+    response = await client.responses.create(
+        model=model_name,
+        input="Say hello in one sentence.",
+        stream=True,
+    )
+
+    events = []
+    async for event in response:
+        events.append(event)
+
+    assert len(events) > 0
+
+    # Concatenate all delta text from streaming events
+    streaming_text = "".join(
+        event.delta for event in events if event.type == "response.output_text.delta"
+    )
+
+    # Get the final response from the last event
+    response_completed_event = events[-1]
+    assert response_completed_event.type == "response.completed"
+    assert response_completed_event.response.status == "completed"
+
+    # Get output_text from the final response
+    final_output_text = response_completed_event.response.output_text
+
+    # Verify final response has output
+    assert len(response_completed_event.response.output) > 0
+
+    # Verify streaming text matches final output_text
+    assert streaming_text == final_output_text, (
+        f"Streaming text does not match final output_text.\n"
+        f"Streaming: {streaming_text!r}\n"
+        f"Final: {final_output_text!r}"
+    )
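For context, here is a minimal standalone sketch of the same streaming-vs-final consistency check that the new test performs, run against an OpenAI-compatible vLLM server outside of pytest. The base URL, API key, and model name below are placeholder assumptions, not values from this change.

# Standalone sketch of the streaming-vs-final consistency check.
# Assumes a vLLM server at http://localhost:8000/v1; the model name is a placeholder.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.responses.create(
        model="my-model",  # placeholder model name
        input="Say hello in one sentence.",
        stream=True,
    )

    deltas: list[str] = []
    final_text = None
    async for event in stream:
        if event.type == "response.output_text.delta":
            deltas.append(event.delta)
        elif event.type == "response.completed":
            final_text = event.response.output_text

    # With the fix, the concatenated deltas should equal the final output_text.
    assert "".join(deltas) == final_text


asyncio.run(main())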
@@ -2,11 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import contextlib
+import copy
 import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from contextlib import AsyncExitStack
+from dataclasses import replace
 from typing import TYPE_CHECKING, Union

 from openai.types.responses.response_function_tool_call_output_item import (
@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):

     def __init__(self):
         self.last_output = None
+
+        # Accumulated final output for streaming mode
+        self._accumulated_text: str = ""
+        self._accumulated_token_ids: list[int] = []
+        self._accumulated_logprobs: list = []
+
         self.num_prompt_tokens = 0
         self.num_output_tokens = 0
         self.num_cached_tokens = 0
@@ -183,6 +191,13 @@ class SimpleContext(ConversationContext):
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])

+        # Accumulate text, token_ids, and logprobs for streaming mode
+        delta_output = output.outputs[0]
+        self._accumulated_text += delta_output.text
+        self._accumulated_token_ids.extend(delta_output.token_ids)
+        if delta_output.logprobs is not None:
+            self._accumulated_logprobs.extend(delta_output.logprobs)
+
         if len(self.input_messages) == 0:
             output_prompt = output.prompt or ""
             output_prompt_token_ids = output.prompt_token_ids or []
@@ -194,11 +209,26 @@ class SimpleContext(ConversationContext):
             )
         self.output_messages.append(
             ResponseRawMessageAndToken(
-                message=output.outputs[0].text,
-                tokens=output.outputs[0].token_ids,
+                message=delta_output.text,
+                tokens=delta_output.token_ids,
             )
         )

+    @property
+    def final_output(self) -> RequestOutput | None:
+        """Return the final output, with complete text/token_ids/logprobs."""
+        if self.last_output is not None and self.last_output.outputs:
+            assert isinstance(self.last_output, RequestOutput)
+            final_output = copy.copy(self.last_output)
+            # copy inner item to avoid modify last_output
+            final_output.outputs = [replace(item) for item in self.last_output.outputs]
+            final_output.outputs[0].text = self._accumulated_text
+            final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids)
+            if self._accumulated_logprobs:
+                final_output.outputs[0].logprobs = self._accumulated_logprobs
+            return final_output
+        return self.last_output
+
     def append_tool_output(self, output) -> None:
         raise NotImplementedError("Should not be called.")
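The SimpleContext changes above accumulate each streamed delta and rebuild a complete output on demand via final_output, rather than exposing only the last delta. Below is an illustrative, self-contained sketch of that accumulate-and-rebuild pattern using toy stand-in types (StepOutput and DeltaAccumulator are hypothetical, not vLLM classes); it shows why last_output alone would drop the earlier chunks.

# Illustrative sketch of the accumulate-and-rebuild pattern (toy types, not vLLM's).
from dataclasses import dataclass, replace


@dataclass
class StepOutput:
    text: str
    token_ids: tuple[int, ...]


class DeltaAccumulator:
    def __init__(self) -> None:
        self.last_output: StepOutput | None = None
        self._text = ""
        self._token_ids: list[int] = []

    def append_output(self, out: StepOutput) -> None:
        # Each streamed step carries only the newly generated chunk.
        self.last_output = out
        self._text += out.text
        self._token_ids.extend(out.token_ids)

    @property
    def final_output(self) -> StepOutput | None:
        # Rebuild a complete output without mutating the last delta.
        if self.last_output is None:
            return None
        return replace(
            self.last_output,
            text=self._text,
            token_ids=tuple(self._token_ids),
        )


acc = DeltaAccumulator()
for chunk in (StepOutput("Hel", (1,)), StepOutput("lo", (2,)), StepOutput("!", (3,))):
    acc.append_output(chunk)

assert acc.last_output.text == "!"        # only the final delta
assert acc.final_output.text == "Hello!"  # the full accumulated text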
@@ -675,7 +675,8 @@ class OpenAIServingResponses(OpenAIServing):
             num_tool_output_tokens = 0
         else:
             assert isinstance(context, SimpleContext)
-            final_res = context.last_output
+            # Use final_output which has accumulated text/token_ids/logprobs
+            final_res = context.final_output
             assert final_res is not None
             assert len(final_res.outputs) == 1
             final_output = final_res.outputs[0]