[gpt-oss] disable tool server initialization if no tool in request (#25790)
Signed-off-by: Andrew Xia <axia@meta.com>
Signed-off-by: Andrew Xia <axia@fb.com>
Co-authored-by: Andrew Xia <axia@fb.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 56d0073f2a
commit 79b2fe7f19
tests/entrypoints/openai/test_serving_responses.py (new file, +129 lines)
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import AsyncExitStack
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio

from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.openai.protocol import ResponsesRequest
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.tool_server import ToolServer


class MockConversationContext(ConversationContext):
    """Mock conversation context for testing"""

    def __init__(self):
        self.init_tool_sessions_called = False
        self.init_tool_sessions_args = None
        self.init_tool_sessions_kwargs = None

    def append_output(self, output) -> None:
        pass

    async def call_tool(self):
        return []

    def need_builtin_tool_call(self) -> bool:
        return False

    def render_for_completion(self):
        return []

    async def init_tool_sessions(self, tool_server, exit_stack, request_id,
                                 mcp_tools):
        self.init_tool_sessions_called = True
        self.init_tool_sessions_args = (tool_server, exit_stack, request_id,
                                        mcp_tools)

    async def cleanup_session(self) -> None:
        pass


@pytest.fixture
def mock_serving_responses():
    """Create a mock OpenAIServingResponses instance"""
    serving_responses = MagicMock(spec=OpenAIServingResponses)
    serving_responses.tool_server = MagicMock(spec=ToolServer)
    return serving_responses


@pytest.fixture
def mock_context():
    """Create a mock conversation context"""
    return MockConversationContext()


@pytest.fixture
def mock_exit_stack():
    """Create a mock async exit stack"""
    return MagicMock(spec=AsyncExitStack)


class TestInitializeToolSessions:
    """Test class for _initialize_tool_sessions method"""

    @pytest_asyncio.fixture
    async def serving_responses_instance(self):
        """Create a real OpenAIServingResponses instance for testing"""
        # Create minimal mocks for required dependencies
        engine_client = MagicMock()
        engine_client.get_model_config = AsyncMock()

        model_config = MagicMock()
        model_config.hf_config.model_type = "test"
        model_config.get_diff_sampling_param.return_value = {}

        models = MagicMock()

        tool_server = MagicMock(spec=ToolServer)

        # Create the actual instance
        instance = OpenAIServingResponses(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=None,
            chat_template=None,
            chat_template_content_format="auto",
            tool_server=tool_server,
        )

        return instance

    @pytest.mark.asyncio
    async def test_initialize_tool_sessions(self, serving_responses_instance,
                                            mock_context, mock_exit_stack):
        """Tool sessions are initialized only when the request carries tools"""

        request = ResponsesRequest(input="test input", tools=[])

        # Call the method with an empty tool list
        await serving_responses_instance._initialize_tool_sessions(
            request, mock_context, mock_exit_stack)
        assert mock_context.init_tool_sessions_called is False

        # Create built-in (non-MCP) tools
        tools = [
            {
                "type": "web_search_preview"
            },
            {
                "type": "code_interpreter",
                "container": {
                    "type": "auto"
                }
            },
        ]

        request = ResponsesRequest(input="test input", tools=tools)

        # Call the method again, now with tools present
        await serving_responses_instance._initialize_tool_sessions(
            request, mock_context, mock_exit_stack)

        # Verify that init_tool_sessions was called
        assert mock_context.init_tool_sessions_called
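The new module is self-contained, so it can be run on its own. A minimal sketch of a programmatic runner, assuming a vLLM dev checkout with pytest and pytest-asyncio installed (plain `pytest tests/entrypoints/openai/test_serving_responses.py -v` from the repo root is equivalent):

# Sketch: run just this new test module programmatically. Assumes it is
# executed from the repository root of a vLLM development install.
import sys

import pytest

sys.exit(
    pytest.main(["tests/entrypoints/openai/test_serving_responses.py", "-v"]))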
vllm/entrypoints/openai/serving_responses.py
@@ -445,6 +445,19 @@ class OpenAIServingResponses(OpenAIServing):
 
         return messages, [prompt_token_ids], [engine_prompt]
 
+    async def _initialize_tool_sessions(self, request: ResponsesRequest,
+                                        context: ConversationContext,
+                                        exit_stack: AsyncExitStack):
+        # we should only initialize the tool session if the request needs tools
+        if len(request.tools) == 0:
+            return
+        mcp_tools = {
+            tool.server_label: tool
+            for tool in request.tools if tool.type == "mcp"
+        }
+        await context.init_tool_sessions(self.tool_server, exit_stack,
+                                         request.request_id, mcp_tools)
+
     async def responses_full_generator(
         self,
         request: ResponsesRequest,
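The whole point of the helper added above is the early return: a request with no tools never touches the tool server. A minimal self-contained sketch of that guard; `init_tool_sessions_if_needed`, `RecordingContext`, and the `SimpleNamespace` requests are illustrative stand-ins, not vLLM APIs:

# Sketch of the "no tools, no tool-server work" guard, under simplified types.
import asyncio
from types import SimpleNamespace


async def init_tool_sessions_if_needed(request, context) -> bool:
    # Mirrors the early return above: empty tool list means no initialization.
    if len(request.tools) == 0:
        return False
    mcp_tools = {t.server_label: t for t in request.tools if t.type == "mcp"}
    await context.init_tool_sessions(mcp_tools)
    return True


class RecordingContext:
    """Records whether tool sessions were ever initialized."""

    def __init__(self):
        self.called = False

    async def init_tool_sessions(self, mcp_tools):
        self.called = True


async def main():
    ctx = RecordingContext()
    assert not await init_tool_sessions_if_needed(
        SimpleNamespace(tools=[]), ctx)
    assert not ctx.called  # tool-less request: tool server untouched

    mcp = SimpleNamespace(type="mcp", server_label="search")
    assert await init_tool_sessions_if_needed(
        SimpleNamespace(tools=[mcp]), ctx)
    assert ctx.called  # request with tools: sessions initialized


asyncio.run(main())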
@@ -461,12 +474,8 @@ class OpenAIServingResponses(OpenAIServing):
 
         async with AsyncExitStack() as exit_stack:
             try:
-                mcp_tools = {
-                    tool.server_label: tool
-                    for tool in request.tools if tool.type == "mcp"
-                }
-                await context.init_tool_sessions(self.tool_server, exit_stack,
-                                                 request.request_id, mcp_tools)
+                await self._initialize_tool_sessions(request, context,
+                                                     exit_stack)
                 async for _ in result_generator:
                     pass
             except asyncio.CancelledError:
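Both this call site and the streaming one in the next hunk hand the helper an AsyncExitStack, so any tool sessions it opens are torn down when the request's `async with` block exits. A minimal sketch of that stdlib pattern, with `FakeSession` as an illustrative stand-in rather than a vLLM type:

# Sketch of the AsyncExitStack pattern the call sites rely on: contexts
# registered on the stack are closed when the block unwinds, even on error.
import asyncio
from contextlib import AsyncExitStack


class FakeSession:
    async def __aenter__(self):
        print("session opened")
        return self

    async def __aexit__(self, *exc):
        print("session closed")


async def main():
    async with AsyncExitStack() as exit_stack:
        # A helper deep in the call chain can register cleanup here ...
        await exit_stack.enter_async_context(FakeSession())
        print("doing work")
    # ... and it runs automatically once the stack unwinds.


asyncio.run(main())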
@@ -1650,12 +1659,10 @@ class OpenAIServingResponses(OpenAIServing):
         async with AsyncExitStack() as exit_stack:
             processer = None
             if self.use_harmony:
-                mcp_tools = {
-                    tool.server_label: tool
-                    for tool in request.tools if tool.type == "mcp"
-                }
-                await context.init_tool_sessions(self.tool_server, exit_stack,
-                                                 request.request_id, mcp_tools)
+                # TODO: in streaming, we noticed this bug:
+                # https://github.com/vllm-project/vllm/issues/25697
+                await self._initialize_tool_sessions(request, context,
+                                                     exit_stack)
                 processer = self._process_harmony_streaming_events
             else:
                 processer = self._process_simple_streaming_events