[gpt-oss][1a] create_responses stream outputs BaseModel type, api server is SSE still (#24759)

Signed-off-by: Andrew Xia <axia@meta.com>
Andrew Xia 2025-09-15 13:08:08 -07:00 committed by GitHub
parent 25aba2b6a3
commit 73df49ef3a
2 changed files with 90 additions and 71 deletions
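
In short: the serving layer (OpenAIServingResponses) now yields typed Pydantic
event objects (BaseModel) instead of pre-formatted SSE strings, and the API
server converts them to Server-Sent Events at the edge, so the wire format
seen by clients is unchanged.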

File 1 of 2: the OpenAI-compatible API server (FastAPI entrypoint)

@@ -15,7 +15,7 @@ import socket
 import tempfile
 import uuid
 from argparse import Namespace
-from collections.abc import AsyncIterator, Awaitable
+from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
@@ -29,6 +29,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from openai import BaseModel
 from prometheus_client import make_asgi_app
 from prometheus_fastapi_instrumentator import Instrumentator
 from starlette.concurrency import iterate_in_threadpool
@@ -577,6 +578,18 @@ async def show_version():
     return JSONResponse(content=ver)
 
 
+async def _convert_stream_to_sse_events(
+        generator: AsyncGenerator[BaseModel,
+                                  None]) -> AsyncGenerator[str, None]:
+    """Convert the generator to a stream of events in SSE format"""
+    async for event in generator:
+        event_type = getattr(event, 'type', 'unknown')
+        # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
+        event_data = (f"event: {event_type}\n"
+                      f"data: {event.model_dump_json(indent=None)}\n\n")
+        yield event_data
+
+
 @router.post("/v1/responses",
              dependencies=[Depends(validate_json_request)],
              responses={
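For reference, here is a minimal, runnable sketch of the framing that
_convert_stream_to_sse_events produces. It assumes only `pydantic` (the server
gets BaseModel via `openai`, which re-exports it); DemoEvent and to_sse are
illustrative names, not part of this change:

```python
import asyncio
from collections.abc import AsyncGenerator

from pydantic import BaseModel


class DemoEvent(BaseModel):
    type: str = "response.output_text.delta"
    delta: str = "hello"


async def to_sse(
        gen: AsyncGenerator[BaseModel, None]) -> AsyncGenerator[str, None]:
    # Same framing as _convert_stream_to_sse_events: one "event:" line naming
    # the event type, one "data:" line with the JSON body, then a blank line.
    async for event in gen:
        event_type = getattr(event, 'type', 'unknown')
        yield (f"event: {event_type}\n"
               f"data: {event.model_dump_json(indent=None)}\n\n")


async def main():
    async def events():
        yield DemoEvent()

    async for frame in to_sse(events()):
        print(frame, end="")
        # event: response.output_text.delta
        # data: {"type":"response.output_text.delta","delta":"hello"}


asyncio.run(main())
```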
@@ -612,7 +625,9 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
                             status_code=generator.error.code)
     elif isinstance(generator, ResponsesResponse):
         return JSONResponse(content=generator.model_dump())
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    return StreamingResponse(content=_convert_stream_to_sse_events(generator),
+                             media_type="text/event-stream")
 
 
 @router.get("/v1/responses/{response_id}")
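Consumer-side, nothing changes: the endpoint still serves text/event-stream,
so any SSE-capable client works as before. A hedged sketch with httpx (the
server address, model name, and payload are assumptions for illustration):

```python
import httpx

# Assumes a vLLM server listening on localhost:8000.
payload = {"model": "gpt-oss-20b", "input": "Say hi", "stream": True}
with httpx.stream("POST", "http://localhost:8000/v1/responses",
                  json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if line.startswith("event:"):
            print(line)  # e.g. "event: response.output_text.delta"
```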
@@ -640,10 +655,10 @@ async def retrieve_responses(
     if isinstance(response, ErrorResponse):
         return JSONResponse(content=response.model_dump(),
                             status_code=response.error.code)
-    elif stream:
-        return StreamingResponse(content=response,
-                                 media_type="text/event-stream")
-    return JSONResponse(content=response.model_dump())
+    elif isinstance(response, ResponsesResponse):
+        return JSONResponse(content=response.model_dump())
+    return StreamingResponse(content=_convert_stream_to_sse_events(response),
+                             media_type="text/event-stream")
 
 
 @router.post("/v1/responses/{response_id}/cancel")
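Retrieval keeps the same contract: a plain GET returns the stored response as
JSON, while stream=true replays the stored events as SSE. A sketch (query
parameter names are taken from the handler above; the server address and
response id are assumptions):

```python
import httpx

# Replay a stored response's stream, skipping events already seen.
with httpx.stream("GET",
                  "http://localhost:8000/v1/responses/resp_abc123",
                  params={"stream": True, "starting_after": 5},
                  timeout=None) as resp:
    for line in resp.iter_lines():
        if line:
            print(line)
```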

File 2 of 2: the Responses API serving layer (OpenAIServingResponses)

@@ -10,7 +10,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
-from typing import Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, TypeVar, Union
 
 import jinja2
 import openai.types.responses as openai_responses_types
@@ -175,7 +175,8 @@ class OpenAIServingResponses(OpenAIServing):
         # HACK(wuhang): This is a hack. We should use a better store.
         # FIXME: If enable_store=True, this may cause a memory leak since we
         # never remove events from the store.
-        self.event_store: dict[str, tuple[deque[str], asyncio.Event]] = {}
+        self.event_store: dict[str, tuple[deque[BaseModel],
+                                          asyncio.Event]] = {}
 
         self.background_tasks: dict[str, asyncio.Task] = {}
@@ -185,7 +186,8 @@ class OpenAIServingResponses(OpenAIServing):
         self,
         request: ResponsesRequest,
         raw_request: Optional[Request] = None,
-    ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]:
+    ) -> Union[AsyncGenerator[BaseModel, None], ResponsesResponse,
+               ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             logger.error("Error with model %s", error_check_ret)
@@ -812,7 +814,7 @@ class OpenAIServingResponses(OpenAIServing):
         *args,
         **kwargs,
     ):
-        event_deque: deque[str] = deque()
+        event_deque: deque[BaseModel] = deque()
         new_event_signal = asyncio.Event()
         self.event_store[request.request_id] = (event_deque, new_event_signal)
         response = None
@@ -827,8 +829,6 @@ class OpenAIServingResponses(OpenAIServing):
                          request.request_id)
             response = self.create_error_response(str(e))
         finally:
-            # Mark as finished with a special marker
-            event_deque.append("__STREAM_END__")
             new_event_signal.set()
 
         if response is not None and isinstance(response, ErrorResponse):
@@ -867,7 +867,7 @@ class OpenAIServingResponses(OpenAIServing):
         self,
         response_id: str,
         starting_after: Optional[int] = None,
-    ):
+    ) -> AsyncGenerator[BaseModel, None]:
         if response_id not in self.event_store:
             raise ValueError(f"Unknown response_id: {response_id}")
@@ -881,9 +881,9 @@ class OpenAIServingResponses(OpenAIServing):
             # Yield existing events from start_index
             while current_index < len(event_deque):
                 event = event_deque[current_index]
-                if event == "__STREAM_END__":
-                    return
                 yield event
+                if getattr(event, 'type', 'unknown') == "response.completed":
+                    return
                 current_index += 1
 
             await new_event_signal.wait()
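The sentinel removal works because the Responses event protocol already has a
terminal event: once "response.completed" is stored, readers can stop without
a magic string. A self-contained sketch of this deque-plus-asyncio.Event
fan-out (all names here are illustrative, not the vLLM implementation):

```python
import asyncio
from collections import deque

from pydantic import BaseModel


class Event(BaseModel):
    type: str


async def producer(events: deque[BaseModel], signal: asyncio.Event):
    for t in ("response.created", "response.output_text.delta",
              "response.completed"):
        events.append(Event(type=t))
        signal.set()  # wake any waiting readers
        await asyncio.sleep(0)


async def consumer(events: deque[BaseModel], signal: asyncio.Event):
    i = 0
    while True:
        while i < len(events):
            event = events[i]
            print(event.type)
            if getattr(event, 'type', 'unknown') == "response.completed":
                return  # terminal event replaces the "__STREAM_END__" marker
            i += 1
        signal.clear()
        await signal.wait()


async def main():
    events: deque[BaseModel] = deque()
    signal = asyncio.Event()
    await asyncio.gather(producer(events, signal), consumer(events, signal))


asyncio.run(main())
```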
@@ -893,7 +893,8 @@ class OpenAIServingResponses(OpenAIServing):
         response_id: str,
         starting_after: Optional[int],
         stream: Optional[bool],
-    ) -> Union[ErrorResponse, ResponsesResponse]:
+    ) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[BaseModel,
+                                                                None]]:
         if not response_id.startswith("resp_"):
             return self._make_invalid_id_error(response_id)
@@ -976,8 +977,9 @@ class OpenAIServingResponses(OpenAIServing):
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: int,
-        _send_event: Callable[[BaseModel], str],
-    ) -> AsyncGenerator[str, None]:
+        _increment_sequence_number_and_return: Callable[[BaseModel],
+                                                         BaseModel],
+    ) -> AsyncGenerator[BaseModel, None]:
         current_content_index = 0
         current_output_index = 0
         current_item_id = ""
@@ -1014,7 +1016,7 @@ class OpenAIServingResponses(OpenAIServing):
                 if not first_delta_sent:
                     current_item_id = str(uuid.uuid4())
                     if delta_message.reasoning_content:
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseOutputItemAddedEvent(
                                 type="response.output_item.added",
@@ -1029,7 +1031,7 @@ class OpenAIServingResponses(OpenAIServing):
                                ),
                            ))
                    else:
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseOutputItemAddedEvent(
                                 type="response.output_item.added",
@@ -1044,7 +1046,7 @@ class OpenAIServingResponses(OpenAIServing):
                                status="in_progress",
                            ),
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.ResponseContentPartAddedEvent(
                             type="response.content_part.added",
                             sequence_number=-1,
@@ -1072,7 +1074,7 @@ class OpenAIServingResponses(OpenAIServing):
                        reason_content = ''.join(
                            pm.reasoning_content for pm in previous_delta_messages
                            if pm.reasoning_content is not None)
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             ResponseReasoningTextDoneEvent(
                                 type="response.reasoning_text.done",
                                 item_id=current_item_id,
@@ -1094,14 +1096,14 @@ class OpenAIServingResponses(OpenAIServing):
                            id=current_item_id,
                            summary=[],
                        )
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             ResponseOutputItemDoneEvent(
                                 type="response.output_item.done",
                                 sequence_number=-1,
                                 output_index=current_output_index,
                                 item=reasoning_item,
                             ))
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.ResponseOutputItemAddedEvent(
                                 type="response.output_item.added",
                                 sequence_number=-1,
@@ -1116,7 +1118,7 @@ class OpenAIServingResponses(OpenAIServing):
                            ))
                        current_output_index += 1
                        current_item_id = str(uuid.uuid4())
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.ResponseContentPartAddedEvent(
                                 type="response.content_part.added",
                                 sequence_number=-1,
@@ -1135,7 +1137,7 @@ class OpenAIServingResponses(OpenAIServing):
                        previous_delta_messages = []
 
                if delta_message.reasoning_content is not None:
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         ResponseReasoningTextDeltaEvent(
                             type="response.reasoning_text.delta",
                             sequence_number=-1,
@@ -1145,7 +1147,7 @@ class OpenAIServingResponses(OpenAIServing):
                            delta=delta_message.reasoning_content,
                        ))
                elif delta_message.content is not None:
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.ResponseTextDeltaEvent(
                             type="response.output_text.delta",
                             sequence_number=-1,
@@ -1168,7 +1170,7 @@ class OpenAIServingResponses(OpenAIServing):
                reason_content = ''.join(pm.reasoning_content
                                         for pm in previous_delta_messages
                                         if pm.reasoning_content is not None)
-                yield _send_event(
+                yield _increment_sequence_number_and_return(
                     ResponseReasoningTextDoneEvent(
                         type="response.reasoning_text.done",
                         item_id=current_item_id,
@@ -1190,7 +1192,7 @@ class OpenAIServingResponses(OpenAIServing):
                    id=current_item_id,
                    summary=[],
                )
-                yield _send_event(
+                yield _increment_sequence_number_and_return(
                     ResponseOutputItemDoneEvent(
                         type="response.output_item.done",
                         sequence_number=-1,
@@ -1201,7 +1203,7 @@ class OpenAIServingResponses(OpenAIServing):
                final_content = ''.join(pm.content
                                        for pm in previous_delta_messages
                                        if pm.content is not None)
-                yield _send_event(
+                yield _increment_sequence_number_and_return(
                     openai_responses_types.ResponseTextDoneEvent(
                         type="response.output_text.done",
                         sequence_number=-1,
@@ -1217,7 +1219,7 @@ class OpenAIServingResponses(OpenAIServing):
                    type="output_text",
                    annotations=[],
                )
-                yield _send_event(
+                yield _increment_sequence_number_and_return(
                     openai_responses_types.ResponseContentPartDoneEvent(
                         type="response.content_part.done",
                         sequence_number=-1,
@@ -1237,7 +1239,7 @@ class OpenAIServingResponses(OpenAIServing):
                    id=current_item_id,
                    summary=[],
                )
-                yield _send_event(
+                yield _increment_sequence_number_and_return(
                     ResponseOutputItemDoneEvent(
                         type="response.output_item.done",
                         sequence_number=-1,
@@ -1255,8 +1257,9 @@ class OpenAIServingResponses(OpenAIServing):
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: int,
-        _send_event: Callable[[BaseModel], str],
-    ) -> AsyncGenerator[str, None]:
+        _increment_sequence_number_and_return: Callable[[BaseModel],
+                                                         BaseModel],
+    ) -> AsyncGenerator[BaseModel, None]:
         current_content_index = 0  # FIXME: this number is never changed
         current_output_index = 0
         current_item_id = ""  # FIXME: this number is never changed
@@ -1288,7 +1291,7 @@ class OpenAIServingResponses(OpenAIServing):
                            id=current_item_id,
                            summary=[],
                        )
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             ResponseReasoningTextDoneEvent(
                                 type="response.reasoning_text.done",
                                 item_id=current_item_id,
@@ -1297,7 +1300,7 @@ class OpenAIServingResponses(OpenAIServing):
                                content_index=current_content_index,
                                text=previous_item.content[0].text,
                            ))
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             ResponseOutputItemDoneEvent(
                                 type="response.output_item.done",
                                 sequence_number=-1,
@@ -1310,7 +1313,7 @@ class OpenAIServingResponses(OpenAIServing):
                            text=previous_item.content[0].text,
                            annotations=[],
                        )
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.ResponseTextDoneEvent(
                                 type="response.output_text.done",
                                 sequence_number=-1,
@@ -1320,7 +1323,7 @@ class OpenAIServingResponses(OpenAIServing):
                                logprobs=[],
                                item_id=current_item_id,
                            ))
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseContentPartDoneEvent(
                                 type="response.content_part.done",
@@ -1330,7 +1333,7 @@ class OpenAIServingResponses(OpenAIServing):
                                content_index=current_content_index,
                                part=text_content,
                            ))
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.ResponseOutputItemDoneEvent(
                                 type="response.output_item.done",
                                 sequence_number=-1,
@@ -1344,12 +1347,13 @@ class OpenAIServingResponses(OpenAIServing):
                            ),
                        ))
 
+            # stream the output of a harmony message
             if ctx.parser.last_content_delta:
                 if (ctx.parser.current_channel == "final"
                         and ctx.parser.current_recipient is None):
                     if not sent_output_item_added:
                         sent_output_item_added = True
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseOutputItemAddedEvent(
                                 type="response.output_item.added",
@@ -1364,7 +1368,7 @@ class OpenAIServingResponses(OpenAIServing):
                                status="in_progress",
                            ),
                        ))
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseContentPartAddedEvent(
                                 type="response.content_part.added",
@@ -1379,7 +1383,7 @@ class OpenAIServingResponses(OpenAIServing):
                                logprobs=[],
                            ),
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.ResponseTextDeltaEvent(
                             type="response.output_text.delta",
                             sequence_number=-1,
@@ -1394,7 +1398,7 @@ class OpenAIServingResponses(OpenAIServing):
                        and ctx.parser.current_recipient is None):
                    if not sent_output_item_added:
                        sent_output_item_added = True
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseOutputItemAddedEvent(
                                 type="response.output_item.added",
@@ -1408,7 +1412,7 @@ class OpenAIServingResponses(OpenAIServing):
                                status="in_progress",
                            ),
                        ))
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseContentPartAddedEvent(
                                 type="response.content_part.added",
@@ -1423,7 +1427,7 @@ class OpenAIServingResponses(OpenAIServing):
                                logprobs=[],
                            ),
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         ResponseReasoningTextDeltaEvent(
                             type="response.reasoning_text.delta",
                             item_id=current_item_id,
@@ -1440,7 +1444,7 @@ class OpenAIServingResponses(OpenAIServing):
                ) and ctx.parser.current_recipient == "python":
                    if not sent_output_item_added:
                        sent_output_item_added = True
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseOutputItemAddedEvent(
                                 type="response.output_item.added",
@@ -1456,7 +1460,7 @@ class OpenAIServingResponses(OpenAIServing):
                                status="in_progress",
                            ),
                        ))
-                        yield _send_event(
+                        yield _increment_sequence_number_and_return(
                             openai_responses_types.
                             ResponseCodeInterpreterCallInProgressEvent(
                                 type=
@@ -1465,7 +1469,7 @@ class OpenAIServingResponses(OpenAIServing):
                                output_index=current_output_index,
                                item_id=current_item_id,
                            ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.
                         ResponseCodeInterpreterCallCodeDeltaEvent(
                             type="response.code_interpreter_call_code.delta",
@@ -1474,6 +1478,8 @@ class OpenAIServingResponses(OpenAIServing):
                            item_id=current_item_id,
                            delta=ctx.parser.last_content_delta,
                        ))
+
+            # stream tool call outputs
             if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0:
                 previous_item = ctx.parser.messages[-1]
                 if (self.tool_server is not None
@@ -1510,7 +1516,7 @@ class OpenAIServingResponses(OpenAIServing):
                        raise ValueError(
                            f"Unknown function name: {function_name}")
 
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.ResponseOutputItemAddedEvent(
                             type="response.output_item.added",
                             sequence_number=-1,
@@ -1525,7 +1531,7 @@ class OpenAIServingResponses(OpenAIServing):
                            status="in_progress",
                        ),
                    ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.
                         ResponseWebSearchCallInProgressEvent(
                             type="response.web_search_call.in_progress",
@@ -1533,7 +1539,7 @@ class OpenAIServingResponses(OpenAIServing):
                            output_index=current_output_index,
                            item_id=current_item_id,
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.
                         ResponseWebSearchCallSearchingEvent(
                             type="response.web_search_call.searching",
@@ -1543,7 +1549,7 @@ class OpenAIServingResponses(OpenAIServing):
                        ))
 
                    # enqueue
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.
                         ResponseWebSearchCallCompletedEvent(
                             type="response.web_search_call.completed",
@@ -1551,7 +1557,7 @@ class OpenAIServingResponses(OpenAIServing):
                            output_index=current_output_index,
                            item_id=current_item_id,
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.ResponseOutputItemDoneEvent(
                             type="response.output_item.done",
                             sequence_number=-1,
@@ -1569,7 +1575,7 @@ class OpenAIServingResponses(OpenAIServing):
                      and self.tool_server.has_tool("python")
                      and previous_item.recipient is not None
                      and previous_item.recipient.startswith("python")):
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.
                         ResponseCodeInterpreterCallCodeDoneEvent(
                             type="response.code_interpreter_call_code.done",
@@ -1578,7 +1584,7 @@ class OpenAIServingResponses(OpenAIServing):
                            item_id=current_item_id,
                            code=previous_item.content[0].text,
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.
                         ResponseCodeInterpreterCallInterpretingEvent(
                             type="response.code_interpreter_call.interpreting",
@@ -1586,7 +1592,7 @@ class OpenAIServingResponses(OpenAIServing):
                            output_index=current_output_index,
                            item_id=current_item_id,
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.
                         ResponseCodeInterpreterCallCompletedEvent(
                             type="response.code_interpreter_call.completed",
@@ -1594,7 +1600,7 @@ class OpenAIServingResponses(OpenAIServing):
                            output_index=current_output_index,
                            item_id=current_item_id,
                        ))
-                    yield _send_event(
+                    yield _increment_sequence_number_and_return(
                         openai_responses_types.ResponseOutputItemDoneEvent(
                             type="response.output_item.done",
                             sequence_number=-1,
@@ -1621,7 +1627,7 @@ class OpenAIServingResponses(OpenAIServing):
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: Optional[int] = None,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[BaseModel, None]:
 
         # TODO:
         # 1. Handle disconnect
@@ -1629,16 +1635,15 @@ class OpenAIServingResponses(OpenAIServing):
         sequence_number = 0
 
-        def _send_event(event: BaseModel):
+        T = TypeVar("T", bound=BaseModel)
+
+        def _increment_sequence_number_and_return(event: T) -> T:
             nonlocal sequence_number
             # Set sequence_number if the event has this attribute
             if hasattr(event, 'sequence_number'):
                 event.sequence_number = sequence_number
                 sequence_number += 1
-            # Get event type from the event's type field if it exists
-            event_type = getattr(event, 'type', 'unknown')
-            return (f"event: {event_type}\n"
-                    f"data: {event.model_dump_json(indent=None)}\n\n")
+            return event
 
         async with AsyncExitStack() as exit_stack:
             processer = None
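The TypeVar is what keeps the refactor type-safe: binding T to BaseModel lets
the helper hand back the same concrete event class it received, so callers and
the SSE converter still see precise event types rather than bare BaseModel. A
minimal sketch of the pattern (module-level here only for brevity):

```python
from typing import TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)

sequence_number = 0


def increment_sequence_number_and_return(event: T) -> T:
    global sequence_number
    # Stamp the running sequence number if the event carries that field.
    if hasattr(event, 'sequence_number'):
        event.sequence_number = sequence_number
        sequence_number += 1
    return event


class Delta(BaseModel):
    type: str = "response.output_text.delta"
    sequence_number: int = -1


delta = increment_sequence_number_and_return(Delta())
print(delta.sequence_number)  # 0 -- and `delta` is still typed as Delta
```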
@@ -1658,24 +1663,23 @@ class OpenAIServingResponses(OpenAIServing):
                    status="in_progress",
                    usage=None,
                ).model_dump()
-                yield _send_event(
+                yield _increment_sequence_number_and_return(
                     ResponseCreatedEvent(
                         type="response.created",
                         sequence_number=-1,
                         response=initial_response,
                     ))
-                yield _send_event(
+                yield _increment_sequence_number_and_return(
                     ResponseInProgressEvent(
                         type="response.in_progress",
                         sequence_number=-1,
                         response=initial_response,
                     ))
-                async for event_data in processer(request, sampling_params,
-                                                  result_generator, context,
-                                                  model_name, tokenizer,
-                                                  request_metadata, created_time,
-                                                  _send_event):
+                async for event_data in processer(
+                        request, sampling_params, result_generator, context,
+                        model_name, tokenizer, request_metadata, created_time,
+                        _increment_sequence_number_and_return):
                     yield event_data
 
                 async def empty_async_generator():
@@ -1694,7 +1698,7 @@ class OpenAIServingResponses(OpenAIServing):
                request_metadata,
                created_time=created_time,
            )
-            yield _send_event(
+            yield _increment_sequence_number_and_return(
                 openai_responses_types.ResponseCompletedEvent(
                     type="response.completed",
                     sequence_number=-1,
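
End to end, the official openai SDK can consume the stream unchanged, since
the server still emits standard SSE. A sketch, assuming a local vLLM server
and an arbitrary model name:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.responses.create(model="gpt-oss-20b",
                                 input="Write a haiku about streaming",
                                 stream=True)
for event in stream:
    # Typed events: response.created, response.in_progress,
    # response.output_text.delta, ..., response.completed.
    print(event.type)
```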