diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index e3285a9bf76d1..2f8b31c8a7ba7 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -21,6 +21,7 @@ from http import HTTPStatus
 from typing import Annotated, Any, Optional
 
 import prometheus_client
+import pydantic
 import regex as re
 import uvloop
 from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
@@ -1203,6 +1204,142 @@ class XRequestIdMiddleware:
         return self.app(scope, receive, send_with_request_id)
 
 
+def _extract_content_from_chunk(chunk_data: dict) -> str:
+    """Extract content from a streaming response chunk."""
+    try:
+        from vllm.entrypoints.openai.protocol import (
+            ChatCompletionStreamResponse, CompletionStreamResponse)
+
+        # Try the typed protocol models first for type-safe parsing
+        if chunk_data.get('object') == 'chat.completion.chunk':
+            chat_response = ChatCompletionStreamResponse.model_validate(
+                chunk_data)
+            if chat_response.choices and chat_response.choices[0].delta.content:
+                return chat_response.choices[0].delta.content
+        elif chunk_data.get('object') == 'text_completion':
+            completion_response = CompletionStreamResponse.model_validate(
+                chunk_data)
+            if completion_response.choices and completion_response.choices[
+                    0].text:
+                return completion_response.choices[0].text
+    except pydantic.ValidationError:
+        # Fall back to manual parsing
+        if 'choices' in chunk_data and chunk_data['choices']:
+            choice = chunk_data['choices'][0]
+            if 'delta' in choice and choice['delta'].get('content'):
+                return choice['delta']['content']
+            elif choice.get('text'):
+                return choice['text']
+    return ""
+
+
+class SSEDecoder:
+    """Robust Server-Sent Events decoder for streaming responses."""
+
+    def __init__(self):
+        self.buffer = ""
+        self.content_buffer = []
+
+    def decode_chunk(self, chunk: bytes) -> list[dict]:
+        """Decode a chunk of SSE data and return parsed events."""
+        import json
+
+        try:
+            chunk_str = chunk.decode('utf-8')
+        except UnicodeDecodeError:
+            # Skip malformed chunks
+            return []
+
+        self.buffer += chunk_str
+        events = []
+
+        # Process complete lines
+        while '\n' in self.buffer:
+            line, self.buffer = self.buffer.split('\n', 1)
+            line = line.rstrip('\r')  # Handle CRLF
+
+            if line.startswith('data: '):
+                data_str = line[6:].strip()
+                if data_str == '[DONE]':
+                    events.append({'type': 'done'})
+                elif data_str:
+                    try:
+                        event_data = json.loads(data_str)
+                        events.append({'type': 'data', 'data': event_data})
+                    except json.JSONDecodeError:
+                        # Skip malformed JSON
+                        continue
+
+        return events
+
+    def extract_content(self, event_data: dict) -> str:
+        """Extract content from event data."""
+        return _extract_content_from_chunk(event_data)
+
+    def add_content(self, content: str) -> None:
+        """Add content to the buffer."""
+        if content:
+            self.content_buffer.append(content)
+
+    def get_complete_content(self) -> str:
+        """Get the complete buffered content."""
+        return ''.join(self.content_buffer)
+
+
+def _log_streaming_response(response, response_body: list) -> None:
+    """Log streaming response with robust SSE parsing."""
+    from starlette.concurrency import iterate_in_threadpool
+
+    sse_decoder = SSEDecoder()
+    chunk_count = 0
+
+    def buffered_iterator():
+        nonlocal chunk_count
+
+        for chunk in response_body:
+            chunk_count += 1
+            yield chunk
+
+            # Parse SSE events from the chunk
+            events = sse_decoder.decode_chunk(chunk)
+
+            for event in events:
+                if event['type'] == 'data':
+                    content = sse_decoder.extract_content(event['data'])
+                    sse_decoder.add_content(content)
+                elif event['type'] == 'done':
+                    # Log the complete content when done
+                    full_content = sse_decoder.get_complete_content()
+                    if full_content:
+                        # Truncate if too long
+                        if len(full_content) > 2048:
+                            full_content = full_content[:2048] + \
+                                "...[truncated]"
+                        logger.info(
+                            "response_body={streaming_complete: "
+                            "content='%s', chunks=%d}",
+                            full_content, chunk_count)
+                    else:
+                        logger.info(
+                            "response_body={streaming_complete: "
+                            "no_content, chunks=%d}",
+                            chunk_count)
+                    return
+
+    response.body_iterator = iterate_in_threadpool(buffered_iterator())
+    logger.info("response_body={streaming_started: chunks=%d}",
+                len(response_body))
+
+
+def _log_non_streaming_response(response_body: list) -> None:
+    """Log non-streaming response."""
+    try:
+        decoded_body = response_body[0].decode()
+        logger.info("response_body={%s}", decoded_body)
+    except UnicodeDecodeError:
+        logger.info("response_body={}")
+
+
 def build_app(args: Namespace) -> FastAPI:
     if args.disable_fastapi_docs:
         app = FastAPI(openapi_url=None,
@@ -1267,8 +1404,17 @@ def build_app(args: Namespace) -> FastAPI:
                 section async for section in response.body_iterator
             ]
             response.body_iterator = iterate_in_threadpool(iter(response_body))
-            logger.info("response_body={%s}",
-                        response_body[0].decode() if response_body else None)
+            # Check whether this is a streaming response via the content-type
+            content_type = response.headers.get("content-type", "")
+            is_streaming = content_type == "text/event-stream; charset=utf-8"
+
+            # Log the response body based on its type
+            if not response_body:
+                logger.info("response_body={}")
+            elif is_streaming:
+                _log_streaming_response(response, response_body)
+            else:
+                _log_non_streaming_response(response_body)
             return response
 
     for middleware in args.middleware:
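
Below is a minimal, self-contained sketch (not part of the diff) of how the SSEDecoder added above reassembles an event that arrives split across two network chunks. The byte payloads are invented for illustration, and the import assumes this patch has been applied to vllm/entrypoints/openai/api_server.py.

# Sketch only: exercises the SSEDecoder from this patch.
from vllm.entrypoints.openai.api_server import SSEDecoder

decoder = SSEDecoder()

# The first chunk ends mid-line, so no complete "data:" line exists yet;
# decode_chunk buffers it and returns no events.
assert decoder.decode_chunk(
    b'data: {"object": "chat.completion.chunk", ') == []

# The second chunk completes the event and appends the [DONE] sentinel.
events = decoder.decode_chunk(
    b'"choices": [{"index": 0, "delta": {"content": "Hello"}}]}\n'
    b'data: [DONE]\n')
assert [e['type'] for e in events] == ['data', 'done']

# extract_content tries the typed protocol models first and falls back to
# plain dict parsing on a pydantic.ValidationError (this chunk omits
# fields such as "model"), so "Hello" is recovered either way.
decoder.add_content(decoder.extract_content(events[0]['data']))
assert decoder.get_complete_content() == "Hello"

Splitting the buffer on '\n' and stripping a trailing '\r' is what lets the decoder tolerate both LF and CRLF framing regardless of where the transport happens to cut the chunks.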