[Misc] feat: output content in stream response (#19608)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
rongfu.leng 2025-07-08 04:45:10 +08:00 committed by GitHub
parent e601efcb10
commit 8e807cdfa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,6 +21,7 @@ from http import HTTPStatus
from typing import Annotated, Any, Optional
import prometheus_client
import pydantic
import regex as re
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
@ -1203,6 +1204,142 @@ class XRequestIdMiddleware:
return self.app(scope, receive, send_with_request_id)
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Pull the text content out of one streaming-response chunk.

    Handles both chat-completion (``delta.content``) and text-completion
    (``text``) chunk shapes; returns an empty string when the chunk
    carries no content.
    """
    try:
        from vllm.entrypoints.openai.protocol import (
            ChatCompletionStreamResponse, CompletionStreamResponse)

        # Prefer type-safe parsing through the pydantic response models.
        object_kind = chunk_data.get('object')
        if object_kind == 'chat.completion.chunk':
            chat_parsed = ChatCompletionStreamResponse.model_validate(
                chunk_data)
            if chat_parsed.choices and chat_parsed.choices[0].delta.content:
                return chat_parsed.choices[0].delta.content
        elif object_kind == 'text_completion':
            text_parsed = CompletionStreamResponse.model_validate(chunk_data)
            if text_parsed.choices and text_parsed.choices[0].text:
                return text_parsed.choices[0].text
    except pydantic.ValidationError:
        # Model validation failed — fall back to best-effort dict access.
        if chunk_data.get('choices'):
            first_choice = chunk_data['choices'][0]
            if 'delta' in first_choice and first_choice['delta'].get(
                    'content'):
                return first_choice['delta']['content']
            if first_choice.get('text'):
                return first_choice['text']
    return ""
class SSEDecoder:
    """Incremental Server-Sent Events decoder for streaming responses.

    Bytes may arrive split at arbitrary boundaries; complete ``data:``
    lines are parsed as JSON events while any trailing partial line is
    buffered until the next chunk arrives.
    """

    def __init__(self):
        self.buffer = ""  # Undecoded tail of the SSE text stream.
        self.content_buffer = []  # Accumulated content fragments.

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode one chunk of SSE bytes and return the parsed events."""
        import json

        try:
            text = chunk.decode('utf-8')
        except UnicodeDecodeError:
            # Drop chunks that are not valid UTF-8.
            return []

        self.buffer += text
        events: list[dict] = []

        # Consume every complete line; keep a partial line buffered.
        while '\n' in self.buffer:
            line, self.buffer = self.buffer.split('\n', 1)
            line = line.rstrip('\r')  # Tolerate CRLF line endings.
            if not line.startswith('data: '):
                continue
            payload = line[6:].strip()
            if payload == '[DONE]':
                events.append({'type': 'done'})
            elif payload:
                try:
                    events.append({
                        'type': 'data',
                        'data': json.loads(payload)
                    })
                except json.JSONDecodeError:
                    # Ignore lines that are not valid JSON.
                    continue
        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content text from a parsed event payload."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Append a non-empty content fragment to the buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Return all buffered content joined into a single string."""
        return ''.join(self.content_buffer)
def _log_streaming_response(response, response_body: list) -> None:
    """Log a streaming (SSE) response with robust event parsing.

    Replaces ``response.body_iterator`` with a wrapper that re-yields each
    chunk unchanged while accumulating the decoded content; when the SSE
    ``[DONE]`` event is seen, the complete content (truncated to 2048
    characters) is logged together with the chunk count.

    Args:
        response: The starlette streaming response whose body iterator is
            wrapped in place.
        response_body: The list of raw body chunks (bytes) to stream.
    """
    from starlette.concurrency import iterate_in_threadpool

    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count

        for chunk in response_body:
            chunk_count += 1
            # Re-yield the chunk untouched so the client sees it verbatim.
            yield chunk

            # Parse SSE events out of the chunk for logging purposes.
            events = sse_decoder.decode_chunk(chunk)

            for event in events:
                if event['type'] == 'data':
                    content = sse_decoder.extract_content(event['data'])
                    sse_decoder.add_content(content)
                elif event['type'] == 'done':
                    # Stream finished: log the accumulated content.
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate overly long content. Bug fix: the
                        # "...[truncated]" marker was previously a dangling
                        # no-op statement and never appended.
                        if len(full_content) > 2048:
                            full_content = (full_content[:2048] +
                                            "...[truncated]")
                        logger.info(
                            "response_body={streaming_complete: " \
                            "content='%s', chunks=%d}",
                            full_content, chunk_count)
                    else:
                        logger.info(
                            "response_body={streaming_complete: " \
                            "no_content, chunks=%d}",
                            chunk_count)
                    return

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}",
                len(response_body))
def _log_non_streaming_response(response_body: list) -> None:
    """Log a non-streaming response body, noting binary payloads.

    Only the first body section is decoded; payloads that are not valid
    text are logged with a ``<binary_data>`` placeholder.
    """
    try:
        body_text = response_body[0].decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
    else:
        logger.info("response_body={%s}", body_text)
def build_app(args: Namespace) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
@ -1267,8 +1404,17 @@ def build_app(args: Namespace) -> FastAPI:
section async for section in response.body_iterator
]
response.body_iterator = iterate_in_threadpool(iter(response_body))
logger.info("response_body={%s}",
response_body[0].decode() if response_body else None)
# Check if this is a streaming response by looking at content-type
content_type = response.headers.get("content-type", "")
is_streaming = content_type == "text/event-stream; charset=utf-8"
# Log response body based on type
if not response_body:
logger.info("response_body={<empty>}")
elif is_streaming:
_log_streaming_response(response, response_body)
else:
_log_non_streaming_response(response_body)
return response
for middleware in args.middleware: