[gpt-oss] disable tool server initialization if no tool in request (#25790)

Signed-off-by: Andrew Xia <axia@meta.com>
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Andrew Xia 2025-10-02 22:08:35 -07:00 committed by yewentao256
parent 56d0073f2a
commit 79b2fe7f19
2 changed files with 148 additions and 12 deletions
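
In effect, the commit moves tool-session setup behind an early-return guard: a request that carries no tools now skips tool-server initialization entirely, instead of always opening MCP sessions. The lines below are an illustration of the new behavior, not part of the commit; the MCP tool shape (server_label, server_url) follows the OpenAI Responses schema and is an assumption here:

from vllm.entrypoints.openai.protocol import ResponsesRequest

request = ResponsesRequest(input="hello", tools=[])
# -> _initialize_tool_sessions() returns immediately; the tool server and
#    context.init_tool_sessions() are never touched.

request = ResponsesRequest(
    input="hello",
    tools=[{"type": "mcp", "server_label": "docs",
            "server_url": "http://localhost:3000"}],
)
# -> the helper collects MCP tools as {"docs": <tool>} and awaits
#    context.init_tool_sessions(tool_server, exit_stack, request_id, mcp_tools)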

@@ -0,0 +1,129 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from contextlib import AsyncExitStack
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+import pytest_asyncio
+
+from vllm.entrypoints.context import ConversationContext
+from vllm.entrypoints.openai.protocol import ResponsesRequest
+from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
+from vllm.entrypoints.tool_server import ToolServer
+
+
+class MockConversationContext(ConversationContext):
+    """Mock conversation context for testing."""
+
+    def __init__(self):
+        self.init_tool_sessions_called = False
+        self.init_tool_sessions_args = None
+        self.init_tool_sessions_kwargs = None
+
+    def append_output(self, output) -> None:
+        pass
+
+    async def call_tool(self):
+        return []
+
+    def need_builtin_tool_call(self) -> bool:
+        return False
+
+    def render_for_completion(self):
+        return []
+
+    async def init_tool_sessions(self, tool_server, exit_stack, request_id,
+                                 mcp_tools):
+        self.init_tool_sessions_called = True
+        self.init_tool_sessions_args = (tool_server, exit_stack, request_id,
+                                        mcp_tools)
+
+    async def cleanup_session(self) -> None:
+        pass
+
+
+@pytest.fixture
+def mock_serving_responses():
+    """Create a mock OpenAIServingResponses instance."""
+    serving_responses = MagicMock(spec=OpenAIServingResponses)
+    serving_responses.tool_server = MagicMock(spec=ToolServer)
+    return serving_responses
+
+
+@pytest.fixture
+def mock_context():
+    """Create a mock conversation context."""
+    return MockConversationContext()
+
+
+@pytest.fixture
+def mock_exit_stack():
+    """Create a mock async exit stack."""
+    return MagicMock(spec=AsyncExitStack)
+
+
+class TestInitializeToolSessions:
+    """Tests for the _initialize_tool_sessions method."""
+
+    @pytest_asyncio.fixture
+    async def serving_responses_instance(self):
+        """Create a real OpenAIServingResponses instance for testing."""
+        # Create minimal mocks for the required dependencies.
+        engine_client = MagicMock()
+        engine_client.get_model_config = AsyncMock()
+        model_config = MagicMock()
+        model_config.hf_config.model_type = "test"
+        model_config.get_diff_sampling_param.return_value = {}
+        models = MagicMock()
+        tool_server = MagicMock(spec=ToolServer)
+        # Create the actual instance.
+        instance = OpenAIServingResponses(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=None,
+            chat_template=None,
+            chat_template_content_format="auto",
+            tool_server=tool_server,
+        )
+        return instance
+
+    @pytest.mark.asyncio
+    async def test_initialize_tool_sessions(self, serving_responses_instance,
+                                            mock_context, mock_exit_stack):
+        """Tool sessions are initialized only when the request has tools."""
+        # A request without tools must not touch the tool server.
+        request = ResponsesRequest(input="test input", tools=[])
+        await serving_responses_instance._initialize_tool_sessions(
+            request, mock_context, mock_exit_stack)
+        assert mock_context.init_tool_sessions_called is False
+
+        # A request with built-in tools must initialize tool sessions.
+        tools = [
+            {
+                "type": "web_search_preview"
+            },
+            {
+                "type": "code_interpreter",
+                "container": {
+                    "type": "auto"
+                }
+            },
+        ]
+        request = ResponsesRequest(input="test input", tools=tools)
+        await serving_responses_instance._initialize_tool_sessions(
+            request, mock_context, mock_exit_stack)
+        # Verify that init_tool_sessions was called this time.
+        assert mock_context.init_tool_sessions_called
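
The test above covers the empty-tools and built-in-tools cases. A sketch of an MCP case that would also exercise the server_label keying, written against the same fixtures; the tool fields follow the OpenAI Responses MCP schema and are an assumption, not part of this commit:

# Hypothetical extension of the test above (not in the commit).
mcp_tool = {
    "type": "mcp",
    "server_label": "docs",
    "server_url": "http://localhost:3000",
}
request = ResponsesRequest(input="test input", tools=[mcp_tool])
await serving_responses_instance._initialize_tool_sessions(
    request, mock_context, mock_exit_stack)
# The mock records its arguments; the MCP tool should be keyed by its label.
_, _, _, mcp_tools = mock_context.init_tool_sessions_args
assert "docs" in mcp_tools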

vllm/entrypoints/openai/serving_responses.py

@@ -445,6 +445,19 @@ class OpenAIServingResponses(OpenAIServing):
         return messages, [prompt_token_ids], [engine_prompt]

+    async def _initialize_tool_sessions(self, request: ResponsesRequest,
+                                        context: ConversationContext,
+                                        exit_stack: AsyncExitStack):
+        # Only initialize tool sessions if the request actually needs tools.
+        if len(request.tools) == 0:
+            return
+        mcp_tools = {
+            tool.server_label: tool
+            for tool in request.tools if tool.type == "mcp"
+        }
+        await context.init_tool_sessions(self.tool_server, exit_stack,
+                                         request.request_id, mcp_tools)
+
     async def responses_full_generator(
         self,
         request: ResponsesRequest,
@@ -461,12 +474,8 @@ class OpenAIServingResponses(OpenAIServing):
         async with AsyncExitStack() as exit_stack:
             try:
-                mcp_tools = {
-                    tool.server_label: tool
-                    for tool in request.tools if tool.type == "mcp"
-                }
-                await context.init_tool_sessions(self.tool_server, exit_stack,
-                                                 request.request_id, mcp_tools)
+                await self._initialize_tool_sessions(request, context,
+                                                     exit_stack)
                 async for _ in result_generator:
                     pass
             except asyncio.CancelledError:
@@ -1650,12 +1659,10 @@ class OpenAIServingResponses(OpenAIServing):
         async with AsyncExitStack() as exit_stack:
             processer = None
             if self.use_harmony:
-                mcp_tools = {
-                    tool.server_label: tool
-                    for tool in request.tools if tool.type == "mcp"
-                }
-                await context.init_tool_sessions(self.tool_server, exit_stack,
-                                                 request.request_id, mcp_tools)
+                # TODO: in streaming, we noticed this bug:
+                # https://github.com/vllm-project/vllm/issues/25697
+                await self._initialize_tool_sessions(request, context,
+                                                     exit_stack)
                 processer = self._process_harmony_streaming_events
             else:
                 processer = self._process_simple_streaming_events
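
Both call sites hand the helper an AsyncExitStack, so any tool sessions opened for a request are torn down together when the request-scoped async-with block unwinds. A minimal, self-contained sketch of that pattern with generic names (not vLLM code):

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager

@asynccontextmanager
async def tool_session(label: str):
    print(f"open {label}")  # e.g. connect to an MCP server
    try:
        yield label
    finally:
        print(f"close {label}")  # runs when the stack unwinds

async def handle_request(labels):
    async with AsyncExitStack() as stack:
        for label in labels:
            await stack.enter_async_context(tool_session(label))
        print("request served while sessions are open")
    # Leaving the async-with closes every session, in reverse order.

asyncio.run(handle_request(["web_search_preview", "code_interpreter"]))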