[gpt-oss][2] fix types for streaming (#24556)

Signed-off-by: Andrew Xia <axia@meta.com>
Andrew Xia 2025-09-17 15:04:28 -07:00 committed by GitHub
parent 3c068c637b
commit bff2e5f1d6
3 changed files with 104 additions and 96 deletions

vllm/entrypoints/openai/api_server.py

@@ -27,7 +27,6 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai import BaseModel
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
@@ -67,7 +66,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
RerankRequest, RerankResponse,
ResponsesRequest,
ResponsesResponse, ScoreRequest,
ScoreResponse, TokenizeRequest,
ScoreResponse,
StreamingResponsesResponse,
TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
@@ -481,8 +482,8 @@ async def show_version():
async def _convert_stream_to_sse_events(
generator: AsyncGenerator[BaseModel,
None]) -> AsyncGenerator[str, None]:
generator: AsyncGenerator[StreamingResponsesResponse, None]
) -> AsyncGenerator[str, None]:
"""Convert the generator to a stream of events in SSE format"""
async for event in generator:
event_type = getattr(event, 'type', 'unknown')
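For context, a minimal sketch of the SSE framing this helper performs, now typed against the union instead of the catch-all BaseModel. Everything past the lines shown above is an assumption about details, not the verbatim vLLM body:

    from collections.abc import AsyncGenerator

    from vllm.entrypoints.openai.protocol import StreamingResponsesResponse

    async def convert_stream_to_sse_events(
        generator: AsyncGenerator[StreamingResponsesResponse, None],
    ) -> AsyncGenerator[str, None]:
        """Convert the generator to a stream of events in SSE format."""
        async for event in generator:
            # Every member of the union defines a Literal `type` field,
            # so the 'unknown' fallback is purely defensive.
            event_type = getattr(event, 'type', 'unknown')
            # One SSE frame: event name, JSON payload, blank-line delimiter.
            yield f"event: {event_type}\ndata: {event.model_dump_json()}\n\n"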

vllm/entrypoints/openai/protocol.py

@@ -18,10 +18,19 @@ from openai.types.chat.chat_completion_audio import (
from openai.types.chat.chat_completion_message import (
Annotation as OpenAIAnnotation)
# yapf: enable
from openai.types.responses import (ResponseFunctionToolCall,
ResponseInputItemParam, ResponseOutputItem,
ResponsePrompt, ResponseReasoningItem,
ResponseStatus)
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent, ResponseCompletedEvent,
ResponseContentPartAddedEvent, ResponseContentPartDoneEvent,
ResponseCreatedEvent, ResponseFunctionToolCall, ResponseInProgressEvent,
ResponseInputItemParam, ResponseOutputItem, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponsePrompt, ResponseReasoningItem,
ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent,
ResponseStatus, ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent)
# Backward compatibility for OpenAI client versions
try: # For older openai versions (< 1.100.0)
@@ -251,6 +260,26 @@ ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
ResponseReasoningItem,
ResponseFunctionToolCall]
StreamingResponsesResponse: TypeAlias = Union[
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseCompletedEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
ResponseWebSearchCallCompletedEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseCodeInterpreterCallCompletedEvent,
]
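The payoff of a closed alias over BaseModel: a type checker can verify that the streaming path only ever yields real Responses-API events, and consumers can narrow on the discriminating Literal `type`. An illustrative handler (not part of vLLM; assumes the alias above is in scope):

    def handle_event(event: StreamingResponsesResponse) -> None:
        # Each union member carries a Literal `type`, so checkers narrow
        # the branch to the matching event class.
        if event.type == "response.completed":
            print(event.response.id)      # ResponseCompletedEvent
        elif event.type == "response.reasoning_text.delta":
            print(event.delta, end="")    # ResponseReasoningTextDeltaEvent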
class ResponsesRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation

vllm/entrypoints/openai/serving_responses.py

@@ -10,24 +10,28 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Callable, Final, Optional, TypeVar, Union
from typing import Callable, Final, Optional, Union
import jinja2
import openai.types.responses as openai_responses_types
from fastapi import Request
from openai import BaseModel
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.responses import (ResponseCreatedEvent,
ResponseFunctionToolCall,
ResponseInProgressEvent,
ResponseOutputItem,
ResponseOutputItemDoneEvent,
ResponseOutputMessage, ResponseOutputText,
ResponseReasoningItem,
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseStatus, response_text_delta_event)
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseCodeInterpreterToolCallParam, ResponseCompletedEvent,
ResponseContentPartAddedEvent, ResponseContentPartDoneEvent,
ResponseCreatedEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch,
ResponseInProgressEvent, ResponseOutputItem, ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent, ResponseOutputMessage, ResponseOutputText,
ResponseReasoningItem, ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent, ResponseStatus, ResponseTextDeltaEvent,
ResponseTextDoneEvent, ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent,
response_function_web_search, response_text_delta_event)
from openai.types.responses.response_output_text import (Logprob,
LogprobTopLogprob)
# yapf: enable
@@ -55,7 +59,8 @@ from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse,
OutputTokensDetails,
RequestResponseMetadata,
ResponsesRequest,
ResponsesResponse, ResponseUsage)
ResponsesResponse, ResponseUsage,
StreamingResponsesResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -175,7 +180,7 @@ class OpenAIServingResponses(OpenAIServing):
# HACK(wuhang): This is a hack. We should use a better store.
# FIXME: If enable_store=True, this may cause a memory leak since we
# never remove events from the store.
self.event_store: dict[str, tuple[deque[BaseModel],
self.event_store: dict[str, tuple[deque[StreamingResponsesResponse],
asyncio.Event]] = {}
self.background_tasks: dict[str, asyncio.Task] = {}
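The retyped store keeps the same shape: per request id, a deque of every event emitted so far plus an asyncio.Event used to wake blocked readers. A sketch of the producer side (the helper name is hypothetical):

    import asyncio
    from collections import deque

    from vllm.entrypoints.openai.protocol import StreamingResponsesResponse

    EventStore = dict[str, tuple[deque[StreamingResponsesResponse],
                                 asyncio.Event]]

    def store_event(store: EventStore, request_id: str,
                    event: StreamingResponsesResponse) -> None:
        event_deque, new_event_signal = store[request_id]
        event_deque.append(event)   # retained so late readers can replay
        new_event_signal.set()      # wake readers awaiting fresh events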
@@ -186,8 +191,8 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[BaseModel, None], ResponsesResponse,
ErrorResponse]:
) -> Union[AsyncGenerator[StreamingResponsesResponse, None],
ResponsesResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
@@ -814,7 +819,7 @@ class OpenAIServingResponses(OpenAIServing):
*args,
**kwargs,
):
event_deque: deque[BaseModel] = deque()
event_deque: deque[StreamingResponsesResponse] = deque()
new_event_signal = asyncio.Event()
self.event_store[request.request_id] = (event_deque, new_event_signal)
response = None
@@ -867,7 +872,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
response_id: str,
starting_after: Optional[int] = None,
) -> AsyncGenerator[BaseModel, None]:
) -> AsyncGenerator[StreamingResponsesResponse, None]:
if response_id not in self.event_store:
raise ValueError(f"Unknown response_id: {response_id}")
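On the read side, the generator replays the deque from an optional offset and parks on the signal once it catches up. A sketch reusing the EventStore alias from the sketch above, assuming starting_after is an index into the stored sequence and that response.completed terminates the stream:

    from collections.abc import AsyncGenerator
    from typing import Optional

    async def replay_events(
        store: EventStore,
        response_id: str,
        starting_after: Optional[int] = None,
    ) -> AsyncGenerator[StreamingResponsesResponse, None]:
        event_deque, new_event_signal = store[response_id]
        next_index = 0 if starting_after is None else starting_after + 1
        while True:
            # Drain everything stored since we last looked.
            while next_index < len(event_deque):
                event = event_deque[next_index]
                next_index += 1
                yield event
                if getattr(event, 'type', None) == "response.completed":
                    return
            # Caught up: block until the producer appends and signals.
            new_event_signal.clear()
            await new_event_signal.wait()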
@@ -893,8 +898,8 @@ class OpenAIServingResponses(OpenAIServing):
response_id: str,
starting_after: Optional[int],
stream: Optional[bool],
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[BaseModel,
None]]:
) -> Union[ErrorResponse, ResponsesResponse, AsyncGenerator[
StreamingResponsesResponse, None]]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
@@ -977,9 +982,9 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[[BaseModel],
BaseModel],
) -> AsyncGenerator[BaseModel, None]:
_increment_sequence_number_and_return: Callable[
[StreamingResponsesResponse], StreamingResponsesResponse],
) -> AsyncGenerator[StreamingResponsesResponse, None]:
current_content_index = 0
current_output_index = 0
current_item_id = ""
@@ -1017,13 +1022,11 @@ class OpenAIServingResponses(OpenAIServing):
current_item_id = str(uuid.uuid4())
if delta_message.reasoning_content:
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseReasoningItem(
item=ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
@@ -1032,13 +1035,11 @@ class OpenAIServingResponses(OpenAIServing):
))
else:
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseOutputMessage(
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
@@ -1047,13 +1048,13 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartAddedEvent(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
@@ -1104,11 +1105,11 @@ class OpenAIServingResponses(OpenAIServing):
item=reasoning_item,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemAddedEvent(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.ResponseOutputMessage(
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
@@ -1119,13 +1120,13 @@ class OpenAIServingResponses(OpenAIServing):
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartAddedEvent(
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
@@ -1148,7 +1149,7 @@ class OpenAIServingResponses(OpenAIServing):
))
elif delta_message.content is not None:
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDeltaEvent(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
@@ -1204,7 +1205,7 @@ class OpenAIServingResponses(OpenAIServing):
for pm in previous_delta_messages
if pm.content is not None)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDoneEvent(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
@@ -1220,7 +1221,7 @@ class OpenAIServingResponses(OpenAIServing):
annotations=[],
)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseContentPartDoneEvent(
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
item_id=current_item_id,
@@ -1257,9 +1258,9 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: int,
_increment_sequence_number_and_return: Callable[[BaseModel],
BaseModel],
) -> AsyncGenerator[BaseModel, None]:
_increment_sequence_number_and_return: Callable[
[StreamingResponsesResponse], StreamingResponsesResponse],
) -> AsyncGenerator[StreamingResponsesResponse, None]:
current_content_index = -1
current_output_index = 0
current_item_id: str = ""
@@ -1314,7 +1315,7 @@ class OpenAIServingResponses(OpenAIServing):
annotations=[],
)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDoneEvent(
ResponseTextDoneEvent(
type="response.output_text.done",
sequence_number=-1,
output_index=current_output_index,
@@ -1324,7 +1325,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartDoneEvent(
type="response.content_part.done",
sequence_number=-1,
@@ -1334,7 +1334,7 @@ class OpenAIServingResponses(OpenAIServing):
part=text_content,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
@@ -1355,13 +1355,11 @@ class OpenAIServingResponses(OpenAIServing):
sent_output_item_added = True
current_item_id = f"msg_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseOutputMessage(
item=ResponseOutputMessage(
id=current_item_id,
type="message",
role="assistant",
@@ -1371,14 +1369,13 @@ class OpenAIServingResponses(OpenAIServing):
))
current_content_index += 1
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
@@ -1386,7 +1383,7 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseTextDeltaEvent(
ResponseTextDeltaEvent(
type="response.output_text.delta",
sequence_number=-1,
content_index=current_content_index,
@@ -1402,13 +1399,11 @@ class OpenAIServingResponses(OpenAIServing):
sent_output_item_added = True
current_item_id = f"msg_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseReasoningItem(
item=ResponseReasoningItem(
type="reasoning",
id=current_item_id,
summary=[],
@@ -1417,14 +1412,13 @@ class OpenAIServingResponses(OpenAIServing):
))
current_content_index += 1
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseContentPartAddedEvent(
type="response.content_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=openai_responses_types.ResponseOutputText(
part=ResponseOutputText(
type="output_text",
text="",
annotations=[],
@@ -1450,13 +1444,11 @@ class OpenAIServingResponses(OpenAIServing):
sent_output_item_added = True
current_item_id = f"tool_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseCodeInterpreterToolCallParam(
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code=None,
@@ -1466,7 +1458,6 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallInProgressEvent(
type=
"response.code_interpreter_call.in_progress",
@@ -1475,7 +1466,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCodeDeltaEvent(
type="response.code_interpreter_call_code.delta",
sequence_number=-1,
@@ -1495,14 +1485,12 @@ class OpenAIServingResponses(OpenAIServing):
action = None
parsed_args = json.loads(previous_item.content[0].text)
if function_name == "search":
action = (openai_responses_types.
response_function_web_search.ActionSearch(
action = (response_function_web_search.ActionSearch(
type="search",
query=parsed_args["query"],
))
elif function_name == "open":
action = (
openai_responses_types.
response_function_web_search.ActionOpenPage(
type="open_page",
# TODO: translate to url
@@ -1510,7 +1498,6 @@ class OpenAIServingResponses(OpenAIServing):
))
elif function_name == "find":
action = (
openai_responses_types.
response_function_web_search.ActionFind(
type="find",
pattern=parsed_args["pattern"],
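Condensed, the dispatch above maps the model's browser tool calls onto the typed web-search actions. The url values here are placeholders, matching the TODO in the diff; only the query/pattern lookups are shown verbatim above:

    from openai.types.responses import response_function_web_search as ws

    def to_web_search_action(function_name: str, parsed_args: dict):
        if function_name == "search":
            return ws.ActionSearch(type="search",
                                   query=parsed_args["query"])
        if function_name == "open":
            # Placeholder URL; the diff marks real translation as TODO.
            return ws.ActionOpenPage(type="open_page", url="about:blank")
        if function_name == "find":
            return ws.ActionFind(type="find",
                                 pattern=parsed_args["pattern"],
                                 url="about:blank")  # placeholder
        return None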
@@ -1523,12 +1510,11 @@ class OpenAIServingResponses(OpenAIServing):
current_item_id = f"tool_{random_uuid()}"
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemAddedEvent(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
response_function_web_search.
item=response_function_web_search.
ResponseFunctionWebSearch(
# TODO: generate a unique id for web search call
type="web_search_call",
@@ -1538,7 +1524,6 @@ class OpenAIServingResponses(OpenAIServing):
),
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallInProgressEvent(
type="response.web_search_call.in_progress",
sequence_number=-1,
@@ -1546,7 +1531,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallSearchingEvent(
type="response.web_search_call.searching",
sequence_number=-1,
@@ -1556,7 +1540,6 @@ class OpenAIServingResponses(OpenAIServing):
# enqueue
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseWebSearchCallCompletedEvent(
type="response.web_search_call.completed",
sequence_number=-1,
@@ -1564,12 +1547,11 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseFunctionWebSearch(
item=ResponseFunctionWebSearch(
type="web_search_call",
id=current_item_id,
action=action,
@@ -1582,7 +1564,6 @@ class OpenAIServingResponses(OpenAIServing):
and previous_item.recipient is not None
and previous_item.recipient.startswith("python")):
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCodeDoneEvent(
type="response.code_interpreter_call_code.done",
sequence_number=-1,
@@ -1591,7 +1572,6 @@ class OpenAIServingResponses(OpenAIServing):
code=previous_item.content[0].text,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallInterpretingEvent(
type="response.code_interpreter_call.interpreting",
sequence_number=-1,
@@ -1599,7 +1579,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.
ResponseCodeInterpreterCallCompletedEvent(
type="response.code_interpreter_call.completed",
sequence_number=-1,
@@ -1607,12 +1586,11 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
))
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseOutputItemDoneEvent(
ResponseOutputItemDoneEvent(
type="response.output_item.done",
sequence_number=-1,
output_index=current_output_index,
item=openai_responses_types.
ResponseCodeInterpreterToolCallParam(
item=ResponseCodeInterpreterToolCallParam(
type="code_interpreter_call",
id=current_item_id,
code=previous_item.content[0].text,
@@ -1633,7 +1611,7 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
) -> AsyncGenerator[BaseModel, None]:
) -> AsyncGenerator[StreamingResponsesResponse, None]:
# TODO:
# 1. Handle disconnect
@@ -1641,9 +1619,9 @@ class OpenAIServingResponses(OpenAIServing):
sequence_number = 0
T = TypeVar("T", bound=BaseModel)
def _increment_sequence_number_and_return(event: T) -> T:
def _increment_sequence_number_and_return(
event: StreamingResponsesResponse
) -> StreamingResponsesResponse:
nonlocal sequence_number
# Set sequence_number if the event has this attribute
if hasattr(event, 'sequence_number'):
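The old closure relied on T = TypeVar("T", bound=BaseModel) so callers got back the precise event class; with the union, input and output are both StreamingResponsesResponse, which is all downstream code needs. As a standalone sketch of the same pattern:

    def make_sequence_stamper():
        """Sketch: stamp a monotonically increasing sequence_number onto
        each event that defines that field, mirroring the closure above."""
        sequence_number = 0

        def stamp(
            event: StreamingResponsesResponse
        ) -> StreamingResponsesResponse:
            nonlocal sequence_number
            if hasattr(event, 'sequence_number'):
                event.sequence_number = sequence_number
                sequence_number += 1
            return event

        return stamp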
@@ -1705,7 +1683,7 @@ class OpenAIServingResponses(OpenAIServing):
created_time=created_time,
)
yield _increment_sequence_number_and_return(
openai_responses_types.ResponseCompletedEvent(
ResponseCompletedEvent(
type="response.completed",
sequence_number=-1,
response=final_response.model_dump(),
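End to end, a client consuming this endpoint sees the same typed events. A hypothetical smoke test against a local vLLM server (base URL, API key, and model name are placeholders):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = client.responses.create(model="openai/gpt-oss-20b",
                                     input="Say hi", stream=True)
    for event in stream:
        if event.type == "response.output_text.delta":
            print(event.delta, end="")
        elif event.type == "response.completed":
            print()  # terminal event: the response is complete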