[Feature][Responses API] Support logprobs (non-stream) (#23319)

Signed-off-by: Kebe <mail@kebe7jun.com>
Author: Kebe <mail@kebe7jun.com>, 2025-08-22 07:09:16 +08:00 (committed by GitHub)
parent 8ef6b8a38c
commit 5368f76855
3 changed files with 86 additions and 4 deletions


@@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI):
         ], )
     print(response)
     assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+async def test_logprobs(client: openai.AsyncOpenAI):
+    response = await client.responses.create(
+        include=["message.output_text.logprobs"],
+        input="What is 13 * 24?",
+        top_logprobs=5,
+    )
+    print(response)
+    outputs = response.output
+    assert outputs[-1].content[-1].logprobs
+    assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5
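For context, this is roughly how a client exercises the new non-streaming surface outside the test suite; a minimal sketch, assuming a vLLM server at localhost:8000 and a placeholder model name:

# Hypothetical usage sketch, not part of this commit.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.responses.create(
    model="my-model",  # placeholder served model name
    include=["message.output_text.logprobs"],  # opt in to logprobs
    input="What is 13 * 24?",
    top_logprobs=5,  # up to 5 alternatives per sampled token
)

# Each output_text content part now carries a list of Logprob entries.
for lp in response.output[-1].content[-1].logprobs:
    print(lp.token, lp.logprob, [t.token for t in lp.top_logprobs])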


@@ -357,13 +357,22 @@ class ResponsesRequest(OpenAIBaseModel):
             temperature=temperature,
             top_p=top_p,
             max_tokens=max_tokens,
-            logprobs=self.top_logprobs,
+            logprobs=self.top_logprobs
+            if self.is_include_output_logprobs() else None,
             stop_token_ids=stop_token_ids,
             output_kind=(RequestOutputKind.DELTA
                          if self.stream else RequestOutputKind.FINAL_ONLY),
             guided_decoding=guided_decoding,
         )
+
+    def is_include_output_logprobs(self) -> bool:
+        """Check if the request includes output logprobs."""
+        if self.include is None:
+            return False
+        return isinstance(
+            self.include,
+            list) and "message.output_text.logprobs" in self.include
+
     @model_validator(mode="before")
     def validate_background(cls, data):
         if not data.get("background"):
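The gating is small enough to sketch standalone; the free function below mirrors is_include_output_logprobs and is illustrative, not part of the commit:

# Illustrative sketch of the include-gating logic above; standalone, not vLLM code.
from typing import Optional, Union

def is_include_output_logprobs(include: Optional[Union[str, list[str]]]) -> bool:
    """True only when `include` is a list containing the logprobs flag."""
    if include is None:
        return False
    return isinstance(include, list) and "message.output_text.logprobs" in include

assert is_include_output_logprobs(["message.output_text.logprobs"])
assert not is_include_output_logprobs(None)
assert not is_include_output_logprobs([])
# SamplingParams.logprobs is then set to top_logprobs only when this is True,
# so logprobs are not computed unless the client asked for them.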
@@ -1808,7 +1817,7 @@ class ResponsesResponse(OpenAIBaseModel):
     service_tier: Literal["auto", "default", "flex", "scale", "priority"]
     status: ResponseStatus
     text: Optional[ResponseTextConfig] = None
-    top_logprobs: int
+    top_logprobs: Optional[int] = None
     truncation: Literal["auto", "disabled"]
     usage: Optional[ResponseUsage] = None
     user: Optional[str] = None


@@ -4,7 +4,7 @@
 import asyncio
 import json
 import time
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
@@ -25,6 +25,8 @@ from openai.types.responses import (ResponseCreatedEvent,
                                     ResponseReasoningItem,
                                     ResponseReasoningTextDeltaEvent,
                                     ResponseReasoningTextDoneEvent)
+from openai.types.responses.response_output_text import (Logprob,
+                                                         LogprobTopLogprob)
 # yapf: enable
 from openai.types.responses.response_reasoning_item import (
     Content as ResponseReasoningTextContent)
@@ -59,6 +61,8 @@ from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob as SampleLogprob
+from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
@@ -201,6 +205,12 @@ class OpenAIServingResponses(OpenAIServing):
             # (i.e., their request's `store=True` just because it's the default
             # value).
             request.store = False
+        if self.use_harmony and request.is_include_output_logprobs():
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message="logprobs are not supported with gpt-oss models",
+                status_code=HTTPStatus.BAD_REQUEST,
+            )
 
         # Handle the previous response ID.
         prev_response_id = request.previous_response_id
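From the client side, the new guard surfaces as an HTTP 400; a hedged sketch, assuming a gpt-oss (harmony) model is being served and using a placeholder model name:

# Hypothetical client-side view of the new validation; not part of this commit.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.responses.create(
        model="gpt-oss-20b",  # placeholder for a harmony-format model
        include=["message.output_text.logprobs"],
        input="hi",
        top_logprobs=2,
    )
except openai.BadRequestError as e:
    # Expect HTTP 400: "logprobs are not supported with gpt-oss models"
    print(e)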
@@ -491,6 +501,51 @@ class OpenAIServingResponses(OpenAIServing):
             self.response_store[response.id] = response
         return response
 
+    def _topk_logprobs(self, logprobs: dict[int, SampleLogprob],
+                       top_logprobs: int,
+                       tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]:
+        """Returns the top-k logprobs from the logprobs dictionary."""
+        out = []
+        for i, (token_id, _logprob) in enumerate(logprobs.items()):
+            if i >= top_logprobs:
+                break
+            text = _logprob.decoded_token if _logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                LogprobTopLogprob(
+                    token=text,
+                    logprob=max(_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                ))
+        return out
+
+    def _create_response_logprobs(
+            self,
+            token_ids: Sequence[int],
+            logprobs: Optional[SampleLogprobs],
+            tokenizer: AnyTokenizer,
+            top_logprobs: Optional[int] = None) -> list[Logprob]:
+        assert logprobs is not None, "logprobs must be provided"
+        assert len(token_ids) == len(logprobs), (
+            "token_ids and logprobs must have the same length")
+        out = []
+        for i, token_id in enumerate(token_ids):
+            logprob = logprobs[i]
+            token_logprob = logprob[token_id]
+            text = token_logprob.decoded_token if token_logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                Logprob(
+                    token=text,
+                    logprob=max(token_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                    top_logprobs=self._topk_logprobs(
+                        logprob,
+                        top_logprobs=top_logprobs,
+                        tokenizer=tokenizer) if top_logprobs else [],
+                ))
+        return out
+
     def _make_response_output_items(
         self,
         request: ResponsesRequest,
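The conversion logic reads cleanly in isolation; a self-contained sketch with a stand-in for vllm.sequence.Logprob and made-up values (names and numbers are illustrative, not vLLM APIs):

# Self-contained sketch of the token -> logprob conversion above.
# SimpleLogprob stands in for vllm.sequence.Logprob; values are made up.
from dataclasses import dataclass
from typing import Optional

@dataclass
class SimpleLogprob:
    logprob: float
    decoded_token: Optional[str] = None

def topk_entries(logprobs: dict[int, SimpleLogprob],
                 k: int) -> list[tuple[str, float, list[int]]]:
    out = []
    for i, (token_id, lp) in enumerate(logprobs.items()):
        if i >= k:
            break
        text = lp.decoded_token if lp.decoded_token is not None else f"<{token_id}>"
        # Clamp very negative values (e.g. -inf) so the JSON payload stays
        # finite, exactly as the handler's max(..., -9999.0) does.
        out.append((text, max(lp.logprob, -9999.0),
                    list(text.encode("utf-8", errors="replace"))))
    return out

candidates = {312: SimpleLogprob(-0.02, "312"),
              313: SimpleLogprob(float("-inf"), "313")}
print(topk_entries(candidates, 2))
# [('312', -0.02, [51, 49, 50]), ('313', -9999.0, [51, 49, 51])]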
@@ -547,7 +602,12 @@ class OpenAIServingResponses(OpenAIServing):
                 text=content,
                 annotations=[],  # TODO
                 type="output_text",
-                logprobs=None,  # TODO
+                logprobs=self._create_response_logprobs(
+                    token_ids=final_output.token_ids,
+                    logprobs=final_output.logprobs,
+                    tokenizer=tokenizer,
+                    top_logprobs=request.top_logprobs,
+                ) if request.is_include_output_logprobs() else None,
             )
             message = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",