mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:45:34 +08:00
[Frontend] Responses API MCP tools for built in tools and to pass through headers (#24628)
Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent
8bed179109
commit
45d7d852d3
106
tests/entrypoints/openai/test_response_api_mcp_tools.py
Normal file
106
tests/entrypoints/openai/test_response_api_mcp_tools.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
MODEL_NAME = "openai/gpt-oss-20b"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a module-scoped ``MonkeyPatch`` and undo its changes afterwards.

    pytest's built-in ``monkeypatch`` fixture is function-scoped, so it cannot
    be requested by a module-scoped fixture such as ``mcp_disabled_server``.
    """
    # ``pytest.MonkeyPatch.context()`` is the public API for "create a
    # MonkeyPatch, hand it out, call undo() on exit"; it avoids importing
    # from the private ``_pytest`` package.
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
    """Yield a demo-tool-server instance with no MCP labels allowlisted.

    The server is started once per module with the Responses API store
    enabled and the ``dangerously_use_uv`` Python execution backend.
    """
    args = ["--enforce-eager", "--tool-server", "demo"]

    with monkeypatch_module.context() as m:
        m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
        m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
        # Make the "disabled" premise explicit: a value left in the ambient
        # environment (developer shell, CI job) must not allowlist any MCP
        # label for this server.
        # NOTE(review): presumably RemoteOpenAIServer passes os.environ on to
        # the server process -- confirm against its implementation.
        m.delenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", raising=False)
        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
            yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
    """Yield a demo-tool-server instance with MCP labels allowlisted.

    ``GPT_OSS_SYSTEM_TOOL_MCP_LABELS`` is set to
    ``code_interpreter,container`` before the server is started, alongside
    the same store / execution-backend settings used by the disabled variant.
    """
    # Insertion order is preserved, so variables are set in the listed order.
    server_env = {
        "VLLM_ENABLE_RESPONSES_API_STORE": "1",
        "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
        "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": "code_interpreter,container",
    }
    cli_args = ["--enforce-eager", "--tool-server", "demo"]

    with monkeypatch_module.context() as patcher:
        for name, value in server_env.items():
            patcher.setenv(name, value)
        # The environment must be in place before the server is launched.
        with RemoteOpenAIServer(MODEL_NAME, cli_args) as server:
            yield server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
    """Yield an async client obtained from the MCP-disabled server fixture."""
    client_cm = mcp_disabled_server.get_async_client()
    async with client_cm as client:
        yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
    """Yield an async client obtained from the MCP-enabled server fixture."""
    client_cm = mcp_enabled_server.get_async_client()
    async with client_cm as client:
        yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
                                         model_name: str):
    """With the label allowlisted, an MCP tool request yields tool output.

    The request passes an ``mcp`` tool whose ``server_label`` is
    ``code_interpreter`` and expects a completed response that reports a
    positive ``tool_output_tokens`` count.
    """
    # TODO: Ideally should be able to set max tool calls
    # to prevent multi-turn, but it is not currently supported
    # would speed up the test
    prompt = ("What's the first 4 digits after the decimal point of "
              "cube root of `19910212 * 20250910`? "
              "Show only the digits. The python interpreter is not stateful "
              "and you must print to see the output.")
    mcp_tool = {
        "type": "mcp",
        "server_label": "code_interpreter",
        # URL unused for DemoToolServer
        "server_url": "http://localhost:8888"
    }

    response = await mcp_enabled_client.responses.create(
        model=model_name,
        input=prompt,
        tools=[mcp_tool],
    )

    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
                                          model_name: str):
    """Without the allowlist, the same MCP tool request yields no tool output.

    The request is identical to the enabled variant; the response must still
    complete, but ``tool_output_tokens`` is expected to be exactly zero.
    """
    # TODO: Ideally should be able to set max tool calls
    # to prevent multi-turn, but it is not currently supported
    # would speed up the test
    prompt = ("What's the first 4 digits after the decimal point of "
              "cube root of `19910212 * 20250910`? "
              "Show only the digits. The python interpreter is not stateful "
              "and you must print to see the output.")
    mcp_tool = {
        "type": "mcp",
        "server_label": "code_interpreter",
        # URL unused for DemoToolServer
        "server_url": "http://localhost:8888"
    }

    response = await mcp_disabled_client.responses.create(
        model=model_name,
        input=prompt,
        tools=[mcp_tool],
    )

    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens == 0
|
||||||
@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str):
|
|||||||
async def test_code_interpreter(client: OpenAI, model_name: str):
|
async def test_code_interpreter(client: OpenAI, model_name: str):
|
||||||
response = await client.responses.create(
|
response = await client.responses.create(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
input="Multiply 64548*15151 using builtin python interpreter.",
|
# TODO: Ideally should be able to set max tool calls
|
||||||
|
# to prevent multi-turn, but it is not currently supported
|
||||||
|
# would speed up the test
|
||||||
|
input=("What's the first 4 digits after the decimal point of "
|
||||||
|
"cube root of `19910212 * 20250910`? "
|
||||||
|
"Show only the digits. The python interpreter is not stateful "
|
||||||
|
"and you must print to see the output."),
|
||||||
tools=[{
|
tools=[{
|
||||||
"type": "code_interpreter",
|
"type": "code_interpreter",
|
||||||
"container": {
|
"container": {
|
||||||
@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
|
|||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.status == "completed"
|
assert response.status == "completed"
|
||||||
|
assert response.usage.output_tokens_details.tool_output_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
def get_weather(latitude, longitude):
|
def get_weather(latitude, longitude):
|
||||||
|
|||||||
216
tests/test_envs.py
Normal file
216
tests/test_envs.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.envs import env_list_with_choices, env_with_choices
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvWithChoices:
    """Checks for the getter returned by ``env_with_choices``."""

    def test_default_value_returned_when_env_not_set(self):
        """An unset variable yields the configured default."""
        getter = env_with_choices("NONEXISTENT_ENV", "default",
                                  ["option1", "option2"])
        assert getter() == "default"

    def test_none_default_returned_when_env_not_set(self):
        """An unset variable yields ``None`` when that is the default."""
        getter = env_with_choices("NONEXISTENT_ENV", None,
                                  ["option1", "option2"])
        assert getter() is None

    def test_valid_value_returned_case_sensitive(self):
        """An exact match is returned as-is in case-sensitive mode."""
        fake_env = {"TEST_ENV": "option1"}
        with patch.dict(os.environ, fake_env):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=True)
            assert getter() == "option1"

    def test_valid_lowercase_value_returned_case_insensitive(self):
        """A lowercase value matches uppercase choices when insensitive."""
        fake_env = {"TEST_ENV": "option1"}
        with patch.dict(os.environ, fake_env):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["OPTION1", "OPTION2"],
                                      case_sensitive=False)
            assert getter() == "option1"

    def test_valid_uppercase_value_returned_case_insensitive(self):
        """An uppercase value matches lowercase choices when insensitive."""
        fake_env = {"TEST_ENV": "OPTION1"}
        with patch.dict(os.environ, fake_env):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=False)
            assert getter() == "OPTION1"

    def test_invalid_value_raises_error_case_sensitive(self):
        """A value outside the choices raises in case-sensitive mode."""
        expected_message = "Invalid value 'invalid' for TEST_ENV"
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=True)
            with pytest.raises(ValueError, match=expected_message):
                getter()

    def test_case_mismatch_raises_error_case_sensitive(self):
        """A case-only mismatch raises in case-sensitive mode."""
        expected_message = "Invalid value 'OPTION1' for TEST_ENV"
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=True)
            with pytest.raises(ValueError, match=expected_message):
                getter()

    def test_invalid_value_raises_error_case_insensitive(self):
        """A value outside the choices raises in case-insensitive mode."""
        expected_message = "Invalid value 'invalid' for TEST_ENV"
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            getter = env_with_choices("TEST_ENV",
                                      "default", ["option1", "option2"],
                                      case_sensitive=False)
            with pytest.raises(ValueError, match=expected_message):
                getter()

    def test_callable_choices_resolved_correctly(self):
        """Choices supplied as a callable accept a value it returns."""

        def lazy_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}):
            getter = env_with_choices("TEST_ENV", "default", lazy_choices)
            assert getter() == "dynamic1"

    def test_callable_choices_with_invalid_value(self):
        """Choices supplied as a callable still reject unknown values."""

        def lazy_choices():
            return ["dynamic1", "dynamic2"]

        expected_message = "Invalid value 'invalid' for TEST_ENV"
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            getter = env_with_choices("TEST_ENV", "default", lazy_choices)
            with pytest.raises(ValueError, match=expected_message):
                getter()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvListWithChoices:
    """Checks for the getter returned by ``env_list_with_choices``."""

    def test_default_list_returned_when_env_not_set(self):
        """An unset variable yields the configured default list."""
        getter = env_list_with_choices("NONEXISTENT_ENV",
                                       ["default1", "default2"],
                                       ["option1", "option2"])
        assert getter() == ["default1", "default2"]

    def test_empty_default_list_returned_when_env_not_set(self):
        """An unset variable yields an empty list when that is the default."""
        getter = env_list_with_choices("NONEXISTENT_ENV", [],
                                       ["option1", "option2"])
        assert getter() == []

    def test_single_valid_value_parsed_correctly(self):
        """A single valid entry becomes a one-element list."""
        fake_env = {"TEST_ENV": "option1"}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1"]

    def test_multiple_valid_values_parsed_correctly(self):
        """Comma-separated valid entries are returned in order."""
        fake_env = {"TEST_ENV": "option1,option2"}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option2"]

    def test_values_with_whitespace_trimmed(self):
        """Whitespace around each entry is stripped before validation."""
        fake_env = {"TEST_ENV": " option1 , option2 "}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option2"]

    def test_empty_values_filtered_out(self):
        """Empty entries between or after commas are dropped."""
        fake_env = {"TEST_ENV": "option1,,option2,"}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option2"]

    def test_empty_string_returns_default(self):
        """An empty string falls back to the default list."""
        fake_env = {"TEST_ENV": ""}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", ["default"],
                                           ["option1", "option2"])
            assert getter() == ["default"]

    def test_only_commas_returns_default(self):
        """A string of only commas falls back to the default list."""
        fake_env = {"TEST_ENV": ",,,"}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", ["default"],
                                           ["option1", "option2"])
            assert getter() == ["default"]

    def test_case_sensitive_validation(self):
        """A case-only mismatch raises in case-sensitive mode."""
        expected_message = "Invalid value 'OPTION2' in TEST_ENV"
        with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"],
                                           case_sensitive=True)
            with pytest.raises(ValueError, match=expected_message):
                getter()

    def test_case_insensitive_validation(self):
        """Mixed-case entries pass and keep their original spelling."""
        fake_env = {"TEST_ENV": "OPTION1,option2"}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"],
                                           case_sensitive=False)
            assert getter() == ["OPTION1", "option2"]

    def test_invalid_value_in_list_raises_error(self):
        """One unknown entry among valid ones raises ``ValueError``."""
        expected_message = "Invalid value 'invalid' in TEST_ENV"
        with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            with pytest.raises(ValueError, match=expected_message):
                getter()

    def test_callable_choices_resolved_correctly(self):
        """Choices supplied as a callable accept the values it returns."""

        def lazy_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
            getter = env_list_with_choices("TEST_ENV", [], lazy_choices)
            assert getter() == ["dynamic1", "dynamic2"]

    def test_callable_choices_with_invalid_value(self):
        """Choices supplied as a callable still reject unknown entries."""

        def lazy_choices():
            return ["dynamic1", "dynamic2"]

        expected_message = "Invalid value 'invalid' in TEST_ENV"
        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
            getter = env_list_with_choices("TEST_ENV", [], lazy_choices)
            with pytest.raises(ValueError, match=expected_message):
                getter()

    def test_duplicate_values_preserved(self):
        """Repeated entries are kept rather than de-duplicated."""
        fake_env = {"TEST_ENV": "option1,option1,option2"}
        with patch.dict(os.environ, fake_env):
            getter = env_list_with_choices("TEST_ENV", [],
                                           ["option1", "option2"])
            assert getter() == ["option1", "option1", "option2"]
|
||||||
@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
|||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
from openai.types.responses.tool import Mcp
|
||||||
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
from openai_harmony import Author, Message, Role, StreamState, TextContent
|
||||||
|
|
||||||
from vllm.entrypoints.harmony_utils import (
|
from vllm.entrypoints.harmony_utils import (
|
||||||
@ -21,6 +22,24 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tool sessions are keyed by tool name (the tool namespace), while request
# tools are identified by tool type. The two do not line up 1:1, so this
# table translates a built-in tool name into its tool type.
_TOOL_NAME_TO_TYPE_MAP = {
    "browser": "web_search_preview",
    "python": "code_interpreter",
    "container": "container",
}


def _map_tool_name_to_tool_type(tool_name: str) -> str:
    """Return the tool type registered for the built-in ``tool_name``.

    Raises:
        ValueError: if ``tool_name`` has no entry in
            ``_TOOL_NAME_TO_TYPE_MAP``; the message lists the known names.
    """
    tool_type = _TOOL_NAME_TO_TYPE_MAP.get(tool_name)
    if tool_type is not None:
        return tool_type

    known_names = ', '.join(_TOOL_NAME_TO_TYPE_MAP)
    raise ValueError(
        f"Built-in tool name '{tool_name}' not defined in mapping. "
        f"Available tools: {known_names}")
|
||||||
|
|
||||||
|
|
||||||
class TurnTokens:
|
class TurnTokens:
|
||||||
"""Tracks token counts for a single conversation turn."""
|
"""Tracks token counts for a single conversation turn."""
|
||||||
@ -59,8 +78,8 @@ class ConversationContext(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
exit_stack: AsyncExitStack,
|
exit_stack: AsyncExitStack, request_id: str,
|
||||||
request_id: str) -> None:
|
mcp_tools: dict[str, Mcp]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -96,8 +115,8 @@ class SimpleContext(ConversationContext):
|
|||||||
raise NotImplementedError("Should not be called.")
|
raise NotImplementedError("Should not be called.")
|
||||||
|
|
||||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
exit_stack: AsyncExitStack,
|
exit_stack: AsyncExitStack, request_id: str,
|
||||||
request_id: str) -> None:
|
mcp_tools: dict[str, Mcp]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def cleanup_session(self) -> None:
|
async def cleanup_session(self) -> None:
|
||||||
@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext):
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
async def init_tool_sessions(self, tool_server: Optional[ToolServer],
|
||||||
exit_stack: AsyncExitStack,
|
exit_stack: AsyncExitStack, request_id: str,
|
||||||
request_id: str) -> None:
|
mcp_tools: dict[str, Mcp]):
|
||||||
if tool_server:
|
if tool_server:
|
||||||
for tool_name in self.available_tools:
|
for tool_name in self.available_tools:
|
||||||
if tool_name not in self._tool_sessions:
|
if tool_name not in self._tool_sessions:
|
||||||
|
tool_type = _map_tool_name_to_tool_type(tool_name)
|
||||||
|
headers = mcp_tools[
|
||||||
|
tool_type].headers if tool_type in mcp_tools else None
|
||||||
tool_session = await exit_stack.enter_async_context(
|
tool_session = await exit_stack.enter_async_context(
|
||||||
tool_server.new_session(tool_name, request_id))
|
tool_server.new_session(tool_name, request_id,
|
||||||
|
headers))
|
||||||
self._tool_sessions[tool_name] = tool_session
|
self._tool_sessions[tool_name] = tool_session
|
||||||
exit_stack.push_async_exit(self.cleanup_session)
|
exit_stack.push_async_exit(self.cleanup_session)
|
||||||
|
|
||||||
|
|||||||
@ -126,8 +126,10 @@ def get_developer_message(
|
|||||||
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.type in ("web_search_preview", "code_interpreter",
|
if tool.type in ("web_search_preview", "code_interpreter",
|
||||||
"container"):
|
"container", "mcp"):
|
||||||
# These are built-in tools that are added to the system message.
|
# These are built-in tools that are added to the system message.
|
||||||
|
# Adding in MCP for now until we support MCP tools executed
|
||||||
|
# server side
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif tool.type == "function":
|
elif tool.type == "function":
|
||||||
|
|||||||
@ -460,8 +460,12 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
async with AsyncExitStack() as exit_stack:
|
async with AsyncExitStack() as exit_stack:
|
||||||
try:
|
try:
|
||||||
|
mcp_tools = {
|
||||||
|
tool.server_label: tool
|
||||||
|
for tool in request.tools if tool.type == "mcp"
|
||||||
|
}
|
||||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||||
request.request_id)
|
request.request_id, mcp_tools)
|
||||||
async for _ in result_generator:
|
async for _ in result_generator:
|
||||||
pass
|
pass
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -748,11 +752,16 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
# New conversation.
|
# New conversation.
|
||||||
reasoning_effort = (request.reasoning.effort
|
reasoning_effort = (request.reasoning.effort
|
||||||
if request.reasoning else None)
|
if request.reasoning else None)
|
||||||
# Temporary: OpenAI types doesn't have container tool
|
|
||||||
# so we used MCP to cover that, up for change
|
|
||||||
tool_types = [tool.type for tool in request.tools]
|
tool_types = [tool.type for tool in request.tools]
|
||||||
if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
|
|
||||||
tool_types.append("container")
|
# Allow the MCP Tool type to enable built in tools if the
|
||||||
|
# server_label is allowlisted in
|
||||||
|
# envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
|
||||||
|
if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
|
||||||
|
for tool in request.tools:
|
||||||
|
if (tool.type == "mcp" and tool.server_label
|
||||||
|
in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS):
|
||||||
|
tool_types.append(tool.server_label)
|
||||||
enable_browser = ("web_search_preview" in tool_types
|
enable_browser = ("web_search_preview" in tool_types
|
||||||
and self.tool_server is not None
|
and self.tool_server is not None
|
||||||
and self.tool_server.has_tool("browser"))
|
and self.tool_server.has_tool("browser"))
|
||||||
@ -1653,8 +1662,12 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
async with AsyncExitStack() as exit_stack:
|
async with AsyncExitStack() as exit_stack:
|
||||||
processer = None
|
processer = None
|
||||||
if self.use_harmony:
|
if self.use_harmony:
|
||||||
|
mcp_tools = {
|
||||||
|
tool.server_label: tool
|
||||||
|
for tool in request.tools if tool.type == "mcp"
|
||||||
|
}
|
||||||
await context.init_tool_sessions(self.tool_server, exit_stack,
|
await context.init_tool_sessions(self.tool_server, exit_stack,
|
||||||
request.request_id)
|
request.request_id, mcp_tools)
|
||||||
processer = self._process_harmony_streaming_events
|
processer = self._process_harmony_streaming_events
|
||||||
else:
|
else:
|
||||||
processer = self._process_simple_streaming_events
|
processer = self._process_simple_streaming_events
|
||||||
|
|||||||
@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
|||||||
async def list_server_and_tools(server_url: str):
|
async def list_server_and_tools(server_url: str):
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
async with sse_client(url=server_url) as streams, ClientSession(
|
async with sse_client(url=server_url) as streams, ClientSession(
|
||||||
*streams) as session:
|
*streams) as session:
|
||||||
initialize_response = await session.initialize()
|
initialize_response = await session.initialize()
|
||||||
@ -86,8 +85,12 @@ class ToolServer(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def new_session(self, tool_name: str,
|
def new_session(
|
||||||
session_id: str) -> AbstractAsyncContextManager[Any]:
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
session_id: str,
|
||||||
|
headers: Optional[dict[str, str]] = None
|
||||||
|
) -> AbstractAsyncContextManager[Any]:
|
||||||
"""
|
"""
|
||||||
Create a session for the tool.
|
Create a session for the tool.
|
||||||
"""
|
"""
|
||||||
@ -144,15 +147,20 @@ class MCPToolServer(ToolServer):
|
|||||||
return self.harmony_tool_descriptions.get(tool_name)
|
return self.harmony_tool_descriptions.get(tool_name)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def new_session(self, tool_name: str, session_id: str):
|
async def new_session(self,
|
||||||
|
tool_name: str,
|
||||||
|
session_id: str,
|
||||||
|
headers: Optional[dict[str, str]] = None):
|
||||||
from mcp import ClientSession
|
from mcp import ClientSession
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
url = self.urls.get(tool_name)
|
url = self.urls.get(tool_name)
|
||||||
headers = {"x-session-id": session_id}
|
request_headers = {"x-session-id": session_id}
|
||||||
|
if headers is not None:
|
||||||
|
request_headers.update(headers)
|
||||||
if not url:
|
if not url:
|
||||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||||
async with sse_client(url=url,
|
async with sse_client(
|
||||||
headers=headers) as streams, ClientSession(
|
url=url, headers=request_headers) as streams, ClientSession(
|
||||||
*streams) as session:
|
*streams) as session:
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
yield session
|
yield session
|
||||||
@ -189,7 +197,10 @@ class DemoToolServer(ToolServer):
|
|||||||
raise ValueError(f"Unknown tool {tool_name}")
|
raise ValueError(f"Unknown tool {tool_name}")
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def new_session(self, tool_name: str, session_id: str):
|
async def new_session(self,
|
||||||
|
tool_name: str,
|
||||||
|
session_id: str,
|
||||||
|
headers: Optional[dict[str, str]] = None):
|
||||||
if tool_name not in self.tools:
|
if tool_name not in self.tools:
|
||||||
raise KeyError(f"Tool '{tool_name}' is not supported")
|
raise KeyError(f"Tool '{tool_name}' is not supported")
|
||||||
yield self.tools[tool_name]
|
yield self.tools[tool_name]
|
||||||
|
|||||||
66
vllm/envs.py
66
vllm/envs.py
@ -185,11 +185,11 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
|
||||||
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
|
||||||
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
|
||||||
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
|
|
||||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||||
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
||||||
|
GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
|
||||||
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@ -261,6 +261,58 @@ def env_with_choices(
|
|||||||
return _get_validated_env
|
return _get_validated_env
|
||||||
|
|
||||||
|
|
||||||
|
def env_list_with_choices(
        env_name: str,
        default: list[str],
        choices: Union[list[str], Callable[[], list[str]]],
        case_sensitive: bool = True) -> Callable[[], list[str]]:
    """
    Create a getter that validates an environment variable
    containing comma-separated values against allowed choices

    Args:
        env_name: Name of the environment variable
        default: Default list of values if not set
        choices: List of valid string options or callable that returns list
        case_sensitive: Whether validation should be case sensitive

    Returns:
        Function for the environment_variables dict that returns a list of
        strings. Entries are whitespace-stripped, empty entries are dropped,
        and order, duplicates and original casing are preserved. When the
        variable is unset or contains no non-empty entry, a fresh copy of
        ``default`` is returned.

    Raises:
        ValueError: (from the returned function) if any entry is not one of
            the allowed choices.
    """

    def _get_validated_env_list() -> list[str]:
        value = os.getenv(env_name)
        if value is None:
            # Hand out a copy so a caller mutating the result cannot change
            # the default seen by later reads.
            return list(default)

        # Split comma-separated values, strip whitespace, drop empty entries
        stripped = (part.strip() for part in value.split(","))
        values = [item for item in stripped if item]

        if not values:
            return list(default)

        # Resolve choices if it's a callable (for lazy loading)
        actual_choices = choices() if callable(choices) else choices

        # Normalise the choices once instead of once per validated value.
        if case_sensitive:
            allowed = set(actual_choices)
        else:
            allowed = {choice.lower() for choice in actual_choices}

        # Validate each value
        for val in values:
            check_value = val if case_sensitive else val.lower()
            if check_value not in allowed:
                raise ValueError(f"Invalid value '{val}' in {env_name}. "
                                 f"Valid options: {actual_choices}.")

        return values

    return _get_validated_env_list
|
||||||
|
|
||||||
|
|
||||||
def get_vllm_port() -> Optional[int]:
|
def get_vllm_port() -> Optional[int]:
|
||||||
"""Get the port from VLLM_PORT environment variable.
|
"""Get the port from VLLM_PORT environment variable.
|
||||||
|
|
||||||
@ -1320,10 +1372,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_TUNED_CONFIG_FOLDER":
|
"VLLM_TUNED_CONFIG_FOLDER":
|
||||||
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
|
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
|
||||||
|
|
||||||
# Allows vllm use container tool
|
|
||||||
"VLLM_GPT_OSS_USE_CONTAINER_TOOL":
|
|
||||||
lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),
|
|
||||||
|
|
||||||
# Allows harmony instructions to be injected on system messages
|
# Allows harmony instructions to be injected on system messages
|
||||||
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
|
"VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
|
||||||
lambda: bool(
|
lambda: bool(
|
||||||
@ -1343,6 +1391,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
|
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
|
||||||
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
|
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
|
||||||
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
|
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
|
||||||
|
|
||||||
|
# Valid values are container,code_interpreter,web_search_preview
|
||||||
|
# ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
|
||||||
|
"GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
|
||||||
|
env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [],
|
||||||
|
["container",
|
||||||
|
"code_interpreter",
|
||||||
|
"web_search_preview"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
# --8<-- [end:env-vars-definition]
|
# --8<-- [end:env-vars-definition]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user