mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 01:47:52 +08:00
[responsesAPI][5] ResponsesParser with tools for full MCP python loop (#29798)
Signed-off-by: Andrew Xia <axia@fb.com> Signed-off-by: Andrew Xia <axia@meta.com> Co-authored-by: Andrew Xia <axia@fb.com>
This commit is contained in:
parent
949a6a19d2
commit
da7bc54ea8
@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@ -13,12 +15,27 @@ MODEL_NAME = "Qwen/Qwen3-8B"
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def server():
|
def server():
|
||||||
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
|
assert importlib.util.find_spec("gpt_oss") is not None, (
|
||||||
|
"Harmony tests require gpt_oss package to be installed"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--reasoning-parser",
|
||||||
|
"qwen3",
|
||||||
|
"--max_model_len",
|
||||||
|
"5000",
|
||||||
|
"--structured-outputs-config.backend",
|
||||||
|
"xgrammar",
|
||||||
|
"--enable-auto-tool-choice",
|
||||||
|
"--tool-call-parser",
|
||||||
|
"hermes",
|
||||||
|
"--tool-server",
|
||||||
|
"demo",
|
||||||
|
]
|
||||||
env_dict = dict(
|
env_dict = dict(
|
||||||
VLLM_ENABLE_RESPONSES_API_STORE="1",
|
VLLM_ENABLE_RESPONSES_API_STORE="1",
|
||||||
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
|
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
|
||||||
# uncomment for tool calling
|
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
|
||||||
# PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
|
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
|
||||||
@ -85,3 +102,79 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
|
|||||||
assert response.output[0].type == "reasoning"
|
assert response.output[0].type == "reasoning"
|
||||||
assert response.output[1].type == "message"
|
assert response.output[1].type == "message"
|
||||||
assert type(response.output[1].content[0].text) is str
|
assert type(response.output[1].content[0].text) is str
|
||||||
|
|
||||||
|
|
||||||
|
def get_horoscope(sign):
    """Return a canned horoscope sentence prefixed with the given sign."""
    horoscope = f"{sign}: Next Tuesday you will befriend a baby otter."
    return horoscope
|
||||||
|
|
||||||
|
|
||||||
|
def call_function(name, args):
    """Dispatch a requested function call to its local implementation.

    ``args`` is expanded as keyword arguments for the target function.
    Raises ``ValueError`` for any name other than ``get_horoscope``.
    """
    # Guard clause: reject anything that is not the single known function.
    if name != "get_horoscope":
        raise ValueError(f"Unknown function: {name}")
    return get_horoscope(**args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
    """Check that a first-turn request with a function tool yields a tool call.

    The response is expected to hold exactly two output items: a reasoning
    item followed by a ``get_horoscope`` function call whose JSON arguments
    contain the ``sign`` key.
    """
    horoscope_parameters = {
        "type": "object",
        "properties": {
            "sign": {"type": "string"},
        },
        "required": ["sign"],
        "additionalProperties": False,
    }
    horoscope_tool = {
        "type": "function",
        "name": "get_horoscope",
        "description": "Get today's horoscope for an astrological sign.",
        "parameters": horoscope_parameters,
        "strict": True,
    }

    response = await client.responses.create(
        model=model_name,
        input="What is the horoscope for Aquarius today?",
        tools=[horoscope_tool],
        temperature=0.0,
    )
    assert response is not None
    assert response.status == "completed"

    # Exactly two items: the reasoning, then the function call itself.
    assert len(response.output) == 2
    reasoning_item, function_call = response.output
    assert reasoning_item.type == "reasoning"
    assert function_call.type == "function_call"

    assert function_call.name == "get_horoscope"
    assert function_call.call_id is not None

    parsed_arguments = json.loads(function_call.arguments)
    assert "sign" in parsed_arguments

    # the multi turn function call is tested above in
    # test_reasoning_and_function_items
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
    """Check the output of a request that offers the code interpreter tool.

    Expected output order: reasoning, an ``mcp_call`` item with string
    arguments and output, reasoning again, then a final message whose text
    contains the product ``312``.
    """
    code_interpreter_tool = {"type": "code_interpreter", "container": {"type": "auto"}}
    response = await client.responses.create(
        model=model_name,
        input="What is 13 * 24? Use python to calculate the result.",
        tools=[code_interpreter_tool],
        temperature=0.0,
    )

    assert response is not None
    assert response.status == "completed"

    items = response.output
    assert items[0].type == "reasoning"

    tool_call_item = items[1]
    assert tool_call_item.type == "mcp_call"
    assert type(tool_call_item.arguments) is str
    assert type(tool_call_item.output) is str

    assert items[2].type == "reasoning"

    # make sure the correct math is in the final output
    final_message = items[3]
    assert final_message.type == "message"
    assert "312" in final_message.content[0].text
|
||||||
|
|||||||
@ -9,10 +9,16 @@ from collections.abc import Callable
|
|||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
from openai.types.responses.response_function_tool_call_output_item import (
|
||||||
|
ResponseFunctionToolCallOutputItem,
|
||||||
|
)
|
||||||
from openai.types.responses.tool import Mcp
|
from openai.types.responses.tool import Mcp
|
||||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm.entrypoints.chat_utils import (
|
||||||
|
ChatTemplateContentFormatOption,
|
||||||
|
)
|
||||||
from vllm.entrypoints.harmony_utils import (
|
from vllm.entrypoints.harmony_utils import (
|
||||||
get_encoding,
|
get_encoding,
|
||||||
get_streamable_parser_for_assistant,
|
get_streamable_parser_for_assistant,
|
||||||
@ -22,16 +28,20 @@ from vllm.entrypoints.openai.parser.responses_parser import (
|
|||||||
get_responses_parser_for_simple_context,
|
get_responses_parser_for_simple_context,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
FunctionCall,
|
||||||
ResponseInputOutputItem,
|
ResponseInputOutputItem,
|
||||||
ResponseRawMessageAndToken,
|
ResponseRawMessageAndToken,
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
)
|
)
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
|
||||||
from vllm.entrypoints.responses_utils import construct_tool_dicts
|
from vllm.entrypoints.responses_utils import construct_tool_dicts
|
||||||
from vllm.entrypoints.tool import Tool
|
from vllm.entrypoints.tool import Tool
|
||||||
from vllm.entrypoints.tool_server import ToolServer
|
from vllm.entrypoints.tool_server import ToolServer
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||||
|
from vllm.tokenizers.protocol import TokenizerLike
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from mcp.client import ClientSession
|
from mcp.client import ClientSession
|
||||||
@ -221,6 +231,10 @@ class ParsableContext(ConversationContext):
|
|||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
|
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
|
||||||
request: ResponsesRequest,
|
request: ResponsesRequest,
|
||||||
|
available_tools: list[str] | None,
|
||||||
|
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
|
||||||
|
chat_template: str | None,
|
||||||
|
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||||
):
|
):
|
||||||
self.num_prompt_tokens = 0
|
self.num_prompt_tokens = 0
|
||||||
self.num_output_tokens = 0
|
self.num_output_tokens = 0
|
||||||
@ -238,12 +252,19 @@ class ParsableContext(ConversationContext):
|
|||||||
reasoning_parser_cls=reasoning_parser_cls,
|
reasoning_parser_cls=reasoning_parser_cls,
|
||||||
response_messages=response_messages,
|
response_messages=response_messages,
|
||||||
request=request,
|
request=request,
|
||||||
|
tool_parser_cls=tool_parser_cls,
|
||||||
)
|
)
|
||||||
|
self.tool_parser_cls = tool_parser_cls
|
||||||
|
self.request = request
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
self.available_tools = available_tools or []
|
||||||
self._tool_sessions: dict[str, ClientSession | Tool] = {}
|
self._tool_sessions: dict[str, ClientSession | Tool] = {}
|
||||||
self.called_tools: set[str] = set()
|
self.called_tools: set[str] = set()
|
||||||
|
|
||||||
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
|
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
|
||||||
|
self.chat_template = chat_template
|
||||||
|
self.chat_template_content_format = chat_template_content_format
|
||||||
|
|
||||||
def append_output(self, output: RequestOutput) -> None:
|
def append_output(self, output: RequestOutput) -> None:
|
||||||
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
self.num_prompt_tokens = len(output.prompt_token_ids or [])
|
||||||
@ -252,14 +273,50 @@ class ParsableContext(ConversationContext):
|
|||||||
self.parser.process(output.outputs[0])
|
self.parser.process(output.outputs[0])
|
||||||
|
|
||||||
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
|
def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
|
||||||
raise NotImplementedError("Should not be called.")
|
self.parser.response_messages.extend(output)
|
||||||
|
|
||||||
def need_builtin_tool_call(self) -> bool:
|
def need_builtin_tool_call(self) -> bool:
|
||||||
"""Return true if the last message is a MCP tool call"""
|
"""Return true if the last message is a MCP tool call"""
|
||||||
|
last_message = self.parser.response_messages[-1]
|
||||||
|
# TODO: figure out which tools are MCP tools
|
||||||
|
if ( # noqa: SIM103
|
||||||
|
last_message.type == "function_call"
|
||||||
|
and last_message.name in ("code_interpreter", "python")
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def call_python_tool(
    self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
    """Run the code carried by ``last_msg`` through the python tool session.

    Records ``"python"`` in ``self.called_tools``. When the session is a
    ``Tool`` instance, the work is delegated to its
    ``get_result_parsable_context``. Otherwise the ``"code"`` entry of the
    JSON arguments is sent via ``call_tool("python", ...)`` and the text of
    the first content entry is wrapped in a single
    ``function_call_output`` item.
    """
    self.called_tools.add("python")

    # Tool instances build the output items themselves.
    if isinstance(tool_session, Tool):
        return await tool_session.get_result_parsable_context(self)

    code = json.loads(last_msg.arguments)["code"]
    tool_result = await tool_session.call_tool("python", {"code": code})
    output_text = tool_result.content[0].text

    return [
        ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=f"call_{random_uuid()}",
            output=output_text,
            status="completed",
        )
    ]
|
||||||
|
|
||||||
async def call_tool(self) -> list[ResponseInputOutputItem]:
|
async def call_tool(self) -> list[ResponseInputOutputItem]:
|
||||||
raise NotImplementedError("Should not be called.")
|
if not self.parser.response_messages:
|
||||||
|
return []
|
||||||
|
last_msg = self.parser.response_messages[-1]
|
||||||
|
if last_msg.name == "code_interpreter":
|
||||||
|
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
|
||||||
|
return []
|
||||||
|
|
||||||
def render_for_completion(self):
|
def render_for_completion(self):
|
||||||
raise NotImplementedError("Should not be called.")
|
raise NotImplementedError("Should not be called.")
|
||||||
@ -271,11 +328,38 @@ class ParsableContext(ConversationContext):
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
mcp_tools: dict[str, Mcp],
|
mcp_tools: dict[str, Mcp],
|
||||||
):
|
):
|
||||||
pass
|
if tool_server:
|
||||||
|
for tool_name in self.available_tools:
|
||||||
|
if tool_name in self._tool_sessions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||||
|
headers = (
|
||||||
|
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
|
||||||
|
)
|
||||||
|
tool_session = await exit_stack.enter_async_context(
|
||||||
|
tool_server.new_session(tool_name, request_id, headers)
|
||||||
|
)
|
||||||
|
self._tool_sessions[tool_name] = tool_session
|
||||||
|
exit_stack.push_async_exit(self.cleanup_session)
|
||||||
|
|
||||||
async def cleanup_session(self, *args, **kwargs) -> None:
|
async def cleanup_session(self, *args, **kwargs) -> None:
|
||||||
"""Can be used as coro to used in __aexit__"""
|
"""Can be used as coro to used in __aexit__"""
|
||||||
raise NotImplementedError("Should not be called.")
|
|
||||||
|
async def cleanup_tool_session(tool_session):
|
||||||
|
if not isinstance(tool_session, Tool):
|
||||||
|
logger.info(
|
||||||
|
"Cleaning up tool session for %s", tool_session._client_info
|
||||||
|
)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await tool_session.call_tool("cleanup_session", {})
|
||||||
|
|
||||||
|
await asyncio.gather(
|
||||||
|
*(
|
||||||
|
cleanup_tool_session(self._tool_sessions[tool])
|
||||||
|
for tool in self.called_tools
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class HarmonyContext(ConversationContext):
|
class HarmonyContext(ConversationContext):
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||||
from openai.types.responses.response_output_message import ResponseOutputMessage
|
from openai.types.responses.response_output_message import ResponseOutputMessage
|
||||||
from openai.types.responses.response_output_text import ResponseOutputText
|
from openai.types.responses.response_output_text import ResponseOutputText
|
||||||
from openai.types.responses.response_reasoning_item import (
|
from openai.types.responses.response_reasoning_item import (
|
||||||
@ -11,8 +12,10 @@ from openai.types.responses.response_reasoning_item import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
|
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
|
||||||
from vllm.outputs import CompletionOutput
|
from vllm.outputs import CompletionOutput
|
||||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||||
|
from vllm.tokenizers.protocol import TokenizerLike
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
@ -29,6 +32,7 @@ class ResponsesParser:
|
|||||||
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
|
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
|
||||||
response_messages: list[ResponseInputOutputItem],
|
response_messages: list[ResponseInputOutputItem],
|
||||||
request: ResponsesRequest,
|
request: ResponsesRequest,
|
||||||
|
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
|
||||||
):
|
):
|
||||||
self.response_messages: list[ResponseInputOutputItem] = (
|
self.response_messages: list[ResponseInputOutputItem] = (
|
||||||
# TODO: initial messages may not be properly typed
|
# TODO: initial messages may not be properly typed
|
||||||
@ -39,6 +43,9 @@ class ResponsesParser:
|
|||||||
self.request = request
|
self.request = request
|
||||||
|
|
||||||
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
|
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
|
||||||
|
self.tool_parser_instance = None
|
||||||
|
if tool_parser_cls is not None:
|
||||||
|
self.tool_parser_instance = tool_parser_cls(tokenizer)
|
||||||
|
|
||||||
def process(self, output: CompletionOutput) -> "ResponsesParser":
|
def process(self, output: CompletionOutput) -> "ResponsesParser":
|
||||||
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
|
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
|
||||||
@ -59,6 +66,29 @@ class ResponsesParser:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
function_calls: list[ResponseFunctionToolCall] = []
|
||||||
|
if self.tool_parser_instance is not None:
|
||||||
|
tool_call_info = self.tool_parser_instance.extract_tool_calls(
|
||||||
|
content if content is not None else "",
|
||||||
|
request=self.request, # type: ignore
|
||||||
|
)
|
||||||
|
if tool_call_info is not None and tool_call_info.tools_called:
|
||||||
|
# extract_tool_calls() returns a list of tool calls.
|
||||||
|
function_calls.extend(
|
||||||
|
ResponseFunctionToolCall(
|
||||||
|
id=f"fc_{random_uuid()}",
|
||||||
|
call_id=f"call_{random_uuid()}",
|
||||||
|
type="function_call",
|
||||||
|
status="completed",
|
||||||
|
name=tool_call.function.name,
|
||||||
|
arguments=tool_call.function.arguments,
|
||||||
|
)
|
||||||
|
for tool_call in tool_call_info.tool_calls
|
||||||
|
)
|
||||||
|
content = tool_call_info.content
|
||||||
|
if content and content.strip() == "":
|
||||||
|
content = None
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
self.response_messages.append(
|
self.response_messages.append(
|
||||||
ResponseOutputMessage(
|
ResponseOutputMessage(
|
||||||
@ -76,6 +106,8 @@ class ResponsesParser:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if len(function_calls) > 0:
|
||||||
|
self.response_messages.extend(function_calls)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
|
|||||||
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
|
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
|
||||||
response_messages: list[ResponseInputOutputItem],
|
response_messages: list[ResponseInputOutputItem],
|
||||||
request: ResponsesRequest,
|
request: ResponsesRequest,
|
||||||
|
tool_parser_cls,
|
||||||
) -> ResponsesParser:
|
) -> ResponsesParser:
|
||||||
"""Factory function to create a ResponsesParser with
|
"""Factory function to create a ResponsesParser with
|
||||||
optional reasoning parser.
|
optional reasoning parser.
|
||||||
@ -98,4 +131,5 @@ def get_responses_parser_for_simple_context(
|
|||||||
reasoning_parser_cls=reasoning_parser_cls,
|
reasoning_parser_cls=reasoning_parser_cls,
|
||||||
response_messages=response_messages,
|
response_messages=response_messages,
|
||||||
request=request,
|
request=request,
|
||||||
|
tool_parser_cls=tool_parser_cls,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -18,6 +18,16 @@ from pydantic import ConfigDict, TypeAdapter
|
|||||||
from starlette.datastructures import Headers
|
from starlette.datastructures import Headers
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
|
from vllm.entrypoints.context import (
|
||||||
|
HarmonyContext,
|
||||||
|
ParsableContext,
|
||||||
|
StreamingHarmonyContext,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
FunctionCall,
|
||||||
|
ResponseInputOutputItem,
|
||||||
|
ResponsesRequest,
|
||||||
|
)
|
||||||
from vllm.entrypoints.pooling.classify.protocol import (
|
from vllm.entrypoints.pooling.classify.protocol import (
|
||||||
ClassificationChatRequest,
|
ClassificationChatRequest,
|
||||||
ClassificationCompletionRequest,
|
ClassificationCompletionRequest,
|
||||||
@ -39,6 +49,7 @@ from vllm.entrypoints.pooling.score.protocol import (
|
|||||||
ScoreRequest,
|
ScoreRequest,
|
||||||
ScoreResponse,
|
ScoreResponse,
|
||||||
)
|
)
|
||||||
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
@ -72,9 +83,7 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
DetokenizeRequest,
|
DetokenizeRequest,
|
||||||
ErrorInfo,
|
ErrorInfo,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
FunctionCall,
|
|
||||||
FunctionDefinition,
|
FunctionDefinition,
|
||||||
ResponsesRequest,
|
|
||||||
TokenizeChatRequest,
|
TokenizeChatRequest,
|
||||||
TokenizeCompletionRequest,
|
TokenizeCompletionRequest,
|
||||||
TokenizeResponse,
|
TokenizeResponse,
|
||||||
@ -85,6 +94,9 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||||
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
|
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
|
||||||
|
from vllm.entrypoints.responses_utils import (
|
||||||
|
construct_input_messages,
|
||||||
|
)
|
||||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
||||||
from vllm.entrypoints.utils import _validate_truncation_size
|
from vllm.entrypoints.utils import _validate_truncation_size
|
||||||
from vllm.inputs.data import PromptType
|
from vllm.inputs.data import PromptType
|
||||||
@ -1224,6 +1236,31 @@ class OpenAIServing:
|
|||||||
)
|
)
|
||||||
return engine_request, tokenization_kwargs
|
return engine_request, tokenization_kwargs
|
||||||
|
|
||||||
|
async def _render_next_turn(
    self,
    request: ResponsesRequest,
    tokenizer: AnyTokenizer,
    messages: list[ResponseInputOutputItem],
    tool_dicts: list[dict[str, Any]] | None,
    tool_parser,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
):
    """Build the prompts for the next turn from the accumulated messages.

    ``messages`` is passed through ``construct_input_messages`` and the
    result is handed to ``self._preprocess_chat`` together with the tool and
    chat-template settings. Returns the ``(request_prompts, engine_prompts)``
    pair; the first element returned by ``_preprocess_chat`` is discarded.
    """
    chat_messages = construct_input_messages(request_input=messages)

    _unused, request_prompts, engine_prompts = await self._preprocess_chat(
        request,
        tokenizer,
        chat_messages,
        tool_dicts=tool_dicts,
        tool_parser=tool_parser,
        chat_template=chat_template,
        chat_template_content_format=chat_template_content_format,
    )
    return request_prompts, engine_prompts
|
||||||
|
|
||||||
async def _generate_with_builtin_tools(
|
async def _generate_with_builtin_tools(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
@ -1286,11 +1323,27 @@ class OpenAIServing:
|
|||||||
|
|
||||||
# Create inputs for the next turn.
|
# Create inputs for the next turn.
|
||||||
# Render the next prompt token ids.
|
# Render the next prompt token ids.
|
||||||
prompt_token_ids = context.render_for_completion()
|
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
|
||||||
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
|
prompt_token_ids = context.render_for_completion()
|
||||||
request_prompt = prompt_token_ids
|
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||||
|
request_prompt = prompt_token_ids
|
||||||
|
elif isinstance(context, ParsableContext):
|
||||||
|
request_prompts, engine_prompts = await self._render_next_turn(
|
||||||
|
context.request,
|
||||||
|
context.tokenizer,
|
||||||
|
context.parser.response_messages,
|
||||||
|
context.tool_dicts,
|
||||||
|
context.tool_parser_cls,
|
||||||
|
context.chat_template,
|
||||||
|
context.chat_template_content_format,
|
||||||
|
)
|
||||||
|
engine_prompt = engine_prompts[0]
|
||||||
|
request_prompt = request_prompts[0]
|
||||||
|
|
||||||
# Update the sampling params.
|
# Update the sampling params.
|
||||||
sampling_params.max_tokens = self.max_model_len - len(prompt_token_ids)
|
sampling_params.max_tokens = self.max_model_len - len(
|
||||||
|
engine_prompt["prompt_token_ids"]
|
||||||
|
)
|
||||||
# OPTIMIZATION
|
# OPTIMIZATION
|
||||||
priority = orig_priority - 1
|
priority = orig_priority - 1
|
||||||
sub_request += 1
|
sub_request += 1
|
||||||
|
|||||||
@ -375,7 +375,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
||||||
|
|
||||||
builtin_tool_list: list[str] = []
|
builtin_tool_list: list[str] = []
|
||||||
if self.use_harmony and self.tool_server is not None:
|
if self.tool_server is not None:
|
||||||
if self.tool_server.has_tool("browser"):
|
if self.tool_server.has_tool("browser"):
|
||||||
builtin_tool_list.append("browser")
|
builtin_tool_list.append("browser")
|
||||||
if self.tool_server.has_tool("python"):
|
if self.tool_server.has_tool("python"):
|
||||||
@ -423,6 +423,10 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
reasoning_parser_cls=self.reasoning_parser,
|
reasoning_parser_cls=self.reasoning_parser,
|
||||||
request=request,
|
request=request,
|
||||||
|
tool_parser_cls=self.tool_parser,
|
||||||
|
available_tools=available_tools,
|
||||||
|
chat_template=self.chat_template,
|
||||||
|
chat_template_content_format=self.chat_template_content_format,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context = SimpleContext()
|
context = SimpleContext()
|
||||||
|
|||||||
@ -16,6 +16,7 @@ from openai.types.responses.response import ToolChoice
|
|||||||
from openai.types.responses.response_function_tool_call_output_item import (
|
from openai.types.responses.response_function_tool_call_output_item import (
|
||||||
ResponseFunctionToolCallOutputItem,
|
ResponseFunctionToolCallOutputItem,
|
||||||
)
|
)
|
||||||
|
from openai.types.responses.response_output_item import McpCall
|
||||||
from openai.types.responses.response_output_message import ResponseOutputMessage
|
from openai.types.responses.response_output_message import ResponseOutputMessage
|
||||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||||
from openai.types.responses.tool import Tool
|
from openai.types.responses.tool import Tool
|
||||||
@ -25,6 +26,7 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
ResponseInputOutputItem,
|
ResponseInputOutputItem,
|
||||||
)
|
)
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
|
||||||
def make_response_output_items_from_parsable_context(
|
def make_response_output_items_from_parsable_context(
|
||||||
@ -36,7 +38,24 @@ def make_response_output_items_from_parsable_context(
|
|||||||
if not isinstance(message, ResponseFunctionToolCallOutputItem):
|
if not isinstance(message, ResponseFunctionToolCallOutputItem):
|
||||||
output_messages.append(message)
|
output_messages.append(message)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("tool calls not supported for response context")
|
if len(output_messages) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
|
||||||
|
)
|
||||||
|
if isinstance(output_messages[-1], ResponseFunctionToolCall):
|
||||||
|
mcp_message = McpCall(
|
||||||
|
id=f"mcp_{random_uuid()}",
|
||||||
|
arguments=output_messages[-1].arguments,
|
||||||
|
name=output_messages[-1].name,
|
||||||
|
server_label=output_messages[
|
||||||
|
-1
|
||||||
|
].name, # TODO: store the server label
|
||||||
|
type="mcp_call",
|
||||||
|
status="completed",
|
||||||
|
output=message.output,
|
||||||
|
# TODO: support error output
|
||||||
|
)
|
||||||
|
output_messages[-1] = mcp_message
|
||||||
|
|
||||||
return output_messages
|
return output_messages
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,17 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from openai.types.responses.response_function_tool_call_output_item import (
|
||||||
|
ResponseFunctionToolCallOutputItem,
|
||||||
|
)
|
||||||
from openai_harmony import Author, Message, Role, TextContent
|
from openai_harmony import Author, Message, Role, TextContent
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# Avoid circular import.
|
# Avoid circular import.
|
||||||
@ -46,6 +51,10 @@ class Tool(ABC):
|
|||||||
async def get_result(self, context: "ConversationContext") -> Any:
|
async def get_result(self, context: "ConversationContext") -> Any:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """Abstract hook: return the tool result for the given context."""
|
||||||
|
|
||||||
|
|
||||||
class HarmonyBrowserTool(Tool):
|
class HarmonyBrowserTool(Tool):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -81,6 +90,9 @@ class HarmonyBrowserTool(Tool):
|
|||||||
tool_output_msgs.append(msg)
|
tool_output_msgs.append(msg)
|
||||||
return tool_output_msgs
|
return tool_output_msgs
|
||||||
|
|
||||||
|
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """Always raise ``NotImplementedError``: this path is not implemented."""
    reason = "Not implemented yet"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_config(self) -> Any:
|
def tool_config(self) -> Any:
|
||||||
return self.browser_tool.tool_config
|
return self.browser_tool.tool_config
|
||||||
@ -138,6 +150,38 @@ class HarmonyPythonTool(Tool):
|
|||||||
tool_output_msgs.append(msg)
|
tool_output_msgs.append(msg)
|
||||||
return tool_output_msgs
|
return tool_output_msgs
|
||||||
|
|
||||||
|
async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
    """
    Run the latest parsable-context call through the harmony python tool.

    The last response message's JSON arguments supply the ``code`` string,
    which is wrapped in a harmony ``Message`` addressed to ``python`` and
    fed to ``self.python_tool.process``. Each message that comes back is
    converted into a ``function_call_output`` item, so the GPTOSS demo
    python tool can be used from a parsable context.
    """
    # Imported here rather than at module level.
    from vllm.entrypoints.context import ParsableContext

    assert isinstance(context, ParsableContext)

    pending_call = context.parser.response_messages[-1]
    code = json.loads(pending_call.arguments)["code"]

    harmony_request = Message(
        author=Author(role="assistant", name=None),
        content=[TextContent(text=code)],
        channel="analysis",
        recipient="python",
        content_type="code",
    )

    # One output item per message yielded by the python tool.
    return [
        ResponseFunctionToolCallOutputItem(
            id=f"fco_{random_uuid()}",
            type="function_call_output",
            call_id=f"call_{random_uuid()}",
            output=reply.content[0].text,
            status="completed",
        )
        async for reply in self.python_tool.process(harmony_request)
    ]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_config(self) -> Any:
|
def tool_config(self) -> Any:
|
||||||
return self.python_tool.tool_config
|
return self.python_tool.tool_config
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user