[Feature][Responses API] Support logprobs(non-stream) (#23319)
Signed-off-by: Kebe <mail@kebe7jun.com>
This commit is contained in:
parent 8ef6b8a38c
commit 5368f76855
@@ -73,3 +73,16 @@ async def test_chat_with_input_type(client: openai.AsyncOpenAI):
         ], )
     print(response)
     assert response.status == "completed"
+
+
+@pytest.mark.asyncio
+async def test_logprobs(client: openai.AsyncOpenAI):
+    response = await client.responses.create(
+        include=["message.output_text.logprobs"],
+        input="What is 13 * 24?",
+        top_logprobs=5,
+    )
+    print(response)
+    outputs = response.output
+    assert outputs[-1].content[-1].logprobs
+    assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5
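The new test above doubles as the usage pattern for the feature. As a minimal client-side sketch (not part of this commit), the same flow from a standalone script looks roughly like this; base_url and api_key are placeholders for a locally running vLLM OpenAI-compatible server, and model= is omitted as in the test (pass it explicitly if your SDK or deployment requires it).

# Minimal client sketch mirroring the test above; illustrative only.
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    response = await client.responses.create(
        include=["message.output_text.logprobs"],
        input="What is 13 * 24?",
        top_logprobs=5,
    )
    # Non-streaming responses now carry per-token logprobs on each output_text
    # part, each with up to `top_logprobs` alternatives.
    logprobs = response.output[-1].content[-1].logprobs
    print(logprobs[0].token, logprobs[0].logprob, len(logprobs[0].top_logprobs))


asyncio.run(main())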
@@ -357,13 +357,22 @@ class ResponsesRequest(OpenAIBaseModel):
             temperature=temperature,
             top_p=top_p,
             max_tokens=max_tokens,
-            logprobs=self.top_logprobs,
+            logprobs=self.top_logprobs
+            if self.is_include_output_logprobs() else None,
             stop_token_ids=stop_token_ids,
             output_kind=(RequestOutputKind.DELTA
                          if self.stream else RequestOutputKind.FINAL_ONLY),
             guided_decoding=guided_decoding,
         )

+    def is_include_output_logprobs(self) -> bool:
+        """Check if the request includes output logprobs."""
+        if self.include is None:
+            return False
+        return isinstance(
+            self.include,
+            list) and "message.output_text.logprobs" in self.include
+
     @model_validator(mode="before")
     def validate_background(cls, data):
         if not data.get("background"):
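Taken in isolation, the gating added to to_sampling_params behaves like the stand-in below; resolve_logprobs is a hypothetical helper written only for illustration, not code from this commit or from vLLM.

# Hypothetical stand-in for the gating above: the engine is only asked for
# logprobs when the client opted in via `include`.
from typing import Optional


def resolve_logprobs(include: Optional[list[str]],
                     top_logprobs: Optional[int]) -> Optional[int]:
    wants_logprobs = (isinstance(include, list)
                      and "message.output_text.logprobs" in include)
    return top_logprobs if wants_logprobs else None


assert resolve_logprobs(["message.output_text.logprobs"], 5) == 5
assert resolve_logprobs(None, 5) is None
assert resolve_logprobs([], 5) is None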
@@ -1808,7 +1817,7 @@ class ResponsesResponse(OpenAIBaseModel):
     service_tier: Literal["auto", "default", "flex", "scale", "priority"]
     status: ResponseStatus
     text: Optional[ResponseTextConfig] = None
-    top_logprobs: int
+    top_logprobs: Optional[int] = None
     truncation: Literal["auto", "disabled"]
     usage: Optional[ResponseUsage] = None
     user: Optional[str] = None
@@ -4,7 +4,7 @@
 import asyncio
 import json
 import time
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
@@ -25,6 +25,8 @@ from openai.types.responses import (ResponseCreatedEvent,
                                     ResponseReasoningItem,
                                     ResponseReasoningTextDeltaEvent,
                                     ResponseReasoningTextDoneEvent)
+from openai.types.responses.response_output_text import (Logprob,
+                                                          LogprobTopLogprob)
 # yapf: enable
 from openai.types.responses.response_reasoning_item import (
     Content as ResponseReasoningTextContent)
@@ -59,6 +61,8 @@ from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
+from vllm.sequence import Logprob as SampleLogprob
+from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid

@@ -201,6 +205,12 @@ class OpenAIServingResponses(OpenAIServing):
             # (i.e., their request's `store=True` just because it's the default
             # value).
             request.store = False
+        if self.use_harmony and request.is_include_output_logprobs():
+            return self.create_error_response(
+                err_type="invalid_request_error",
+                message="logprobs are not supported with gpt-oss models",
+                status_code=HTTPStatus.BAD_REQUEST,
+            )

         # Handle the previous response ID.
         prev_response_id = request.previous_response_id
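On the client side this early rejection surfaces as an HTTP 400. A sketch of what that looks like with the OpenAI SDK; the 400-to-BadRequestError mapping is standard SDK behavior rather than something added by this commit, and the client construction is a placeholder.

# Sketch only: with a gpt-oss (harmony) model served, requesting output
# logprobs is rejected before generation starts.
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        await client.responses.create(
            input="What is 13 * 24?",
            include=["message.output_text.logprobs"],
        )
    except openai.BadRequestError as exc:
        print(exc)  # "logprobs are not supported with gpt-oss models"


asyncio.run(main())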
@@ -491,6 +501,51 @@ class OpenAIServingResponses(OpenAIServing):
             self.response_store[response.id] = response
         return response

+    def _topk_logprobs(self, logprobs: dict[int,
+                                            SampleLogprob], top_logprobs: int,
+                       tokenizer: AnyTokenizer) -> list[LogprobTopLogprob]:
+        """Returns the top-k logprobs from the logprobs dictionary."""
+        out = []
+        for i, (token_id, _logprob) in enumerate(logprobs.items()):
+            if i >= top_logprobs:
+                break
+            text = _logprob.decoded_token if _logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                LogprobTopLogprob(
+                    token=text,
+                    logprob=max(_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                ))
+        return out
+
+    def _create_response_logprobs(
+            self,
+            token_ids: Sequence[int],
+            logprobs: Optional[SampleLogprobs],
+            tokenizer: AnyTokenizer,
+            top_logprobs: Optional[int] = None) -> list[Logprob]:
+        assert logprobs is not None, "logprobs must be provided"
+        assert len(token_ids) == len(logprobs), (
+            "token_ids and logprobs.token_ids must have the same length")
+        out = []
+        for i, token_id in enumerate(token_ids):
+            logprob = logprobs[i]
+            token_logprob = logprob[token_id]
+            text = token_logprob.decoded_token if token_logprob.decoded_token \
+                is not None else tokenizer.decode([token_id])
+            out.append(
+                Logprob(
+                    token=text,
+                    logprob=max(token_logprob.logprob, -9999.0),
+                    bytes=list(text.encode("utf-8", errors="replace")),
+                    top_logprobs=self._topk_logprobs(logprob,
+                                                     top_logprobs=top_logprobs,
+                                                     tokenizer=tokenizer)
+                    if top_logprobs else [],
+                ))
+        return out
+
     def _make_response_output_items(
         self,
         request: ResponsesRequest,
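The conversion the two helpers above perform can be sketched standalone. In the sketch below, plain dicts and tuples stand in for vLLM's internal Logprob objects and for the OpenAI Logprob/LogprobTopLogprob models; to_openai_logprob and the sample data are illustrative names, not part of this commit.

# Illustrative stand-in for _create_response_logprobs/_topk_logprobs: a
# per-token candidate dict from the engine becomes the OpenAI Responses
# logprob shape. Plain dicts/tuples replace the real model classes.
def to_openai_logprob(token_id: int,
                      candidates: dict[int, tuple[str, float]],
                      top_logprobs: int) -> dict:
    # candidates maps token_id -> (decoded text, logprob); the sampled token
    # is always present and the dict is ordered by rank.
    text, logprob = candidates[token_id]
    return {
        "token": text,
        "logprob": max(logprob, -9999.0),  # clamp -inf, as the helpers do
        "bytes": list(text.encode("utf-8", errors="replace")),
        "top_logprobs": [{
            "token": t,
            "logprob": max(lp, -9999.0),
            "bytes": list(t.encode("utf-8", errors="replace")),
        } for t, lp in list(candidates.values())[:top_logprobs]],
    }


print(to_openai_logprob(312, {312: ("312", -0.02), 99: (" 312", -4.5)}, 2))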
@@ -547,7 +602,12 @@ class OpenAIServingResponses(OpenAIServing):
                 text=content,
                 annotations=[],  # TODO
                 type="output_text",
-                logprobs=None,  # TODO
+                logprobs=self._create_response_logprobs(
+                    token_ids=final_output.token_ids,
+                    logprobs=final_output.logprobs,
+                    tokenizer=tokenizer,
+                    top_logprobs=request.top_logprobs,
+                ) if request.is_include_output_logprobs() else None,
             )
             message = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",