[gpt-oss] Support tool call and implement MCP tool server (#22427)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent e290594072
commit fe6d8257a1
@@ -237,7 +237,10 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
text=content.text,
content=[
ResponseReasoningTextContent(text=content.text,
type="reasoning_text")
],
status=None,
)
output_items.append(reasoning_item)

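Note on the hunk above: the reasoning text moves from a bare text= field into a content list of reasoning_text entries. A hedged sketch of the resulting serialized shape (illustration only, not part of the diff; the id and text values are invented):

# Hypothetical serialized form of the reasoning item built above; the id and
# text values are made up, the field layout mirrors the constructor call.
reasoning_item_json = {
    "id": "rs_abc123",
    "type": "reasoning",
    "summary": [],
    "content": [
        {"type": "reasoning_text", "text": "The user asked about ..."},
    ],
    "status": None,
}
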
@@ -94,7 +94,8 @@ from vllm.entrypoints.openai.serving_tokenization import (
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription, OpenAIServingTranslation)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.tool_server import DemoToolServer, ToolServer
from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer,
ToolServer)
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
log_non_default_args, with_cancellation)
from vllm.logger import init_logger

@@ -1635,6 +1636,9 @@ async def init_app_state(

if args.tool_server == "demo":
tool_server: Optional[ToolServer] = DemoToolServer()
elif args.tool_server:
tool_server = MCPToolServer()
await tool_server.add_tool_server(args.tool_server)
else:
tool_server = None

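The branch above gives three modes: the bundled demo tools, one or more external MCP servers, or no tool server at all. A minimal sketch of the same selection outside init_app_state (illustration only, not part of the diff; the argument values are hypothetical):

# Sketch only: mirrors the branching added to init_app_state above.
from typing import Optional

from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer


async def build_tool_server(tool_server_arg: Optional[str]) -> Optional[ToolServer]:
    if tool_server_arg == "demo":
        # Built-in demo tools bundled with vLLM.
        return DemoToolServer()
    elif tool_server_arg:
        # Anything else is treated as MCP server address(es).
        server = MCPToolServer()
        await server.add_tool_server(tool_server_arg)
        return server
    return None

# e.g. asyncio.run(build_tool_server("demo")) or
#      asyncio.run(build_tool_server("localhost:8001"))  # hypothetical MCP endpoint
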
@@ -4,6 +4,7 @@
import asyncio
import time
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Any, Callable, Final, Optional, Union

@@ -226,65 +227,114 @@ class OpenAIServingResponses(OpenAIServing):

# Schedule the request and get the result generator.
generators: list[AsyncGenerator[ConversationContext, None]] = []
try:
tool_sessions: dict[str, Any] = {}
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)

trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))

context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(
messages, tool_sessions)
else:
context = HarmonyContext(messages, tool_sessions)
builtin_tool_list: list[str] = []
if self.use_harmony and self.tool_server is not None:
if self.tool_server.has_tool("browser"):
builtin_tool_list.append("browser")
if self.tool_server.has_tool("python"):
builtin_tool_list.append("python")
async with AsyncExitStack() as exit_stack:
try:
if self.tool_server is not None:
# TODO: initialize tool sessions lazily when the session
# is actually used.
tool_session_ctxs: dict[str, Any] = {
tool_name:
exit_stack.enter_async_context(
self.tool_server.new_session(tool_name))
for tool_name in builtin_tool_list
}
tool_sessions = {}
for tool_name in builtin_tool_list:
tool_sessions[tool_name] = (
await tool_session_ctxs[tool_name])
else:
context = SimpleContext()
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
assert len(builtin_tool_list) == 0
tool_sessions = {}
for i, engine_prompt in enumerate(engine_prompts):
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params)

trace_headers = (None if raw_request is None else await
self._get_trace_headers(
raw_request.headers))

context: ConversationContext
if self.use_harmony:
if request.stream:
context = StreamingHarmonyContext(
messages, tool_sessions)
else:
context = HarmonyContext(messages, tool_sessions)
else:
context = SimpleContext()
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
lora_request=lora_request,
priority=request.priority,
trace_headers=trace_headers,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

assert len(generators) == 1
result_generator, = generators

# Store the input messages.
if request.store:
self.msg_store[request.request_id] = messages

if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
async with self.response_store_lock:
self.response_store[response.id] = response

assert len(generators) == 1
result_generator, = generators
# Run the request in the background.
task = asyncio.create_task(
self._run_background_request(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)

# Store the input messages.
if request.store:
self.msg_store[request.request_id] = messages
# For cleanup.
response_id = response.id
self.background_tasks[response_id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None))
return response

if request.background:
created_time = int(time.time())
response = ResponsesResponse.from_request(
request,
sampling_params,
model_name=model_name,
created_time=created_time,
output=[],
status="queued",
usage=None,
)
async with self.response_store_lock:
self.response_store[response.id] = response
if request.stream:
raise NotImplementedError(
"Streaming responses are not supported")

# Run the request in the background.
task = asyncio.create_task(
self._run_background_request(
try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,

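The key pattern in the hunk above is the shared AsyncExitStack: one new_session(tool_name) async context is entered per built-in tool, and every session is closed together when the request scope exits, on success or on error. A self-contained sketch of that pattern with dummy sessions (illustration only, not part of the diff):

# Standalone sketch of the AsyncExitStack pattern used above: enter one async
# context per tool name, keep the live sessions in a dict, and let the stack
# close all of them when the request scope exits. All names here are dummies.
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def new_session(tool_name: str):
    print(f"open {tool_name}")
    try:
        yield f"session-for-{tool_name}"
    finally:
        print(f"close {tool_name}")


async def main() -> None:
    builtin_tool_list = ["browser", "python"]
    async with AsyncExitStack() as exit_stack:
        tool_session_ctxs = {
            name: exit_stack.enter_async_context(new_session(name))
            for name in builtin_tool_list
        }
        tool_sessions = {
            name: await tool_session_ctxs[name] for name in builtin_tool_list
        }
        print(tool_sessions)
    # Both sessions are closed here, in reverse order of entry.


asyncio.run(main())
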
@@ -292,33 +342,10 @@ class OpenAIServingResponses(OpenAIServing):
model_name,
tokenizer,
request_metadata,
created_time,
),
name=f"create_{response.id}",
)

# For cleanup.
response_id = response.id
self.background_tasks[response_id] = task
task.add_done_callback(
lambda _: self.background_tasks.pop(response_id, None))
return response

if request.stream:
raise NotImplementedError("Streaming responses are not supported")

try:
return await self.responses_full_generator(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
)
except Exception as e:
return self.create_error_response(str(e))
)
except Exception as e:
return self.create_error_response(str(e))
return self.create_error_response("Should not reach here")

async def _make_request(
self,

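For background requests, the task registers itself under the response id and removes itself on completion via add_done_callback, so the registry does not leak entries. A self-contained sketch of that bookkeeping (illustration only, not part of the diff; names are placeholders):

import asyncio

background_tasks: dict[str, asyncio.Task] = {}


async def run_request(response_id: str) -> None:
    # Stand-in for _run_background_request(...).
    await asyncio.sleep(0.1)


async def main() -> None:
    response_id = "resp_example"
    task = asyncio.create_task(run_request(response_id),
                               name=f"create_{response_id}")
    # For cleanup: forget the task once it completes, as in the code above.
    background_tasks[response_id] = task
    task.add_done_callback(
        lambda _: background_tasks.pop(response_id, None))
    await task
    print(background_tasks)  # {} once the done callback has fired


asyncio.run(main())
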
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

from openai_harmony import ToolNamespaceConfig

@@ -11,6 +11,61 @@ from vllm.logger import init_logger

logger = init_logger(__name__)

if TYPE_CHECKING:
from mcp.types import ListToolsResult


async def list_server_and_tools(server_url: str):
from mcp import ClientSession
from mcp.client.sse import sse_client

async with sse_client(url=server_url) as streams, ClientSession(
*streams) as session:
initialize_response = await session.initialize()
list_tools_response = await session.list_tools()
return initialize_response, list_tools_response


def trim_schema(schema: dict) -> dict:
# Turn JSON Schema from MCP generated into Harmony's variant.
if "title" in schema:
del schema["title"]
if "default" in schema and schema["default"] is None:
del schema["default"]
if "anyOf" in schema:
# Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
# into "type": ["type-1", "type-2"]
# if there's more than 1 types, also remove "null" type as Harmony will
# just ignore it
types = [
type_dict["type"] for type_dict in schema["anyOf"]
if type_dict["type"] != 'null'
]
schema["type"] = types
del schema["anyOf"]
if "properties" in schema:
schema["properties"] = {
k: trim_schema(v)
for k, v in schema["properties"].items()
}
return schema


def post_process_tools_description(
list_tools_result: "ListToolsResult") -> "ListToolsResult":
# Adapt the MCP tool result for Harmony
for tool in list_tools_result.tools:
tool.inputSchema = trim_schema(tool.inputSchema)

# Some tools schema don't need to be part of the prompt (e.g. simple text
# in text out for Python)
list_tools_result.tools = [
tool for tool in list_tools_result.tools
if getattr(tool.annotations, "include_in_prompt", True)
]

return list_tools_result


class ToolServer(ABC):

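To make trim_schema concrete: it drops "title", drops null "default" values, and flattens "anyOf" into a "type" list while discarding "null". A hedged example (illustration only, not part of the diff; the sample schema is invented, and the import assumes the function ships in vllm/entrypoints/tool_server.py alongside ToolServer, as this hunk suggests):

# Illustrative input/output for trim_schema as defined above.
from vllm.entrypoints.tool_server import trim_schema

mcp_schema = {
    "title": "SearchArgs",
    "type": "object",
    "properties": {
        "query": {"title": "Query", "type": "string"},
        "topn": {
            "title": "Topn",
            "anyOf": [{"type": "integer"}, {"type": "null"}],
            "default": None,
        },
    },
}

print(trim_schema(mcp_schema))
# Expected (per the rules above): titles dropped, null default dropped, and
# "anyOf" flattened into "type": ["integer"] for the "topn" property:
# {'type': 'object',
#  'properties': {'query': {'type': 'string'},
#                 'topn': {'type': ['integer']}}}
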
@@ -38,6 +93,66 @@ class ToolServer(ABC):
...


class MCPToolServer(ToolServer):

def __init__(self):
try:
import mcp # noqa: F401
except ImportError:
raise ImportError(
"mcp is not installed. Please run `pip install mcp` to use "
"MCPToolServer.") from None
self.harmony_tool_descriptions = {}

async def add_tool_server(self, server_url: str):
from mcp.types import ToolDescription
tool_urls = server_url.split(",")
self.harmony_tool_descriptions = {}
self.urls: dict[str, str] = {}
for url in tool_urls:
url = f"http://{url}/sse"
initialize_response, list_tools_response = (
await list_server_and_tools(url))

list_tools_response = post_process_tools_description(
list_tools_response)

tool_from_mcp = ToolNamespaceConfig(
name=initialize_response.serverInfo.name,
description=initialize_response.instructions,
tools=[
ToolDescription.new(name=tool.name,
description=tool.description,
parameters=tool.inputSchema)
for tool in list_tools_response.tools
])
self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
if tool_from_mcp.name not in self.urls:
self.urls[tool_from_mcp.name] = url
else:
logger.warning(
"Tool %s already exists. Ignoring duplicate tool server %s",
tool_from_mcp.name, url)

def has_tool(self, tool_name: str):
return tool_name in self.harmony_tool_descriptions

def get_tool_description(self, tool_name: str):
return self.harmony_tool_descriptions.get(tool_name)

@asynccontextmanager
async def new_session(self, tool_name: str):
from mcp import ClientSession
from mcp.client.sse import sse_client
url = self.urls.get(tool_name)
if not url:
raise KeyError(f"Tool '{tool_name}' is not supported")
async with sse_client(url=url) as streams, ClientSession(
*streams) as session:
await session.initialize()
yield session


class DemoToolServer(ToolServer):

def __init__(self):

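A hedged usage sketch of MCPToolServer: add_tool_server takes a comma-separated list of host:port entries, each wrapped as http://<host:port>/sse, and new_session yields a live mcp ClientSession. The addresses and the "browser" namespace name below are invented (illustration only, not part of the diff):

# Sketch only: the method names come from the MCPToolServer implementation above.
import asyncio

from vllm.entrypoints.tool_server import MCPToolServer


async def main() -> None:
    server = MCPToolServer()
    # Comma-separated host:port list; each entry becomes http://<host:port>/sse.
    await server.add_tool_server("localhost:8001,localhost:8002")

    if server.has_tool("browser"):
        print(server.get_tool_description("browser"))
        async with server.new_session("browser") as session:
            # session is an mcp.ClientSession; list_tools() is the same call
            # used by list_server_and_tools above.
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])


asyncio.run(main())
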
@@ -67,4 +182,6 @@ class DemoToolServer(ToolServer):

@asynccontextmanager
async def new_session(self, tool_name: str):
if tool_name not in self.tools:
raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name]