From da7bc54ea8f44a2dcacc4a9869721bd105006e10 Mon Sep 17 00:00:00 2001 From: Andrew Xia Date: Fri, 5 Dec 2025 08:11:50 -0800 Subject: [PATCH] [responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798) Signed-off-by: Andrew Xia Signed-off-by: Andrew Xia Co-authored-by: Andrew Xia --- .../openai_responses_client_with_tools.py | 2 +- .../test_response_api_parsable_context.py | 99 ++++++++++++++++++- vllm/entrypoints/context.py | 92 ++++++++++++++++- .../openai/parser/responses_parser.py | 34 +++++++ vllm/entrypoints/openai/serving_engine.py | 65 ++++++++++-- vllm/entrypoints/openai/serving_responses.py | 6 +- vllm/entrypoints/responses_utils.py | 21 +++- vllm/entrypoints/tool.py | 44 +++++++++ 8 files changed, 347 insertions(+), 16 deletions(-) diff --git a/examples/online_serving/openai_responses_client_with_tools.py b/examples/online_serving/openai_responses_client_with_tools.py index 276010197b5ab..c85c8cf807b49 100644 --- a/examples/online_serving/openai_responses_client_with_tools.py +++ b/examples/online_serving/openai_responses_client_with_tools.py @@ -3,7 +3,7 @@ """ Set up this example by starting a vLLM OpenAI-compatible server with tool call options enabled. -Reasoning models can be used through the Responses API as seen here +Reasoning models can be used through the Responses API as seen here https://platform.openai.com/docs/api-reference/responses For example: vllm serve Qwen/Qwen3-1.7B --reasoning-parser qwen3 \ diff --git a/tests/entrypoints/openai/test_response_api_parsable_context.py b/tests/entrypoints/openai/test_response_api_parsable_context.py index 1b2795770d4c7..1899c5f04fe3f 100644 --- a/tests/entrypoints/openai/test_response_api_parsable_context.py +++ b/tests/entrypoints/openai/test_response_api_parsable_context.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +import json import pytest import pytest_asyncio @@ -13,12 +15,27 @@ MODEL_NAME = "Qwen/Qwen3-8B" @pytest.fixture(scope="module") def server(): - args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"] + assert importlib.util.find_spec("gpt_oss") is not None, ( + "Harmony tests require gpt_oss package to be installed" + ) + + args = [ + "--reasoning-parser", + "qwen3", + "--max_model_len", + "5000", + "--structured-outputs-config.backend", + "xgrammar", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--tool-server", + "demo", + ] env_dict = dict( VLLM_ENABLE_RESPONSES_API_STORE="1", VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1", - # uncomment for tool calling - # PYTHON_EXECUTION_BACKEND="dangerously_use_uv", + PYTHON_EXECUTION_BACKEND="dangerously_use_uv", ) with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: @@ -85,3 +102,79 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str): assert response.output[0].type == "reasoning" assert response.output[1].type == "message" assert type(response.output[1].content[0].text) is str + + +def get_horoscope(sign): + return f"{sign}: Next Tuesday you will befriend a baby otter." 
+ + +def call_function(name, args): + if name == "get_horoscope": + return get_horoscope(**args) + else: + raise ValueError(f"Unknown function: {name}") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_call_first_turn(client: OpenAI, model_name: str): + tools = [ + { + "type": "function", + "name": "get_horoscope", + "description": "Get today's horoscope for an astrological sign.", + "parameters": { + "type": "object", + "properties": { + "sign": {"type": "string"}, + }, + "required": ["sign"], + "additionalProperties": False, + }, + "strict": True, + } + ] + + response = await client.responses.create( + model=model_name, + input="What is the horoscope for Aquarius today?", + tools=tools, + temperature=0.0, + ) + assert response is not None + assert response.status == "completed" + assert len(response.output) == 2 + assert response.output[0].type == "reasoning" + assert response.output[1].type == "function_call" + + function_call = response.output[1] + assert function_call.name == "get_horoscope" + assert function_call.call_id is not None + + args = json.loads(function_call.arguments) + assert "sign" in args + + # the multi turn function call is tested above in + # test_reasoning_and_function_items + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_mcp_tool_call(client: OpenAI, model_name: str): + response = await client.responses.create( + model=model_name, + input="What is 13 * 24? Use python to calculate the result.", + tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], + temperature=0.0, + ) + + assert response is not None + assert response.status == "completed" + assert response.output[0].type == "reasoning" + assert response.output[1].type == "mcp_call" + assert type(response.output[1].arguments) is str + assert type(response.output[1].output) is str + assert response.output[2].type == "reasoning" + # make sure the correct math is in the final output + assert response.output[3].type == "message" + assert "312" in response.output[3].content[0].text diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index 43783c92667af..f50c473d7a773 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -9,10 +9,16 @@ from collections.abc import Callable from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Union +from openai.types.responses.response_function_tool_call_output_item import ( + ResponseFunctionToolCallOutputItem, +) from openai.types.responses.tool import Mcp from openai_harmony import Author, Message, Role, StreamState, TextContent from vllm import envs +from vllm.entrypoints.chat_utils import ( + ChatTemplateContentFormatOption, +) from vllm.entrypoints.harmony_utils import ( get_encoding, get_streamable_parser_for_assistant, @@ -22,16 +28,20 @@ from vllm.entrypoints.openai.parser.responses_parser import ( get_responses_parser_for_simple_context, ) from vllm.entrypoints.openai.protocol import ( + FunctionCall, ResponseInputOutputItem, ResponseRawMessageAndToken, ResponsesRequest, ) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser from vllm.entrypoints.responses_utils import construct_tool_dicts from vllm.entrypoints.tool import Tool from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput from vllm.reasoning.abs_reasoning_parsers import ReasoningParser +from vllm.tokenizers.protocol import TokenizerLike from vllm.transformers_utils.tokenizer 
import AnyTokenizer +from vllm.utils import random_uuid if TYPE_CHECKING: from mcp.client import ClientSession @@ -221,6 +231,10 @@ class ParsableContext(ConversationContext): tokenizer: AnyTokenizer, reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None, request: ResponsesRequest, + available_tools: list[str] | None, + tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, ): self.num_prompt_tokens = 0 self.num_output_tokens = 0 @@ -238,12 +252,19 @@ class ParsableContext(ConversationContext): reasoning_parser_cls=reasoning_parser_cls, response_messages=response_messages, request=request, + tool_parser_cls=tool_parser_cls, ) + self.tool_parser_cls = tool_parser_cls + self.request = request + self.tokenizer = tokenizer + self.available_tools = available_tools or [] self._tool_sessions: dict[str, ClientSession | Tool] = {} self.called_tools: set[str] = set() self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) + self.chat_template = chat_template + self.chat_template_content_format = chat_template_content_format def append_output(self, output: RequestOutput) -> None: self.num_prompt_tokens = len(output.prompt_token_ids or []) @@ -252,14 +273,50 @@ class ParsableContext(ConversationContext): self.parser.process(output.outputs[0]) def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None: - raise NotImplementedError("Should not be called.") + self.parser.response_messages.extend(output) def need_builtin_tool_call(self) -> bool: """Return true if the last message is a MCP tool call""" + last_message = self.parser.response_messages[-1] + # TODO: figure out which tools are MCP tools + if ( # noqa: SIM103 + last_message.type == "function_call" + and last_message.name in ("code_interpreter", "python") + ): + return True + return False + async def call_python_tool( + self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall + ) -> list[ResponseInputOutputItem]: + self.called_tools.add("python") + if isinstance(tool_session, Tool): + return await tool_session.get_result_parsable_context(self) + args = json.loads(last_msg.arguments) + param = { + "code": args["code"], + } + result = await tool_session.call_tool("python", param) + result_str = result.content[0].text + + message = ResponseFunctionToolCallOutputItem( + id=f"fco_{random_uuid()}", + type="function_call_output", + call_id=f"call_{random_uuid()}", + output=result_str, + status="completed", + ) + + return [message] + async def call_tool(self) -> list[ResponseInputOutputItem]: - raise NotImplementedError("Should not be called.") + if not self.parser.response_messages: + return [] + last_msg = self.parser.response_messages[-1] + if last_msg.name == "code_interpreter": + return await self.call_python_tool(self._tool_sessions["python"], last_msg) + return [] def render_for_completion(self): raise NotImplementedError("Should not be called.") @@ -271,11 +328,38 @@ class ParsableContext(ConversationContext): request_id: str, mcp_tools: dict[str, Mcp], ): - pass + if tool_server: + for tool_name in self.available_tools: + if tool_name in self._tool_sessions: + continue + + tool_type = _map_tool_name_to_tool_type(tool_name) + headers = ( + mcp_tools[tool_type].headers if tool_type in mcp_tools else None + ) + tool_session = await exit_stack.enter_async_context( + tool_server.new_session(tool_name, request_id, headers) + ) + self._tool_sessions[tool_name] = tool_session + 
exit_stack.push_async_exit(self.cleanup_session) async def cleanup_session(self, *args, **kwargs) -> None: """Can be used as coro to used in __aexit__""" - raise NotImplementedError("Should not be called.") + + async def cleanup_tool_session(tool_session): + if not isinstance(tool_session, Tool): + logger.info( + "Cleaning up tool session for %s", tool_session._client_info + ) + with contextlib.suppress(Exception): + await tool_session.call_tool("cleanup_session", {}) + + await asyncio.gather( + *( + cleanup_tool_session(self._tool_sessions[tool]) + for tool in self.called_tools + ) + ) class HarmonyContext(ConversationContext): diff --git a/vllm/entrypoints/openai/parser/responses_parser.py b/vllm/entrypoints/openai/parser/responses_parser.py index 1bc8e81bd9dfc..00045a7ccfd24 100644 --- a/vllm/entrypoints/openai/parser/responses_parser.py +++ b/vllm/entrypoints/openai/parser/responses_parser.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_text import ResponseOutputText from openai.types.responses.response_reasoning_item import ( @@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import ( ) from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser from vllm.outputs import CompletionOutput from vllm.reasoning.abs_reasoning_parsers import ReasoningParser +from vllm.tokenizers.protocol import TokenizerLike from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid @@ -29,6 +32,7 @@ class ResponsesParser: reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser], response_messages: list[ResponseInputOutputItem], request: ResponsesRequest, + tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, ): self.response_messages: list[ResponseInputOutputItem] = ( # TODO: initial messages may not be properly typed @@ -39,6 +43,9 @@ class ResponsesParser: self.request = request self.reasoning_parser_instance = reasoning_parser_cls(tokenizer) + self.tool_parser_instance = None + if tool_parser_cls is not None: + self.tool_parser_instance = tool_parser_cls(tokenizer) def process(self, output: CompletionOutput) -> "ResponsesParser": reasoning_content, content = self.reasoning_parser_instance.extract_reasoning( @@ -59,6 +66,29 @@ class ResponsesParser: ) ) + function_calls: list[ResponseFunctionToolCall] = [] + if self.tool_parser_instance is not None: + tool_call_info = self.tool_parser_instance.extract_tool_calls( + content if content is not None else "", + request=self.request, # type: ignore + ) + if tool_call_info is not None and tool_call_info.tools_called: + # extract_tool_calls() returns a list of tool calls. 
+ function_calls.extend( + ResponseFunctionToolCall( + id=f"fc_{random_uuid()}", + call_id=f"call_{random_uuid()}", + type="function_call", + status="completed", + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in tool_call_info.tool_calls + ) + content = tool_call_info.content + if content and content.strip() == "": + content = None + if content: self.response_messages.append( ResponseOutputMessage( @@ -76,6 +106,8 @@ class ResponsesParser: ], ) ) + if len(function_calls) > 0: + self.response_messages.extend(function_calls) return self @@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context( reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser], response_messages: list[ResponseInputOutputItem], request: ResponsesRequest, + tool_parser_cls, ) -> ResponsesParser: """Factory function to create a ResponsesParser with optional reasoning parser. @@ -98,4 +131,5 @@ def get_responses_parser_for_simple_context( reasoning_parser_cls=reasoning_parser_cls, response_messages=response_messages, request=request, + tool_parser_cls=tool_parser_cls, ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index bfa98f29a064b..99936f588f28b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter from starlette.datastructures import Headers from typing_extensions import TypeIs +from vllm.entrypoints.context import ( + HarmonyContext, + ParsableContext, + StreamingHarmonyContext, +) +from vllm.entrypoints.openai.protocol import ( + FunctionCall, + ResponseInputOutputItem, + ResponsesRequest, +) from vllm.entrypoints.pooling.classify.protocol import ( ClassificationChatRequest, ClassificationCompletionRequest, @@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import ( ScoreRequest, ScoreResponse, ) +from vllm.transformers_utils.tokenizer import AnyTokenizer if sys.version_info >= (3, 12): from typing import TypedDict @@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import ( DetokenizeRequest, ErrorInfo, ErrorResponse, - FunctionCall, FunctionDefinition, - ResponsesRequest, TokenizeChatRequest, TokenizeCompletionRequest, TokenizeResponse, @@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig +from vllm.entrypoints.responses_utils import ( + construct_input_messages, +) from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import PromptType @@ -1224,6 +1236,31 @@ class OpenAIServing: ) return engine_request, tokenization_kwargs + async def _render_next_turn( + self, + request: ResponsesRequest, + tokenizer: AnyTokenizer, + messages: list[ResponseInputOutputItem], + tool_dicts: list[dict[str, Any]] | None, + tool_parser, + chat_template: str | None, + chat_template_content_format: ChatTemplateContentFormatOption, + ): + new_messages = construct_input_messages( + request_input=messages, + ) + + _, request_prompts, engine_prompts = await self._preprocess_chat( + request, + tokenizer, + new_messages, + tool_dicts=tool_dicts, + tool_parser=tool_parser, + chat_template=chat_template, + 
chat_template_content_format=chat_template_content_format, + ) + return request_prompts, engine_prompts + async def _generate_with_builtin_tools( self, request_id: str, @@ -1286,11 +1323,27 @@ class OpenAIServing: # Create inputs for the next turn. # Render the next prompt token ids. - prompt_token_ids = context.render_for_completion() - engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) - request_prompt = prompt_token_ids + if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): + prompt_token_ids = context.render_for_completion() + engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) + request_prompt = prompt_token_ids + elif isinstance(context, ParsableContext): + request_prompts, engine_prompts = await self._render_next_turn( + context.request, + context.tokenizer, + context.parser.response_messages, + context.tool_dicts, + context.tool_parser_cls, + context.chat_template, + context.chat_template_content_format, + ) + engine_prompt = engine_prompts[0] + request_prompt = request_prompts[0] + # Update the sampling params. - sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids) + sampling_params.max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"] + ) # OPTIMIZATION priority = orig_priority - 1 sub_request += 1 diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 3c9ae8e8c8087..1eb1243e7e5bc 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing): generators: list[AsyncGenerator[ConversationContext, None]] = [] builtin_tool_list: list[str] = [] - if self.use_harmony and self.tool_server is not None: + if self.tool_server is not None: if self.tool_server.has_tool("browser"): builtin_tool_list.append("browser") if self.tool_server.has_tool("python"): @@ -423,6 +423,10 @@ class OpenAIServingResponses(OpenAIServing): tokenizer=tokenizer, reasoning_parser_cls=self.reasoning_parser, request=request, + tool_parser_cls=self.tool_parser, + available_tools=available_tools, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, ) else: context = SimpleContext() diff --git a/vllm/entrypoints/responses_utils.py b/vllm/entrypoints/responses_utils.py index 5f21e2c44450c..fbc137bac4543 100644 --- a/vllm/entrypoints/responses_utils.py +++ b/vllm/entrypoints/responses_utils.py @@ -16,6 +16,7 @@ from openai.types.responses.response import ToolChoice from openai.types.responses.response_function_tool_call_output_item import ( ResponseFunctionToolCallOutputItem, ) +from openai.types.responses.response_output_item import McpCall from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_reasoning_item import ResponseReasoningItem from openai.types.responses.tool import Tool @@ -25,6 +26,7 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionMessageParam, ResponseInputOutputItem, ) +from vllm.utils import random_uuid def make_response_output_items_from_parsable_context( @@ -36,7 +38,24 @@ def make_response_output_items_from_parsable_context( if not isinstance(message, ResponseFunctionToolCallOutputItem): output_messages.append(message) else: - raise NotImplementedError("tool calls not supported for response context") + if len(output_messages) == 0: + raise ValueError( + "Cannot have a FunctionToolCallOutput before FunctionToolCall." 
+ ) + if isinstance(output_messages[-1], ResponseFunctionToolCall): + mcp_message = McpCall( + id=f"mcp_{random_uuid()}", + arguments=output_messages[-1].arguments, + name=output_messages[-1].name, + server_label=output_messages[ + -1 + ].name, # TODO: store the server label + type="mcp_call", + status="completed", + output=message.output, + # TODO: support error output + ) + output_messages[-1] = mcp_message return output_messages diff --git a/vllm/entrypoints/tool.py b/vllm/entrypoints/tool.py index c74ce1ee16de1..4feed827385d1 100644 --- a/vllm/entrypoints/tool.py +++ b/vllm/entrypoints/tool.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json import os from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from openai.types.responses.response_function_tool_call_output_item import ( + ResponseFunctionToolCallOutputItem, +) from openai_harmony import Author, Message, Role, TextContent from vllm.logger import init_logger +from vllm.utils import random_uuid if TYPE_CHECKING: # Avoid circular import. @@ -46,6 +51,10 @@ class Tool(ABC): async def get_result(self, context: "ConversationContext") -> Any: pass + @abstractmethod + async def get_result_parsable_context(self, context: "ConversationContext") -> Any: + pass + class HarmonyBrowserTool(Tool): def __init__(self): @@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool): tool_output_msgs.append(msg) return tool_output_msgs + async def get_result_parsable_context(self, context: "ConversationContext") -> Any: + raise NotImplementedError("Not implemented yet") + @property def tool_config(self) -> Any: return self.browser_tool.tool_config @@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool): tool_output_msgs.append(msg) return tool_output_msgs + async def get_result_parsable_context(self, context: "ConversationContext") -> Any: + """ + This function converts parsable context types to harmony and + back so we can use GPTOSS demo python tool + """ + from vllm.entrypoints.context import ParsableContext + + assert isinstance(context, ParsableContext) + + last_msg = context.parser.response_messages[-1] + args = json.loads(last_msg.arguments) + + last_msg_harmony = Message( + author=Author(role="assistant", name=None), + content=[TextContent(text=args["code"])], + channel="analysis", + recipient="python", + content_type="code", + ) + + tool_output_msgs = [] + async for msg in self.python_tool.process(last_msg_harmony): + processed = ResponseFunctionToolCallOutputItem( + id=f"fco_{random_uuid()}", + type="function_call_output", + call_id=f"call_{random_uuid()}", + output=msg.content[0].text, + status="completed", + ) + tool_output_msgs.append(processed) + return tool_output_msgs + @property def tool_config(self) -> Any: return self.python_tool.tool_config
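
For reference, here is a minimal client-side sketch of the MCP python loop that test_mcp_tool_call above exercises. The serve flags and environment variables are taken from the test fixture in this patch; the base URL, port, and API key are assumptions for a typical local deployment, not part of the change itself.

    from openai import OpenAI

    # Assumes a server started roughly like the test fixture in this patch, e.g.:
    #   VLLM_ENABLE_RESPONSES_API_STORE=1 VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT=1 \
    #   PYTHON_EXECUTION_BACKEND=dangerously_use_uv \
    #   vllm serve Qwen/Qwen3-8B --reasoning-parser qwen3 \
    #       --enable-auto-tool-choice --tool-call-parser hermes --tool-server demo
    # Base URL and API key below are placeholders for a local deployment.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    response = client.responses.create(
        model="Qwen/Qwen3-8B",
        input="What is 13 * 24? Use python to calculate the result.",
        tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
    )

    # With the parsable-context loop, the output should interleave reasoning,
    # an mcp_call item for the python tool, and a final message with the answer.
    for item in response.output:
        print(item.type)
    print(response.output[-1].content[0].text)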