[gpt-oss] Support tool call and implement MCP tool server (#22427)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-08-08 15:06:37 -07:00 committed by GitHub
parent e290594072
commit fe6d8257a1
4 changed files with 233 additions and 82 deletions


@@ -237,7 +237,10 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
text=content.text,
content=[
ResponseReasoningTextContent(text=content.text,
type="reasoning_text")
],
status=None,
)
output_items.append(reasoning_item)
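This hunk swaps the reasoning item's flat `text` field for a `content` list of `ResponseReasoningTextContent` entries. A hedged sketch of the resulting serialized output item, with a placeholder standing in for the model's reasoning text:

    # Illustrative shape of the reasoning item after this change (values are
    # placeholders; the field names come straight from the constructor above).
    reasoning_item = {
        "id": "rs_<random_uuid>",
        "type": "reasoning",
        "summary": [],
        "content": [
            {"type": "reasoning_text", "text": "<model reasoning text>"},
        ],
        "status": None,
    }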


@@ -94,7 +94,8 @@ from vllm.entrypoints.openai.serving_tokenization import (
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription, OpenAIServingTranslation)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.tool_server import DemoToolServer, ToolServer
from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer,
ToolServer)
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
log_non_default_args, with_cancellation)
from vllm.logger import init_logger
@@ -1635,6 +1636,9 @@ async def init_app_state(
if args.tool_server == "demo":
tool_server: Optional[ToolServer] = DemoToolServer()
elif args.tool_server:
tool_server = MCPToolServer()
await tool_server.add_tool_server(args.tool_server)
else:
tool_server = None
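With this change, `--tool-server demo` keeps the built-in `DemoToolServer`, while any other non-empty value is treated as a comma-separated list of MCP server addresses. A hedged sketch of that second path (the host:port values are placeholders, not part of the commit):

    # Sketch only: how a comma-separated --tool-server value is consumed.
    tool_server = MCPToolServer()
    # Each entry is probed over SSE at http://<host:port>/sse
    # (see the add_tool_server implementation further down).
    await tool_server.add_tool_server("localhost:8000,localhost:8001")
    if tool_server.has_tool("python"):
        ...  # this tool namespace can now be exposed to Harmony prompts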


@@ -4,6 +4,7 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Any, Callable, Final, Optional, Union
@@ -226,65 +227,114 @@ class OpenAIServingResponses(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[ConversationContext, None]] = []
try:
tool_sessions: dict[str, Any] = {}
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(
messages, tool_sessions)
else:
context = HarmonyContext(messages, tool_sessions)
builtin_tool_list: list[str] = []
if self.use_harmony and self.tool_server is not None:
if self.tool_server.has_tool("browser"):
builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"):
builtin_tool_list.append("python")
async with AsyncExitStack() as exit_stack:
try:
if self.tool_server is not None:
# TODO: initialize tool sessions lazily when the session
# is actually used.
tool_session_ctxs: dict[str, Any] = {
tool_name:
exit_stack.enter_async_context(
self.tool_server.new_session(tool_name))
for tool_name in builtin_tool_list
}
tool_sessions = {}
for tool_name in builtin_tool_list:
tool_sessions[tool_name] = (
await tool_session_ctxs[tool_name])
else:
context = SimpleContext()
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
assert len(builtin_tool_list) == 0
tool_sessions = {}
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(
raw_request.headers))
context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(
messages, tool_sessions)
else:
context = HarmonyContext(messages, tool_sessions)
else:
context = SimpleContext()
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert len(generators) == 1
result_generator, = generators
# Store the input messages.
if request.store:
self.msg_store[request.request_id] = messages
if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async with self.response_store_lock:
self.response_store[response.id] = response
assert len(generators) == 1
result_generator, = generators
# Run the request in the background.
task = asyncio.create_task(
self._run_background_request(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)
# Store the input messages.
if request.store:
self.msg_store[request.request_id] = messages
# For cleanup.
response_id = response.id
self.background_tasks[response_id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None))
return response
if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
)
async with self.response_store_lock:
self.response_store[response.id] = response
if request.stream:
raise NotImplementedError(
"Streaming responses are not supported")
# Run the request in the background.
task = asyncio.create_task(
self._run_background_request(
try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
@@ -292,33 +342,10 @@ class OpenAIServingResponses(OpenAIServing):
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)
# For cleanup.
response_id = response.id
self.background_tasks[response_id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None))
return response
if request.stream:
raise NotImplementedError("Streaming responses are not supported")
try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
except Exception as e:
return self.create_error_response(str(e))
)
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response("Should not reach here")
async def _make_request(
self,

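The structural point of the change above is that per-request tool sessions are now opened through an `AsyncExitStack`: every session entered for `builtin_tool_list` stays open while the request runs and is closed when the stack unwinds, whether generation completes or raises. A minimal sketch of the same pattern, with `tool_server` and the tool names standing in for the attributes used above:

    from contextlib import AsyncExitStack

    async def open_tool_sessions(tool_server, tool_names):
        # Simplified version of the code above: enter one async context per
        # built-in tool and keep every session open until the stack exits.
        async with AsyncExitStack() as stack:
            sessions = {
                name: await stack.enter_async_context(tool_server.new_session(name))
                for name in tool_names
            }
            ...  # run generation with `sessions`; all of them close on exit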

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
from openai_harmony import ToolNamespaceConfig
@@ -11,6 +11,61 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
if TYPE_CHECKING:
from mcp.types import ListToolsResult
async def list_server_and_tools(server_url: str):
from mcp import ClientSession
from mcp.client.sse import sse_client
async with sse_client(url=server_url) as streams, ClientSession(
*streams) as session:
initialize_response = await session.initialize()
list_tools_response = await session.list_tools()
return initialize_response, list_tools_response
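# Hedged usage sketch for the helper above (the URL is a placeholder): the two
# results are the MCP initialize handshake and the server's tool listing, which
# add_tool_server() below adapts into a Harmony tool namespace, e.g.
#   init_resp, tools_resp = await list_server_and_tools("http://localhost:8000/sse")
#   init_resp.serverInfo.name   -> namespace name
#   tools_resp.tools            -> tool schemas to trim for Harmony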
def trim_schema(schema: dict) -> dict:
# Turn JSON Schema from MCP generated into Harmony's variant.
if "title" in schema:
del schema["title"]
if "default" in schema and schema["default"] is None:
del schema["default"]
if "anyOf" in schema:
# Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
# into "type": ["type-1", "type-2"]
# if there's more than 1 types, also remove "null" type as Harmony will
# just ignore it
types = [
type_dict["type"] for type_dict in schema["anyOf"]
if type_dict["type"] != 'null'
]
schema["type"] = types
del schema["anyOf"]
if "properties" in schema:
schema["properties"] = {
k: trim_schema(v)
for k, v in schema["properties"].items()
}
return schema
def post_process_tools_description(
list_tools_result: "ListToolsResult") -> "ListToolsResult":
# Adapt the MCP tool result for Harmony
for tool in list_tools_result.tools:
tool.inputSchema = trim_schema(tool.inputSchema)
# Some tools schema don't need to be part of the prompt (e.g. simple text
# in text out for Python)
list_tools_result.tools = [
tool for tool in list_tools_result.tools
if getattr(tool.annotations, "include_in_prompt", True)
]
return list_tools_result
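# Worked example for trim_schema above (input values are illustrative only):
# an MCP-generated schema with "title", a null "default" and an "anyOf" union
# collapses into the Harmony-style variant.
example_in = {
    "title": "Query",
    "type": "object",
    "properties": {
        "topn": {
            "title": "Topn",
            "anyOf": [{"type": "integer"}, {"type": "null"}],
            "default": None,
        },
    },
}
assert trim_schema(example_in) == {
    "type": "object",
    "properties": {"topn": {"type": ["integer"]}},
}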
class ToolServer(ABC):
@@ -38,6 +93,66 @@
...
class MCPToolServer(ToolServer):
def __init__(self):
try:
import mcp # noqa: F401
except ImportError:
raise ImportError(
"mcp is not installed. Please run `pip install mcp` to use "
"MCPToolServer.") from None
self.harmony_tool_descriptions = {}
async def add_tool_server(self, server_url: str):
from openai_harmony import ToolDescription
tool_urls = server_url.split(",")
self.harmony_tool_descriptions = {}
self.urls: dict[str, str] = {}
for url in tool_urls:
url = f"http://{url}/sse"
initialize_response, list_tools_response = (
await list_server_and_tools(url))
list_tools_response = post_process_tools_description(
list_tools_response)
tool_from_mcp = ToolNamespaceConfig(
name=initialize_response.serverInfo.name,
description=initialize_response.instructions,
tools=[
ToolDescription.new(name=tool.name,
description=tool.description,
parameters=tool.inputSchema)
for tool in list_tools_response.tools
])
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
if tool_from_mcp.name not in self.urls:
self.urls[tool_from_mcp.name] = url
else:
logger.warning(
"Tool %s already exists. Ignoring duplicate tool server %s",
tool_from_mcp.name, url)
def has_tool(self, tool_name: str):
return tool_name in self.harmony_tool_descriptions
def get_tool_description(self, tool_name: str):
return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager
async def new_session(self, tool_name: str):
from mcp import ClientSession
from mcp.client.sse import sse_client
url = self.urls.get(tool_name)
if not url:
raise KeyError(f"Tool '{tool_name}' is not supported")
async with sse_client(url=url) as streams, ClientSession(
*streams) as session:
await session.initialize()
yield session
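# Hedged usage sketch for new_session above (tool name and arguments are
# placeholders; call_tool is the standard mcp.ClientSession method):
#
#   async with tool_server.new_session("browser") as session:
#       result = await session.call_tool("search", {"query": "vllm"})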
class DemoToolServer(ToolServer):
def __init__(self):
@@ -67,4 +182,6 @@ class DemoToolServer(ToolServer):
@asynccontextmanager
async def new_session(self, tool_name: str):
if tool_name not in self.tools:
raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name]
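The diff above only covers the client side: `MCPToolServer` consumes existing MCP servers over SSE. A hypothetical counterpart it could connect to, sketched with the FastMCP helper from the `mcp` Python SDK (server name, tool, and transport are illustrative assumptions, not part of this commit):

    from mcp.server.fastmcp import FastMCP

    server = FastMCP("demo-tools")

    @server.tool()
    def echo(text: str) -> str:
        """Toy tool: return the input text unchanged."""
        return text

    if __name__ == "__main__":
        # Serve over SSE so `--tool-server <host>:<port>` resolves to the
        # http://<host>:<port>/sse endpoint built in add_tool_server().
        server.run(transport="sse")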