mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:45:00 +08:00
[responsesAPI] support input output messages for non harmony models (#29549)
Signed-off-by: Andrew Xia <axia@fb.com> Co-authored-by: Andrew Xia <axia@fb.com>
This commit is contained in:
parent
bbfb55c29e
commit
3a7751485b
@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
|
||||
assert response.status == "completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_enable_response_messages(client: OpenAI, model_name: str):
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
input="Hello?",
|
||||
extra_body={"enable_response_messages": True},
|
||||
)
|
||||
assert response.status == "completed"
|
||||
assert response.input_messages[0]["type"] == "raw_message_tokens"
|
||||
assert type(response.input_messages[0]["message"]) is str
|
||||
assert len(response.input_messages[0]["message"]) > 10
|
||||
assert type(response.input_messages[0]["tokens"][0]) is int
|
||||
assert type(response.output_messages[0]["message"]) is str
|
||||
assert len(response.output_messages[0]["message"]) > 10
|
||||
assert type(response.output_messages[0]["tokens"][0]) is int
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_reasoning_item(client: OpenAI, model_name: str):
|
||||
|
||||
@ -23,6 +23,7 @@ from vllm.entrypoints.openai.parser.responses_parser import (
|
||||
)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ResponseInputOutputItem,
|
||||
ResponseRawMessageAndToken,
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.entrypoints.responses_utils import construct_tool_dicts
|
||||
@ -148,6 +149,8 @@ def _create_json_parse_error_messages(
|
||||
|
||||
|
||||
class SimpleContext(ConversationContext):
|
||||
"""This is a context that cannot handle MCP tool calls"""
|
||||
|
||||
def __init__(self):
|
||||
self.last_output = None
|
||||
self.num_prompt_tokens = 0
|
||||
@ -158,6 +161,9 @@ class SimpleContext(ConversationContext):
|
||||
# not implemented yet for SimpleContext
|
||||
self.all_turn_metrics = []
|
||||
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
self.output_messages: list[ResponseRawMessageAndToken] = []
|
||||
|
||||
def append_output(self, output) -> None:
|
||||
self.last_output = output
|
||||
if not isinstance(output, RequestOutput):
|
||||
@ -166,6 +172,22 @@ class SimpleContext(ConversationContext):
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
|
||||
if len(self.input_messages) == 0:
|
||||
output_prompt = output.prompt or ""
|
||||
output_prompt_token_ids = output.prompt_token_ids or []
|
||||
self.input_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output_prompt,
|
||||
tokens=output_prompt_token_ids,
|
||||
)
|
||||
)
|
||||
self.output_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output.outputs[0].text,
|
||||
tokens=output.outputs[0].token_ids,
|
||||
)
|
||||
)
|
||||
|
||||
def append_tool_output(self, output) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
@ -1598,6 +1598,20 @@ def serialize_messages(msgs):
|
||||
return [serialize_message(msg) for msg in msgs] if msgs else None
|
||||
|
||||
|
||||
class ResponseRawMessageAndToken(OpenAIBaseModel):
|
||||
"""Class to show the raw message.
|
||||
If message / tokens diverge, tokens is the source of truth"""
|
||||
|
||||
message: str
|
||||
tokens: list[int]
|
||||
type: Literal["raw_message_tokens"] = "raw_message_tokens"
|
||||
|
||||
|
||||
ResponseInputOutputMessage: TypeAlias = (
|
||||
list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
|
||||
)
|
||||
|
||||
|
||||
class ResponsesResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
@ -1631,8 +1645,8 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
# These are populated when enable_response_messages is set to True
|
||||
# NOTE: custom serialization is needed
|
||||
# see serialize_input_messages and serialize_output_messages
|
||||
input_messages: list[ChatCompletionMessageParam] | None = None
|
||||
output_messages: list[ChatCompletionMessageParam] | None = None
|
||||
input_messages: ResponseInputOutputMessage | None = None
|
||||
output_messages: ResponseInputOutputMessage | None = None
|
||||
# --8<-- [end:responses-extra-params]
|
||||
|
||||
# NOTE: openAI harmony doesn't serialize TextContent properly,
|
||||
@ -1658,8 +1672,8 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
output: list[ResponseOutputItem],
|
||||
status: ResponseStatus,
|
||||
usage: ResponseUsage | None = None,
|
||||
input_messages: list[ChatCompletionMessageParam] | None = None,
|
||||
output_messages: list[ChatCompletionMessageParam] | None = None,
|
||||
input_messages: ResponseInputOutputMessage | None = None,
|
||||
output_messages: ResponseInputOutputMessage | None = None,
|
||||
) -> "ResponsesResponse":
|
||||
incomplete_details: IncompleteDetails | None = None
|
||||
if status == "incomplete":
|
||||
|
||||
@ -86,6 +86,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ResponseCompletedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseInputOutputMessage,
|
||||
ResponseReasoningPartAddedEvent,
|
||||
ResponseReasoningPartDoneEvent,
|
||||
ResponsesRequest,
|
||||
@ -629,8 +630,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# "completed" is implemented as the "catch-all" for now.
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
input_messages = None
|
||||
output_messages = None
|
||||
input_messages: ResponseInputOutputMessage | None = None
|
||||
output_messages: ResponseInputOutputMessage | None = None
|
||||
if self.use_harmony:
|
||||
assert isinstance(context, HarmonyContext)
|
||||
output = self._make_response_output_items_with_harmony(context)
|
||||
@ -670,12 +671,10 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
output = self._make_response_output_items(request, final_output, tokenizer)
|
||||
|
||||
# TODO: context for non-gptoss models doesn't use messages
|
||||
# so we can't get them out yet
|
||||
if request.enable_response_messages:
|
||||
raise NotImplementedError(
|
||||
"enable_response_messages is currently only supported for gpt-oss"
|
||||
)
|
||||
input_messages = context.input_messages
|
||||
output_messages = context.output_messages
|
||||
|
||||
# Calculate usage.
|
||||
assert final_res.prompt_token_ids is not None
|
||||
num_tool_output_tokens = 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user