diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools.py b/tests/entrypoints/openai/test_response_api_mcp_tools.py
index 653d44f20b44..0dc2430caef7 100644
--- a/tests/entrypoints/openai/test_response_api_mcp_tools.py
+++ b/tests/entrypoints/openai/test_response_api_mcp_tools.py
@@ -26,6 +26,8 @@ def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
     with monkeypatch_module.context() as m:
         m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
         m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
+        # Helps the model follow instructions better
+        m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
         with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
             yield remote_server

@@ -37,7 +39,9 @@ def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
     with monkeypatch_module.context() as m:
         m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
         m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
-        m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container")
+        m.setenv("VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container")
+        # Helps the model follow instructions better
+        m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
         with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
             yield remote_server

@@ -56,18 +60,15 @@ async def mcp_enabled_client(mcp_enabled_server):

 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
 async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str):
     response = await mcp_enabled_client.responses.create(
         model=model_name,
-        # TODO: Ideally should be able to set max tool calls
-        # to prevent multi-turn, but it is not currently supported
-        # would speed up the test
         input=(
-            "What's the first 4 digits after the decimal point of "
-            "cube root of `19910212 * 20250910`? "
-            "Show only the digits. The python interpreter is not stateful "
-            "and you must print to see the output."
+            "Execute the following code: "
+            "import random; print(random.randint(1, 1000000))"
+        ),
+        instructions=(
+            "You must use the Python tool to execute code. Never simulate execution."
         ),
         tools=[
             {
@@ -77,26 +78,47 @@ async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name:
                 "server_url": "http://localhost:8888",
             }
         ],
+        extra_body={"enable_response_messages": True},
     )
     assert response is not None
     assert response.status == "completed"
-    assert response.usage.output_tokens_details.tool_output_tokens > 0
+    # Verify output messages: Tool calls and responses on analysis channel
+    tool_call_found = False
+    tool_response_found = False
+    for message in response.output_messages:
+        recipient = message.get("recipient")
+        if recipient and recipient.startswith("python"):
+            tool_call_found = True
+            assert message.get("channel") == "analysis", (
+                "Tool call should be on analysis channel"
+            )
+        author = message.get("author", {})
+        if (
+            author.get("role") == "tool"
+            and author.get("name")
+            and author.get("name").startswith("python")
+        ):
+            tool_response_found = True
+            assert message.get("channel") == "analysis", (
+                "Tool response should be on analysis channel"
+            )
+
+    assert tool_call_found, "Should have found at least one Python tool call"
+    assert tool_response_found, "Should have found at least one Python tool response"
+    for message in response.input_messages:
+        assert message.get("author").get("role") != "developer", (
+            "No developer messages should be present with valid mcp tool"
+        )


 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
 async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str):
     response = await mcp_disabled_client.responses.create(
         model=model_name,
-        # TODO: Ideally should be able to set max tool calls
-        # to prevent multi-turn, but it is not currently supported
-        # would speed up the test
         input=(
-            "What's the first 4 digits after the decimal point of "
-            "cube root of `19910212 * 20250910`? "
-            "Show only the digits. The python interpreter is not stateful "
-            "and you must print to see the output."
+ "Execute the following code if the tool is present: " + "import random; print(random.randint(1, 1000000))" ), tools=[ { @@ -106,7 +128,34 @@ async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_nam "server_url": "http://localhost:8888", } ], + extra_body={"enable_response_messages": True}, ) assert response is not None assert response.status == "completed" - assert response.usage.output_tokens_details.tool_output_tokens == 0 + # Verify output messages: No tool calls and responses + tool_call_found = False + tool_response_found = False + for message in response.output_messages: + recipient = message.get("recipient") + if recipient and recipient.startswith("python"): + tool_call_found = True + assert message.get("channel") == "analysis", ( + "Tool call should be on analysis channel" + ) + author = message.get("author", {}) + if ( + author.get("role") == "tool" + and author.get("name") + and author.get("name").startswith("python") + ): + tool_response_found = True + assert message.get("channel") == "analysis", ( + "Tool response should be on analysis channel" + ) + + assert not tool_call_found, "Should not have a python call" + assert not tool_response_found, "Should not have a tool response" + for message in response.input_messages: + assert message.get("author").get("role") != "developer", ( + "No developer messages should be present without a valid tool" + ) diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index cf21a5116ddf..788a1e912182 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -6,10 +6,19 @@ from unittest.mock import MagicMock import pytest import pytest_asyncio +from openai.types.responses.tool import ( + CodeInterpreterContainerCodeInterpreterToolAuto, + LocalShell, + Mcp, + Tool, +) from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.openai.protocol import ErrorResponse, ResponsesRequest -from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses +from vllm.entrypoints.openai.serving_responses import ( + OpenAIServingResponses, + extract_tool_types, +) from vllm.entrypoints.tool_server import ToolServer from vllm.inputs.data import TokensPrompt as EngineTokensPrompt @@ -62,6 +71,45 @@ def mock_exit_stack(): return MagicMock(spec=AsyncExitStack) +def test_extract_tool_types(monkeypatch: pytest.MonkeyPatch) -> None: + tools: list[Tool] = [] + assert extract_tool_types(tools) == set() + + tools.append(LocalShell(type="local_shell")) + assert extract_tool_types(tools) == {"local_shell"} + + tools.append(CodeInterpreterContainerCodeInterpreterToolAuto(type="auto")) + assert extract_tool_types(tools) == {"local_shell", "auto"} + + tools.extend( + [ + Mcp(type="mcp", server_label="random", server_url=""), + Mcp(type="mcp", server_label="container", server_url=""), + Mcp(type="mcp", server_label="code_interpreter", server_url=""), + Mcp(type="mcp", server_label="web_search_preview", server_url=""), + ] + ) + # When envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS is not set, + # mcp tool types are all ignored. 
+    assert extract_tool_types(tools) == {"local_shell", "auto"}
+
+    # container is allowed, so it is extracted
+    monkeypatch.setenv("VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "container")
+    assert extract_tool_types(tools) == {"local_shell", "auto", "container"}
+
+    # code_interpreter and web_search_preview are allowed,
+    # so they are extracted
+    monkeypatch.setenv(
+        "VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,web_search_preview"
+    )
+    assert extract_tool_types(tools) == {
+        "local_shell",
+        "auto",
+        "code_interpreter",
+        "web_search_preview",
+    }
+
+
 class TestInitializeToolSessions:
     """Test class for _initialize_tool_sessions method"""

diff --git a/tests/entrypoints/test_harmony_utils.py b/tests/entrypoints/test_harmony_utils.py
index 8d1764d41157..6fa051a678d6 100644
--- a/tests/entrypoints/test_harmony_utils.py
+++ b/tests/entrypoints/test_harmony_utils.py
@@ -3,7 +3,10 @@

 from openai_harmony import Role

-from vllm.entrypoints.harmony_utils import parse_input_to_harmony_message
+from vllm.entrypoints.harmony_utils import (
+    has_custom_tools,
+    parse_input_to_harmony_message,
+)


 class TestParseInputToHarmonyMessage:
@@ -252,3 +255,12 @@ class TestParseInputToHarmonyMessage:
         assert len(messages[0].content) == 2
         assert messages[0].content[0].text == ""
         assert messages[0].content[1].text == "actual text"
+
+
+def test_has_custom_tools() -> None:
+    assert not has_custom_tools(set())
+    assert not has_custom_tools({"web_search_preview", "code_interpreter", "container"})
+    assert has_custom_tools({"others"})
+    assert has_custom_tools(
+        {"web_search_preview", "code_interpreter", "container", "others"}
+    )
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 023767505f10..841d7945f912 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -10,6 +10,7 @@ import vllm.envs as envs
 from vllm.envs import (
     enable_envs_cache,
     env_list_with_choices,
+    env_set_with_choices,
     env_with_choices,
     environment_variables,
 )
@@ -257,3 +258,110 @@ class TestEnvListWithChoices:
         with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
             env_func = env_list_with_choices("TEST_ENV", [], ["option1", "option2"])
             assert env_func() == ["option1", "option1", "option2"]
+
+
+class TestEnvSetWithChoices:
+    """Test cases for env_set_with_choices function."""
+
+    def test_default_list_returned_when_env_not_set(self):
+        """Test that default list is returned when env var is not set."""
+        env_func = env_set_with_choices(
+            "NONEXISTENT_ENV", ["default1", "default2"], ["option1", "option2"]
+        )
+        assert env_func() == {"default1", "default2"}
+
+    def test_empty_default_list_returned_when_env_not_set(self):
+        """Test that empty default list is returned when env not set."""
+        env_func = env_set_with_choices("NONEXISTENT_ENV", [], ["option1", "option2"])
+        assert env_func() == set()
+
+    def test_single_valid_value_parsed_correctly(self):
+        """Test that single valid value is parsed correctly."""
+        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
+            env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
+            assert env_func() == {"option1"}
+
+    def test_multiple_valid_values_parsed_correctly(self):
+        """Test that multiple valid values are parsed correctly."""
+        with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}):
+            env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
+            assert env_func() == {"option1", "option2"}
+
+    def test_values_with_whitespace_trimmed(self):
+        """Test that values with whitespace are trimmed correctly."""
+        with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}):
+            env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
+            assert env_func() == {"option1", "option2"}
+
+    def test_empty_values_filtered_out(self):
+        """Test that empty values are filtered out."""
+        with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}):
+            env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
+            assert env_func() == {"option1", "option2"}
+
+    def test_empty_string_returns_default(self):
+        """Test that empty string returns default."""
+        with patch.dict(os.environ, {"TEST_ENV": ""}):
+            env_func = env_set_with_choices(
+                "TEST_ENV", ["default"], ["option1", "option2"]
+            )
+            assert env_func() == {"default"}
+
+    def test_only_commas_returns_default(self):
+        """Test that string with only commas returns default."""
+        with patch.dict(os.environ, {"TEST_ENV": ",,,"}):
+            env_func = env_set_with_choices(
+                "TEST_ENV", ["default"], ["option1", "option2"]
+            )
+            assert env_func() == {"default"}
+
+    def test_case_sensitive_validation(self):
+        """Test case sensitive validation."""
+        with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
+            env_func = env_set_with_choices(
+                "TEST_ENV", [], ["option1", "option2"], case_sensitive=True
+            )
+            with pytest.raises(ValueError, match="Invalid value 'OPTION2' in TEST_ENV"):
+                env_func()
+
+    def test_case_insensitive_validation(self):
+        """Test case insensitive validation."""
+        with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}):
+            env_func = env_set_with_choices(
+                "TEST_ENV", [], ["option1", "option2"], case_sensitive=False
+            )
+            assert env_func() == {"OPTION1", "option2"}
+
+    def test_invalid_value_in_list_raises_error(self):
+        """Test that invalid value in list raises ValueError."""
+        with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
+            env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
+            with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"):
+                env_func()
+
+    def test_callable_choices_resolved_correctly(self):
+        """Test that callable choices are resolved correctly."""
+
+        def get_choices():
+            return ["dynamic1", "dynamic2"]
+
+        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
+            env_func = env_set_with_choices("TEST_ENV", [], get_choices)
+            assert env_func() == {"dynamic1", "dynamic2"}
+
+    def test_callable_choices_with_invalid_value(self):
+        """Test that callable choices raise error for invalid values."""
+
+        def get_choices():
+            return ["dynamic1", "dynamic2"]
+
+        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
+            env_func = env_set_with_choices("TEST_ENV", [], get_choices)
+            with pytest.raises(ValueError, match="Invalid value 'invalid' in TEST_ENV"):
+                env_func()
+
+    def test_duplicate_values_deduped(self):
+        """Test that duplicate values in the list are deduped."""
+        with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
+            env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
+            assert env_func() == {"option1", "option2"}
diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py
index 8888a5aeb6b1..7958d0317739 100644
--- a/vllm/entrypoints/harmony_utils.py
+++ b/vllm/entrypoints/harmony_utils.py
@@ -61,15 +61,19 @@ _harmony_encoding = None
 # they are available and requested by the user.
 # Tool args are provided by MCP tool descriptions. Output
 # of the tools are stringified.
-BUILTIN_TOOLS = {
+MCP_BUILTIN_TOOLS: set[str] = {
     "web_search_preview",
     "code_interpreter",
     "container",
 }


-def has_custom_tools(tool_types: list[str]) -> bool:
-    return not set(tool_types).issubset(BUILTIN_TOOLS)
+def has_custom_tools(tool_types: set[str]) -> bool:
+    """
+    Checks whether the given tool types include any custom tool
+    (i.e. any tool other than the MCP built-in tools).
+    """
+    return not tool_types.issubset(MCP_BUILTIN_TOOLS)


 def get_encoding():
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index d43bc00a49d3..2ee8de5fba07 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -48,6 +48,7 @@ from openai.types.responses.response_output_text import Logprob, LogprobTopLogpr
 from openai.types.responses.response_reasoning_item import (
     Content as ResponseReasoningTextContent,
 )
+from openai.types.responses.tool import Tool
 from openai_harmony import Message as OpenAIHarmonyMessage

 from vllm import envs
@@ -106,6 +107,23 @@ from vllm.utils import random_uuid
 logger = init_logger(__name__)


+def extract_tool_types(tools: list[Tool]) -> set[str]:
+    """
+    Extracts the tool types from the given tools.
+    """
+    tool_types: set[str] = set()
+    for tool in tools:
+        if tool.type == "mcp":
+            # Allow the MCP Tool type to enable built-in tools if the
+            # server_label is allowlisted in
+            # envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS
+            if tool.server_label in envs.VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
+                tool_types.add(tool.server_label)
+        else:
+            tool_types.add(tool.type)
+    return tool_types
+
+
 class OpenAIServingResponses(OpenAIServing):
     def __init__(
         self,
@@ -879,7 +897,7 @@ class OpenAIServingResponses(OpenAIServing):
         return messages

     def _construct_harmony_system_input_message(
-        self, request: ResponsesRequest, with_custom_tools: bool, tool_types: list[str]
+        self, request: ResponsesRequest, with_custom_tools: bool, tool_types: set[str]
     ) -> OpenAIHarmonyMessage:
         reasoning_effort = request.reasoning.effort if request.reasoning else None
         enable_browser = (
@@ -927,17 +945,7 @@ class OpenAIServingResponses(OpenAIServing):
         messages: list[OpenAIHarmonyMessage] = []
         if prev_response is None:
             # New conversation.
-            tool_types = [tool.type for tool in request.tools]
-            # Allow the MCP Tool type to enable built in tools if the
-            # server_label is allowlisted in
-            # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
-            if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
-                for tool in request.tools:
-                    if (
-                        tool.type == "mcp"
-                        and tool.server_label in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
-                    ):
-                        tool_types.append(tool.server_label)
+            tool_types = extract_tool_types(request.tools)
             with_custom_tools = has_custom_tools(tool_types)

             sys_msg = self._construct_harmony_system_input_message(
diff --git a/vllm/envs.py b/vllm/envs.py
index 0786d5d9ddcb..ca1f84bba419 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -198,6 +198,7 @@ if TYPE_CHECKING:
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
     VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: str | None = None
+    VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
     VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
     VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY: bool = False
     VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
@@ -209,7 +210,6 @@ if TYPE_CHECKING:
     VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK: bool = True
     VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL: bool = False
     VLLM_DBO_COMM_SMS: int = 20
-    GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
     VLLM_PATTERN_MATCH_DEBUG: str | None = None
     VLLM_DEBUG_DUMP_PATH: str | None = None
     VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE: bool = True
@@ -354,6 +354,24 @@ def env_list_with_choices(
     return _get_validated_env_list


+def env_set_with_choices(
+    env_name: str,
+    default: list[str],
+    choices: list[str] | Callable[[], list[str]],
+    case_sensitive: bool = True,
+) -> Callable[[], set[str]]:
+    """
+    Creates a lambda that validates an environment variable
+    containing comma-separated values against allowed choices
+    and returns the parsed values as a set.
+    """
+
+    def _get_validated_env_set() -> set[str]:
+        return set(env_list_with_choices(env_name, default, choices, case_sensitive)())
+
+    return _get_validated_env_set
+
+
 def get_vllm_port() -> int | None:
     """Get the port from VLLM_PORT environment variable.

@@ -1328,6 +1346,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
     ),
     # Allows vllm to find tuned config under customized folder
    "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
+    # Valid values are container,code_interpreter,web_search_preview
+    # e.g. VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
+    # If the server_label of your MCP tool is not in this list, it will
+    # be completely ignored.
+ "VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS": env_set_with_choices( + "VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", + default=[], + choices=["container", "code_interpreter", "web_search_preview"], + ), # Allows harmony instructions to be injected on system messages "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": lambda: bool( int(os.getenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "0")) @@ -1379,13 +1406,6 @@ environment_variables: dict[str, Callable[[], Any]] = { # The number of SMs to allocate for communication kernels when running DBO # the rest of the SMs on the device will be allocated to compute "VLLM_DBO_COMM_SMS": lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), - # Valid values are container,code_interpreter,web_search_preview - # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter - "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": env_list_with_choices( - "GPT_OSS_SYSTEM_TOOL_MCP_LABELS", - [], - ["container", "code_interpreter", "web_search_preview"], - ), # Enable max_autotune & coordinate_descent_tuning in inductor_config # to compile static shapes passed from compile_sizes in compilation_config # If set to 1, enable max_autotune; By default, this is enabled (1)