Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-08 00:45:19 +08:00)
[responsesAPI][3] ResponsesParser to set up non harmony MCP (#29413)
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
Parent: 0ec8422171
Commit: 52cb349fc0
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import pytest_asyncio
from openai import OpenAI

from ...utils import RemoteOpenAIServer


MODEL_NAME = "Qwen/Qwen3-8B"


@pytest.fixture(scope="module")
def server():
    args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
    env_dict = dict(
        VLLM_ENABLE_RESPONSES_API_STORE="1",
        VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
        # uncomment for tool calling
        # PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
    )

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_basic(client: OpenAI, model_name: str):
    response = await client.responses.create(
        model=model_name,
        input="What is 13 * 24?",
    )
    assert response is not None
    print("response: ", response)
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
    response = await client.responses.create(
        model=model_name,
        input=[
            {"type": "message", "content": "Hello.", "role": "user"},
            {
                "type": "reasoning",
                "id": "lol",
                "content": [
                    {
                        "type": "reasoning_text",
                        "text": "We need to respond: greeting.",
                    }
                ],
                "summary": [],
            },
            {
                "arguments": '{"location": "Paris", "unit": "celsius"}',
                "call_id": "call_5f7b38f3b81e4b8380fd0ba74f3ca3ab",
                "name": "get_weather",
                "type": "function_call",
                "id": "fc_4fe5d6fc5b6c4d6fa5f24cc80aa27f78",
                "status": "completed",
            },
            {
                "call_id": "call_5f7b38f3b81e4b8380fd0ba74f3ca3ab",
                "id": "fc_4fe5d6fc5b6c4d6fa5f24cc80aa27f78",
                "output": "The weather in Paris is 20 Celsius",
                "status": "completed",
                "type": "function_call_output",
            },
        ],
        temperature=0.0,
    )
    assert response is not None
    assert response.status == "completed"
    # make sure we get a reasoning and text output
    assert response.output[0].type == "reasoning"
    assert response.output[1].type == "message"
    assert type(response.output[1].content[0].text) is str
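For orientation, the flow these tests exercise can also be driven by hand against a running server. The launch command, port, and API key below are assumptions that mirror the fixture's args and env_dict rather than anything shown in this diff:

# Assumed server launch, matching the fixture's settings:
#   VLLM_ENABLE_RESPONSES_API_STORE=1 VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT=1 \
#     vllm serve Qwen/Qwen3-8B --reasoning-parser qwen3 --max_model_len 5000
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.responses.create(model="Qwen/Qwen3-8B", input="What is 13 * 24?")
# With the experimental parser context enabled, the output should mirror what
# test_reasoning_and_function_items asserts: a reasoning item, then a message.
print(response.output[0].type)  # "reasoning"
print(response.output[1].type)  # "message"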
@@ -1530,6 +1530,7 @@ def _parse_chat_message_content(
    role = message["role"]
    content = message.get("content")
    reasoning = message.get("reasoning") or message.get("reasoning_content")

    if content is None:
        content = []
    elif isinstance(content, str):
@@ -5,6 +5,7 @@ import contextlib
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union

@@ -17,9 +18,19 @@ from vllm.entrypoints.harmony_utils import (
    get_streamable_parser_for_assistant,
    render_for_completion,
)
from vllm.entrypoints.openai.parser.responses_parser import (
    get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import (
    ResponseInputOutputItem,
    ResponsesRequest,
)
from vllm.entrypoints.responses_utils import construct_tool_dicts
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer

if TYPE_CHECKING:
    from mcp.client import ClientSession

@@ -180,6 +191,71 @@ class SimpleContext(ConversationContext):
        raise NotImplementedError("Should not be called.")


class ParsableContext(ConversationContext):
    def __init__(
        self,
        *,
        response_messages: list[ResponseInputOutputItem],
        tokenizer: AnyTokenizer,
        reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
        request: ResponsesRequest,
    ):
        self.num_prompt_tokens = 0
        self.num_output_tokens = 0
        self.num_cached_tokens = 0
        # TODO: num_reasoning_tokens is not implemented yet.
        self.num_reasoning_tokens = 0
        # not implemented yet for ParsableContext
        self.all_turn_metrics: list[TurnMetrics] = []

        if reasoning_parser_cls is None:
            raise ValueError("reasoning_parser_cls must be provided.")

        self.parser = get_responses_parser_for_simple_context(
            tokenizer=tokenizer,
            reasoning_parser_cls=reasoning_parser_cls,
            response_messages=response_messages,
            request=request,
        )

        self._tool_sessions: dict[str, ClientSession | Tool] = {}
        self.called_tools: set[str] = set()

        self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)

    def append_output(self, output: RequestOutput) -> None:
        self.num_prompt_tokens = len(output.prompt_token_ids or [])
        self.num_cached_tokens = output.num_cached_tokens or 0
        self.num_output_tokens += len(output.outputs[0].token_ids or [])
        self.parser.process(output.outputs[0])

    def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
        raise NotImplementedError("Should not be called.")

    def need_builtin_tool_call(self) -> bool:
        """Return True if the last message is an MCP tool call."""
        return False

    async def call_tool(self) -> list[ResponseInputOutputItem]:
        raise NotImplementedError("Should not be called.")

    def render_for_completion(self):
        raise NotImplementedError("Should not be called.")

    async def init_tool_sessions(
        self,
        tool_server: ToolServer | None,
        exit_stack: AsyncExitStack,
        request_id: str,
        mcp_tools: dict[str, Mcp],
    ):
        pass

    async def cleanup_session(self, *args, **kwargs) -> None:
        """Can be used as a coroutine in __aexit__."""
        raise NotImplementedError("Should not be called.")


class HarmonyContext(ConversationContext):
    def __init__(
        self,
vllm/entrypoints/openai/parser/__init__.py (new file, 0 lines)
vllm/entrypoints/openai/parser/responses_parser.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections.abc import Callable

from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
    Content,
    ResponseReasoningItem,
)

from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

logger = logging.getLogger(__name__)


class ResponsesParser:
    """Incremental parser over completion tokens with reasoning support."""

    def __init__(
        self,
        *,
        tokenizer: AnyTokenizer,
        reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
        response_messages: list[ResponseInputOutputItem],
        request: ResponsesRequest,
    ):
        self.response_messages: list[ResponseInputOutputItem] = (
            # TODO: initial messages may not be properly typed
            response_messages
        )
        self.num_init_messages = len(response_messages)
        self.tokenizer = tokenizer
        self.request = request

        self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)

    def process(self, output: CompletionOutput) -> "ResponsesParser":
        reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
            output.text, request=self.request
        )
        if reasoning_content:
            self.response_messages.append(
                ResponseReasoningItem(
                    type="reasoning",
                    id=f"rs_{random_uuid()}",
                    summary=[],
                    content=[
                        Content(
                            type="reasoning_text",
                            text=reasoning_content,
                        )
                    ],
                )
            )

        if content:
            self.response_messages.append(
                ResponseOutputMessage(
                    type="message",
                    id=f"msg_{random_uuid()}",
                    status="completed",
                    role="assistant",
                    content=[
                        ResponseOutputText(
                            annotations=[],  # TODO
                            type="output_text",
                            text=content,
                            logprobs=None,  # TODO
                        )
                    ],
                )
            )

        return self


def get_responses_parser_for_simple_context(
    *,
    tokenizer: AnyTokenizer,
    reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
    response_messages: list[ResponseInputOutputItem],
    request: ResponsesRequest,
) -> ResponsesParser:
    """Factory function to create a ResponsesParser with
    optional reasoning parser.

    Returns:
        ResponsesParser instance configured with the provided parser
    """
    return ResponsesParser(
        tokenizer=tokenizer,
        reasoning_parser_cls=reasoning_parser_cls,
        response_messages=response_messages,
        request=request,
    )
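To see concretely what the new ResponsesParser emits, here is a small, self-contained driver. The StubReasoningParser, its <think> convention, and the SimpleNamespace standing in for a vLLM CompletionOutput are all illustrative assumptions, not part of this change:

from types import SimpleNamespace

from vllm.entrypoints.openai.parser.responses_parser import (
    get_responses_parser_for_simple_context,
)


class StubReasoningParser:
    """Illustrative stand-in for a real ReasoningParser (e.g. the qwen3 one)."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def extract_reasoning(self, text, request=None):
        # Split "<think>reasoning</think>answer" into (reasoning, answer).
        reasoning, sep, answer = text.partition("</think>")
        if not sep:
            return None, text
        return reasoning.removeprefix("<think>"), answer


parser = get_responses_parser_for_simple_context(
    tokenizer=None,  # not used by the stub
    reasoning_parser_cls=StubReasoningParser,
    response_messages=[],
    request=None,  # a real ResponsesRequest is not needed by the stub
)
parser.process(SimpleNamespace(text="<think>13 * 24 = 312</think>The answer is 312."))
print([item.type for item in parser.response_messages])  # ['reasoning', 'message']

ParsableContext drives the same object: each engine RequestOutput is fed through append_output, and the serving layer later reads parser.response_messages past num_init_messages to build the final Responses output items.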
@@ -60,6 +60,7 @@ from vllm.entrypoints.chat_utils import (
from vllm.entrypoints.context import (
    ConversationContext,
    HarmonyContext,
    ParsableContext,
    SimpleContext,
    StreamingHarmonyContext,
)
@@ -96,8 +97,9 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
    construct_input_messages,
    convert_tool_responses_to_completions_format,
    construct_tool_dicts,
    extract_tool_types,
    make_response_output_items_from_parsable_context,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
@@ -228,7 +230,6 @@ class OpenAIServingResponses(OpenAIServing):
        self.tool_parser = self._get_tool_parser(
            tool_parser_name=tool_parser, enable_auto_tools=enable_auto_tools
        )
        self.exclude_tools_when_tool_choice_none = False
        # HACK(woosuk): This is a hack. We should use a better store.
        # FIXME: If enable_store=True, this may cause a memory leak since we
        # never remove responses from the store.
@@ -413,7 +414,17 @@ class OpenAIServingResponses(OpenAIServing):
            else:
                context = HarmonyContext(messages, available_tools)
        else:
            context = SimpleContext()
            if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT:
                # This is a feature in development for parsing
                # tokens during generation instead of at the end
                context = ParsableContext(
                    response_messages=messages,
                    tokenizer=tokenizer,
                    reasoning_parser_cls=self.reasoning_parser,
                    request=request,
                )
            else:
                context = SimpleContext()

        if self.reasoning_parser is not None:
            reasoning_parser = self.reasoning_parser(tokenizer)
@@ -534,15 +545,7 @@ class OpenAIServingResponses(OpenAIServing):
        prev_response: ResponsesResponse | None,
        tokenizer: TokenizerLike,
    ):
        if request.tools is None or (
            request.tool_choice == "none" and self.exclude_tools_when_tool_choice_none
        ):
            tool_dicts = None
        else:
            tool_dicts = [
                convert_tool_responses_to_completions_format(tool.model_dump())
                for tool in request.tools
            ]
        tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
        # Construct the input messages.
        messages = construct_input_messages(
            request_instructions=request.instructions,
@@ -642,6 +645,22 @@ class OpenAIServingResponses(OpenAIServing):
                status = "cancelled"
            else:
                status = "incomplete"
        elif isinstance(context, ParsableContext):
            response_messages = context.parser.response_messages[
                context.parser.num_init_messages :
            ]
            output = make_response_output_items_from_parsable_context(response_messages)

            # TODO: context for non-gptoss models doesn't use messages
            # so we can't get them out yet
            if request.enable_response_messages:
                raise NotImplementedError(
                    "enable_response_messages is currently only supported for gpt-oss"
                )

            # TODO: Calculate usage.
            # assert final_res.prompt_token_ids is not None
            num_tool_output_tokens = 0
        else:
            assert isinstance(context, SimpleContext)
            final_res = context.last_output
@@ -661,7 +680,7 @@ class OpenAIServingResponses(OpenAIServing):
            assert final_res.prompt_token_ids is not None
            num_tool_output_tokens = 0

        assert isinstance(context, (SimpleContext, HarmonyContext))
        assert isinstance(context, (SimpleContext, HarmonyContext, ParsableContext))
        num_prompt_tokens = context.num_prompt_tokens
        num_generated_tokens = context.num_output_tokens
        num_cached_tokens = context.num_cached_tokens
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any

from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessageToolCallParam,

@@ -10,6 +12,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
    Function as FunctionCallTool,
)
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
    ResponseFunctionToolCallOutputItem,
)

@@ -24,6 +27,20 @@ from vllm.entrypoints.openai.protocol import (
)


def make_response_output_items_from_parsable_context(
    response_messages: list[ResponseInputOutputItem],
) -> list[ResponseOutputItem]:
    """Given a list of response messages, construct ResponseOutputItems."""
    output_messages: list[ResponseOutputItem] = []
    for message in response_messages:
        if not isinstance(message, ResponseFunctionToolCallOutputItem):
            output_messages.append(message)
        else:
            raise NotImplementedError("tool calls not supported for response context")

    return output_messages


def construct_input_messages(
    *,
    request_instructions: str | None = None,

@@ -146,3 +163,16 @@ def convert_tool_responses_to_completions_format(tool: dict) -> dict:
        "type": "function",
        "function": tool,
    }


def construct_tool_dicts(
    tools: list[Tool], tool_choice: ToolChoice
) -> list[dict[str, Any]] | None:
    if tools is None or (tool_choice == "none"):
        tool_dicts = None
    else:
        tool_dicts = [
            convert_tool_responses_to_completions_format(tool.model_dump())
            for tool in tools
        ]
    return tool_dicts
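A quick, hedged illustration of construct_tool_dicts: the StubTool below stands in for the pydantic tool objects carried by request.tools (only model_dump() matters here), and the exact shape of each entry depends on convert_tool_responses_to_completions_format:

from vllm.entrypoints.responses_utils import construct_tool_dicts


class StubTool:
    """Illustrative stand-in for a Responses-API tool object."""

    def model_dump(self):
        return {"type": "function", "name": "get_weather", "parameters": {}}


print(construct_tool_dicts([StubTool()], tool_choice="auto"))
# -> a Chat Completions style list like [{"type": "function", "function": {...}}]
print(construct_tool_dicts([StubTool()], tool_choice="none"))
# -> None: tools are dropped entirely when tool_choice is "none"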
@@ -214,6 +214,7 @@ if TYPE_CHECKING:
    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
    VLLM_TUNED_CONFIG_FOLDER: str | None = None
    VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
    VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False
    VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
    VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False

@@ -1444,6 +1445,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(
        int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))
    ),
    # Experimental: use this to enable MCP tool calling for non harmony models
    "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool(
        int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0"))
    ),
    # Allows vllm to find tuned config under customized folder
    "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
    # Valid values are container,code_interpreter,web_search_preview
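Since the flag is read as a 0/1 integer through the lambda table above, toggling it from Python is just an environment write; this sketch assumes vllm.envs resolves variables lazily on attribute access:

import os

# Enable the experimental Responses parser context for this process.
os.environ["VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT"] = "1"

import vllm.envs as envs

print(envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT)  # True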