[responsesAPI][6] Fix multi turn MCP tokenization (#30230)

Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
Andrew Xia 2025-12-09 18:13:13 -08:00 committed by GitHub
parent abe93bce59
commit c3487aca34
5 changed files with 110 additions and 13 deletions

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import pytest
+from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
 from openai.types.responses.response_function_tool_call_output_item import (
     ResponseFunctionToolCallOutputItem,
 )
@@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
 )
 
 from vllm.entrypoints.responses_utils import (
-    construct_chat_message_with_tool_call,
+    _construct_single_message_from_response_item,
+    construct_chat_messages_with_tool_call,
     convert_tool_responses_to_completions_format,
 )
@@ -42,7 +44,43 @@ class TestResponsesUtils:
         assert result == {"type": "function", "function": input_tool}
 
-    def test_construct_chat_message_with_tool_call(self):
+    def test_construct_chat_messages_with_tool_call(self):
+        """Test construction of chat messages with tool calls."""
+        reasoning_item = ResponseReasoningItem(
+            id="lol",
+            summary=[],
+            type="reasoning",
+            content=[
+                Content(
+                    text="Leroy Jenkins",
+                    type="reasoning_text",
+                )
+            ],
+            encrypted_content=None,
+            status=None,
+        )
+        mcp_tool_item = ResponseFunctionToolCall(
+            id="mcp_123",
+            call_id="call_123",
+            type="function_call",
+            status="completed",
+            name="python",
+            arguments='{"code": "123+456"}',
+        )
+        input_items = [reasoning_item, mcp_tool_item]
+        messages = construct_chat_messages_with_tool_call(input_items)
+        assert len(messages) == 1
+        message = messages[0]
+        assert message["role"] == "assistant"
+        assert message["reasoning"] == "Leroy Jenkins"
+        assert message["tool_calls"][0]["id"] == "call_123"
+        assert message["tool_calls"][0]["function"]["name"] == "python"
+        assert (
+            message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
+        )
+
+    def test_construct_single_message_from_response_item(self):
         item = ResponseReasoningItem(
             id="lol",
             summary=[],
@@ -56,7 +94,7 @@ class TestResponsesUtils:
             encrypted_content=None,
             status=None,
         )
-        formatted_item = construct_chat_message_with_tool_call(item)
+        formatted_item = _construct_single_message_from_response_item(item)
         assert formatted_item["role"] == "assistant"
         assert formatted_item["reasoning"] == "Leroy Jenkins"
@@ -74,7 +112,7 @@ class TestResponsesUtils:
             status=None,
         )
-        formatted_item = construct_chat_message_with_tool_call(item)
+        formatted_item = _construct_single_message_from_response_item(item)
         assert formatted_item["role"] == "assistant"
         assert (
             formatted_item["reasoning"]
@@ -88,7 +126,7 @@ class TestResponsesUtils:
             output="1234",
             status="completed",
         )
-        formatted_item = construct_chat_message_with_tool_call(tool_call_output)
+        formatted_item = _construct_single_message_from_response_item(tool_call_output)
         assert formatted_item["role"] == "tool"
         assert formatted_item["content"] == "1234"
         assert formatted_item["tool_call_id"] == "temp"
@@ -102,7 +140,7 @@ class TestResponsesUtils:
             status=None,
        )
         with pytest.raises(ValueError):
-            construct_chat_message_with_tool_call(item)
+            _construct_single_message_from_response_item(item)
 
         output_item = ResponseOutputMessage(
             id="msg_bf585bbbe3d500e0",
@@ -119,6 +157,6 @@ class TestResponsesUtils:
             type="message",
         )
-        formatted_item = construct_chat_message_with_tool_call(output_item)
+        formatted_item = _construct_single_message_from_response_item(output_item)
         assert formatted_item["role"] == "assistant"
         assert formatted_item["content"] == "dongyi"

View File

@@ -8,3 +8,5 @@ Shared constants for vLLM entrypoints.
 # These constants help mitigate header abuse attacks
 H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304  # 4 MB
 H11_MAX_HEADER_COUNT_DEFAULT = 256
+
+MCP_PREFIX = "mcp_"

View File

@@ -19,6 +19,7 @@ from vllm import envs
 from vllm.entrypoints.chat_utils import (
     ChatTemplateContentFormatOption,
 )
+from vllm.entrypoints.constants import MCP_PREFIX
 from vllm.entrypoints.openai.parser.harmony_utils import (
     get_encoding,
     get_streamable_parser_for_assistant,
@@ -303,7 +304,7 @@ class ParsableContext(ConversationContext):
         result_str = result.content[0].text
         message = ResponseFunctionToolCallOutputItem(
-            id=f"fco_{random_uuid()}",
+            id=f"mcpo_{random_uuid()}",
             type="function_call_output",
             call_id=f"call_{random_uuid()}",
             output=result_str,
@@ -385,6 +386,9 @@ class ParsableContext(ConversationContext):
         if not self.parser.response_messages:
             return []
         last_msg = self.parser.response_messages[-1]
+        # change the id so this is treated as an mcp_ function call
+        last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
+        self.parser.response_messages[-1] = last_msg
         if last_msg.name == "code_interpreter":
             return await self.call_python_tool(self._tool_sessions["python"], last_msg)
         elif last_msg.name == "web_search_preview":
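
The id rewrite above is what lets the input path fold this tool call back into its reasoning message on the next turn. A toy illustration, with SimpleNamespace standing in for the parser's typed message object and would_merge as a hypothetical condensation of the guards added in responses_utils below:

from types import SimpleNamespace

MCP_PREFIX = "mcp_"

def would_merge(item, messages):
    # Mirrors the key guards in _maybe_combine_reasoning_and_tool_call.
    if not item.id.startswith(MCP_PREFIX):
        return False
    if not messages:
        return False
    last = messages[-1]
    return last.get("role") == "assistant" and last.get("reasoning") is not None

call = SimpleNamespace(id="fc_abc123")  # id as the parser originally minted it
history = [{"role": "assistant", "reasoning": "..."}]
assert not would_merge(call, history)   # skipped: id lacks the MCP prefix

call.id = f"{MCP_PREFIX}abc123"         # the rewrite this hunk performs
assert would_merge(call, history)       # now folded into the reasoning message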

View File

@@ -1339,6 +1339,7 @@ class OpenAIServing:
         )
         engine_prompt = engine_prompts[0]
         request_prompt = request_prompts[0]
+        prompt_text, _, _ = self._get_prompt_components(request_prompt)
 
         # Update the sampling params.
         sampling_params.max_tokens = self.max_model_len - len(

View File

@@ -22,6 +22,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
 from openai.types.responses.tool import Tool
 
 from vllm import envs
+from vllm.entrypoints.constants import MCP_PREFIX
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionMessageParam,
     ResponseInputOutputItem,
@@ -44,13 +45,13 @@ def make_response_output_items_from_parsable_context(
         )
         if isinstance(output_messages[-1], ResponseFunctionToolCall):
             mcp_message = McpCall(
-                id=f"mcp_{random_uuid()}",
+                id=f"{MCP_PREFIX}{random_uuid()}",
                 arguments=output_messages[-1].arguments,
                 name=output_messages[-1].name,
                 server_label=output_messages[
                     -1
                 ].name,  # TODO: store the server label
-                type="mcp_call",
+                type=f"{MCP_PREFIX}call",
                 status="completed",
                 output=message.output,
                 # TODO: support error output
@@ -98,12 +99,63 @@ def construct_input_messages(
     if isinstance(request_input, str):
         messages.append({"role": "user", "content": request_input})
     else:
-        for item in request_input:
-            messages.append(construct_chat_message_with_tool_call(item))
+        input_messages = construct_chat_messages_with_tool_call(request_input)
+        messages.extend(input_messages)
     return messages
 
 
-def construct_chat_message_with_tool_call(
+def _maybe_combine_reasoning_and_tool_call(
+    item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
+) -> ChatCompletionMessageParam | None:
+    """Many models treat MCP calls and reasoning as a single message.
+    This function checks whether the last message is a reasoning message
+    and the current item is a tool call."""
+    if not (
+        isinstance(item, ResponseFunctionToolCall) and item.id.startswith(MCP_PREFIX)
+    ):
+        return None
+    if len(messages) == 0:
+        return None
+    last_message = messages[-1]
+    if not (
+        last_message.get("role") == "assistant"
+        and last_message.get("reasoning") is not None
+    ):
+        return None
+    last_message["tool_calls"] = [
+        ChatCompletionMessageToolCallParam(
+            id=item.call_id,
+            function=FunctionCallTool(
+                name=item.name,
+                arguments=item.arguments,
+            ),
+            type="function",
+        )
+    ]
+    return last_message
+
+
+def construct_chat_messages_with_tool_call(
+    input_messages: list[ResponseInputOutputItem],
+) -> list[ChatCompletionMessageParam]:
+    """Wraps _construct_single_message_from_response_item, because some chat
+    messages are built from multiple response items: for example, a reasoning
+    item and an MCP tool call are two response items but one chat message.
+    """
+    messages: list[ChatCompletionMessageParam] = []
+    for item in input_messages:
+        maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
+        if maybe_combined_message is not None:
+            messages[-1] = maybe_combined_message
+        else:
+            messages.append(_construct_single_message_from_response_item(item))
+    return messages
+
+
+def _construct_single_message_from_response_item(
     item: ResponseInputOutputItem,
 ) -> ChatCompletionMessageParam:
     if isinstance(item, ResponseFunctionToolCall):
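
Taken together, the new helpers collapse a reasoning item followed by an mcp_-prefixed tool call into a single assistant chat message, so re-tokenizing the conversation on later turns matches what the model originally emitted. A condensed before/after sketch, mirroring the new test, with plain dicts standing in for the typed response items and ChatCompletionMessageParam:

# Two response items in, one assistant chat message out.
input_items = [
    {"type": "reasoning", "content": "Leroy Jenkins"},
    {"type": "function_call", "id": "mcp_123", "call_id": "call_123",
     "name": "python", "arguments": '{"code": "123+456"}'},
]

# After construct_chat_messages_with_tool_call, the reasoning item and the
# mcp_-prefixed tool call have been merged into one message:
merged = {
    "role": "assistant",
    "reasoning": "Leroy Jenkins",
    "tool_calls": [{
        "id": "call_123",
        "type": "function",
        "function": {"name": "python", "arguments": '{"code": "123+456"}'},
    }],
}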