mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[Feature][Responses API] Support logprobs(non-stream) (#23319)
Signed-off-by: Kebe <mail@kebe7jun.com>
This commit is contained in:
parent
8ef6b8a38c
commit
5368f76855
@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI):
|
||||
], )
|
||||
print(response)
|
||||
assert response.status == "completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logprobs(client: openai.AsyncOpenAI):
|
||||
response = await client.responses.create(
|
||||
include=["message.output_text.logprobs"],
|
||||
input="What is 13 * 24?",
|
||||
top_logprobs=5,
|
||||
)
|
||||
print(response)
|
||||
outputs = response.output
|
||||
assert outputs[-1].content[-1].logprobs
|
||||
assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5
|
||||
|
||||
@ -357,13 +357,22 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=self.top_logprobs,
|
||||
logprobs=self.top_logprobs
|
||||
if self.is_include_output_logprobs() else None,
|
||||
stop_token_ids=stop_token_ids,
|
||||
output_kind=(RequestOutputKind.DELTA
|
||||
if self.stream else RequestOutputKind.FINAL_ONLY),
|
||||
guided_decoding=guided_decoding,
|
||||
)
|
||||
|
||||
def is_include_output_logprobs(self) -> bool:
|
||||
"""Check if the request includes output logprobs."""
|
||||
if self.include is None:
|
||||
return False
|
||||
return isinstance(
|
||||
self.include,
|
||||
list) and "message.output_text.logprobs" in self.include
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_background(cls, data):
|
||||
if not data.get("background"):
|
||||
@ -1808,7 +1817,7 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
service_tier: Literal["auto", "default", "flex", "scale", "priority"]
|
||||
status: ResponseStatus
|
||||
text: Optional[ResponseTextConfig] = None
|
||||
top_logprobs: int
|
||||
top_logprobs: Optional[int] = None
|
||||
truncation: Literal["auto", "disabled"]
|
||||
usage: Optional[ResponseUsage] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
|
||||
from contextlib import AsyncExitStack
|
||||
from copy import copy
|
||||
from http import HTTPStatus
|
||||
@ -25,6 +25,8 @@ from openai.types.responses import (ResponseCreatedEvent,
|
||||
ResponseReasoningItem,
|
||||
ResponseReasoningTextDeltaEvent,
|
||||
ResponseReasoningTextDoneEvent)
|
||||
from openai.types.responses.response_output_text import (Logprob,
|
||||
LogprobTopLogprob)
|
||||
# yapf: enable
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent)
|
||||
@ -59,6 +61,8 @@ from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Logprob as SampleLogprob
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
@ -201,6 +205,12 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# (i.e., their request's `store=True` just because it's the default
|
||||
# value).
|
||||
request.store = False
|
||||
if self.use_harmony and request.is_include_output_logprobs():
|
||||
return self.create_error_response(
|
||||
err_type="invalid_request_error",
|
||||
message="logprobs are not supported with gpt-oss models",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Handle the previous response ID.
|
||||
prev_response_id = request.previous_response_id
|
||||
@ -491,6 +501,51 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self.response_store[response.id] = response
|
||||
return response
|
||||
|
||||
def _topk_logprobs(self, logprobs: dict[int,
|
||||
SampleLogprob], top_logprobs: int,
|
||||
tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]:
|
||||
"""Returns the top-k logprobs from the logprobs dictionary."""
|
||||
out = []
|
||||
for i, (token_id, _logprob) in enumerate(logprobs.items()):
|
||||
if i >= top_logprobs:
|
||||
break
|
||||
text = _logprob.decoded_token if _logprob.decoded_token \
|
||||
is not None else tokenizer.decode([token_id])
|
||||
out.append(
|
||||
LogprobTopLogprob(
|
||||
token=text,
|
||||
logprob=max(_logprob.logprob, -9999.0),
|
||||
bytes=list(text.encode("utf-8", errors="replace")),
|
||||
))
|
||||
return out
|
||||
|
||||
def _create_response_logprobs(
|
||||
self,
|
||||
token_ids: Sequence[int],
|
||||
logprobs: Optional[SampleLogprobs],
|
||||
tokenizer: AnyTokenizer,
|
||||
top_logprobs: Optional[int] = None) -> list[Logprob]:
|
||||
assert logprobs is not None, "logprobs must be provided"
|
||||
assert len(token_ids) == len(logprobs), (
|
||||
"token_ids and logprobs.token_ids must have the same length")
|
||||
out = []
|
||||
for i, token_id in enumerate(token_ids):
|
||||
logprob = logprobs[i]
|
||||
token_logprob = logprob[token_id]
|
||||
text = token_logprob.decoded_token if token_logprob.decoded_token \
|
||||
is not None else tokenizer.decode([token_id])
|
||||
out.append(
|
||||
Logprob(
|
||||
token=text,
|
||||
logprob=max(token_logprob.logprob, -9999.0),
|
||||
bytes=list(text.encode("utf-8", errors="replace")),
|
||||
top_logprobs=self._topk_logprobs(logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
tokenizer=tokenizer)
|
||||
if top_logprobs else [],
|
||||
))
|
||||
return out
|
||||
|
||||
def _make_response_output_items(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
@ -547,7 +602,12 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
text=content,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=None, # TODO
|
||||
logprobs=self._create_response_logprobs(
|
||||
token_ids=final_output.token_ids,
|
||||
logprobs=final_output.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
top_logprobs=request.top_logprobs,
|
||||
) if request.is_include_output_logprobs() else None,
|
||||
)
|
||||
message = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user