mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 07:56:31 +08:00
[Misc] feat output content in stream response (#19608)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
parent
e601efcb10
commit
8e807cdfa4
@ -21,6 +21,7 @@ from http import HTTPStatus
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
import prometheus_client
|
||||
import pydantic
|
||||
import regex as re
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
@ -1203,6 +1204,142 @@ class XRequestIdMiddleware:
|
||||
return self.app(scope, receive, send_with_request_id)
|
||||
|
||||
|
||||
def _extract_content_from_chunk(chunk_data: dict) -> str:
|
||||
"""Extract content from a streaming response chunk."""
|
||||
try:
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionStreamResponse, CompletionStreamResponse)
|
||||
|
||||
# Try using Completion types for type-safe parsing
|
||||
if chunk_data.get('object') == 'chat.completion.chunk':
|
||||
chat_response = ChatCompletionStreamResponse.model_validate(
|
||||
chunk_data)
|
||||
if chat_response.choices and chat_response.choices[0].delta.content:
|
||||
return chat_response.choices[0].delta.content
|
||||
elif chunk_data.get('object') == 'text_completion':
|
||||
completion_response = CompletionStreamResponse.model_validate(
|
||||
chunk_data)
|
||||
if completion_response.choices and completion_response.choices[
|
||||
0].text:
|
||||
return completion_response.choices[0].text
|
||||
except pydantic.ValidationError:
|
||||
# Fallback to manual parsing
|
||||
if 'choices' in chunk_data and chunk_data['choices']:
|
||||
choice = chunk_data['choices'][0]
|
||||
if 'delta' in choice and choice['delta'].get('content'):
|
||||
return choice['delta']['content']
|
||||
elif choice.get('text'):
|
||||
return choice['text']
|
||||
return ""
|
||||
|
||||
|
||||
class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses.

    Feed raw byte chunks to :meth:`decode_chunk`; it returns the events
    completed by that chunk and keeps any trailing partial line (and any
    trailing partial UTF-8 sequence) for the next call. Text extracted from
    events can be accumulated with :meth:`add_content` and read back with
    :meth:`get_complete_content`.
    """

    def __init__(self):
        import codecs

        # Text received so far that does not yet end in a newline.
        self.buffer = ""
        # Content fragments in arrival order; joined on demand.
        self.content_buffer: list[str] = []
        # Incremental decoder so a multi-byte character split across two
        # chunks is reassembled instead of making both chunks undecodable.
        # ``errors='replace'`` keeps decoding from ever raising.
        self._decoder = codecs.getincrementaldecoder('utf-8')(
            errors='replace')

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events.

        Args:
            chunk: Raw bytes from the response body; may end mid-line or
                mid-character.

        Returns:
            A list of events, each either ``{'type': 'done'}`` for the
            ``[DONE]`` sentinel or ``{'type': 'data', 'data': <parsed JSON>}``.
            Non-``data`` lines, empty payloads and payloads that are not
            valid JSON are skipped.
        """
        import json

        self.buffer += self._decoder.decode(chunk)
        events = []

        # Process complete lines
        while '\n' in self.buffer:
            line, self.buffer = self.buffer.split('\n', 1)
            line = line.rstrip('\r')  # Handle CRLF

            # The space after the colon is optional in the SSE format.
            if not line.startswith('data:'):
                continue

            data_str = line[len('data:'):].strip()
            if data_str == '[DONE]':
                events.append({'type': 'done'})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({'type': 'data', 'data': event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer; empty content is ignored."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return ''.join(self.content_buffer)
|
||||
|
||||
|
||||
def _log_streaming_response(response, response_body: list) -> None:
    """Log streaming response with robust SSE parsing.

    Replaces ``response.body_iterator`` with an iterator that forwards every
    chunk of ``response_body`` unchanged and, as a side effect, parses the
    chunks as Server-Sent Events. When the ``[DONE]`` sentinel is seen, the
    concatenated content of all preceding events is logged once (truncated
    to 2048 characters) together with the number of chunks forwarded so far.

    Args:
        response: The response object whose ``body_iterator`` is replaced.
        response_body: The already-collected body sections (bytes) to replay.
    """
    from starlette.concurrency import iterate_in_threadpool

    max_logged_chars = 2048
    sse_decoder = SSEDecoder()
    chunk_count = 0
    summary_logged = False

    def log_summary() -> None:
        """Emit the one-off completion log line."""
        full_content = sse_decoder.get_complete_content()
        if not full_content:
            logger.info(
                "response_body={streaming_complete: "
                "no_content, chunks=%d}", chunk_count)
            return

        # Truncate if too long
        if len(full_content) > max_logged_chars:
            full_content = (full_content[:max_logged_chars] +
                            "...[truncated]")
        logger.info(
            "response_body={streaming_complete: "
            "content='%s', chunks=%d}", full_content, chunk_count)

    def buffered_iterator():
        nonlocal chunk_count, summary_logged

        for chunk in response_body:
            chunk_count += 1
            # Forward first so logging work never delays the client.
            yield chunk

            # After [DONE] there is nothing left to log, but any remaining
            # chunks must still reach the client, so keep iterating.
            if summary_logged:
                continue

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event['type'] == 'data':
                    content = sse_decoder.extract_content(event['data'])
                    sse_decoder.add_content(content)
                elif event['type'] == 'done':
                    # Log complete content when done
                    log_summary()
                    summary_logged = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}",
                len(response_body))
|
||||
|
||||
|
||||
def _log_non_streaming_response(response_body: list) -> None:
    """Log the first section of a non-streaming response body.

    The section is decoded with the default codec; if it is not valid text,
    a ``<binary_data>`` placeholder is logged instead.

    Args:
        response_body: Collected body sections (bytes); must be non-empty.
    """
    first_section = response_body[0]
    try:
        body_text = first_section.decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
    else:
        logger.info("response_body={%s}", body_text)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
if args.disable_fastapi_docs:
|
||||
app = FastAPI(openapi_url=None,
|
||||
@ -1267,8 +1404,17 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
section async for section in response.body_iterator
|
||||
]
|
||||
response.body_iterator = iterate_in_threadpool(iter(response_body))
|
||||
logger.info("response_body={%s}",
|
||||
response_body[0].decode() if response_body else None)
|
||||
# Check if this is a streaming response by looking at content-type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
is_streaming = content_type == "text/event-stream; charset=utf-8"
|
||||
|
||||
# Log response body based on type
|
||||
if not response_body:
|
||||
logger.info("response_body={<empty>}")
|
||||
elif is_streaming:
|
||||
_log_streaming_response(response, response_body)
|
||||
else:
|
||||
_log_non_streaming_response(response_body)
|
||||
return response
|
||||
|
||||
for middleware in args.middleware:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user