[gpt-oss] Support tool call and implement MCP tool server (#22427)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-08-08 15:06:37 -07:00 (committed by GitHub)
Parent: e290594072
Commit: fe6d8257a1
4 changed files with 233 additions and 82 deletions

View File: vllm/entrypoints/harmony_utils.py

@@ -237,7 +237,10 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
             id=f"rs_{random_uuid()}",
             summary=[],
             type="reasoning",
-            text=content.text,
+            content=[
+                ResponseReasoningTextContent(text=content.text,
+                                             type="reasoning_text")
+            ],
             status=None,
         )
         output_items.append(reasoning_item)
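
Review note: the reasoning item now carries a content list of typed parts instead of a flat text field, matching the Responses API's "reasoning_text" part type. A minimal sketch of the resulting shape; the ResponseReasoningItem constructor name and the literal text are illustrative, not part of this diff:

    reasoning_item = ResponseReasoningItem(  # hypothetical constructor name
        id=f"rs_{random_uuid()}",
        summary=[],
        type="reasoning",
        content=[
            ResponseReasoningTextContent(text="First, check the inputs...",
                                         type="reasoning_text")
        ],
        status=None,
    )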

View File: vllm/entrypoints/openai/api_server.py

@@ -94,7 +94,8 @@ from vllm.entrypoints.openai.serving_tokenization import (
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription, OpenAIServingTranslation)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.tool_server import DemoToolServer, ToolServer
+from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer,
+                                          ToolServer)
 from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                     log_non_default_args, with_cancellation)
 from vllm.logger import init_logger

@@ -1635,6 +1636,9 @@ async def init_app_state(
     if args.tool_server == "demo":
         tool_server: Optional[ToolServer] = DemoToolServer()
+    elif args.tool_server:
+        tool_server = MCPToolServer()
+        await tool_server.add_tool_server(args.tool_server)
     else:
         tool_server = None
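
Review note: any non-"demo" --tool-server value now takes the MCP path. A minimal sketch of what this branch does, with hypothetical host:port values; add_tool_server (added in tool_server.py later in this commit) splits the value on "," and wraps each entry as http://<host:port>/sse:

    tool_server = MCPToolServer()
    await tool_server.add_tool_server("localhost:8001,localhost:8002")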

View File: vllm/entrypoints/openai/serving_responses.py

@@ -4,6 +4,7 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
+from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
 from typing import Any, Callable, Final, Optional, Union
@@ -226,8 +227,31 @@ class OpenAIServingResponses(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[ConversationContext, None]] = []

-        try:
-            tool_sessions: dict[str, Any] = {}
-            for i, engine_prompt in enumerate(engine_prompts):
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
+        builtin_tool_list: list[str] = []
+        if self.use_harmony and self.tool_server is not None:
+            if self.tool_server.has_tool("browser"):
+                builtin_tool_list.append("browser")
+            if self.tool_server.has_tool("python"):
+                builtin_tool_list.append("python")
+
+        async with AsyncExitStack() as exit_stack:
+            try:
+                if self.tool_server is not None:
+                    # TODO: initialize tool sessions lazily when the session
+                    # is actually used.
+                    tool_session_ctxs: dict[str, Any] = {
+                        tool_name:
+                        exit_stack.enter_async_context(
+                            self.tool_server.new_session(tool_name))
+                        for tool_name in builtin_tool_list
+                    }
+                    tool_sessions = {}
+                    for tool_name in builtin_tool_list:
+                        tool_sessions[tool_name] = (
+                            await tool_session_ctxs[tool_name])
+                else:
+                    assert len(builtin_tool_list) == 0
+                    tool_sessions = {}
+                for i, engine_prompt in enumerate(engine_prompts):
+                    default_max_tokens = self.max_model_len - len(
+                        engine_prompt["prompt_token_ids"])
@@ -235,7 +259,8 @@ class OpenAIServingResponses(OpenAIServing):
                     default_max_tokens, self.default_sampling_params)
-                trace_headers = (None if raw_request is None else await
-                                 self._get_trace_headers(raw_request.headers))
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(
+                                     raw_request.headers))

                 context: ConversationContext
                 if self.use_harmony:
@@ -305,7 +330,8 @@ class OpenAIServingResponses(OpenAIServing):
                 return response

             if request.stream:
-                raise NotImplementedError("Streaming responses are not supported")
+                raise NotImplementedError(
+                    "Streaming responses are not supported")

             try:
                 return await self.responses_full_generator(
@@ -319,6 +345,7 @@ class OpenAIServingResponses(OpenAIServing):
                 )
             except Exception as e:
                 return self.create_error_response(str(e))
+        return self.create_error_response("Should not reach here")

     async def _make_request(
         self,
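
Review note: the trailing create_error_response("Should not reach here") is needed because the success and error returns now happen inside the async with block; it keeps the function returning on every control path and should be unreachable in practice.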

View File: vllm/entrypoints/tool_server.py

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from abc import ABC, abstractmethod
 from contextlib import AbstractAsyncContextManager, asynccontextmanager
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 from openai_harmony import ToolNamespaceConfig
@@ -11,6 +11,61 @@ from vllm.logger import init_logger

 logger = init_logger(__name__)

+if TYPE_CHECKING:
+    from mcp.types import ListToolsResult
+
+
+async def list_server_and_tools(server_url: str):
+    from mcp import ClientSession
+    from mcp.client.sse import sse_client
+
+    async with sse_client(url=server_url) as streams, ClientSession(
+            *streams) as session:
+        initialize_response = await session.initialize()
+        list_tools_response = await session.list_tools()
+        return initialize_response, list_tools_response
+
+
+def trim_schema(schema: dict) -> dict:
+    # Turn the JSON Schema generated by MCP into Harmony's variant.
+    if "title" in schema:
+        del schema["title"]
+    if "default" in schema and schema["default"] is None:
+        del schema["default"]
+    if "anyOf" in schema:
+        # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
+        # into "type": ["type-1", "type-2"]. If there is more than one
+        # type, also drop the "null" type, as Harmony will just ignore it.
+        types = [
+            type_dict["type"] for type_dict in schema["anyOf"]
+            if type_dict["type"] != 'null'
+        ]
+        schema["type"] = types
+        del schema["anyOf"]
+    if "properties" in schema:
+        schema["properties"] = {
+            k: trim_schema(v)
+            for k, v in schema["properties"].items()
+        }
+    return schema
+
+
+def post_process_tools_description(
+        list_tools_result: "ListToolsResult") -> "ListToolsResult":
+    # Adapt the MCP tool result for Harmony.
+    for tool in list_tools_result.tools:
+        tool.inputSchema = trim_schema(tool.inputSchema)
+
+    # Some tool schemas don't need to be part of the prompt (e.g. simple
+    # text-in, text-out for Python).
+    list_tools_result.tools = [
+        tool for tool in list_tools_result.tools
+        if getattr(tool.annotations, "include_in_prompt", True)
+    ]
+
+    return list_tools_result
+
+
 class ToolServer(ABC):
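
Review note: the anyOf flattening is easiest to see on a concrete schema. A hypothetical MCP-generated input schema and what trim_schema reduces it to (note that trim_schema mutates in place):

    before = {
        "title": "FetchArgs",  # dropped
        "type": "object",
        "properties": {
            "timeout": {
                "title": "Timeout",  # dropped
                "default": None,  # dropped because the default is None
                "anyOf": [{"type": "number"}, {"type": "null"}],
            },
        },
    }
    after = trim_schema(before)
    # after == {
    #     "type": "object",
    #     "properties": {"timeout": {"type": ["number"]}},
    # }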
@@ -38,6 +93,66 @@ class ToolServer(ABC):
         ...


+class MCPToolServer(ToolServer):
+
+    def __init__(self):
+        try:
+            import mcp  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "mcp is not installed. Please run `pip install mcp` to use "
+                "MCPToolServer.") from None
+        self.harmony_tool_descriptions = {}
+
+    async def add_tool_server(self, server_url: str):
+        from openai_harmony import ToolDescription
+        tool_urls = server_url.split(",")
+        self.harmony_tool_descriptions = {}
+        self.urls: dict[str, str] = {}
+        for url in tool_urls:
+            url = f"http://{url}/sse"
+            initialize_response, list_tools_response = (
+                await list_server_and_tools(url))
+
+            list_tools_response = post_process_tools_description(
+                list_tools_response)
+            tool_from_mcp = ToolNamespaceConfig(
+                name=initialize_response.serverInfo.name,
+                description=initialize_response.instructions,
+                tools=[
+                    ToolDescription.new(name=tool.name,
+                                        description=tool.description,
+                                        parameters=tool.inputSchema)
+                    for tool in list_tools_response.tools
+                ])
+            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
+            if tool_from_mcp.name not in self.urls:
+                self.urls[tool_from_mcp.name] = url
+            else:
+                logger.warning(
+                    "Tool %s already exists. Ignoring duplicate tool "
+                    "server %s", tool_from_mcp.name, url)
+
+    def has_tool(self, tool_name: str):
+        return tool_name in self.harmony_tool_descriptions
+
+    def get_tool_description(self, tool_name: str):
+        return self.harmony_tool_descriptions.get(tool_name)
+
+    @asynccontextmanager
+    async def new_session(self, tool_name: str):
+        from mcp import ClientSession
+        from mcp.client.sse import sse_client
+        url = self.urls.get(tool_name)
+        if not url:
+            raise KeyError(f"Tool '{tool_name}' is not supported")
+        async with sse_client(url=url) as streams, ClientSession(
+                *streams) as session:
+            await session.initialize()
+            yield session
+
+
 class DemoToolServer(ToolServer):

     def __init__(self):
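
Review note: putting the new class together, a hedged end-to-end sketch against a hypothetical MCP SSE server on localhost:8001; the URL and tool name are illustrative, while the class and methods are from this diff:

    import asyncio

    from vllm.entrypoints.tool_server import MCPToolServer

    async def main():
        server = MCPToolServer()
        # Registers http://localhost:8001/sse and fetches its tool namespace.
        await server.add_tool_server("localhost:8001")
        if server.has_tool("browser"):
            print(server.get_tool_description("browser"))
            # new_session() yields a live mcp.ClientSession.
            async with server.new_session("browser") as session:
                result = await session.list_tools()
                print([tool.name for tool in result.tools])

    asyncio.run(main())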
@@ -67,4 +182,6 @@ class DemoToolServer(ToolServer):

     @asynccontextmanager
     async def new_session(self, tool_name: str):
+        if tool_name not in self.tools:
+            raise KeyError(f"Tool '{tool_name}' is not supported")
         yield self.tools[tool_name]