From fe6d8257a1859cdd938cb2ec2a63a45c666dcca3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 8 Aug 2025 15:06:37 -0700 Subject: [PATCH] [gpt-oss] Support tool call and implement MCP tool server (#22427) Signed-off-by: Chen Zhang --- vllm/entrypoints/harmony_utils.py | 5 +- vllm/entrypoints/openai/api_server.py | 6 +- vllm/entrypoints/openai/serving_responses.py | 185 +++++++++++-------- vllm/entrypoints/tool_server.py | 119 +++++++++++- 4 files changed, 233 insertions(+), 82 deletions(-) diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index 87e76e08a0b4..efca1472e44c 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -237,7 +237,10 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: id=f"rs_{random_uuid()}", summary=[], type="reasoning", - text=content.text, + content=[ + ResponseReasoningTextContent(text=content.text, + type="reasoning_text") + ], status=None, ) output_items.append(reasoning_item) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c695ea8b5a0e..00eaba8c872f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -94,7 +94,8 @@ from vllm.entrypoints.openai.serving_tokenization import ( from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation) from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.entrypoints.tool_server import DemoToolServer, ToolServer +from vllm.entrypoints.tool_server import (DemoToolServer, MCPToolServer, + ToolServer) from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, log_non_default_args, with_cancellation) from vllm.logger import init_logger @@ -1635,6 +1636,9 @@ async def init_app_state( if args.tool_server == "demo": tool_server: Optional[ToolServer] = DemoToolServer() + elif args.tool_server: + tool_server = MCPToolServer() + await tool_server.add_tool_server(args.tool_server) else: tool_server = None diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index a7554e0d6831..1e3746e956e0 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -4,6 +4,7 @@ import asyncio import time from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import AsyncExitStack from copy import copy from http import HTTPStatus from typing import Any, Callable, Final, Optional, Union @@ -226,65 +227,114 @@ class OpenAIServingResponses(OpenAIServing): # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[ConversationContext, None]] = [] - try: - tool_sessions: dict[str, Any] = {} - for i, engine_prompt in enumerate(engine_prompts): - default_max_tokens = self.max_model_len - len( - engine_prompt["prompt_token_ids"]) - sampling_params = request.to_sampling_params( - default_max_tokens, self.default_sampling_params) - trace_headers = (None if raw_request is None else await - self._get_trace_headers(raw_request.headers)) - - context: ConversationContext - if self.use_harmony: - if request.stream: - context = StreamingHarmonyContext( - messages, tool_sessions) - else: - context = HarmonyContext(messages, tool_sessions) + builtin_tool_list: list[str] = [] + if self.use_harmony and self.tool_server is not None: + if self.tool_server.has_tool("browser"): + builtin_tool_list.append("browser") + if self.tool_server.has_tool("python"): + builtin_tool_list.append("python") + async with AsyncExitStack() as exit_stack: + try: + if self.tool_server is not None: + # TODO: initialize tool sessions lazily when the session + # is actually used. + tool_session_ctxs: dict[str, Any] = { + tool_name: + exit_stack.enter_async_context( + self.tool_server.new_session(tool_name)) + for tool_name in builtin_tool_list + } + tool_sessions = {} + for tool_name in builtin_tool_list: + tool_sessions[tool_name] = ( + await tool_session_ctxs[tool_name]) else: - context = SimpleContext() - generator = self._generate_with_builtin_tools( - request_id=request.request_id, - request_prompt=request_prompts[i], - engine_prompt=engine_prompt, - sampling_params=sampling_params, - context=context, - lora_request=lora_request, - priority=request.priority, - trace_headers=trace_headers, + assert len(builtin_tool_list) == 0 + tool_sessions = {} + for i, engine_prompt in enumerate(engine_prompts): + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers( + raw_request.headers)) + + context: ConversationContext + if self.use_harmony: + if request.stream: + context = StreamingHarmonyContext( + messages, tool_sessions) + else: + context = HarmonyContext(messages, tool_sessions) + else: + context = SimpleContext() + generator = self._generate_with_builtin_tools( + request_id=request.request_id, + request_prompt=request_prompts[i], + engine_prompt=engine_prompt, + sampling_params=sampling_params, + context=context, + lora_request=lora_request, + priority=request.priority, + trace_headers=trace_headers, + ) + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert len(generators) == 1 + result_generator, = generators + + # Store the input messages. + if request.store: + self.msg_store[request.request_id] = messages + + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, ) - generators.append(generator) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + async with self.response_store_lock: + self.response_store[response.id] = response - assert len(generators) == 1 - result_generator, = generators + # Run the request in the background. 
+                task = asyncio.create_task(
+                    self._run_background_request(
+                        request,
+                        sampling_params,
+                        result_generator,
+                        context,
+                        model_name,
+                        tokenizer,
+                        request_metadata,
+                        created_time,
+                    ),
+                    name=f"create_{response.id}",
+                )
 
-        # Store the input messages.
-        if request.store:
-            self.msg_store[request.request_id] = messages
+                # For cleanup.
+                response_id = response.id
+                self.background_tasks[response_id] = task
+                task.add_done_callback(
+                    lambda _: self.background_tasks.pop(response_id, None))
+                return response
 
-        if request.background:
-            created_time = int(time.time())
-            response = ResponsesResponse.from_request(
-                request,
-                sampling_params,
-                model_name=model_name,
-                created_time=created_time,
-                output=[],
-                status="queued",
-                usage=None,
-            )
-            async with self.response_store_lock:
-                self.response_store[response.id] = response
+            if request.stream:
+                raise NotImplementedError(
+                    "Streaming responses are not supported")
 
-            # Run the request in the background.
-            task = asyncio.create_task(
-                self._run_background_request(
+            try:
+                return await self.responses_full_generator(
                     request,
                     sampling_params,
                     result_generator,
@@ -292,33 +342,10 @@ class OpenAIServingResponses(OpenAIServing):
                     model_name,
                     tokenizer,
                     request_metadata,
-                    created_time,
-                ),
-                name=f"create_{response.id}",
-            )
-
-            # For cleanup.
-            response_id = response.id
-            self.background_tasks[response_id] = task
-            task.add_done_callback(
-                lambda _: self.background_tasks.pop(response_id, None))
-            return response
-
-        if request.stream:
-            raise NotImplementedError("Streaming responses are not supported")
-
-        try:
-            return await self.responses_full_generator(
-                request,
-                sampling_params,
-                result_generator,
-                context,
-                model_name,
-                tokenizer,
-                request_metadata,
-            )
-        except Exception as e:
-            return self.create_error_response(str(e))
+            )
+        except Exception as e:
+            return self.create_error_response(str(e))
+        return self.create_error_response("Should not reach here")
 
     async def _make_request(
         self,
diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py
index 769c40e8cc58..352704b2b374 100644
--- a/vllm/entrypoints/tool_server.py
+++ b/vllm/entrypoints/tool_server.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
 from contextlib import AbstractAsyncContextManager, asynccontextmanager
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 from openai_harmony import ToolNamespaceConfig
 
@@ -11,6 +11,61 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
+if TYPE_CHECKING:
+    from mcp.types import ListToolsResult
+
+
+async def list_server_and_tools(server_url: str):
+    from mcp import ClientSession
+    from mcp.client.sse import sse_client
+
+    async with sse_client(url=server_url) as streams, ClientSession(
+            *streams) as session:
+        initialize_response = await session.initialize()
+        list_tools_response = await session.list_tools()
+        return initialize_response, list_tools_response
+
+
+def trim_schema(schema: dict) -> dict:
+    # Turn the JSON Schema generated by MCP into Harmony's variant.
+    if "title" in schema:
+        del schema["title"]
+    if "default" in schema and schema["default"] is None:
+        del schema["default"]
+    if "anyOf" in schema:
+        # Turn "anyOf": [{"type": "type-1"}, {"type": "type-2"}]
+        # into "type": ["type-1", "type-2"].
+        # If there is more than one type, also remove the "null" type, as
+        # Harmony will just ignore it.
+        types = [
+            type_dict["type"] for type_dict in schema["anyOf"]
+            if type_dict["type"] != 'null'
+        ]
+        schema["type"] = types
+        del schema["anyOf"]
+    if "properties" in schema:
+        schema["properties"] = {
+            k: trim_schema(v)
+            for k, v in schema["properties"].items()
+        }
+    return schema
+
+
+def post_process_tools_description(
+        list_tools_result: "ListToolsResult") -> "ListToolsResult":
+    # Adapt the MCP tool result for Harmony
+    for tool in list_tools_result.tools:
+        tool.inputSchema = trim_schema(tool.inputSchema)
+
+    # Some tool schemas don't need to be part of the prompt (e.g. the Python
+    # tool is simple text in, text out).
+    list_tools_result.tools = [
+        tool for tool in list_tools_result.tools
+        if getattr(tool.annotations, "include_in_prompt", True)
+    ]
+
+    return list_tools_result
+
 
 class ToolServer(ABC):
 
@@ -38,6 +93,66 @@ class ToolServer(ABC):
         ...
 
 
+class MCPToolServer(ToolServer):
+
+    def __init__(self):
+        try:
+            import mcp  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "mcp is not installed. Please run `pip install mcp` to use "
+                "MCPToolServer.") from None
+        self.harmony_tool_descriptions = {}
+
+    async def add_tool_server(self, server_url: str):
+        from openai_harmony import ToolDescription
+        tool_urls = server_url.split(",")
+        self.harmony_tool_descriptions = {}
+        self.urls: dict[str, str] = {}
+        for url in tool_urls:
+            url = f"http://{url}/sse"
+            initialize_response, list_tools_response = (
+                await list_server_and_tools(url))
+
+            list_tools_response = post_process_tools_description(
+                list_tools_response)
+
+            tool_from_mcp = ToolNamespaceConfig(
+                name=initialize_response.serverInfo.name,
+                description=initialize_response.instructions,
+                tools=[
+                    ToolDescription.new(name=tool.name,
+                                        description=tool.description,
+                                        parameters=tool.inputSchema)
+                    for tool in list_tools_response.tools
+                ])
+            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
+            if tool_from_mcp.name not in self.urls:
+                self.urls[tool_from_mcp.name] = url
+            else:
+                logger.warning(
+                    "Tool %s already exists. Ignoring duplicate tool server %s",
+                    tool_from_mcp.name, url)
+
+    def has_tool(self, tool_name: str):
+        return tool_name in self.harmony_tool_descriptions
+
+    def get_tool_description(self, tool_name: str):
+        return self.harmony_tool_descriptions.get(tool_name)
+
+    @asynccontextmanager
+    async def new_session(self, tool_name: str):
+        from mcp import ClientSession
+        from mcp.client.sse import sse_client
+        url = self.urls.get(tool_name)
+        if not url:
+            raise KeyError(f"Tool '{tool_name}' is not supported")
+        async with sse_client(url=url) as streams, ClientSession(
+                *streams) as session:
+            await session.initialize()
+            yield session
+
+
 class DemoToolServer(ToolServer):
 
     def __init__(self):
@@ -67,4 +182,6 @@ class DemoToolServer(ToolServer):
 
     @asynccontextmanager
     async def new_session(self, tool_name: str):
+        if tool_name not in self.tools:
+            raise KeyError(f"Tool '{tool_name}' is not supported")
         yield self.tools[tool_name]
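
Notes on the serving_responses.py change: every builtin tool session is now entered on a single `AsyncExitStack`, so the sessions stay open for the whole generation and are all torn down when the stack unwinds, including on error. Below is a minimal runnable sketch of that pattern; `new_session` here is a hypothetical stand-in for `ToolServer.new_session`, not vLLM code:

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def new_session(tool_name: str):
    # Hypothetical stand-in for ToolServer.new_session.
    print(f"opening {tool_name} session")
    try:
        yield f"<session for {tool_name}>"
    finally:
        print(f"closing {tool_name} session")


async def handle_request(tool_names: list[str]) -> None:
    async with AsyncExitStack() as exit_stack:
        # Enter one session per tool on the same stack; all of them are
        # closed together when the `async with` block exits, even if
        # generation raises in between.
        tool_sessions = {
            name: await exit_stack.enter_async_context(new_session(name))
            for name in tool_names
        }
        print("generating with", tool_sessions)


asyncio.run(handle_request(["browser", "python"]))
```

This is also why the handler now ends with `return self.create_error_response("Should not reach here")`: the successful return paths all live inside the `async with AsyncExitStack()` block, and the trailing return only covers the path where the block exits without returning.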
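
The `trim_schema` / `post_process_tools_description` pair rewrites the JSON Schema that MCP emits (pydantic-style, with `title`, `default: None`, and `anyOf` unions) into the flatter variant Harmony expects. A worked example follows; the input schema is invented for illustration, and the import assumes the module patched above is on the path:

```python
from vllm.entrypoints.tool_server import trim_schema

# Made-up MCP-style input schema for a hypothetical "fetch" tool.
mcp_schema = {
    "title": "fetch_args",
    "properties": {
        "url": {"title": "Url", "type": "string"},
        "max_length": {
            "title": "Max Length",
            "anyOf": [{"type": "integer"}, {"type": "null"}],
            "default": None,
        },
    },
}

print(trim_schema(mcp_schema))
# {'properties': {'url': {'type': 'string'},
#                 'max_length': {'type': ['integer']}}}
```

Note that `trim_schema` mutates its argument in place (the titles, the null default, and the `anyOf` key are deleted from the original dict) in addition to returning it.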
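
On the CLI side, `--tool-server` now accepts either the literal `demo` or a comma-separated list of `host:port` entries; `MCPToolServer` connects to `http://<host:port>/sse` for each and registers the namespace under the server's `serverInfo.name`. Below is a hedged sketch of a minimal SSE server this could attach to, built with `FastMCP` from the `mcp` SDK; the tool body is a toy, and the `"python"` name is chosen only because serving_responses.py wires just the `"browser"` and `"python"` namespaces into `builtin_tool_list`:

```python
from mcp.server.fastmcp import FastMCP

# The server name becomes serverInfo.name, i.e. the Harmony tool namespace.
server = FastMCP("python")


@server.tool()
def run(code: str) -> str:
    """Toy stand-in for a code-execution tool; it only echoes a summary."""
    return f"received {len(code)} characters of code"


if __name__ == "__main__":
    # Exposes an SSE endpoint at /sse (port 8000 by default), matching the
    # f"http://{url}/sse" convention in MCPToolServer.add_tool_server.
    server.run(transport="sse")
```

With that running, something like `vllm serve <model> --tool-server localhost:8000` should attach it; a second server exposing an already-registered namespace is skipped with the "already exists" warning added above.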