[Frontend] [gpt-oss] Mcp type bug (#27689)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Signed-off-by: Alec Solder <alecs@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: Alec Solder <alecs@fb.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Alec S 2025-10-29 06:01:32 -04:00 committed by GitHub
parent 3c7fefdeba
commit ab2eb27b74
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 293 additions and 44 deletions

View File

@ -26,6 +26,8 @@ def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@ -37,7 +39,9 @@ def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container")
m.setenv("VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@ -56,18 +60,15 @@ async def mcp_enabled_client(mcp_enabled_server):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str):
response = await mcp_enabled_client.responses.create(
model=model_name,
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=(
"What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
tools=[
{
@ -77,26 +78,47 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name:
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0
# Verify output messages: Tool calls and responses on analysis channel
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert tool_call_found, "Should have found at least one Python tool call"
assert tool_response_found, "Should have found at least one Python tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present with valid mcp tool"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str):
response = await mcp_disabled_client.responses.create(
model=model_name,
# TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported
# would speed up the test
input=(
"What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."
"Execute the following code if the tool is present: "
"import random; print(random.randint(1, 1000000))"
),
tools=[
{
@ -106,7 +128,34 @@ async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_nam
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens == 0
# Verify output messages: No tool calls and responses
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert not tool_call_found, "Should not have a python call"
assert not tool_response_found, "Should not have a tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present without a valid tool"
)

View File

@ -6,10 +6,19 @@ from unittest.mock import MagicMock
import pytest
import pytest_asyncio
from openai.types.responses.tool import (
CodeInterpreterContainerCodeInterpreterToolAuto,
LocalShell,
Mcp,
Tool,
)
from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_responses import (
OpenAIServingResponses,
extract_tool_types,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
@ -62,6 +71,45 @@ def mock_exit_stack():
return MagicMock(spec=AsyncExitStack)
def test_extract_tool_types(monkeypatch: pytest.MonkeyPatch) -> None:
tools: list[Tool] = []
assert extract_tool_types(tools) == set()
tools.append(LocalShell(type="local_shell"))
assert extract_tool_types(tools) == {"local_shell"}
tools.append(CodeInterpreterContainerCodeInterpreterToolAuto(type="auto"))
assert extract_tool_types(tools) == {"local_shell", "auto"}
tools.extend(
[
Mcp(type="mcp", server_label="random", server_url=""),
Mcp(type="mcp", server_label="container", server_url=""),
Mcp(type="mcp", server_label="code_interpreter", server_url=""),
Mcp(type="mcp", server_label="web_search_preview", server_url=""),
]
)
# When envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS is not set,
# mcp tool types are all ignored.
assert extract_tool_types(tools) == {"local_shell", "auto"}
# container is allowed, it would be extracted
monkeypatch.setenv("VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "container")
assert extract_tool_types(tools) == {"local_shell", "auto", "container"}
# code_interpreter and web_search_preview are allowed,
# they would be extracted
monkeypatch.setenv(
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,web_search_preview"
)
assert extract_tool_types(tools) == {
"local_shell",
"auto",
"code_interpreter",
"web_search_preview",
}
class TestInitializeToolSessions:
"""Test class for _initialize_tool_sessions method"""

View File

@ -3,7 +3,10 @@
from openai_harmony import Role
from vllm.entrypoints.harmony_utils import parse_input_to_harmony_message
from vllm.entrypoints.harmony_utils import (
has_custom_tools,
parse_input_to_harmony_message,
)
class TestParseInputToHarmonyMessage:
@ -252,3 +255,12 @@ class TestParseInputToHarmonyMessage:
assert len(messages[0].content) == 2
assert messages[0].content[0].text == ""
assert messages[0].content[1].text == "actual text"
def test_has_custom_tools() -> None:
assert not has_custom_tools(set())
assert not has_custom_tools({"web_search_preview", "code_interpreter", "container"})
assert has_custom_tools({"others"})
assert has_custom_tools(
{"web_search_preview", "code_interpreter", "container", "others"}
)

View File

@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.envs import (
enable_envs_cache,
env_list_with_choices,
env_set_with_choices,
env_with_choices,
environment_variables,
)
@ -257,3 +258,110 @@ class TestEnvListWithChoices:
with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == ["option1", "option1", "option2"]
class TestEnvSetWithChoices:
"""Test cases for env_set_with_choices function."""
def test_default_list_returned_when_env_not_set(self):
"""Test that default list is returned when env var is not set."""
env_func = env_set_with_choices(
"NONEXISTENT_ENV", ["default1", "default2"], ["option1", "option2"]
)
assert env_func() == {"default1", "default2"}
def test_empty_default_list_returned_when_env_not_set(self):
"""Test that empty default list is returned when env not set."""
env_func = env_set_with_choices("NONEXISTENT_ENV", [], ["option1", "option2"])
assert env_func() == set()
def test_single_valid_value_parsed_correctly(self):
"""Test that single valid value is parsed correctly."""
with patch.dict(os.environ, {"TEST_ENV": "option1"}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == {"option1"}
def test_multiple_valid_values_parsed_correctly(self):
"""Test that multiple valid values are parsed correctly."""
with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == {"option1", "option2"}
def test_values_with_whitespace_trimmed(self):
"""Test that values with whitespace are trimmed correctly."""
with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == {"option1", "option2"}
def test_empty_values_filtered_out(self):
"""Test that empty values are filtered out."""
with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == {"option1", "option2"}
def test_empty_string_returns_default(self):
"""Test that empty string returns default."""
with patch.dict(os.environ, {"TEST_ENV": ""}):
env_func = env_set_with_choices(
"TEST_ENV", ["default"], ["option1", "option2"]
)
assert env_func() == {"default"}
def test_only_commas_returns_default(self):
"""Test that string with only commas returns default."""
with patch.dict(os.environ, {"TEST_ENV": ",,,"}):
env_func = env_set_with_choices(
"TEST_ENV", ["default"], ["option1", "option2"]
)
assert env_func() == {"default"}
def test_case_sensitive_validation(self):
"""Test case sensitive validation."""
with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
env_func = env_set_with_choices(
"TEST_ENV", [], ["option1", "option2"], case_sensitive=True
)
with pytest.raises(ValueError, match="Invalid value 'OPTION2' in TEST_ENV"):
env_func()
def test_case_insensitive_validation(self):
"""Test case insensitive validation."""
with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}):
env_func = env_set_with_choices(
"TEST_ENV", [], ["option1", "option2"], case_sensitive=False
)
assert env_func() == {"OPTION1", "option2"}
def test_invalid_value_in_list_raises_error(self):
"""Test that invalid value in list raises ValueError."""
with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"):
env_func()
def test_callable_choices_resolved_correctly(self):
"""Test that callable choices are resolved correctly."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
env_func = env_set_with_choices("TEST_ENV", [], get_choices)
assert env_func() == {"dynamic1", "dynamic2"}
def test_callable_choices_with_invalid_value(self):
"""Test that callable choices raise error for invalid values."""
def get_choices():
return ["dynamic1", "dynamic2"]
with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
env_func = env_set_with_choices("TEST_ENV", [], get_choices)
with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"):
env_func()
def test_duplicate_values_deduped(self):
"""Test that duplicate values in the list are deduped."""
with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == {"option1", "option2"}

View File

@ -61,15 +61,19 @@ _harmony_encoding = None
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
BUILTIN_TOOLS = {
MCP_BUILTIN_TOOLS: set[str] = {
"web_search_preview",
"code_interpreter",
"container",
}
def has_custom_tools(tool_types: list[str]) -> bool:
return not set(tool_types).issubset(BUILTIN_TOOLS)
def has_custom_tools(tool_types: set[str]) -> bool:
"""
Checks if the given tool types are custom tools
(i.e. any tool other than MCP buildin tools)
"""
return not tool_types.issubset(MCP_BUILTIN_TOOLS)
def get_encoding():

View File

@ -48,6 +48,7 @@ from openai.types.responses.response_output_text import Logprob, LogprobTopLogpr
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import Message as OpenAIHarmonyMessage
from vllm import envs
@ -106,6 +107,23 @@ from vllm.utils import random_uuid
logger = init_logger(__name__)
def extract_tool_types(tools: list[Tool]) -> set[str]:
"""
Extracts the tool types from the given tools.
"""
tool_types: set[str] = set()
for tool in tools:
if tool.type == "mcp":
# Allow the MCP Tool type to enable built in tools if the
# server_label is allowlisted in
# envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS
if tool.server_label in envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
tool_types.add(tool.server_label)
else:
tool_types.add(tool.type)
return tool_types
class OpenAIServingResponses(OpenAIServing):
def __init__(
self,
@ -879,7 +897,7 @@ class OpenAIServingResponses(OpenAIServing):
return messages
def _construct_harmony_system_input_message(
self, request: ResponsesRequest, with_custom_tools: bool, tool_types: list[str]
self, request: ResponsesRequest, with_custom_tools: bool, tool_types: set[str]
) -> OpenAIHarmonyMessage:
reasoning_effort = request.reasoning.effort if request.reasoning else None
enable_browser = (
@ -927,17 +945,7 @@ class OpenAIServingResponses(OpenAIServing):
messages: list[OpenAIHarmonyMessage] = []
if prev_response is None:
# New conversation.
tool_types = [tool.type for tool in request.tools]
# Allow the MCP Tool type to enable built in tools if the
# server_label is allowlisted in
# envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
for tool in request.tools:
if (
tool.type == "mcp"
and tool.server_label in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
):
tool_types.append(tool.server_label)
tool_types = extract_tool_types(request.tools)
with_custom_tools = has_custom_tools(tool_types)
sys_msg = self._construct_harmony_system_input_message(

View File

@ -198,6 +198,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: str | None = None
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
@ -209,7 +210,6 @@ if TYPE_CHECKING:
VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK: bool = True
VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL: bool = False
VLLM_DBO_COMM_SMS: int = 20
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: str | None = None
VLLM_DEBUG_DUMP_PATH: str | None = None
VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
@ -354,6 +354,24 @@ def env_list_with_choices(
return _get_validated_env_list
def env_set_with_choices(
env_name: str,
default: list[str],
choices: list[str] | Callable[[], list[str]],
case_sensitive: bool = True,
) -> Callable[[], set[str]]:
"""
Creates a lambda which that validates environment variable
containing comma-separated values against allowed choices which
returns choices as a set.
"""
def _get_validated_env_set() -> set[str]:
return set(env_list_with_choices(env_name, default, choices, case_sensitive)())
return _get_validated_env_set
def get_vllm_port() -> int | None:
"""Get the port from VLLM_PORT environment variable.
@ -1328,6 +1346,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Valid values are container,code_interpreter,web_search_preview
# ex VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
# If the server_label of your mcp tool is not in this list it will
# be completely ignored.
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS": env_set_with_choices(
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
default=[],
choices=["container", "code_interpreter", "web_search_preview"],
),
# Allows harmony instructions to be injected on system messages
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool(
int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0"))
@ -1379,13 +1406,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# The number of SMs to allocate for communication kernels when running DBO
# the rest of the SMs on the device will be allocated to compute
"VLLM_DBO_COMM_SMS": lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")),
# Valid values are container,code_interpreter,web_search_preview
# ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS": env_list_with_choices(
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
[],
["container", "code_interpreter", "web_search_preview"],
),
# Enable max_autotune & coordinate_descent_tuning in inductor_config
# to compile static shapes passed from compile_sizes in compilation_config
# If set to 1, enable max_autotune; By default, this is enabled (1)