[responsesAPI] support input output messages for non harmony models (#29549)
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
parent bbfb55c29e
commit 3a7751485b
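For context, a sketch of the request shape exercised by the new test below, written with the synchronous OpenAI client; the base URL, API key, and model name are placeholders for illustration, not values from this commit:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.responses.create(
    model="my-non-harmony-model",  # placeholder model name
    input="Hello?",
    extra_body={"enable_response_messages": True},
)

# With the flag set, the response also carries the raw prompt and output text
# together with their token ids (see ResponseRawMessageAndToken below).
first_input = response.input_messages[0]
assert first_input["type"] == "raw_message_tokens"
print(first_input["message"], first_input["tokens"][:5])
print(response.output_messages[0]["message"])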
@@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
     assert response.status == "completed"
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_enable_response_messages(client: OpenAI, model_name: str):
+    response = await client.responses.create(
+        model=model_name,
+        input="Hello?",
+        extra_body={"enable_response_messages": True},
+    )
+    assert response.status == "completed"
+    assert response.input_messages[0]["type"] == "raw_message_tokens"
+    assert type(response.input_messages[0]["message"]) is str
+    assert len(response.input_messages[0]["message"]) > 10
+    assert type(response.input_messages[0]["tokens"][0]) is int
+    assert type(response.output_messages[0]["message"]) is str
+    assert len(response.output_messages[0]["message"]) > 10
+    assert type(response.output_messages[0]["tokens"][0]) is int
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_reasoning_item(client: OpenAI, model_name: str):
@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.parser.responses_parser import (
 )
 from vllm.entrypoints.openai.protocol import (
     ResponseInputOutputItem,
+    ResponseRawMessageAndToken,
     ResponsesRequest,
 )
 from vllm.entrypoints.responses_utils import construct_tool_dicts
@@ -148,6 +149,8 @@ def _create_json_parse_error_messages(
 
 
 class SimpleContext(ConversationContext):
+    """This is a context that cannot handle MCP tool calls"""
+
     def __init__(self):
         self.last_output = None
         self.num_prompt_tokens = 0
@@ -158,6 +161,9 @@ class SimpleContext(ConversationContext):
         # not implemented yet for SimpleContext
         self.all_turn_metrics = []
 
+        self.input_messages: list[ResponseRawMessageAndToken] = []
+        self.output_messages: list[ResponseRawMessageAndToken] = []
+
     def append_output(self, output) -> None:
         self.last_output = output
         if not isinstance(output, RequestOutput):
@@ -166,6 +172,22 @@
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
 
+        if len(self.input_messages) == 0:
+            output_prompt = output.prompt or ""
+            output_prompt_token_ids = output.prompt_token_ids or []
+            self.input_messages.append(
+                ResponseRawMessageAndToken(
+                    message=output_prompt,
+                    tokens=output_prompt_token_ids,
+                )
+            )
+        self.output_messages.append(
+            ResponseRawMessageAndToken(
+                message=output.outputs[0].text,
+                tokens=output.outputs[0].token_ids,
+            )
+        )
+
     def append_tool_output(self, output) -> None:
         raise NotImplementedError("Should not be called.")
 
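A minimal, self-contained sketch of the capture logic added above; the underscore-prefixed dataclasses are hypothetical stand-ins for vLLM's RequestOutput, CompletionOutput, and ResponseRawMessageAndToken types, and the token ids are arbitrary illustrative values:

from dataclasses import dataclass

@dataclass
class _CompletionOutput:   # stand-in for vllm.outputs.CompletionOutput
    text: str
    token_ids: list[int]

@dataclass
class _RequestOutput:      # stand-in for vllm.outputs.RequestOutput
    prompt: str | None
    prompt_token_ids: list[int] | None
    outputs: list[_CompletionOutput]

@dataclass
class _RawMessage:         # stand-in for ResponseRawMessageAndToken
    message: str
    tokens: list[int]

input_messages: list[_RawMessage] = []
output_messages: list[_RawMessage] = []

def append_output(output: _RequestOutput) -> None:
    # The prompt is captured only once, from the first engine output.
    if len(input_messages) == 0:
        input_messages.append(
            _RawMessage(message=output.prompt or "", tokens=output.prompt_token_ids or [])
        )
    # Every engine output contributes one raw output message.
    output_messages.append(
        _RawMessage(message=output.outputs[0].text, tokens=list(output.outputs[0].token_ids))
    )

append_output(_RequestOutput("Hello?", [9906, 30], [_CompletionOutput("Hi there!", [13347, 1070, 0])]))
assert len(input_messages) == 1 and len(output_messages) == 1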
@@ -1598,6 +1598,20 @@ def serialize_messages(msgs):
     return [serialize_message(msg) for msg in msgs] if msgs else None
 
 
+class ResponseRawMessageAndToken(OpenAIBaseModel):
+    """Class to show the raw message.
+    If message / tokens diverge, tokens is the source of truth"""
+
+    message: str
+    tokens: list[int]
+    type: Literal["raw_message_tokens"] = "raw_message_tokens"
+
+
+ResponseInputOutputMessage: TypeAlias = (
+    list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
+)
+
+
 class ResponsesResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
     created_at: int = Field(default_factory=lambda: int(time.time()))
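A rough illustration of how the new protocol model serializes; this sketch subclasses pydantic's BaseModel directly rather than vLLM's OpenAIBaseModel, and the token ids are made-up illustrative values:

from typing import Literal
from pydantic import BaseModel

class RawMessageAndToken(BaseModel):   # mirrors ResponseRawMessageAndToken
    message: str
    tokens: list[int]
    type: Literal["raw_message_tokens"] = "raw_message_tokens"

msg = RawMessageAndToken(message="Hello?", tokens=[9906, 30])
print(msg.model_dump())
# {'message': 'Hello?', 'tokens': [9906, 30], 'type': 'raw_message_tokens'}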
@@ -1631,8 +1645,8 @@ class ResponsesResponse(OpenAIBaseModel):
     # These are populated when enable_response_messages is set to True
     # NOTE: custom serialization is needed
     # see serialize_input_messages and serialize_output_messages
-    input_messages: list[ChatCompletionMessageParam] | None = None
-    output_messages: list[ChatCompletionMessageParam] | None = None
+    input_messages: ResponseInputOutputMessage | None = None
+    output_messages: ResponseInputOutputMessage | None = None
     # --8<-- [end:responses-extra-params]
 
     # NOTE: openAI harmony doesn't serialize TextContent properly,
@@ -1658,8 +1672,8 @@
         output: list[ResponseOutputItem],
         status: ResponseStatus,
         usage: ResponseUsage | None = None,
-        input_messages: list[ChatCompletionMessageParam] | None = None,
-        output_messages: list[ChatCompletionMessageParam] | None = None,
+        input_messages: ResponseInputOutputMessage | None = None,
+        output_messages: ResponseInputOutputMessage | None = None,
     ) -> "ResponsesResponse":
         incomplete_details: IncompleteDetails | None = None
         if status == "incomplete":
@@ -86,6 +86,7 @@ from vllm.entrypoints.openai.protocol import (
     ResponseCompletedEvent,
     ResponseCreatedEvent,
     ResponseInProgressEvent,
+    ResponseInputOutputMessage,
     ResponseReasoningPartAddedEvent,
     ResponseReasoningPartDoneEvent,
     ResponsesRequest,
@@ -629,8 +630,8 @@ class OpenAIServingResponses(OpenAIServing):
         # "completed" is implemented as the "catch-all" for now.
         status: ResponseStatus = "completed"
 
-        input_messages = None
-        output_messages = None
+        input_messages: ResponseInputOutputMessage | None = None
+        output_messages: ResponseInputOutputMessage | None = None
         if self.use_harmony:
             assert isinstance(context, HarmonyContext)
             output = self._make_response_output_items_with_harmony(context)
@@ -670,12 +671,10 @@ class OpenAIServingResponses(OpenAIServing):
 
             output = self._make_response_output_items(request, final_output, tokenizer)
 
-            # TODO: context for non-gptoss models doesn't use messages
-            # so we can't get them out yet
             if request.enable_response_messages:
-                raise NotImplementedError(
-                    "enable_response_messages is currently only supported for gpt-oss"
-                )
+                input_messages = context.input_messages
+                output_messages = context.output_messages
+
         # Calculate usage.
         assert final_res.prompt_token_ids is not None
         num_tool_output_tokens = 0
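A condensed sketch of what the non-harmony branch now does with the flag, instead of raising NotImplementedError; resolve_response_messages is a hypothetical helper written for illustration, not a function in this commit:

def resolve_response_messages(context, enable_response_messages: bool):
    """Return (input_messages, output_messages), or (None, None) if the flag is off."""
    if not enable_response_messages:
        return None, None
    # SimpleContext now tracks these lists (see the append_output change above),
    # so non-harmony models can return them just like the harmony path does.
    return context.input_messages, context.output_messages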