[Frontend] Responses API MCP tools for built in tools and to pass through headers (#24628)

Signed-off-by: Alec Solder <alecs@fb.com>
Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com>
Co-authored-by: Alec Solder <alecs@fb.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
Alec S 2025-09-22 19:38:19 -04:00 committed by GitHub
parent 8bed179109
commit 45d7d852d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 463 additions and 29 deletions

View File

@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from openai import OpenAI
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a module-scoped ``MonkeyPatch`` and undo its patches afterwards."""
    from _pytest.monkeypatch import MonkeyPatch

    patcher = MonkeyPatch()
    yield patcher
    # Runs when the fixture is torn down: revert everything that was patched.
    patcher.undo()
@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
    """Start a demo-tool-server backed server with no MCP label allowlist.

    Yields the running ``RemoteOpenAIServer``; the environment changes are
    reverted when the ``monkeypatch_module.context()`` block exits.
    """
    args = ["--enforce-eager", "--tool-server", "demo"]

    with monkeypatch_module.context() as m:
        m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
        m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
        # This fixture represents the "flag disabled" case, so make sure an
        # allowlist present in the ambient environment cannot leak in and
        # enable built-in tools behind the test's back.
        # NOTE(review): assumes RemoteOpenAIServer inherits os.environ, which
        # the setenv calls above rely on as well.
        m.delenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", raising=False)
        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
            yield remote_server
@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
    """Start a demo-tool-server backed server with MCP labels allowlisted.

    ``GPT_OSS_SYSTEM_TOOL_MCP_LABELS`` is set to
    ``code_interpreter,container`` for the lifetime of the server.
    """
    server_args = ["--enforce-eager", "--tool-server", "demo"]
    env_overrides = {
        "VLLM_ENABLE_RESPONSES_API_STORE": "1",
        "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
        "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": "code_interpreter,container",
    }

    with monkeypatch_module.context() as patch_ctx:
        for env_name, env_value in env_overrides.items():
            patch_ctx.setenv(env_name, env_value)
        with RemoteOpenAIServer(MODEL_NAME, server_args) as server:
            yield server
@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
    """Yield an async client obtained from ``mcp_disabled_server``."""
    async with mcp_disabled_server.get_async_client() as client:
        yield client
@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
    """Yield an async client obtained from ``mcp_enabled_server``."""
    async with mcp_enabled_server.get_async_client() as client:
        yield client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
                                         model_name: str):
    """An MCP tool labelled ``code_interpreter`` should yield tool output.

    Uses the client from ``mcp_enabled_client`` and expects a completed
    response that reports a positive number of tool output tokens.
    """
    prompt = ("What's the first 4 digits after the decimal point of "
              "cube root of `19910212 * 20250910`? "
              "Show only the digits. The python interpreter is not stateful "
              "and you must print to see the output.")
    mcp_tool = {
        "type": "mcp",
        "server_label": "code_interpreter",
        # URL unused for DemoToolServer
        "server_url": "http://localhost:8888"
    }

    # TODO: Ideally we would cap the number of tool calls to prevent
    # multi-turn execution and speed up the test, but that is not
    # currently supported.
    response = await mcp_enabled_client.responses.create(
        model=model_name,
        input=prompt,
        tools=[mcp_tool],
    )

    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens > 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
                                          model_name: str):
    """An MCP tool labelled ``code_interpreter`` should yield no tool output.

    Uses the client from ``mcp_disabled_client`` and expects a completed
    response that reports zero tool output tokens.
    """
    prompt = ("What's the first 4 digits after the decimal point of "
              "cube root of `19910212 * 20250910`? "
              "Show only the digits. The python interpreter is not stateful "
              "and you must print to see the output.")
    mcp_tool = {
        "type": "mcp",
        "server_label": "code_interpreter",
        # URL unused for DemoToolServer
        "server_url": "http://localhost:8888"
    }

    # TODO: Ideally we would cap the number of tool calls to prevent
    # multi-turn execution and speed up the test, but that is not
    # currently supported.
    response = await mcp_disabled_client.responses.create(
        model=model_name,
        input=prompt,
        tools=[mcp_tool],
    )

    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens == 0

View File

@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str):
async def test_code_interpreter(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="Multiply 64548*15151 using builtin python interpreter.",
# TODO: Ideally we would cap the number of tool calls to prevent
# multi-turn execution and speed up the test, but that is not
# currently supported.
input=("What's the first 4 digits after the decimal point of "
"cube root of `19910212 * 20250910`? "
"Show only the digits. The python interpreter is not stateful "
"and you must print to see the output."),
tools=[{
"type": "code_interpreter",
"container": {
@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
)
assert response is not None
assert response.status == "completed"
assert response.usage.output_tokens_details.tool_output_tokens > 0
def get_weather(latitude, longitude):

216
tests/test_envs.py Normal file
View File

@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from unittest.mock import patch
import pytest
from vllm.envs import env_list_with_choices, env_with_choices
class TestEnvWithChoices:
    """Exercise the getter produced by ``env_with_choices``."""

    def test_default_value_returned_when_env_not_set(self):
        """An unset variable yields the configured default."""
        allowed = ["option1", "option2"]
        getter = env_with_choices("NONEXISTENT_ENV", "default", allowed)
        assert getter() == "default"

    def test_none_default_returned_when_env_not_set(self):
        """An unset variable with a ``None`` default yields ``None``."""
        allowed = ["option1", "option2"]
        getter = env_with_choices("NONEXISTENT_ENV", None, allowed)
        assert getter() is None

    def test_valid_value_returned_case_sensitive(self):
        """An exact match is returned when matching is case sensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            allowed = ["option1", "option2"]
            getter = env_with_choices("TEST_ENV",
                                      "default",
                                      allowed,
                                      case_sensitive=True)
            assert getter() == "option1"

    def test_valid_lowercase_value_returned_case_insensitive(self):
        """A lowercase value matches uppercase choices when insensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            allowed = ["OPTION1", "OPTION2"]
            getter = env_with_choices("TEST_ENV",
                                      "default",
                                      allowed,
                                      case_sensitive=False)
            assert getter() == "option1"

    def test_valid_uppercase_value_returned_case_insensitive(self):
        """An uppercase value matches lowercase choices when insensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
            allowed = ["option1", "option2"]
            getter = env_with_choices("TEST_ENV",
                                      "default",
                                      allowed,
                                      case_sensitive=False)
            assert getter() == "OPTION1"

    def test_invalid_value_raises_error_case_sensitive(self):
        """An unknown value raises ``ValueError`` when case sensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            allowed = ["option1", "option2"]
            getter = env_with_choices("TEST_ENV",
                                      "default",
                                      allowed,
                                      case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                getter()

    def test_case_mismatch_raises_error_case_sensitive(self):
        """A value differing only by case raises when case sensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
            allowed = ["option1", "option2"]
            getter = env_with_choices("TEST_ENV",
                                      "default",
                                      allowed,
                                      case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'OPTION1' for TEST_ENV"):
                getter()

    def test_invalid_value_raises_error_case_insensitive(self):
        """An unknown value raises ``ValueError`` when case insensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            allowed = ["option1", "option2"]
            getter = env_with_choices("TEST_ENV",
                                      "default",
                                      allowed,
                                      case_sensitive=False)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                getter()

    def test_callable_choices_resolved_correctly(self):
        """Choices supplied as a callable accept a value it returns."""

        def dynamic_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}):
            getter = env_with_choices("TEST_ENV", "default", dynamic_choices)
            assert getter() == "dynamic1"

    def test_callable_choices_with_invalid_value(self):
        """Choices supplied as a callable reject a value it does not return."""

        def dynamic_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            getter = env_with_choices("TEST_ENV", "default", dynamic_choices)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                getter()
class TestEnvListWithChoices:
    """Exercise the getter produced by ``env_list_with_choices``."""

    def test_default_list_returned_when_env_not_set(self):
        """An unset variable yields the configured default list."""
        allowed = ["option1", "option2"]
        getter = env_list_with_choices("NONEXISTENT_ENV",
                                       ["default1", "default2"], allowed)
        assert getter() == ["default1", "default2"]

    def test_empty_default_list_returned_when_env_not_set(self):
        """An unset variable with an empty default yields an empty list."""
        allowed = ["option1", "option2"]
        getter = env_list_with_choices("NONEXISTENT_ENV", [], allowed)
        assert getter() == []

    def test_single_valid_value_parsed_correctly(self):
        """A single valid value becomes a one-element list."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [], allowed)
            assert getter() == ["option1"]

    def test_multiple_valid_values_parsed_correctly(self):
        """Comma-separated valid values are returned in order."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [], allowed)
            assert getter() == ["option1", "option2"]

    def test_values_with_whitespace_trimmed(self):
        """Whitespace around each entry is stripped."""
        with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [], allowed)
            assert getter() == ["option1", "option2"]

    def test_empty_values_filtered_out(self):
        """Empty entries between or after commas are dropped."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [], allowed)
            assert getter() == ["option1", "option2"]

    def test_empty_string_returns_default(self):
        """An empty string falls back to the default list."""
        with patch.dict(os.environ, {"TEST_ENV": ""}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", ["default"], allowed)
            assert getter() == ["default"]

    def test_only_commas_returns_default(self):
        """A string made only of commas falls back to the default list."""
        with patch.dict(os.environ, {"TEST_ENV": ",,,"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", ["default"], allowed)
            assert getter() == ["default"]

    def test_case_sensitive_validation(self):
        """A case-mismatched entry raises when matching is case sensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [],
                                           allowed,
                                           case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'OPTION2' in TEST_ENV"):
                getter()

    def test_case_insensitive_validation(self):
        """Mixed-case entries pass and keep their original casing."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [],
                                           allowed,
                                           case_sensitive=False)
            assert getter() == ["OPTION1", "option2"]

    def test_invalid_value_in_list_raises_error(self):
        """One unknown entry among valid ones raises ``ValueError``."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [], allowed)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' in TEST_ENV"):
                getter()

    def test_callable_choices_resolved_correctly(self):
        """Choices supplied as a callable accept values it returns."""

        def dynamic_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
            getter = env_list_with_choices("TEST_ENV", [], dynamic_choices)
            assert getter() == ["dynamic1", "dynamic2"]

    def test_callable_choices_with_invalid_value(self):
        """Choices supplied as a callable reject values it does not return."""

        def dynamic_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
            getter = env_list_with_choices("TEST_ENV", [], dynamic_choices)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' in TEST_ENV"):
                getter()

    def test_duplicate_values_preserved(self):
        """Repeated entries are kept rather than de-duplicated."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
            allowed = ["option1", "option2"]
            getter = env_list_with_choices("TEST_ENV", [], allowed)
            assert getter() == ["option1", "option1", "option2"]

View File

@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent
from vllm.entrypoints.harmony_utils import (
@ -21,6 +22,24 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# This is currently needed as the tool type doesn't 1:1 match the
# tool namespace, which is what is used to look up the
# connection to the tool server
_TOOL_NAME_TO_TYPE_MAP = {
"browser": "web_search_preview",
"python": "code_interpreter",
"container": "container",
}
def _map_tool_name_to_tool_type(tool_name: str) -> str:
if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
raise ValueError(
f"Built-in tool name '{tool_name}' not defined in mapping. "
f"Available tools: {available_tools}")
return _TOOL_NAME_TO_TYPE_MAP[tool_name]
class TurnTokens:
"""Tracks token counts for a single conversation turn."""
@ -59,8 +78,8 @@ class ConversationContext(ABC):
@abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]) -> None:
pass
@abstractmethod
@ -96,8 +115,8 @@ class SimpleContext(ConversationContext):
raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]) -> None:
pass
async def cleanup_session(self) -> None:
@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext):
]
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack,
request_id: str) -> None:
exit_stack: AsyncExitStack, request_id: str,
mcp_tools: dict[str, Mcp]):
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = mcp_tools[
tool_type].headers if tool_type in mcp_tools else None
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id))
tool_server.new_session(tool_name, request_id,
headers))
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)

View File

@ -126,8 +126,10 @@ def get_developer_message(
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter",
"container"):
"container", "mcp"):
# These are built-in tools that are added to the system message.
# Adding in MCP for now until we support MCP tools executed
# server side
pass
elif tool.type == "function":

View File

@ -460,8 +460,12 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack:
try:
mcp_tools = {
tool.server_label: tool
for tool in request.tools if tool.type == "mcp"
}
await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id)
request.request_id, mcp_tools)
async for _ in result_generator:
pass
except asyncio.CancelledError:
@ -748,11 +752,16 @@ class OpenAIServingResponses(OpenAIServing):
# New conversation.
reasoning_effort = (request.reasoning.effort
if request.reasoning else None)
# Temporary: OpenAI types doesn't have container tool
# so we used MCP to cover that, up for change
tool_types = [tool.type for tool in request.tools]
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
tool_types.append("container")
# Allow the MCP Tool type to enable built in tools if the
# server_label is allowlisted in
# envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
for tool in request.tools:
if (tool.type == "mcp" and tool.server_label
in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS):
tool_types.append(tool.server_label)
enable_browser = ("web_search_preview" in tool_types
and self.tool_server is not None
and self.tool_server.has_tool("browser"))
@ -1653,8 +1662,12 @@ class OpenAIServingResponses(OpenAIServing):
async with AsyncExitStack() as exit_stack:
processer = None
if self.use_harmony:
mcp_tools = {
tool.server_label: tool
for tool in request.tools if tool.type == "mcp"
}
await context.init_tool_sessions(self.tool_server, exit_stack,
request.request_id)
request.request_id, mcp_tools)
processer = self._process_harmony_streaming_events
else:
processer = self._process_simple_streaming_events

View File

@ -18,7 +18,6 @@ if TYPE_CHECKING:
async def list_server_and_tools(server_url: str):
from mcp import ClientSession
from mcp.client.sse import sse_client
async with sse_client(url=server_url) as streams, ClientSession(
*streams) as session:
initialize_response = await session.initialize()
@ -86,8 +85,12 @@ class ToolServer(ABC):
pass
@abstractmethod
def new_session(self, tool_name: str,
session_id: str) -> AbstractAsyncContextManager[Any]:
def new_session(
self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None
) -> AbstractAsyncContextManager[Any]:
"""
Create a session for the tool.
"""
@ -144,15 +147,20 @@ class MCPToolServer(ToolServer):
return self.harmony_tool_descriptions.get(tool_name)
@asynccontextmanager
async def new_session(self, tool_name: str, session_id: str):
async def new_session(self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None):
from mcp import ClientSession
from mcp.client.sse import sse_client
url = self.urls.get(tool_name)
headers = {"x-session-id": session_id}
request_headers = {"x-session-id": session_id}
if headers is not None:
request_headers.update(headers)
if not url:
raise KeyError(f"Tool '{tool_name}' is not supported")
async with sse_client(url=url,
headers=headers) as streams, ClientSession(
async with sse_client(
url=url, headers=request_headers) as streams, ClientSession(
*streams) as session:
await session.initialize()
yield session
@ -189,7 +197,10 @@ class DemoToolServer(ToolServer):
raise ValueError(f"Unknown tool {tool_name}")
@asynccontextmanager
async def new_session(self, tool_name: str, session_id: str):
async def new_session(self,
tool_name: str,
session_id: str,
headers: Optional[dict[str, str]] = None):
if tool_name not in self.tools:
raise KeyError(f"Tool '{tool_name}' is not supported")
yield self.tools[tool_name]

View File

@ -185,11 +185,11 @@ if TYPE_CHECKING:
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
@ -261,6 +261,58 @@ def env_with_choices(
return _get_validated_env
def env_list_with_choices(
        env_name: str,
        default: list[str],
        choices: Union[list[str], Callable[[], list[str]]],
        case_sensitive: bool = True) -> Callable[[], list[str]]:
    """
    Create a getter that validates an environment variable
    containing comma-separated values against allowed choices

    Args:
        env_name: Name of the environment variable
        default: Default list of values if not set, or if the variable
            contains no non-empty entries
        choices: List of valid string options or callable that returns list
        case_sensitive: Whether validation should be case sensitive

    Returns:
        Function for the environment_variables dict that returns a list of
        strings. Entries keep their original casing and order; duplicates
        are preserved.

    Raises:
        ValueError: Raised by the returned function when an entry is not
            one of the allowed choices.
    """

    def _get_validated_env_list() -> list[str]:
        value = os.getenv(env_name)
        if value is None:
            # Hand out a copy so a caller mutating the result cannot
            # corrupt the default seen by later calls.
            return list(default)

        # Split comma-separated values, strip whitespace, drop empties
        stripped = (part.strip() for part in value.split(","))
        values = [part for part in stripped if part]

        if not values:
            return list(default)

        # Resolve choices if it's a callable (for lazy loading)
        actual_choices = choices() if callable(choices) else choices

        # Normalize the allowed set once instead of once per value
        if case_sensitive:
            allowed = set(actual_choices)
        else:
            allowed = {choice.lower() for choice in actual_choices}

        # Validate each value
        for val in values:
            check_value = val if case_sensitive else val.lower()
            if check_value not in allowed:
                raise ValueError(f"Invalid value '{val}' in {env_name}. "
                                 f"Valid options: {actual_choices}.")

        return values

    return _get_validated_env_list
def get_vllm_port() -> Optional[int]:
"""Get the port from VLLM_PORT environment variable.
@ -1320,10 +1372,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
# Allows vllm use container tool
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
# Allows harmony instructions to be injected on system messages
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
lambda: bool(
@ -1343,6 +1391,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
# Allowlist of MCP server_label values that may enable built-in tools.
# Valid values are container, code_interpreter, web_search_preview
# e.g. GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [],
["container",
"code_interpreter",
"web_search_preview"]),
}
# --8<-- [end:env-vars-definition]