mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 21:45:40 +08:00
[gpt-oss] Support tool call and implement MCP tool server (#22427)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
e290594072
commit
fe6d8257a1
@ -237,7 +237,10 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
|
|||||||
id=f"rs_{random_uuid()}",
|
id=f"rs_{random_uuid()}",
|
||||||
summary=[],
|
summary=[],
|
||||||
type="reasoning",
|
type="reasoning",
|
||||||
text=content.text,
|
content=[
|
||||||
|
ResponseReasoningTextContent(text=content.text,
|
||||||
|
type="reasoning_text")
|
||||||
|
],
|
||||||
status=None,
|
status=None,
|
||||||
)
|
)
|
||||||
output_items.append(reasoning_item)
|
output_items.append(reasoning_item)
|
||||||
|
|||||||
@ -94,7 +94,8 @@ from vllm.entrypoints.openai.serving_tokenization import (
|
|||||||
from vllm.entrypoints.openai.serving_transcription import (
|
from vllm.entrypoints.openai.serving_transcription import (
|
||||||
OpenAIServingTranscription, OpenAIServingTranslation)
|
OpenAIServingTranscription, OpenAIServingTranslation)
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
from vllm.entrypoints.tool_server import DemoToolServer, ToolServer
|
from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer,
|
||||||
|
ToolServer)
|
||||||
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
||||||
log_non_default_args, with_cancellation)
|
log_non_default_args, with_cancellation)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -1635,6 +1636,9 @@ async def init_app_state(
|
|||||||
|
|
||||||
if args.tool_server == "demo":
|
if args.tool_server == "demo":
|
||||||
tool_server: Optional[ToolServer] = DemoToolServer()
|
tool_server: Optional[ToolServer] = DemoToolServer()
|
||||||
|
elif args.tool_server:
|
||||||
|
tool_server = MCPToolServer()
|
||||||
|
await tool_server.add_tool_server(args.tool_server)
|
||||||
else:
|
else:
|
||||||
tool_server = None
|
tool_server = None
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections.abc import AsyncGenerator, AsyncIterator
|
from collections.abc import AsyncGenerator, AsyncIterator
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Callable, Final, Optional, Union
|
from typing import Any, Callable, Final, Optional, Union
|
||||||
@ -226,8 +227,31 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
# Schedule the request and get the result generator.
|
# Schedule the request and get the result generator.
|
||||||
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
generators: list[AsyncGenerator[ConversationContext, None]] = []
|
||||||
|
|
||||||
|
builtin_tool_list: list[str] = []
|
||||||
|
if self.use_harmony and self.tool_server is not None:
|
||||||
|
if self.tool_server.has_tool("browser"):
|
||||||
|
builtin_tool_list.append("browser")
|
||||||
|
if self.tool_server.has_tool("python"):
|
||||||
|
builtin_tool_list.append("python")
|
||||||
|
async with AsyncExitStack() as exit_stack:
|
||||||
try:
|
try:
|
||||||
tool_sessions: dict[str, Any] = {}
|
if self.tool_server is not None:
|
||||||
|
# TODO: initialize tool sessions lazily when the session
|
||||||
|
# is actually used.
|
||||||
|
tool_session_ctxs: dict[str, Any] = {
|
||||||
|
tool_name:
|
||||||
|
exit_stack.enter_async_context(
|
||||||
|
self.tool_server.new_session(tool_name))
|
||||||
|
for tool_name in builtin_tool_list
|
||||||
|
}
|
||||||
|
tool_sessions = {}
|
||||||
|
for tool_name in builtin_tool_list:
|
||||||
|
tool_sessions[tool_name] = (
|
||||||
|
await tool_session_ctxs[tool_name])
|
||||||
|
else:
|
||||||
|
assert len(builtin_tool_list) == 0
|
||||||
|
tool_sessions = {}
|
||||||
for i, engine_prompt in enumerate(engine_prompts):
|
for i, engine_prompt in enumerate(engine_prompts):
|
||||||
default_max_tokens = self.max_model_len - len(
|
default_max_tokens = self.max_model_len - len(
|
||||||
engine_prompt["prompt_token_ids"])
|
engine_prompt["prompt_token_ids"])
|
||||||
@ -235,7 +259,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
default_max_tokens, self.default_sampling_params)
|
default_max_tokens, self.default_sampling_params)
|
||||||
|
|
||||||
trace_headers = (None if raw_request is None else await
|
trace_headers = (None if raw_request is None else await
|
||||||
self._get_trace_headers(raw_request.headers))
|
self._get_trace_headers(
|
||||||
|
raw_request.headers))
|
||||||
|
|
||||||
context: ConversationContext
|
context: ConversationContext
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
@ -305,7 +330,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
raise NotImplementedError("Streaming responses are not supported")
|
raise NotImplementedError(
|
||||||
|
"Streaming responses are not supported")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.responses_full_generator(
|
return await self.responses_full_generator(
|
||||||
@ -319,6 +345,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
return self.create_error_response("Should not reach here")
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from openai_harmony import ToolNamespaceConfig
|
from openai_harmony import ToolNamespaceConfig
|
||||||
|
|
||||||
@ -11,6 +11,61 @@ from vllm.logger import init_logger
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from mcp.types import ListToolsResult
|
||||||
|
|
||||||
|
|
||||||
|
async def list_server_and_tools(server_url: str):
    """Open a short-lived MCP session and fetch the server's metadata.

    Connects to ``server_url`` with the MCP SSE client, performs the
    ``initialize`` handshake and then lists the server's tools. The
    connection is closed before this coroutine returns.

    Args:
        server_url: URL of the MCP server's SSE endpoint.

    Returns:
        A ``(initialize_response, list_tools_response)`` tuple holding the
        results of ``session.initialize()`` and ``session.list_tools()``.
    """
    # Imported inside the function so that importing this module does not
    # require the ``mcp`` package.
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    async with sse_client(url=server_url) as streams:
        async with ClientSession(*streams) as session:
            handshake = await session.initialize()
            tool_listing = await session.list_tools()
            return handshake, tool_listing
|
||||||
|
|
||||||
|
|
||||||
|
def trim_schema(schema: dict) -> dict:
    """Turn an MCP-generated JSON Schema into Harmony's variant, in place.

    The following rewrites are applied to ``schema`` and, recursively, to
    every dict found under ``schema["properties"]``:

    * the ``"title"`` key is removed;
    * a ``"default"`` whose value is ``None`` is removed;
    * ``"anyOf": [{"type": "a"}, {"type": "b"}]`` is collapsed into
      ``"type": ["a", "b"]``, dropping ``"null"`` because Harmony ignores
      it.

    Args:
        schema: The JSON Schema dict to rewrite. It is mutated.

    Returns:
        The same ``schema`` object, after rewriting.
    """
    schema.pop("title", None)
    if "default" in schema and schema["default"] is None:
        del schema["default"]
    if "anyOf" in schema:
        alternatives = schema["anyOf"]
        # Only collapse when every alternative carries a plain "type".
        # Alternatives such as {"$ref": ...} have no "type" key; they used
        # to raise KeyError here, so leave such an "anyOf" untouched.
        if all(
                isinstance(alt, dict) and "type" in alt
                for alt in alternatives):
            schema["type"] = [
                alt["type"] for alt in alternatives if alt["type"] != "null"
            ]
            del schema["anyOf"]
    if "properties" in schema:
        schema["properties"] = {
            # Non-dict property schemas (e.g. boolean schemas) are kept
            # as they are.
            key: trim_schema(value) if isinstance(value, dict) else value
            for key, value in schema["properties"].items()
        }
    return schema
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_tools_description(
        list_tools_result: "ListToolsResult") -> "ListToolsResult":
    """Adapt an MCP ``list_tools`` result for Harmony.

    Every tool's ``inputSchema`` is rewritten with ``trim_schema``, then
    tools whose ``annotations.include_in_prompt`` is falsy are dropped. A
    tool without that attribute is kept.

    Args:
        list_tools_result: The MCP tool listing. It is modified in place.

    Returns:
        The same ``list_tools_result`` object with its ``tools`` replaced
        by the filtered list.
    """
    # First pass: convert every schema, including those of tools that are
    # filtered out below.
    for mcp_tool in list_tools_result.tools:
        mcp_tool.inputSchema = trim_schema(mcp_tool.inputSchema)

    # Second pass: some tool schemas don't need to be part of the prompt
    # (e.g. simple text in, text out for Python).
    prompt_tools = []
    for mcp_tool in list_tools_result.tools:
        if getattr(mcp_tool.annotations, "include_in_prompt", True):
            prompt_tools.append(mcp_tool)
    list_tools_result.tools = prompt_tools

    return list_tools_result
|
||||||
|
|
||||||
|
|
||||||
class ToolServer(ABC):
|
class ToolServer(ABC):
|
||||||
|
|
||||||
@ -38,6 +93,66 @@ class ToolServer(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolServer(ToolServer):
    """Tool server backed by remote MCP servers reached over SSE.

    Servers are registered with ``add_tool_server``. Each one is exposed
    under the name it reports in its MCP ``initialize`` response.
    """

    def __init__(self):
        """Create an empty tool server.

        Raises:
            ImportError: If the ``mcp`` package is not installed.
        """
        try:
            import mcp  # noqa: F401
        except ImportError:
            raise ImportError(
                "mcp is not installed. Please run `pip install mcp` to use "
                "MCPToolServer.") from None
        # MCP server name -> Harmony ToolNamespaceConfig.
        self.harmony_tool_descriptions = {}
        # MCP server name -> SSE endpoint URL. Created here so that
        # new_session() raises KeyError rather than AttributeError when
        # it is called before add_tool_server().
        self.urls: dict[str, str] = {}

    async def add_tool_server(self, server_url: str):
        """Discover the tools of one or more MCP servers.

        Any previously registered servers are discarded first.

        Args:
            server_url: Comma-separated list of server addresses. Each
                entry is expanded to ``http://<entry>/sse``. Surrounding
                whitespace is stripped and empty entries are skipped.
        """
        # ToolDescription is a Harmony type (it pairs with the
        # ToolNamespaceConfig built below), so it comes from
        # openai_harmony rather than mcp.types.
        from openai_harmony import ToolDescription

        self.harmony_tool_descriptions = {}
        self.urls = {}
        for address in server_url.split(","):
            address = address.strip()
            if not address:
                continue
            url = f"http://{address}/sse"
            initialize_response, list_tools_response = (
                await list_server_and_tools(url))

            server_name = initialize_response.serverInfo.name
            if server_name in self.urls:
                # Skip the duplicate before touching any state, so the
                # stored description and URL always belong to the same
                # server.
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    server_name, url)
                continue

            list_tools_response = post_process_tools_description(
                list_tools_response)

            tool_from_mcp = ToolNamespaceConfig(
                name=server_name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(name=tool.name,
                                        description=tool.description,
                                        parameters=tool.inputSchema)
                    for tool in list_tools_response.tools
                ])
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            self.urls[tool_from_mcp.name] = url

    def has_tool(self, tool_name: str):
        """Return whether a server named ``tool_name`` is registered."""
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(self, tool_name: str):
        """Return the Harmony config for ``tool_name``, or ``None``."""
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def new_session(self, tool_name: str):
        """Open an initialized MCP client session for ``tool_name``.

        Used as an async context manager; the SSE connection and the
        session are closed when the context exits.

        Args:
            tool_name: Name of a server registered by ``add_tool_server``.

        Yields:
            The ``ClientSession``, after ``initialize()`` has completed.

        Raises:
            KeyError: If no server is registered under ``tool_name``.
        """
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        url = self.urls.get(tool_name)
        if not url:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        async with sse_client(url=url) as streams, ClientSession(
                *streams) as session:
            await session.initialize()
            yield session
|
||||||
|
|
||||||
|
|
||||||
class DemoToolServer(ToolServer):
|
class DemoToolServer(ToolServer):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -67,4 +182,6 @@ class DemoToolServer(ToolServer):
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def new_session(self, tool_name: str):
|
async def new_session(self, tool_name: str):
|
||||||
|
if tool_name not in self.tools:
|
||||||
|
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||||
yield self.tools[tool_name]
|
yield self.tools[tool_name]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user