mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 00:57:58 +08:00
[gpt-oss] disable tool server initialization if no tool in request (#25790)
Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Andrew Xia <axia@fb.com> Co-authored-by: Andrew Xia <axia@fb.com>
This commit is contained in:
parent
6a7796e871
commit
e5017cd6d6
129
tests/entrypoints/openai/test_serving_responses.py
Normal file
129
tests/entrypoints/openai/test_serving_responses.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from vllm.entrypoints.context import ConversationContext
|
||||||
|
from vllm.entrypoints.openai.protocol import ResponsesRequest
|
||||||
|
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
|
||||||
|
from vllm.entrypoints.tool_server import ToolServer
|
||||||
|
|
||||||
|
|
||||||
|
class MockConversationContext(ConversationContext):
    """Test double for :class:`ConversationContext`.

    Records whether (and with which arguments) ``init_tool_sessions`` was
    invoked, so tests can assert that tool-session setup is skipped or
    performed as expected. All other abstract methods are inert no-ops.
    """

    def __init__(self):
        # Flag flipped by init_tool_sessions(); tests assert on it.
        self.init_tool_sessions_called = False
        # Positional arguments captured from the last init_tool_sessions call.
        self.init_tool_sessions_args = None
        # Reserved for keyword-argument capture; currently never populated.
        self.init_tool_sessions_kwargs = None

    def append_output(self, output) -> None:
        # No output tracking needed for these tests.
        pass

    async def call_tool(self):
        # Pretend no tool produced any messages.
        return []

    def need_builtin_tool_call(self) -> bool:
        # Never request a builtin tool call during tests.
        return False

    def render_for_completion(self):
        # Empty prompt rendering is sufficient for these tests.
        return []

    async def init_tool_sessions(self, tool_server, exit_stack, request_id,
                                 mcp_tools):
        # Record the call and its arguments for later assertions.
        self.init_tool_sessions_called = True
        self.init_tool_sessions_args = (
            tool_server,
            exit_stack,
            request_id,
            mcp_tools,
        )

    async def cleanup_session(self) -> None:
        # Nothing to release in the mock.
        pass
|
@pytest.fixture
def mock_serving_responses():
    """Yield a spec'd MagicMock standing in for OpenAIServingResponses."""
    mock = MagicMock(spec=OpenAIServingResponses)
    # Give it a tool_server attribute, as the real class would have.
    mock.tool_server = MagicMock(spec=ToolServer)
    return mock
@pytest.fixture
def mock_context():
    """Yield a fresh MockConversationContext for each test."""
    context = MockConversationContext()
    return context
|
@pytest.fixture
def mock_exit_stack():
    """Yield a spec'd MagicMock standing in for an AsyncExitStack."""
    stack = MagicMock(spec=AsyncExitStack)
    return stack
|
class TestInitializeToolSessions:
    """Tests for OpenAIServingResponses._initialize_tool_sessions."""

    @pytest_asyncio.fixture
    async def serving_responses_instance(self):
        """Build a real OpenAIServingResponses with mocked dependencies."""
        # Minimal engine client: only get_model_config is awaited.
        engine_client = MagicMock()
        engine_client.get_model_config = AsyncMock()

        model_config = MagicMock()
        model_config.hf_config.model_type = "test"
        model_config.get_diff_sampling_param.return_value = {}

        models = MagicMock()
        tool_server = MagicMock(spec=ToolServer)

        # Instantiate the class under test with the mocks above.
        return OpenAIServingResponses(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=None,
            chat_template=None,
            chat_template_content_format="auto",
            tool_server=tool_server,
        )

    @pytest.mark.asyncio
    async def test_initialize_tool_sessions(self, serving_responses_instance,
                                            mock_context, mock_exit_stack):
        """Tool sessions are initialized only when the request has tools."""
        # A request with no tools must skip session initialization entirely.
        empty_request = ResponsesRequest(input="test input", tools=[])
        await serving_responses_instance._initialize_tool_sessions(
            empty_request, mock_context, mock_exit_stack)
        assert mock_context.init_tool_sessions_called is False

        # Builtin (non-MCP) tools still trigger session initialization.
        tools = [
            {
                "type": "web_search_preview"
            },
            {
                "type": "code_interpreter",
                "container": {
                    "type": "auto"
                }
            },
        ]
        tooled_request = ResponsesRequest(input="test input", tools=tools)
        await serving_responses_instance._initialize_tool_sessions(
            tooled_request, mock_context, mock_exit_stack)

        # This time the context must have been asked to set up sessions.
        assert mock_context.init_tool_sessions_called
||||||
@@ -445,6 +445,19 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
return messages, [prompt_token_ids], [engine_prompt]
|
return messages, [prompt_token_ids], [engine_prompt]
|
||||||
|
|
||||||
|
async def _initialize_tool_sessions(self, request: ResponsesRequest,
                                    context: ConversationContext,
                                    exit_stack: AsyncExitStack):
    """Initialize tool sessions on *context* for this request.

    Requests that declare no tools skip initialization entirely, so the
    tool server is never touched for plain (tool-less) responses. When
    tools are present, MCP tools are grouped by server label and handed
    to the context together with the tool server and exit stack.
    """
    # Tool-less request: nothing to set up.
    if not request.tools:
        return
    # Only MCP-type tools carry per-server session config, keyed by label.
    mcp_tools = {
        tool.server_label: tool
        for tool in request.tools if tool.type == "mcp"
    }
    await context.init_tool_sessions(self.tool_server, exit_stack,
                                     request.request_id, mcp_tools)
async def responses_full_generator(
|
async def responses_full_generator(
|
||||||
self,
|
self,
|
||||||
request: ResponsesRequest,
|
request: ResponsesRequest,
|
||||||
@@ -461,12 +474,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
async with AsyncExitStack() as exit_stack:
|
async with AsyncExitStack() as exit_stack:
|
||||||
try:
|
try:
|
||||||
mcp_tools = {
|
await self._initialize_tool_sessions(request, context,
|
||||||
tool.server_label: tool
|
exit_stack)
|
||||||
for tool in request.tools if tool.type == "mcp"
|
|
||||||
}
|
|
||||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
|
||||||
request.request_id, mcp_tools)
|
|
||||||
async for _ in result_generator:
|
async for _ in result_generator:
|
||||||
pass
|
pass
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@@ -1650,12 +1659,10 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
async with AsyncExitStack() as exit_stack:
|
async with AsyncExitStack() as exit_stack:
|
||||||
processer = None
|
processer = None
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
mcp_tools = {
|
# TODO: in streaming, we noticed this bug:
|
||||||
tool.server_label: tool
|
# https://github.com/vllm-project/vllm/issues/25697
|
||||||
for tool in request.tools if tool.type == "mcp"
|
await self._initialize_tool_sessions(request, context,
|
||||||
}
|
exit_stack)
|
||||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
|
||||||
request.request_id, mcp_tools)
|
|
||||||
processer = self._process_harmony_streaming_events
|
processer = self._process_harmony_streaming_events
|
||||||
else:
|
else:
|
||||||
processer = self._process_simple_streaming_events
|
processer = self._process_simple_streaming_events
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user