[Feature][Responses API] Support logprobs(non-stream) (#23319)

Signed-off-by: Kebe <mail@kebe7jun.com>
Kebe 2025-08-22 07:09:16 +08:00 committed by GitHub
parent 8ef6b8a38c
commit 5368f76855
GPG Key ID: B5690EEEBB952194
3 changed files with 86 additions and 4 deletions


@@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI):
     ], )
     print(response)
     assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+async def test_logprobs(client: openai.AsyncOpenAI):
+    response = await client.responses.create(
+        include=["message.output_text.logprobs"],
+        input="What is 13 * 24?",
+        top_logprobs=5,
+    )
+    print(response)
+    outputs = response.output
+    assert outputs[-1].content[-1].logprobs
+    assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5
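A minimal consumption sketch for the shape this test asserts, assuming a local vLLM server at http://localhost:8000/v1 and a placeholder model name (both are assumptions, not part of this change): the last output item is a message whose text part carries one Logprob entry per generated token, each with up to `top_logprobs` alternatives.

import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")  # assumed local server


async def show_logprobs() -> None:
    response = await client.responses.create(
        model="my-model",  # placeholder; the test relies on the server default
        include=["message.output_text.logprobs"],
        input="What is 13 * 24?",
        top_logprobs=5,
    )
    text_part = response.output[-1].content[-1]
    for lp in text_part.logprobs:
        # Each entry holds the sampled token, its logprob, raw bytes, and the
        # top-k alternatives requested via `top_logprobs`.
        alts = ", ".join(f"{t.token!r}: {t.logprob:.2f}" for t in lp.top_logprobs)
        print(f"{lp.token!r} ({lp.logprob:.2f}) -> {alts}")


if __name__ == "__main__":
    asyncio.run(show_logprobs())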


@@ -357,13 +357,22 @@ class ResponsesRequest(OpenAIBaseModel):
             temperature=temperature,
             top_p=top_p,
             max_tokens=max_tokens,
-            logprobs=self.top_logprobs,
+            logprobs=self.top_logprobs
+            if self.is_include_output_logprobs() else None,
             stop_token_ids=stop_token_ids,
             output_kind=(RequestOutputKind.DELTA
                          if self.stream else RequestOutputKind.FINAL_ONLY),
             guided_decoding=guided_decoding,
         )

+    def is_include_output_logprobs(self) -> bool:
+        """Check if the request includes output logprobs."""
+        if self.include is None:
+            return False
+        return isinstance(
+            self.include,
+            list) and "message.output_text.logprobs" in self.include
+
     @model_validator(mode="before")
     def validate_background(cls, data):
         if not data.get("background"):
@@ -1808,7 +1817,7 @@ class ResponsesResponse(OpenAIBaseModel):
     service_tier: Literal["auto", "default", "flex", "scale", "priority"]
     status: ResponseStatus
     text: Optional[ResponseTextConfig] = None
-    top_logprobs: int
+    top_logprobs: Optional[int] = None
     truncation: Literal["auto", "disabled"]
     usage: Optional[ResponseUsage] = None
     user: Optional[str] = None
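A minimal sketch of the gating this protocol change introduces, using a plain dataclass stand-in rather than the actual ResponsesRequest: `top_logprobs` only reaches the sampler's `logprobs` setting when the `include` list explicitly asks for "message.output_text.logprobs".

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:  # stand-in for ResponsesRequest, illustration only
    top_logprobs: int = 0
    include: Optional[list[str]] = None

    def is_include_output_logprobs(self) -> bool:
        # Mirrors the new helper: True only when output-text logprobs are
        # explicitly requested via `include`.
        if self.include is None:
            return False
        return (isinstance(self.include, list)
                and "message.output_text.logprobs" in self.include)


req = FakeRequest(top_logprobs=5)
assert not req.is_include_output_logprobs()  # include unset -> no logprobs

req = FakeRequest(top_logprobs=5, include=["message.output_text.logprobs"])
assert req.is_include_output_logprobs()
logprobs = req.top_logprobs if req.is_include_output_logprobs() else None
assert logprobs == 5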


@@ -4,7 +4,7 @@
 import asyncio
 import json
 import time
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
@@ -25,6 +25,8 @@ from openai.types.responses import (ResponseCreatedEvent,
                                     ResponseReasoningItem,
                                     ResponseReasoningTextDeltaEvent,
                                     ResponseReasoningTextDoneEvent)
+from openai.types.responses.response_output_text import (Logprob,
+                                                          LogprobTopLogprob)
 # yapf: enable
 from openai.types.responses.response_reasoning_item import (
     Content as ResponseReasoningTextContent)
@@ -59,6 +61,8 @@ from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob as SampleLogprob
+from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
@@ -201,6 +205,12 @@ class OpenAIServingResponses(OpenAIServing):
             # (i.e., their request's `store=True` just because it's the default
             # value).
             request.store = False
+        if self.use_harmony and request.is_include_output_logprobs():
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message="logprobs are not supported with gpt-oss models",
+                status_code=HTTPStatus.BAD_REQUEST,
+            )

         # Handle the previous response ID.
         prev_response_id = request.previous_response_id
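What a client sees once this guard trips, as a hedged sketch: requesting output-text logprobs from a Harmony-format (gpt-oss) model now fails fast with HTTP 400 instead of being silently ignored. The server address and model name below are assumptions.

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1",
                       api_key="EMPTY")  # assumed local server

try:
    client.responses.create(
        model="gpt-oss-20b",  # placeholder for a harmony-format model
        include=["message.output_text.logprobs"],
        input="What is 13 * 24?",
        top_logprobs=5,
    )
except openai.BadRequestError as exc:
    # The server responds with HTTP 400 and
    # "logprobs are not supported with gpt-oss models".
    print(exc)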
@@ -491,6 +501,51 @@ class OpenAIServingResponses(OpenAIServing):
             self.response_store[response.id] = response
         return response

+    def _topk_logprobs(self, logprobs: dict[int,
+                                            SampleLogprob], top_logprobs: int,
+                       tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]:
+        """Returns the top-k logprobs from the logprobs dictionary."""
+        out = []
+        for i, (token_id, _logprob) in enumerate(logprobs.items()):
+            if i >= top_logprobs:
+                break
+            text = _logprob.decoded_token if _logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                LogprobTopLogprob(
+                    token=text,
+                    logprob=max(_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                ))
+        return out
+
+    def _create_response_logprobs(
+            self,
+            token_ids: Sequence[int],
+            logprobs: Optional[SampleLogprobs],
+            tokenizer: AnyTokenizer,
+            top_logprobs: Optional[int] = None) -> list[Logprob]:
+        assert logprobs is not None, "logprobs must be provided"
+        assert len(token_ids) == len(logprobs), (
+            "token_ids and logprobs.token_ids must have the same length")
+        out = []
+        for i, token_id in enumerate(token_ids):
+            logprob = logprobs[i]
+            token_logprob = logprob[token_id]
+            text = token_logprob.decoded_token if token_logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                Logprob(
+                    token=text,
+                    logprob=max(token_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                    top_logprobs=self._topk_logprobs(logprob,
+                                                     top_logprobs=top_logprobs,
+                                                     tokenizer=tokenizer)
+                    if top_logprobs else [],
+                ))
+        return out
+
     def _make_response_output_items(
         self,
         request: ResponsesRequest,
@@ -547,7 +602,12 @@ class OpenAIServingResponses(OpenAIServing):
                 text=content,
                 annotations=[],  # TODO
                 type="output_text",
-                logprobs=None,  # TODO
+                logprobs=self._create_response_logprobs(
+                    token_ids=final_output.token_ids,
+                    logprobs=final_output.logprobs,
+                    tokenizer=tokenizer,
+                    top_logprobs=request.top_logprobs,
+                ) if request.is_include_output_logprobs() else None,
             )
             message = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",