[responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798)

Signed-off-by: Andrew Xia <axia@fb.com>
Signed-off-by: Andrew Xia <axia@meta.com>
Co-authored-by: Andrew Xia <axia@fb.com>
This commit is contained in:
Andrew Xia 2025-12-05 08:11:50 -08:00 committed by GitHub
parent 949a6a19d2
commit da7bc54ea8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 347 additions and 16 deletions

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@ -13,12 +15,27 @@ MODEL_NAME = "Qwen/Qwen3-8B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"] assert importlib.util.find_spec("gpt_oss") is not None, (
"Harmony tests require gpt_oss package to be installed"
)
args = [
"--reasoning-parser",
"qwen3",
"--max_model_len",
"5000",
"--structured-outputs-config.backend",
"xgrammar",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--tool-server",
"demo",
]
env_dict = dict( env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1", VLLM_ENABLE_RESPONSES_API_STORE="1",
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1", VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
# uncomment for tool calling PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
# PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
) )
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
@ -85,3 +102,79 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
assert response.output[0].type == "reasoning" assert response.output[0].type == "reasoning"
assert response.output[1].type == "message" assert response.output[1].type == "message"
assert type(response.output[1].content[0].text) is str assert type(response.output[1].content[0].text) is str
def get_horoscope(sign: str) -> str:
    """Return the canned horoscope used by the function-calling tests.

    Args:
        sign: Astrological sign to embed in the reply.

    Returns:
        The sign followed by a fixed prediction sentence.
    """
    return f"{sign}: Next Tuesday you will befriend a baby otter."
def call_function(name, args):
    """Dispatch a tool call by function name.

    Args:
        name: Name of the function the model asked for.
        args: Keyword arguments to forward to that function.

    Returns:
        Whatever the selected function returns.

    Raises:
        ValueError: If ``name`` is not a known function.
    """
    if name == "get_horoscope":
        return get_horoscope(**args)
    raise ValueError(f"Unknown function: {name}")
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
    """A user-defined function tool is called on the first turn.

    The response must contain exactly a reasoning item followed by a
    ``function_call`` item for ``get_horoscope`` whose arguments are a JSON
    string containing ``sign``.
    """
    tools = [
        {
            "type": "function",
            "name": "get_horoscope",
            "description": "Get today's horoscope for an astrological sign.",
            "parameters": {
                "type": "object",
                "properties": {
                    "sign": {"type": "string"},
                },
                "required": ["sign"],
                "additionalProperties": False,
            },
            "strict": True,
        }
    ]

    response = await client.responses.create(
        model=model_name,
        input="What is the horoscope for Aquarius today?",
        tools=tools,
        temperature=0.0,
    )

    assert response is not None
    assert response.status == "completed"

    assert len(response.output) == 2
    assert response.output[0].type == "reasoning"
    assert response.output[1].type == "function_call"

    function_call = response.output[1]
    assert function_call.name == "get_horoscope"
    assert function_call.call_id is not None

    # Arguments are delivered as a JSON-encoded string.
    args = json.loads(function_call.arguments)
    assert "sign" in args
# The multi-turn function call flow is covered above by
# test_reasoning_and_function_items.
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
    """The built-in python tool is run server-side and surfaced as an MCP call.

    Expected output order: reasoning, ``mcp_call`` (with string arguments and
    string output), reasoning, then the final message containing the result.
    """
    response = await client.responses.create(
        model=model_name,
        input="What is 13 * 24? Use python to calculate the result.",
        tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
        temperature=0.0,
    )

    assert response is not None
    assert response.status == "completed"

    # Fail with a readable assertion instead of an IndexError when the model
    # returns fewer items than the four inspected below.
    assert len(response.output) >= 4

    assert response.output[0].type == "reasoning"

    mcp_call = response.output[1]
    assert mcp_call.type == "mcp_call"
    assert isinstance(mcp_call.arguments, str)
    assert isinstance(mcp_call.output, str)

    assert response.output[2].type == "reasoning"

    # make sure the correct math is in the final output
    final_message = response.output[3]
    assert final_message.type == "message"
    assert "312" in final_message.content[0].text

View File

@ -9,10 +9,16 @@ from collections.abc import Callable
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm import envs from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.harmony_utils import ( from vllm.entrypoints.harmony_utils import (
get_encoding, get_encoding,
get_streamable_parser_for_assistant, get_streamable_parser_for_assistant,
@ -22,16 +28,20 @@ from vllm.entrypoints.openai.parser.responses_parser import (
get_responses_parser_for_simple_context, get_responses_parser_for_simple_context,
) )
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem, ResponseInputOutputItem,
ResponseRawMessageAndToken, ResponseRawMessageAndToken,
ResponsesRequest, ResponsesRequest,
) )
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.entrypoints.responses_utils import construct_tool_dicts from vllm.entrypoints.responses_utils import construct_tool_dicts
from vllm.entrypoints.tool import Tool from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
if TYPE_CHECKING: if TYPE_CHECKING:
from mcp.client import ClientSession from mcp.client import ClientSession
@ -221,6 +231,10 @@ class ParsableContext(ConversationContext):
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None, reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
request: ResponsesRequest, request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
): ):
self.num_prompt_tokens = 0 self.num_prompt_tokens = 0
self.num_output_tokens = 0 self.num_output_tokens = 0
@ -238,12 +252,19 @@ class ParsableContext(ConversationContext):
reasoning_parser_cls=reasoning_parser_cls, reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages, response_messages=response_messages,
request=request, request=request,
tool_parser_cls=tool_parser_cls,
) )
self.tool_parser_cls = tool_parser_cls
self.request = request
self.tokenizer = tokenizer
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {} self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set() self.called_tools: set[str] = set()
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice) self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
def append_output(self, output: RequestOutput) -> None: def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or []) self.num_prompt_tokens = len(output.prompt_token_ids or [])
@ -252,14 +273,50 @@ class ParsableContext(ConversationContext):
self.parser.process(output.outputs[0]) self.parser.process(output.outputs[0])
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
    """Append tool result items to the parser's running message list.

    Args:
        output: Items produced by a tool call; they are added, in order, to
            the end of ``self.parser.response_messages``.
    """
    self.parser.response_messages.extend(output)
def need_builtin_tool_call(self) -> bool:
    """Return True if the last message is a built-in (MCP) tool call.

    The last entry of ``self.parser.response_messages`` counts as a built-in
    tool call when its ``type`` is ``"function_call"`` and its ``name`` is
    ``"code_interpreter"`` or ``"python"``. An empty message list, or a last
    entry lacking those attributes, yields False.
    """
    messages = self.parser.response_messages
    if not messages:
        return False
    last_message = messages[-1]
    # getattr: not every item kind stored in the list carries a name/type.
    # TODO: figure out which tools are MCP tools
    return getattr(last_message, "type", None) == "function_call" and getattr(
        last_message, "name", None
    ) in ("code_interpreter", "python")
async def call_python_tool(
    self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
    """Run the python tool for ``last_msg`` and return its output item(s).

    Args:
        tool_session: Either an in-process ``Tool`` (delegated to via
            ``get_result_parsable_context``) or a session object exposing
            ``call_tool``.
        last_msg: The function call to execute; ``arguments`` must be a JSON
            string containing a ``code`` key.

    Returns:
        A list of ``function_call_output`` items holding the tool's text
        result.
    """
    self.called_tools.add("python")
    if isinstance(tool_session, Tool):
        return await tool_session.get_result_parsable_context(self)

    args = json.loads(last_msg.arguments)
    result = await tool_session.call_tool("python", {"code": args["code"]})
    result_str = result.content[0].text

    # Tie the output to the call that produced it; fall back to a fresh id
    # only when the call object carries none.
    call_id = getattr(last_msg, "call_id", None) or f"call_{random_uuid()}"

    return [
        ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=call_id,
            output=result_str,
            status="completed",
        )
    ]
async def call_tool(self) -> list[ResponseInputOutputItem]:
    """Execute the built-in tool requested by the last message, if any.

    Returns:
        The tool's output items, or an empty list when there are no messages
        or the last message does not name a supported built-in tool.
    """
    messages = self.parser.response_messages
    if not messages:
        return []

    last_msg = messages[-1]
    # Accept the same names that need_builtin_tool_call() reports as built-in
    # calls, so a call flagged there always produces a tool output here.
    if getattr(last_msg, "name", None) in ("code_interpreter", "python"):
        return await self.call_python_tool(self._tool_sessions["python"], last_msg)
    return []
def render_for_completion(self): def render_for_completion(self):
raise NotImplementedError("Should not be called.") raise NotImplementedError("Should not be called.")
@ -271,11 +328,38 @@ class ParsableContext(ConversationContext):
request_id: str, request_id: str,
mcp_tools: dict[str, Mcp], mcp_tools: dict[str, Mcp],
): ):
pass if tool_server:
for tool_name in self.available_tools:
if tool_name in self._tool_sessions:
continue
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def cleanup_session(self, *args, **kwargs) -> None:
    """Clean up the session of every tool that was called.

    Any positional or keyword arguments are accepted and ignored, so this
    coroutine can be registered as an ``__aexit__``-style exit callback.
    """

    async def cleanup_tool_session(tool_session):
        # In-process Tool objects are skipped; only other session objects
        # receive a "cleanup_session" call.
        if isinstance(tool_session, Tool):
            return
        logger.info("Cleaning up tool session for %s", tool_session._client_info)
        # Best effort: a failing cleanup call is deliberately ignored.
        with contextlib.suppress(Exception):
            await tool_session.call_tool("cleanup_session", {})

    await asyncio.gather(
        *(
            cleanup_tool_session(self._tool_sessions[tool])
            for tool in self.called_tools
        )
    )
class HarmonyContext(ConversationContext): class HarmonyContext(ConversationContext):

View File

@ -3,6 +3,7 @@
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import (
) )
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
@ -29,6 +32,7 @@ class ResponsesParser:
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser], reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem], response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest, request: ResponsesRequest,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
): ):
self.response_messages: list[ResponseInputOutputItem] = ( self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed # TODO: initial messages may not be properly typed
@ -39,6 +43,9 @@ class ResponsesParser:
self.request = request self.request = request
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer) self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer)
def process(self, output: CompletionOutput) -> "ResponsesParser": def process(self, output: CompletionOutput) -> "ResponsesParser":
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning( reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
@ -59,6 +66,29 @@ class ResponsesParser:
) )
) )
function_calls: list[ResponseFunctionToolCall] = []
if self.tool_parser_instance is not None:
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None
if content: if content:
self.response_messages.append( self.response_messages.append(
ResponseOutputMessage( ResponseOutputMessage(
@ -76,6 +106,8 @@ class ResponsesParser:
], ],
) )
) )
if len(function_calls) > 0:
self.response_messages.extend(function_calls)
return self return self
@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser], reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem], response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest, request: ResponsesRequest,
tool_parser_cls,
) -> ResponsesParser: ) -> ResponsesParser:
"""Factory function to create a ResponsesParser with """Factory function to create a ResponsesParser with
optional reasoning parser. optional reasoning parser.
@ -98,4 +131,5 @@ def get_responses_parser_for_simple_context(
reasoning_parser_cls=reasoning_parser_cls, reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages, response_messages=response_messages,
request=request, request=request,
tool_parser_cls=tool_parser_cls,
) )

View File

@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers from starlette.datastructures import Headers
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.entrypoints.context import (
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.pooling.classify.protocol import ( from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest, ClassificationChatRequest,
ClassificationCompletionRequest, ClassificationCompletionRequest,
@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreRequest, ScoreRequest,
ScoreResponse, ScoreResponse,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from typing import TypedDict from typing import TypedDict
@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import (
DetokenizeRequest, DetokenizeRequest,
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
FunctionCall,
FunctionDefinition, FunctionDefinition,
ResponsesRequest,
TokenizeChatRequest, TokenizeChatRequest,
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.responses_utils import (
construct_input_messages,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
@ -1224,6 +1236,31 @@ class OpenAIServing:
) )
return engine_request, tokenization_kwargs return engine_request, tokenization_kwargs
async def _render_next_turn(
    self,
    request: ResponsesRequest,
    tokenizer: AnyTokenizer,
    messages: list[ResponseInputOutputItem],
    tool_dicts: list[dict[str, Any]] | None,
    tool_parser,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
):
    """Render the prompts for the next turn of a multi-turn response.

    The accumulated response items are converted to input messages with
    ``construct_input_messages`` and then passed through
    ``self._preprocess_chat`` together with the tool and chat-template
    settings.

    Returns:
        A ``(request_prompts, engine_prompts)`` tuple taken from the second
        and third values returned by ``_preprocess_chat``.
    """
    new_messages = construct_input_messages(request_input=messages)

    # The first value returned by _preprocess_chat is not needed here.
    _, request_prompts, engine_prompts = await self._preprocess_chat(
        request,
        tokenizer,
        new_messages,
        tool_dicts=tool_dicts,
        tool_parser=tool_parser,
        chat_template=chat_template,
        chat_template_content_format=chat_template_content_format,
    )
    return request_prompts, engine_prompts
async def _generate_with_builtin_tools( async def _generate_with_builtin_tools(
self, self,
request_id: str, request_id: str,
@ -1286,11 +1323,27 @@ class OpenAIServing:
# Create inputs for the next turn. # Create inputs for the next turn.
# Render the next prompt token ids. # Render the next prompt token ids.
prompt_token_ids = context.render_for_completion() if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) prompt_token_ids = context.render_for_completion()
request_prompt = prompt_token_ids engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids
elif isinstance(context, ParsableContext):
request_prompts, engine_prompts = await self._render_next_turn(
context.request,
context.tokenizer,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,
context.chat_template,
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0]
# Update the sampling params. # Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids) sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION # OPTIMIZATION
priority = orig_priority - 1 priority = orig_priority - 1
sub_request += 1 sub_request += 1

View File

@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing):
generators: list[AsyncGenerator[ConversationContext, None]] = [] generators: list[AsyncGenerator[ConversationContext, None]] = []
builtin_tool_list: list[str] = [] builtin_tool_list: list[str] = []
if self.use_harmony and self.tool_server is not None: if self.tool_server is not None:
if self.tool_server.has_tool("browser"): if self.tool_server.has_tool("browser"):
builtin_tool_list.append("browser") builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"): if self.tool_server.has_tool("python"):
@ -423,6 +423,10 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer=tokenizer, tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser, reasoning_parser_cls=self.reasoning_parser,
request=request, request=request,
tool_parser_cls=self.tool_parser,
available_tools=available_tools,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
) )
else: else:
context = SimpleContext() context = SimpleContext()

View File

@ -16,6 +16,7 @@ from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import ( from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem, ResponseFunctionToolCallOutputItem,
) )
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool from openai.types.responses.tool import Tool
@ -25,6 +26,7 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ResponseInputOutputItem, ResponseInputOutputItem,
) )
from vllm.utils import random_uuid
def make_response_output_items_from_parsable_context( def make_response_output_items_from_parsable_context(
@ -36,7 +38,24 @@ def make_response_output_items_from_parsable_context(
if not isinstance(message, ResponseFunctionToolCallOutputItem): if not isinstance(message, ResponseFunctionToolCallOutputItem):
output_messages.append(message) output_messages.append(message)
else: else:
raise NotImplementedError("tool calls not supported for response context") if len(output_messages) == 0:
raise ValueError(
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
)
if isinstance(output_messages[-1], ResponseFunctionToolCall):
mcp_message = McpCall(
id=f"mcp_{random_uuid()}",
arguments=output_messages[-1].arguments,
name=output_messages[-1].name,
server_label=output_messages[
-1
].name, # TODO: store the server label
type="mcp_call",
status="completed",
output=message.output,
# TODO: support error output
)
output_messages[-1] = mcp_message
return output_messages return output_messages

View File

@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai_harmony import Author, Message, Role, TextContent from openai_harmony import Author, Message, Role, TextContent
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import random_uuid
if TYPE_CHECKING: if TYPE_CHECKING:
# Avoid circular import. # Avoid circular import.
@ -46,6 +51,10 @@ class Tool(ABC):
async def get_result(self, context: "ConversationContext") -> Any: async def get_result(self, context: "ConversationContext") -> Any:
pass pass
@abstractmethod
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """Run the tool against ``context`` and return its result.

    Counterpart of ``get_result`` for callers that pass a parsable
    (non-Harmony) context; concrete tools must override it.
    """
class HarmonyBrowserTool(Tool): class HarmonyBrowserTool(Tool):
def __init__(self): def __init__(self):
@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool):
tool_output_msgs.append(msg) tool_output_msgs.append(msg)
return tool_output_msgs return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """Not supported for this tool.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("Not implemented yet")
@property @property
def tool_config(self) -> Any: def tool_config(self) -> Any:
return self.browser_tool.tool_config return self.browser_tool.tool_config
@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool):
tool_output_msgs.append(msg) tool_output_msgs.append(msg)
return tool_output_msgs return tool_output_msgs
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """Run the last function call in ``context`` through the python tool.

    This function converts parsable context types to harmony and
    back so we can use GPTOSS demo python tool.

    Args:
        context: Must be a ``ParsableContext``; its last response message is
            read as the call to execute, and that message's ``arguments``
            must be a JSON string with a ``code`` key.

    Returns:
        A list with one ``function_call_output`` item per message yielded by
        the python tool.
    """
    # Imported here rather than at module level to avoid a circular import.
    from vllm.entrypoints.context import ParsableContext

    assert isinstance(context, ParsableContext)

    last_msg = context.parser.response_messages[-1]
    args = json.loads(last_msg.arguments)

    last_msg_harmony = Message(
        author=Author(role="assistant", name=None),
        content=[TextContent(text=args["code"])],
        channel="analysis",
        recipient="python",
        content_type="code",
    )

    # Tie every output to the call that produced it; fall back to a fresh id
    # only when the call object carries none.
    call_id = getattr(last_msg, "call_id", None) or f"call_{random_uuid()}"

    return [
        ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=call_id,
            output=msg.content[0].text,
            status="completed",
        )
        async for msg in self.python_tool.process(last_msg_harmony)
    ]
@property @property
def tool_config(self) -> Any: def tool_config(self) -> Any:
return self.python_tool.tool_config return self.python_tool.tool_config