mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 20:31:20 +08:00
[Bugfix] fix streaming final output for non harmony (#30237)
Signed-off-by: penfree <qiupengfei@baidu.com> Co-authored-by: penfree <qiupengfei@baidu.com>
This commit is contained in:
parent
511e81e7c9
commit
bbd850e597
@ -87,3 +87,48 @@ async def test_reasoning_item(client: OpenAI, model_name: str):
|
|||||||
assert response.output[0].type == "reasoning"
|
assert response.output[0].type == "reasoning"
|
||||||
assert response.output[1].type == "message"
|
assert response.output[1].type == "message"
|
||||||
assert type(response.output[1].content[0].text) is str
|
assert type(response.output[1].content[0].text) is str
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_streaming_output_consistency(client: OpenAI, model_name: str):
    """Check that the streamed deltas add up to the final ``output_text``.

    All events of a streaming request are collected. The ``delta`` payloads
    of the 'response.output_text.delta' events are joined and compared with
    the ``output_text`` of the response carried by the last event, which
    must be 'response.completed'.
    """
    stream = await client.responses.create(
        model=model_name,
        input="Say hello in one sentence.",
        stream=True,
    )

    # Drain the stream completely before inspecting anything.
    events = [event async for event in stream]
    assert len(events) > 0

    # Text as a streaming client would reconstruct it from the deltas.
    delta_chunks = [
        event.delta
        for event in events
        if event.type == "response.output_text.delta"
    ]
    streaming_text = "".join(delta_chunks)

    # The stream has to end with the completed response.
    completed_event = events[-1]
    assert completed_event.type == "response.completed"
    final_response = completed_event.response
    assert final_response.status == "completed"

    final_output_text = final_response.output_text

    # The final response must carry at least one output item.
    assert len(final_response.output) > 0

    # Both views of the generated text have to agree exactly.
    assert streaming_text == final_output_text, (
        f"Streaming text does not match final output_text.\n"
        f"Streaming: {streaming_text!r}\n"
        f"Final: {final_output_text!r}"
    )
|
||||||
|
|||||||
@ -2,11 +2,13 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
|
from dataclasses import replace
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
from openai.types.responses.response_function_tool_call_output_item import (
|
from openai.types.responses.response_function_tool_call_output_item import (
|
||||||
@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.last_output = None
|
self.last_output = None
|
||||||
|
|
||||||
|
# Accumulated final output for streaming mode
|
||||||
|
self._accumulated_text: str = ""
|
||||||
|
self._accumulated_token_ids: list[int] = []
|
||||||
|
self._accumulated_logprobs: list = []
|
||||||
|
|
||||||
self.num_prompt_tokens = 0
|
self.num_prompt_tokens = 0
|
||||||
self.num_output_tokens = 0
|
self.num_output_tokens = 0
|
||||||
self.num_cached_tokens = 0
|
self.num_cached_tokens = 0
|
||||||
@ -183,6 +191,13 @@ class SimpleContext(ConversationContext):
|
|||||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||||
|
|
||||||
|
# Accumulate text, token_ids, and logprobs for streaming mode
|
||||||
|
delta_output = output.outputs[0]
|
||||||
|
self._accumulated_text += delta_output.text
|
||||||
|
self._accumulated_token_ids.extend(delta_output.token_ids)
|
||||||
|
if delta_output.logprobs is not None:
|
||||||
|
self._accumulated_logprobs.extend(delta_output.logprobs)
|
||||||
|
|
||||||
if len(self.input_messages) == 0:
|
if len(self.input_messages) == 0:
|
||||||
output_prompt = output.prompt or ""
|
output_prompt = output.prompt or ""
|
||||||
output_prompt_token_ids = output.prompt_token_ids or []
|
output_prompt_token_ids = output.prompt_token_ids or []
|
||||||
@ -194,11 +209,26 @@ class SimpleContext(ConversationContext):
|
|||||||
)
|
)
|
||||||
self.output_messages.append(
|
self.output_messages.append(
|
||||||
ResponseRawMessageAndToken(
|
ResponseRawMessageAndToken(
|
||||||
message=output.outputs[0].text,
|
message=delta_output.text,
|
||||||
tokens=output.outputs[0].token_ids,
|
tokens=delta_output.token_ids,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
def final_output(self) -> RequestOutput | None:
    """Return the last output with the accumulated results filled in.

    The context accumulates text, token ids and logprobs from every
    appended output (for streaming mode). This property returns a copy of
    ``last_output`` whose first completion has ``text`` and ``token_ids``
    replaced by the accumulated values, and ``logprobs`` replaced as well
    when any were accumulated.

    Returns:
        ``last_output`` itself when it is ``None`` or has no completions.
        Otherwise a copy, so that neither ``last_output`` nor the internal
        accumulators are shared with the caller.
    """
    last = self.last_output
    if last is None or not last.outputs:
        return last
    assert isinstance(last, RequestOutput)

    merged = copy.copy(last)
    # copy.copy is shallow: rebuild the outputs list from fresh items so
    # the overrides below do not modify last_output.
    merged.outputs = [replace(item) for item in last.outputs]

    first = merged.outputs[0]
    first.text = self._accumulated_text
    first.token_ids = tuple(self._accumulated_token_ids)
    if self._accumulated_logprobs:
        # Snapshot the list (as token_ids is snapshotted above) so the
        # returned object does not alias the internal accumulator.
        first.logprobs = list(self._accumulated_logprobs)
    return merged
|
||||||
|
|
||||||
def append_tool_output(self, output) -> None:
    """Always raise ``NotImplementedError``; ``output`` is not used."""
    message = "Should not be called."
    raise NotImplementedError(message)
|
||||||
|
|
||||||
|
|||||||
@ -675,7 +675,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
num_tool_output_tokens = 0
|
num_tool_output_tokens = 0
|
||||||
else:
|
else:
|
||||||
assert isinstance(context, SimpleContext)
|
assert isinstance(context, SimpleContext)
|
||||||
final_res = context.last_output
|
# Use final_output which has accumulated text/token_ids/logprobs
|
||||||
|
final_res = context.final_output
|
||||||
assert final_res is not None
|
assert final_res is not None
|
||||||
assert len(final_res.outputs) == 1
|
assert len(final_res.outputs) == 1
|
||||||
final_output = final_res.outputs[0]
|
final_output = final_res.outputs[0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user