Implement OpenAI Responses API [1/N] (#20504)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-07-06 18:32:13 -07:00 committed by GitHub
parent c18b3b8e8b
commit 462b269280
12 changed files with 1106 additions and 8 deletions


@@ -95,6 +95,10 @@ def test_openapi_stateless(case: schemathesis.Case):
case.operation.method.upper(),
case.operation.path,
)
if case.operation.path.startswith("/v1/responses"):
# Skip the Responses API as it is meant to be stateful.
return
timeout = {
# requires a longer timeout
("POST", "/v1/chat/completions"):


@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
# Use a small reasoning model to test the responses API.
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")
def default_server_args():
return [
"--max-model-len",
"8192",
"--enforce-eager", # For faster startup.
"--reasoning-parser",
"deepseek_r1",
]
@pytest.fixture(scope="module")
def server(default_server_args):
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
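
For reference, the client fixture above is roughly what a standalone script would set up by hand. A minimal sketch, assuming the server started by the fixture is already listening on the default local address and that any placeholder API key is accepted (vLLM does not validate the key by default):

import openai

# Hypothetical standalone setup; the tests instead call server.get_async_client(),
# which fills in the base URL and key for the spawned server automatically.
client = openai.AsyncOpenAI(
    base_url="http://localhost:8000/v1",  # assumed default vLLM address/port
    api_key="EMPTY",
)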


@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai # use the official client for correctness check
import pytest
@pytest.mark.asyncio
async def test_simple_input(client: openai.AsyncOpenAI):
response = await client.responses.create(input="What is 13 * 24?")
print(response)
outputs = response.output
# Whether the output contains the answer.
assert outputs[-1].type == "message"
assert "312" in outputs[-1].content[0].text
# Whether the output contains the reasoning.
assert outputs[0].type == "reasoning"
assert outputs[0].text != ""
@pytest.mark.asyncio
async def test_instructions(client: openai.AsyncOpenAI):
response = await client.responses.create(
instructions="Finish the answer with QED.",
input="What is 13 * 24?",
)
print(response)
output_text = response.output[-1].content[0].text
assert "312" in output_text
assert "QED" in output_text
@pytest.mark.asyncio
async def test_chat(client: openai.AsyncOpenAI):
response = await client.responses.create(input=[
{
"role": "system",
"content": "Finish the answer with QED."
},
{
"role": "user",
"content": "What is 5 * 3?"
},
{
"role": "assistant",
"content": "15. QED."
},
{
"role": "user",
"content": "Multiply the result by 2."
},
], )
print(response)
output_text = response.output[-1].content[0].text
assert "30" in output_text
assert "QED" in output_text
@pytest.mark.asyncio
async def test_chat_with_input_type(client: openai.AsyncOpenAI):
response = await client.responses.create(input=[
{
"role": "user",
"content": [{
"type": "input_text",
"text": "Hello!"
}],
},
], )
print(response)
assert response.status == "completed"
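
The assertions above rely on the reasoning item coming first and the final message coming last in response.output. For illustration, a sketch that walks the output items without assuming their positions, using only the item types exercised by these tests:

response = await client.responses.create(input="What is 13 * 24?")
for item in response.output:
    if item.type == "reasoning":
        print("reasoning:", item.text)
    elif item.type == "message":
        # A message holds a list of content parts; output_text parts carry the answer.
        print("answer:", "".join(part.text for part in item.content))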


@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import openai
import pytest
@pytest.mark.asyncio
async def test_store(client: openai.AsyncOpenAI):
# By default, store is True.
response = await client.responses.create(input="Hello!")
assert response.status == "completed"
# Retrieve the response.
response = await client.responses.retrieve(response.id)
assert response.status == "completed"
# Test store=False.
response = await client.responses.create(
input="Hello!",
store=False,
)
assert response.status == "completed"
# The response should not be found.
with pytest.raises(openai.NotFoundError,
match="Response with id .* not found."):
await client.responses.retrieve(response.id)
@pytest.mark.asyncio
async def test_background(client: openai.AsyncOpenAI):
# NOTE: This query should be easy enough for the model to answer
# within the 10-second polling window below.
response = await client.responses.create(
input="Hello!",
background=True,
)
assert response.status == "queued"
max_retries = 10
for _ in range(max_retries):
await asyncio.sleep(1)
response = await client.responses.retrieve(response.id)
if response.status != "queued":
break
print(response)
assert response.status == "completed"
@pytest.mark.asyncio
async def test_background_error(client: openai.AsyncOpenAI):
with pytest.raises(
openai.BadRequestError,
match="background can only be used when `store` is true"):
_ = await client.responses.create(
input="What is 13 * 24?",
background=True,
store=False,
)
@pytest.mark.asyncio
async def test_background_cancel(client: openai.AsyncOpenAI):
response = await client.responses.create(
input="Write a long story about a cat.",
background=True,
)
assert response.status == "queued"
# Cancel the response before it is completed.
# FIXME: This test can be flaky.
await asyncio.sleep(0.5)
response = await client.responses.cancel(response.id)
assert response.status == "cancelled"
# Make sure the response status remains unchanged.
await asyncio.sleep(5)
response = await client.responses.retrieve(response.id)
assert response.status == "cancelled"
@pytest.mark.asyncio
async def test_cancel_completed(client: openai.AsyncOpenAI):
response = await client.responses.create(input="Hello")
assert response.status == "completed"
with pytest.raises(openai.BadRequestError,
match="Cannot cancel a synchronous response."):
await client.responses.cancel(response.id)
@pytest.mark.asyncio
async def test_previous_response_id(client: openai.AsyncOpenAI):
response1 = await client.responses.create(
instructions="You are tested on your ability to retrieve the correct "
"information from the previous response.",
input="Hello, my name is John.")
response2 = await client.responses.create(
input="Actually, my name is not John. My real name is Mark.",
previous_response_id=response1.id,
)
response3 = await client.responses.create(
input="What is my real name again? Answer in one word.",
previous_response_id=response2.id,
)
print(response3)
assert "Mark" in response3.output[-1].content[0].text
assert "John" not in response3.output[-1].content[0].text
@pytest.mark.asyncio
async def test_two_responses_with_same_prev_id(client: openai.AsyncOpenAI):
response1 = await client.responses.create(
instructions="You are tested on your ability to retrieve the correct "
"information from the previous response.",
input="Hello, my name is John.")
# Both response 2 and 3 use response 1 as the previous response.
response2 = client.responses.create(
input="Actually, my name is not John. My name is Mark.",
previous_response_id=response1.id,
)
response3 = client.responses.create(
input="What is my name again? Answer in one word.",
previous_response_id=response1.id,
)
_ = await response2
response3_result = await response3
print(response3_result)
assert "John" in response3_result.output[-1].content[0].text
assert "Mark" not in response3_result.output[-1].content[0].text


@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import openai
import pytest
from pydantic import BaseModel
@pytest.mark.asyncio
async def test_structured_output(client: openai.AsyncOpenAI):
response = await client.responses.create(
input=[
{
"role": "system",
"content": "Extract the event information."
},
{
"role": "user",
"content":
"Alice and Bob are going to a science fair on Friday.",
},
],
text={
"format": {
"type": "json_schema",
"name": "calendar_event",
"schema": {
"type": "object",
"properties": {
"event_name": {
"type": "string"
},
"date": {
"type": "string"
},
"participants": {
"type": "array",
"items": {
"type": "string"
}
},
},
"required": ["event_name", "date", "participants"],
"additionalProperties": False,
},
"description": "A calendar event.",
"strict": True,
}
},
)
print(response)
# NOTE: The JSON schema is applied to the output text, not reasoning.
output_text = response.output[-1].content[0].text
event = json.loads(output_text)
assert event["event_name"].lower() == "science fair"
assert event["date"] == "Friday"
participants = event["participants"]
assert len(participants) == 2
assert participants[0] == "Alice"
assert participants[1] == "Bob"
@pytest.mark.asyncio
async def test_structured_output_with_parse(client: openai.AsyncOpenAI):
class CalendarEvent(BaseModel):
event_name: str
date: str
participants: list[str]
response = await client.responses.parse(
model=None,
instructions="Extract the event information.",
input="Alice and Bob are going to a science fair on Friday.",
text_format=CalendarEvent,
)
print(response)
# The output is successfully parsed.
event = response.output_parsed
assert event is not None
# The output is correct.
assert event.event_name.lower() == "science fair"
assert event.date == "Friday"
participants = event.participants
assert len(participants) == 2
assert participants[0] == "Alice"
assert participants[1] == "Bob"


@@ -902,6 +902,8 @@ MM_PARSER_MAP: dict[
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"input_text":
lambda part: _TextParser(part).get("text", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds":
@@ -1040,7 +1042,7 @@ def _parse_chat_message_content_part(
"with empty / unparsable content.", part, part_type)
return None
if part_type in ("text", "refusal"):
if part_type in ("text", "input_text", "refusal"):
str_content = cast(str, content)
if wrap_dicts:
return {'type': 'text', 'text': str_content}
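
With this mapping, the Responses-style input_text part goes through the same parser as the Chat Completions text part. For illustration, both of the following user messages resolve to the same prompt text:

# Chat Completions style content part
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}
# Responses API style content part (as in test_chat_with_input_type above)
{"role": "user", "content": [{"type": "input_text", "text": "Hello!"}]}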


@@ -69,8 +69,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
ResponsesRequest,
ResponsesResponse, ScoreRequest,
ScoreResponse, TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
@@ -87,6 +88,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
@@ -368,6 +370,10 @@ def models(request: Request) -> OpenAIServingModels:
return request.app.state.openai_serving_models
def responses(request: Request) -> Optional[OpenAIServingResponses]:
return request.app.state.openai_serving_responses
def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat
@@ -531,6 +537,71 @@ async def show_version():
return JSONResponse(content=ver)
@router.post("/v1/responses",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {
"content": {
"text/event-stream": {}
}
},
HTTPStatus.BAD_REQUEST.value: {
"model": ErrorResponse
},
HTTPStatus.NOT_FOUND.value: {
"model": ErrorResponse
},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {
"model": ErrorResponse
},
})
@with_cancellation
async def create_responses(request: ResponsesRequest, raw_request: Request):
handler = responses(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Responses API")
generator = await handler.create_responses(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ResponsesResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.get("/v1/responses/{response_id}")
async def retrieve_responses(response_id: str, raw_request: Request):
handler = responses(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Responses API")
response = await handler.retrieve_responses(response_id)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return JSONResponse(content=response.model_dump())
@router.post("/v1/responses/{response_id}/cancel")
async def cancel_responses(response_id: str, raw_request: Request):
handler = responses(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Responses API")
response = await handler.cancel_responses(response_id)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return JSONResponse(content=response.model_dump())
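
The three routes registered above correspond to create, retrieve, and cancel. A sketch of exercising them over plain HTTP, assuming a server on localhost:8000 (the official client used in the tests wraps the same endpoints):

import requests

BASE = "http://localhost:8000"  # assumed local server address

created = requests.post(f"{BASE}/v1/responses",
                        json={"input": "Hello!", "background": True}).json()
# Fetch the stored response by id, then cancel it if it is still running.
fetched = requests.get(f"{BASE}/v1/responses/{created['id']}").json()
if fetched["status"] in ("queued", "in_progress"):
    requests.post(f"{BASE}/v1/responses/{created['id']}/cancel")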
@router.post("/v1/chat/completions",
dependencies=[Depends(validate_json_request)],
responses={
@@ -1272,6 +1343,22 @@ async def init_app_state(
prompt_adapters=args.prompt_adapters,
)
await state.openai_serving_models.init_static_loras()
state.openai_serving_responses = OpenAIServingResponses(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
expand_tools_even_if_tool_choice_none=args.
expand_tools_even_if_tool_choice_none,
tool_parser=args.tool_call_parser,
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,


@@ -11,6 +11,12 @@ from typing import Annotated, Any, ClassVar, Literal, Optional, Union
import regex as re
import torch
from fastapi import HTTPException, UploadFile
from openai.types.responses import (ResponseInputParam, ResponseOutputItem,
ResponseOutputMessage, ResponsePrompt,
ResponseStatus, ResponseTextConfig)
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import TypeAlias
@@ -220,6 +226,124 @@ def get_logits_processors(processors: Optional[LogitsProcessors],
return None
class ResponsesRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/responses/create
background: Optional[bool] = False
include: Optional[list[
Literal[
"code_interpreter_call.outputs",
"computer_call_output.output.image_url",
"file_search_call.results",
"message.input_image.image_url",
"message.output_text.logprobs",
"reasoning.encrypted_content",
],
]] = None
input: Union[str, ResponseInputParam]
instructions: Optional[str] = None
max_output_tokens: Optional[int] = None
max_tool_calls: Optional[int] = None
metadata: Optional[Metadata] = None
model: Optional[str] = None
parallel_tool_calls: Optional[bool] = True
previous_response_id: Optional[str] = None
prompt: Optional[ResponsePrompt] = None
reasoning: Optional[Reasoning] = None
service_tier: Literal["auto", "default", "flex", "scale",
"priority"] = "auto"
store: Optional[bool] = True
stream: Optional[bool] = False
temperature: Optional[float] = None
text: Optional[ResponseTextConfig] = None
tool_choice: ToolChoice = "auto"
tools: list[Tool] = Field(default_factory=list)
top_logprobs: Optional[int] = 0
top_p: Optional[float] = None
truncation: Optional[Literal["auto", "disabled"]] = "disabled"
user: Optional[str] = None
# --8<-- [start:responses-extra-params]
request_id: str = Field(
default_factory=lambda: f"resp_{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."),
)
mm_processor_kwargs: Optional[dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."),
)
# --8<-- [end:responses-extra-params]
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
}
def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
if self.max_output_tokens is None:
max_tokens = default_max_tokens
else:
max_tokens = min(self.max_output_tokens, default_max_tokens)
default_sampling_params = default_sampling_params or {}
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
# Structured output
guided_decoding = None
if self.text is not None and self.text.format is not None:
response_format = self.text.format
if response_format.type == "json_schema":
guided_decoding = GuidedDecodingParams.from_optional(
json=response_format.schema_)
elif response_format.type == "json_object":
raise NotImplementedError("json_object is not supported")
# TODO: add more parameters
return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
logprobs=self.top_logprobs,
output_kind=(RequestOutputKind.DELTA
if self.stream else RequestOutputKind.FINAL_ONLY),
guided_decoding=guided_decoding,
)
@model_validator(mode="before")
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
raise ValueError(
"background can only be used when `store` is true")
return data
@model_validator(mode="before")
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise ValueError("prompt template is not supported")
return data
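
To make the conversion concrete, a sketch of turning a request into SamplingParams the way the serving layer does later in this commit (the 4096-token budget is an arbitrary illustrative value):

from vllm.entrypoints.openai.protocol import ResponsesRequest

request = ResponsesRequest(input="What is 13 * 24?", max_output_tokens=256)
# The serving layer passes the remaining context budget as default_max_tokens;
# unset temperature/top_p fall back to the defaults declared above (1.0 each).
sampling_params = request.to_sampling_params(default_max_tokens=4096)
assert sampling_params.max_tokens == 256
assert sampling_params.temperature == 1.0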
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -1473,6 +1597,83 @@ class TranscriptionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
class ResponseReasoningItem(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"rs_{random_uuid()}")
text: str
summary: list = Field(default_factory=list)
type: Literal["reasoning"] = "reasoning"
encrypted_content: Optional[str] = None
status: Optional[Literal["in_progress", "completed", "incomplete"]]
class ResponsesResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
created_at: int = Field(default_factory=lambda: int(time.time()))
# error: Optional[ResponseError] = None
# incomplete_details: Optional[IncompleteDetails] = None
instructions: Optional[str] = None
metadata: Optional[Metadata] = None
model: str
object: Literal["response"] = "response"
output: list[Union[ResponseOutputMessage, ResponseReasoningItem]]
parallel_tool_calls: bool
temperature: float
tool_choice: ToolChoice
tools: list[Tool]
top_p: float
background: bool
max_output_tokens: int
max_tool_calls: Optional[int] = None
previous_response_id: Optional[str] = None
prompt: Optional[ResponsePrompt] = None
reasoning: Optional[Reasoning] = None
service_tier: Literal["auto", "default", "flex", "scale", "priority"]
status: ResponseStatus
text: Optional[ResponseTextConfig] = None
top_logprobs: int
truncation: Literal["auto", "disabled"]
usage: Optional[UsageInfo] = None
user: Optional[str] = None
@classmethod
def from_request(
cls,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
created_time: int,
output: list[ResponseOutputItem],
status: ResponseStatus,
usage: Optional[UsageInfo] = None,
) -> "ResponsesResponse":
return cls(
id=request.request_id,
created_at=created_time,
instructions=request.instructions,
metadata=request.metadata,
model=model_name,
output=output,
parallel_tool_calls=request.parallel_tool_calls,
temperature=sampling_params.temperature,
tool_choice=request.tool_choice,
tools=request.tools,
top_p=sampling_params.top_p,
background=request.background,
max_output_tokens=sampling_params.max_tokens,
max_tool_calls=request.max_tool_calls,
previous_response_id=request.previous_response_id,
prompt=request.prompt,
reasoning=request.reasoning,
service_tier=request.service_tier,
status=status,
text=request.text,
top_logprobs=sampling_params.logprobs,
truncation=request.truncation,
user=request.user,
usage=usage,
)
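
For illustration, a sketch of how the serving layer added later in this commit assembles the output list and builds a stored response from a finished request (the values are placeholders):

import time
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from vllm.entrypoints.openai.protocol import (ResponseReasoningItem,
                                              ResponsesRequest,
                                              ResponsesResponse)
from vllm.utils import random_uuid

request = ResponsesRequest(input="What is 13 * 24?")
sampling_params = request.to_sampling_params(default_max_tokens=4096)
output = [
    ResponseReasoningItem(text="13 * 24 = 312", status=None),
    ResponseOutputMessage(id=f"msg_{random_uuid()}",
                          content=[ResponseOutputText(text="312", annotations=[],
                                                      type="output_text",
                                                      logprobs=None)],
                          role="assistant", status="completed", type="message"),
]
response = ResponsesResponse.from_request(request, sampling_params,
                                          model_name="Qwen/Qwen3-0.6B",
                                          created_time=int(time.time()),
                                          output=output, status="completed")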
BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
ScoreRequest, RerankRequest]


@@ -53,7 +53,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
PoolingResponse, RerankRequest,
ScoreRequest, ScoreResponse,
ResponsesRequest, ScoreRequest,
ScoreResponse,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
@@ -91,7 +92,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
ResponsesRequest]
AnyResponse = Union[
CompletionResponse,
@@ -762,7 +764,7 @@ class OpenAIServing:
async def _preprocess_chat(
self,
request: ChatLikeRequest,
request: Union[ChatLikeRequest, ResponsesRequest],
tokenizer: AnyTokenizer,
messages: list[ChatCompletionMessageParam],
chat_template: Optional[str],


@@ -0,0 +1,464 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from http import HTTPStatus
from typing import Callable, Final, Optional, Union
import jinja2
from fastapi import Request
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ChatTemplateContentFormatOption)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse,
PromptTokenUsageInfo,
RequestResponseMetadata,
ResponseReasoningItem,
ResponsesRequest,
ResponsesResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
class OpenAIServingResponses(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "",
enable_auto_tools: bool = False,
expand_tools_even_if_tool_choice_none: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
enable_force_include_usage=enable_force_include_usage,
)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.reasoning_parser: Optional[Callable[[AnyTokenizer],
ReasoningParser]] = None
if reasoning_parser:
try:
self.reasoning_parser = (
ReasoningParserManager.get_reasoning_parser(
reasoning_parser))
assert self.reasoning_parser is not None
except Exception as e:
raise TypeError(
f"{reasoning_parser=} has not been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
source = self.model_config.generation_config
source = "model" if source == "auto" else source
logger.info("Using default chat sampling params from %s: %s",
source, self.default_sampling_params)
# HACK(woosuk): This is a hack. We should use a better store.
# FIXME: This causes a memory leak since we never remove responses
# from the store.
self.response_store: dict[str, ResponsesResponse] = {}
self.response_store_lock = asyncio.Lock()
# HACK(woosuk): This is a hack. We should use a better store.
# FIXME: This causes a memory leak since we never remove messages
# from the store.
self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {}
self.background_tasks: dict[str, asyncio.Task] = {}
async def create_responses(
self,
request: ResponsesRequest,
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
# Handle the previous response ID.
prev_response_id = request.previous_response_id
if prev_response_id is not None:
if not prev_response_id.startswith("resp_"):
return self._make_invalid_id_error(prev_response_id)
async with self.response_store_lock:
prev_response = self.response_store.get(prev_response_id)
if prev_response is None:
return self._make_not_found_error(prev_response_id)
else:
prev_response = None
# Construct the input messages.
messages = self._construct_input_messages(request, prev_response)
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_name = self._get_model_name(request.model, lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
_, request_prompts, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
messages,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
except (ValueError, TypeError, RuntimeError,
jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
request_metadata = RequestResponseMetadata(
request_id=request.request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)
self._log_inputs(request.request_id,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request.request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert len(generators) == 1
result_generator, = generators
# Store the input messages.
if request.store:
self.msg_store[request.request_id] = messages
if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
)
async with self.response_store_lock:
self.response_store[response.id] = response
# Run the request in the background.
task = asyncio.create_task(
self._run_background_request(
request,
sampling_params,
result_generator,
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)
# For cleanup.
response_id = response.id
self.background_tasks[response_id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None))
return response
if request.stream:
raise NotImplementedError("Streaming responses are not supported")
try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
model_name,
tokenizer,
request_metadata,
)
except Exception as e:
return self.create_error_response(str(e))
async def responses_full_generator(
self,
request: ResponsesRequest,
sampling_params: SamplingParams,
result_generator: AsyncIterator[RequestOutput],
model_name: str,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
created_time: Optional[int] = None,
) -> Union[ErrorResponse, ResponsesResponse]:
if created_time is None:
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
try:
async for res in result_generator:
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert final_res is not None
assert len(final_res.outputs) == 1
final_output = final_res.outputs[0]
if self.reasoning_parser:
try:
reasoning_parser = self.reasoning_parser(tokenizer)
except RuntimeError as e:
logger.exception("Error in reasoning parser creation.")
return self.create_error_response(str(e))
reasoning_content, content = (
reasoning_parser.extract_reasoning_content(final_output.text,
request=request))
else:
reasoning_content = None
content = final_output.text
output = []
if reasoning_content:
reasoning_item = ResponseReasoningItem(
text=reasoning_content,
status=None, # NOTE: Only the last output item has status.
)
output.append(reasoning_item)
if content:
output_text = ResponseOutputText(
text=content,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
message = ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=[output_text],
role="assistant",
status="completed",
type="message",
)
output.append(message)
# Calculate usage.
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = len(final_output.token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens)
request_metadata.final_usage_info = usage
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=output,
status="completed",
usage=usage,
)
if request.store:
async with self.response_store_lock:
stored_response = self.response_store.get(response.id)
# If the response is already cancelled, don't update it.
if (stored_response is None
or stored_response.status != "cancelled"):
self.response_store[response.id] = response
return response
def _construct_input_messages(
self,
request: ResponsesRequest,
prev_response: Optional[ResponsesResponse] = None,
) -> list[ChatCompletionMessageParam]:
messages: list[ChatCompletionMessageParam] = []
if request.instructions:
messages.append({
"role": "system",
"content": request.instructions,
})
# Prepend the conversation history.
if prev_response is not None:
# Add the previous messages.
prev_msg = self.msg_store[prev_response.id]
messages.extend(prev_msg)
# Add the previous output.
for output_item in prev_response.output:
# NOTE: We skip the reasoning output.
if isinstance(output_item, ResponseOutputMessage):
for content in output_item.content:
messages.append({
"role": "assistant",
"content": content.text,
})
# Append the new input.
# The Responses API supports simple text inputs without the chat format.
if isinstance(request.input, str):
messages.append({"role": "user", "content": request.input})
else:
messages.extend(request.input) # type: ignore
return messages
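
To make the merge order concrete, a sketch of the message list this method produces for a follow-up request that sets previous_response_id and carries no new instructions (values are illustrative; as noted above, reasoning items from the previous turn are skipped):

[
    # Messages stored in msg_store for the previous response,
    # including its own system instruction:
    {"role": "system", "content": "Finish the answer with QED."},
    {"role": "user", "content": "What is 13 * 24?"},
    # The previous response's output message (the reasoning item is dropped):
    {"role": "assistant", "content": "312. QED."},
    # The new request's input:
    {"role": "user", "content": "Multiply the result by 2."},
]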
async def _run_background_request(
self,
request: ResponsesRequest,
*args,
**kwargs,
):
try:
response = await self.responses_full_generator(
request, *args, **kwargs)
except Exception as e:
logger.exception("Background request failed for %s",
request.request_id)
response = self.create_error_response(str(e))
if isinstance(response, ErrorResponse):
# If the request has failed, update the status to "failed".
response_id = request.request_id
async with self.response_store_lock:
stored_response = self.response_store.get(response_id)
assert stored_response is not None
if stored_response.status not in ("completed", "cancelled"):
stored_response.status = "failed"
async def retrieve_responses(
self,
response_id: str,
) -> Union[ErrorResponse, ResponsesResponse]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
async with self.response_store_lock:
response = self.response_store.get(response_id)
if response is None:
return self._make_not_found_error(response_id)
return response
async def cancel_responses(
self,
response_id: str,
) -> Union[ErrorResponse, ResponsesResponse]:
if not response_id.startswith("resp_"):
return self._make_invalid_id_error(response_id)
async with self.response_store_lock:
response = self.response_store.get(response_id)
if response is None:
return self._make_not_found_error(response_id)
prev_status = response.status
if prev_status not in ("queued", "in_progress"):
return self.create_error_response(
err_type="invalid_request_error",
message="Cannot cancel a synchronous response.",
)
# Update the status to "cancelled".
response.status = "cancelled"
# Abort the request.
if (task := self.background_tasks.get(response_id)):
task.cancel()
try:
await task
except asyncio.CancelledError:
logger.exception("Background task for %s was cancelled",
response_id)
return response
def _make_invalid_id_error(self, response_id: str) -> ErrorResponse:
return self.create_error_response(
err_type="invalid_request_error",
message=(f"Invalid 'response_id': '{response_id}'. "
"Expected an ID that begins with 'resp'."),
)
def _make_not_found_error(self, response_id: str) -> ErrorResponse:
return self.create_error_response(
err_type="invalid_request_error",
message=f"Response with id '{response_id}' not found.",
status_code=HTTPStatus.NOT_FOUND,
)


@@ -10,7 +10,7 @@ from functools import cached_property
from typing import Callable, Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
DeltaMessage, ResponsesRequest)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import import_from_path, is_list_of
@@ -66,7 +66,9 @@ class ReasoningParser:
@abstractmethod
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
self,
model_output: str,
request: Union[ChatCompletionRequest, ResponsesRequest],
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from a complete model-generated string.