diff --git a/tests/entrypoints/openai/test_response_api_simple.py b/tests/entrypoints/openai/test_response_api_simple.py
index 425b8199a0fd..aee03199bc6f 100644
--- a/tests/entrypoints/openai/test_response_api_simple.py
+++ b/tests/entrypoints/openai/test_response_api_simple.py
@@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
     assert response.status == "completed"
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_enable_response_messages(client: OpenAI, model_name: str):
+    response = await client.responses.create(
+        model=model_name,
+        input="Hello?",
+        extra_body={"enable_response_messages": True},
+    )
+    assert response.status == "completed"
+    assert response.input_messages[0]["type"] == "raw_message_tokens"
+    assert type(response.input_messages[0]["message"]) is str
+    assert len(response.input_messages[0]["message"]) > 10
+    assert type(response.input_messages[0]["tokens"][0]) is int
+    assert type(response.output_messages[0]["message"]) is str
+    assert len(response.output_messages[0]["message"]) > 10
+    assert type(response.output_messages[0]["tokens"][0]) is int
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_reasoning_item(client: OpenAI, model_name: str):
diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py
index 1260f65dba59..43783c92667a 100644
--- a/vllm/entrypoints/context.py
+++ b/vllm/entrypoints/context.py
@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.parser.responses_parser import (
 )
 from vllm.entrypoints.openai.protocol import (
     ResponseInputOutputItem,
+    ResponseRawMessageAndToken,
     ResponsesRequest,
 )
 from vllm.entrypoints.responses_utils import construct_tool_dicts
@@ -148,6 +149,8 @@ def _create_json_parse_error_messages(
 
 
 class SimpleContext(ConversationContext):
+    """This is a context that cannot handle MCP tool calls"""
+
     def __init__(self):
         self.last_output = None
         self.num_prompt_tokens = 0
@@ -158,6 +161,9 @@ class SimpleContext(ConversationContext):
         # not implemented yet for SimpleContext
         self.all_turn_metrics = []
 
+        self.input_messages: list[ResponseRawMessageAndToken] = []
+        self.output_messages: list[ResponseRawMessageAndToken] = []
+
     def append_output(self, output) -> None:
         self.last_output = output
         if not isinstance(output, RequestOutput):
@@ -166,6 +172,22 @@ class SimpleContext(ConversationContext):
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
 
+        if len(self.input_messages) == 0:
+            output_prompt = output.prompt or ""
+            output_prompt_token_ids = output.prompt_token_ids or []
+            self.input_messages.append(
+                ResponseRawMessageAndToken(
+                    message=output_prompt,
+                    tokens=output_prompt_token_ids,
+                )
+            )
+        self.output_messages.append(
+            ResponseRawMessageAndToken(
+                message=output.outputs[0].text,
+                tokens=output.outputs[0].token_ids,
+            )
+        )
+
     def append_tool_output(self, output) -> None:
         raise NotImplementedError("Should not be called.")
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 0f4b2b4d7aad..2d34a6a0cd5a 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1598,6 +1598,20 @@ def serialize_messages(msgs):
     return [serialize_message(msg) for msg in msgs] if msgs else None
 
 
+class ResponseRawMessageAndToken(OpenAIBaseModel):
+    """Class to show the raw message.
+    If message / tokens diverge, tokens is the source of truth"""
+
+    message: str
+    tokens: list[int]
+    type: Literal["raw_message_tokens"] = "raw_message_tokens"
+
+
+ResponseInputOutputMessage: TypeAlias = (
+    list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken]
+)
+
+
 class ResponsesResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
     created_at: int = Field(default_factory=lambda: int(time.time()))
@@ -1631,8 +1645,8 @@ class ResponsesResponse(OpenAIBaseModel):
     # These are populated when enable_response_messages is set to True
     # NOTE: custom serialization is needed
     # see serialize_input_messages and serialize_output_messages
-    input_messages: list[ChatCompletionMessageParam] | None = None
-    output_messages: list[ChatCompletionMessageParam] | None = None
+    input_messages: ResponseInputOutputMessage | None = None
+    output_messages: ResponseInputOutputMessage | None = None
     # --8<-- [end:responses-extra-params]
 
     # NOTE: openAI harmony doesn't serialize TextContent properly,
@@ -1658,8 +1672,8 @@
         output: list[ResponseOutputItem],
         status: ResponseStatus,
         usage: ResponseUsage | None = None,
-        input_messages: list[ChatCompletionMessageParam] | None = None,
-        output_messages: list[ChatCompletionMessageParam] | None = None,
+        input_messages: ResponseInputOutputMessage | None = None,
+        output_messages: ResponseInputOutputMessage | None = None,
     ) -> "ResponsesResponse":
         incomplete_details: IncompleteDetails | None = None
         if status == "incomplete":
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 5ad86194ce1b..3c9ae8e8c808 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -86,6 +86,7 @@ from vllm.entrypoints.openai.protocol import (
     ResponseCompletedEvent,
     ResponseCreatedEvent,
     ResponseInProgressEvent,
+    ResponseInputOutputMessage,
     ResponseReasoningPartAddedEvent,
     ResponseReasoningPartDoneEvent,
     ResponsesRequest,
@@ -629,8 +630,8 @@
         # "completed" is implemented as the "catch-all" for now.
         status: ResponseStatus = "completed"
 
-        input_messages = None
-        output_messages = None
+        input_messages: ResponseInputOutputMessage | None = None
+        output_messages: ResponseInputOutputMessage | None = None
         if self.use_harmony:
             assert isinstance(context, HarmonyContext)
             output = self._make_response_output_items_with_harmony(context)
@@ -670,12 +671,10 @@
             output = self._make_response_output_items(request, final_output, tokenizer)
 
-            # TODO: context for non-gptoss models doesn't use messages
-            # so we can't get them out yet
             if request.enable_response_messages:
-                raise NotImplementedError(
-                    "enable_response_messages is currently only supported for gpt-oss"
-                )
+                input_messages = context.input_messages
+                output_messages = context.output_messages
+
         # Calculate usage.
         assert final_res.prompt_token_ids is not None
         num_tool_output_tokens = 0