[gpt-oss][1a] create_responses stream outputs BaseModel type, api server is SSE still (#24759)

Signed-off-by: Andrew Xia <axia@meta.com>
Andrew Xia 2025-09-15 13:08:08 -07:00 committed by GitHub
parent 25aba2b6a3
commit 73df49ef3a
2 changed files with 90 additions and 71 deletions

vllm/entrypoints/openai/api_server.py

@@ -15,7 +15,7 @@ import socket
import tempfile
import uuid
from argparse import Namespace
from collections.abc import AsyncIterator, Awaitable
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
@@ -29,6 +29,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai import BaseModel
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
@@ -577,6 +578,18 @@ async def show_version():
return JSONResponse(content=ver)
async def _convert_stream_to_sse_events(
generator: AsyncGenerator[BaseModel,
None]) -> AsyncGenerator[str, None]:
"""Convert the generator to a stream of events in SSE format"""
async for event in generator:
event_type = getattr(event, 'type', 'unknown')
# https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
event_data = (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")
yield event_data
@router.post("/v1/responses",
dependencies=[Depends(validate_json_request)],
responses={
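The new _convert_stream_to_sse_events helper is now the single place where typed events become SSE wire frames. A minimal standalone sketch of its behavior (the helper body is copied here so the snippet runs by itself; DemoEvent is an illustrative stand-in for the real OpenAI event models):

import asyncio
from collections.abc import AsyncGenerator

from pydantic import BaseModel  # `from openai import BaseModel` re-exports this


async def convert_stream_to_sse_events(
        generator: AsyncGenerator[BaseModel, None]) -> AsyncGenerator[str, None]:
    # Same logic as the helper added above, inlined to keep the sketch standalone.
    async for event in generator:
        event_type = getattr(event, 'type', 'unknown')
        yield (f"event: {event_type}\n"
               f"data: {event.model_dump_json(indent=None)}\n\n")


class DemoEvent(BaseModel):  # illustrative; not a real OpenAI event type
    type: str
    delta: str


async def main() -> None:
    async def events():
        yield DemoEvent(type="response.output_text.delta", delta="Hi")
        yield DemoEvent(type="response.completed", delta="")

    async for frame in convert_stream_to_sse_events(events()):
        print(frame, end="")
        # event: response.output_text.delta
        # data: {"type":"response.output_text.delta","delta":"Hi"}


asyncio.run(main())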
@@ -612,7 +625,9 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
status_code=generator.error.code)
elif isinstance(generator, ResponsesResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
return StreamingResponse(content=_convert_stream_to_sse_events(generator),
media_type="text/event-stream")
@router.get("/v1/responses/{response_id}")
@@ -640,10 +655,10 @@ async def retrieve_responses(
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.error.code)
elif stream:
return StreamingResponse(content=response,
media_type="text/event-stream")
return JSONResponse(content=response.model_dump())
elif isinstance(response, ResponsesResponse):
return JSONResponse(content=response.model_dump())
return StreamingResponse(content=_convert_stream_to_sse_events(response),
media_type="text/event-stream")
@router.post("/v1/responses/{response_id}/cancel")

vllm/entrypoints/openai/serving_responses.py

@@ -10,7 +10,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Callable, Final, Optional, Union
from typing import Callable, Final, Optional, TypeVar, Union
import jinja2
import openai.types.responses as openai_responses_types
@@ -175,7 +175,8 @@ class OpenAIServingResponses(OpenAIServing):
# HACK(wuhang): This is a hack. We should use a better store.
# FIXME: If enable_store=True, this may cause a memory leak since we
# never remove events from the store.
self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
self.event_store: dict[str, tuple[deque[BaseModel],
asyncio.Event]] = {}
self.background_tasks: dict[str, asyncio.Task] = {}
@@ -185,7 +186,8 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]:
) -> Union[AsyncGenerator[BaseModel, None], ResponsesResponse,
ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
@@ -812,7 +814,7 @@ class OpenAIServingResponses(OpenAIServing):
*args,
**kwargs,
):
event_deque: deque[str] = deque()
event_deque: deque[BaseModel] = deque()
new_event_signal = asyncio.Event()
self.event_store[request.request_id] = (event_deque, new_event_signal)
response = None
@@ -827,8 +829,6 @@ class OpenAIServingResponses(OpenAIServing):
request.request_id)
response = self.create_error_response(str(e))
finally:
# Mark as finished with a special marker
event_deque.append("__STREAM_END__")
new_event_signal.set()
if response is not None and isinstance(response, ErrorResponse):
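Because the deque now holds typed events instead of pre-rendered strings, the background task no longer appends a "__STREAM_END__" marker; the terminal response.completed event itself marks the end of the stream. A sketch of the producer side of the store, assuming the same deque-plus-Event layout (publish is an illustrative name):

import asyncio
from collections import deque

from pydantic import BaseModel

# request id -> (stored typed events, signal that new events arrived)
event_store: dict[str, tuple[deque[BaseModel], asyncio.Event]] = {}


def publish(request_id: str, event: BaseModel) -> None:
    event_deque, new_event_signal = event_store[request_id]
    event_deque.append(event)  # the final event has type "response.completed"
    new_event_signal.set()  # wake readers blocked on new_event_signal.wait()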
@@ -867,7 +867,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
response_id: str,
starting_after: Optional[int] = None,
):
) -> AsyncGenerator[BaseModel, None]:
if response_id not in self.event_store:
raise ValueError(f"Unknown response_id: {response_id}")
@@ -881,9 +881,9 @@ class OpenAIServingResponses(OpenAIServing):
# Yield existing events from start_index
while current_index < len(event_deque):
event = event_deque[current_index]
if event == "__STREAM_END__":
return
yield event
if getattr(event, 'type', 'unknown') == "response.completed":
return
current_index += 1
await new_event_signal.wait()
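On the replay side, the loop yields each stored event and stops once it has delivered the terminal event, detected from the event's own type field rather than a sentinel string. A synchronous sketch of that logic (Evt is an illustrative stand-in; the real loop is async and waits on new_event_signal between passes):

from collections import deque

from pydantic import BaseModel


class Evt(BaseModel):  # illustrative stand-in for the OpenAI event types
    type: str


def replay(event_deque: deque[BaseModel], start: int = 0):
    index = start
    while index < len(event_deque):
        event = event_deque[index]
        yield event  # yield first, so the terminal event is still delivered
        if getattr(event, 'type', 'unknown') == "response.completed":
            return
        index += 1


events = deque([Evt(type="response.created"), Evt(type="response.completed")])
assert [e.type for e in replay(events)] == ["response.created",
                                            "response.completed"]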
@@ -893,7 +893,8 @@ class OpenAIServingResponses(OpenAIServing):
response_id: str,
starting_after: Optional[int],
stream: Optional[bool],
) -> Union[ErrorResponse, ResponsesResponse]:
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[BaseModel,
None]]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
@@ -976,8 +977,9 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: int,
_send_event: Callable[[BaseModel], str],
) -> AsyncGenerator[str, None]:
_increment_sequence_number_and_return: Callable[[BaseModel],
BaseModel],
) -> AsyncGenerator[BaseModel, None]:
current_content_index = 0
current_output_index = 0
current_item_id = ""
@@ -1014,7 +1016,7 @@ class OpenAIServingResponses(OpenAIServing):
if not first_delta_sent:
current_item_id = str(uuid.uuid4())
if delta_message.reasoning_content:
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1029,7 +1031,7 @@ class OpenAIServingResponses(OpenAIServing):
),
))
else:
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1044,7 +1046,7 @@ class OpenAIServingResponses(OpenAIServing):
status="in_progress",
),
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
@@ -1072,7 +1074,7 @@ class OpenAIServingResponses(OpenAIServing):
reason_content = ''.join(
pm.reasoning_content for pm in previous_delta_messages
if pm.reasoning_content is not None)
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
@@ -1094,14 +1096,14 @@ class OpenAIServingResponses(OpenAIServing):
id=current_item_id,
summary=[],
)
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=reasoning_item,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
@@ -1116,7 +1118,7 @@ class OpenAIServingResponses(OpenAIServing):
))
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
@@ -1135,7 +1137,7 @@ class OpenAIServingResponses(OpenAIServing):
previous_delta_messages = []
if delta_message.reasoning_content is not None:
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
sequence_number=-1,
@ -1145,7 +1147,7 @@ class OpenAIServingResponses(OpenAIServing):
delta=delta_message.reasoning_content,
))
elif delta_message.content is not None:
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
@@ -1168,7 +1170,7 @@ class OpenAIServingResponses(OpenAIServing):
reason_content = ''.join(pm.reasoning_content
for pm in previous_delta_messages
if pm.reasoning_content is not None)
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
@@ -1190,7 +1192,7 @@ class OpenAIServingResponses(OpenAIServing):
id=current_item_id,
summary=[],
)
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
@@ -1201,7 +1203,7 @@ class OpenAIServingResponses(OpenAIServing):
final_content = ''.join(pm.content
for pm in previous_delta_messages
if pm.content is not None)
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
@@ -1217,7 +1219,7 @@ class OpenAIServingResponses(OpenAIServing):
type="output_text",
annotations=[],
)
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
@@ -1237,7 +1239,7 @@ class OpenAIServingResponses(OpenAIServing):
id=current_item_id,
summary=[],
)
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
@@ -1255,8 +1257,9 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: int,
_send_event: Callable[[BaseModel], str],
) -> AsyncGenerator[str, None]:
_increment_sequence_number_and_return: Callable[[BaseModel],
BaseModel],
) -> AsyncGenerator[BaseModel, None]:
current_content_index = 0 # FIXME: this number is never changed
current_output_index = 0
current_item_id = "" # FIXME: this number is never changed
@@ -1288,7 +1291,7 @@ class OpenAIServingResponses(OpenAIServing):
id=current_item_id,
summary=[],
)
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseReasoningTextDoneEvent(
type="response.reasoning_text.done",
item_id=current_item_id,
@@ -1297,7 +1300,7 @@ class OpenAIServingResponses(OpenAIServing):
content_index=current_content_index,
text=previous_item.content[0].text,
))
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
@@ -1310,7 +1313,7 @@ class OpenAIServingResponses(OpenAIServing):
text=previous_item.content[0].text,
annotations=[],
)
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
@@ -1320,7 +1323,7 @@ class OpenAIServingResponses(OpenAIServing):
logprobs=[],
item_id=current_item_id,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartDoneEvent(
type="response.content_part.done",
@@ -1330,7 +1333,7 @@ class OpenAIServingResponses(OpenAIServing):
content_index=current_content_index,
part=text_content,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
@@ -1344,12 +1347,13 @@ class OpenAIServingResponses(OpenAIServing):
),
))
# stream the output of a harmony message
if ctx.parser.last_content_delta:
if (ctx.parser.current_channel == "final"
and ctx.parser.current_recipient is None):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1364,7 +1368,7 @@ class OpenAIServingResponses(OpenAIServing):
status="in_progress",
),
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartAddedEvent(
type="response.content_part.added",
@@ -1379,7 +1383,7 @@ class OpenAIServingResponses(OpenAIServing):
logprobs=[],
),
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
@@ -1394,7 +1398,7 @@ class OpenAIServingResponses(OpenAIServing):
and ctx.parser.current_recipient is None):
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1408,7 +1412,7 @@ class OpenAIServingResponses(OpenAIServing):
status="in_progress",
),
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartAddedEvent(
type="response.content_part.added",
@@ -1423,7 +1427,7 @@ class OpenAIServingResponses(OpenAIServing):
logprobs=[],
),
))
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseReasoningTextDeltaEvent(
type="response.reasoning_text.delta",
item_id=current_item_id,
@@ -1440,7 +1444,7 @@ class OpenAIServingResponses(OpenAIServing):
) and ctx.parser.current_recipient == "python":
if not sent_output_item_added:
sent_output_item_added = True
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1456,7 +1460,7 @@ class OpenAIServingResponses(OpenAIServing):
status="in_progress",
),
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallInProgressEvent(
type=
@@ -1465,7 +1469,7 @@ class OpenAIServingResponses(OpenAIServing):
output_index=current_output_index,
item_id=current_item_id,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCodeDeltaEvent(
type="response.code_interpreter_call_code.delta",
@@ -1474,6 +1478,8 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
delta=ctx.parser.last_content_delta,
))
# stream tool call outputs
if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0:
previous_item = ctx.parser.messages[-1]
if (self.tool_server is not None
@@ -1510,7 +1516,7 @@ class OpenAIServingResponses(OpenAIServing):
raise ValueError(
f"Unknown function name: {function_name}")
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
@@ -1525,7 +1531,7 @@ class OpenAIServingResponses(OpenAIServing):
status="in_progress",
),
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress",
@@ -1533,7 +1539,7 @@ class OpenAIServingResponses(OpenAIServing):
output_index=current_output_index,
item_id=current_item_id,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching",
@@ -1543,7 +1549,7 @@ class OpenAIServingResponses(OpenAIServing):
))
# enqueue
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed",
@@ -1551,7 +1557,7 @@ class OpenAIServingResponses(OpenAIServing):
output_index=current_output_index,
item_id=current_item_id,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
@@ -1569,7 +1575,7 @@ class OpenAIServingResponses(OpenAIServing):
and self.tool_server.has_tool("python")
and previous_item.recipient is not None
and previous_item.recipient.startswith("python")):
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done",
@@ -1578,7 +1584,7 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
code=previous_item.content[0].text,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting",
@@ -1586,7 +1592,7 @@ class OpenAIServingResponses(OpenAIServing):
output_index=current_output_index,
item_id=current_item_id,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed",
@@ -1594,7 +1600,7 @@ class OpenAIServingResponses(OpenAIServing):
output_index=current_output_index,
item_id=current_item_id,
))
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
@@ -1621,7 +1627,7 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
) -> AsyncGenerator[str, None]:
) -> AsyncGenerator[BaseModel, None]:
# TODO:
# 1. Handle disconnect
@@ -1629,16 +1635,15 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number = 0
def _send_event(event: BaseModel):
T = TypeVar("T", bound=BaseModel)
def _increment_sequence_number_and_return(event: T) -> T:
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
event.sequence_number = sequence_number
sequence_number += 1
# Get event type from the event's type field if it exists
event_type = getattr(event, 'type', 'unknown')
return (f"event: {event_type}\n"
f"data: {event.model_dump_json(indent=None)}\n\n")
return event
async with AsyncExitStack() as exit_stack:
processer = None
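The renamed closure now stamps the sequence number and returns the event object itself, and the TypeVar bound lets the type checker see that the concrete event type is preserved. A standalone sketch of the pattern (the factory wrapper is illustrative; in the diff it is a closure over sequence_number inside the generator):

from typing import TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)


def make_stamper():
    sequence_number = 0

    def stamp(event: T) -> T:  # returns the same concrete type it was given
        nonlocal sequence_number
        # only events that carry a sequence_number field get stamped
        if hasattr(event, 'sequence_number'):
            event.sequence_number = sequence_number
        sequence_number += 1
        return event

    return stamp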
@@ -1658,24 +1663,23 @@ class OpenAIServingResponses(OpenAIServing):
status="in_progress",
usage=None,
).model_dump()
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseCreatedEvent(
type="response.created",
sequence_number=-1,
response=initial_response,
))
yield _send_event(
yield _increment_sequence_number_and_return(
ResponseInProgressEvent(
type="response.in_progress",
sequence_number=-1,
response=initial_response,
))
async for event_data in processer(request, sampling_params,
result_generator, context,
model_name, tokenizer,
request_metadata, created_time,
_send_event):
async for event_data in processer(
request, sampling_params, result_generator, context,
model_name, tokenizer, request_metadata, created_time,
_increment_sequence_number_and_return):
yield event_data
async def empty_async_generator():
@@ -1694,7 +1698,7 @@ class OpenAIServingResponses(OpenAIServing):
request_metadata,
created_time=created_time,
)
yield _send_event(
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,