Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-09 03:35:01 +08:00
[Frontend] Responses API MCP tools for built in tools and to pass through headers (#24628)
Signed-off-by: Alec Solder <alecs@fb.com>
Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com>
Co-authored-by: Alec Solder <alecs@fb.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent 8bed179109
commit 45d7d852d3
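The change lets a Responses API request declare built-in tools through the `mcp` tool type and forward per-request headers to the backing tool server. A minimal sketch of such a request, assuming the OpenAI Python client pointed at a vLLM server started with `--tool-server` (the base URL, API key, and header value below are placeholders):

from openai import OpenAI

# Placeholder endpoint/key for a locally running vLLM OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.responses.create(
    model="openai/gpt-oss-20b",
    input="Use python to compute 2**10 and print the result.",
    tools=[{
        "type": "mcp",
        # The server_label must be allowlisted via GPT_OSS_SYSTEM_TOOL_MCP_LABELS
        # for it to enable the corresponding built-in tool.
        "server_label": "code_interpreter",
        # URL is unused for the demo tool server.
        "server_url": "http://localhost:8888",
        # Headers are merged into the MCP session headers by MCPToolServer
        # (placeholder token shown here).
        "headers": {"Authorization": "Bearer <token>"},
    }],
)
print(response.status)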
tests/entrypoints/openai/test_response_api_mcp_tools.py (new file, 106 lines)
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import pytest_asyncio
from openai import OpenAI

from ...utils import RemoteOpenAIServer

MODEL_NAME = "openai/gpt-oss-20b"


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch
    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
    args = ["--enforce-eager", "--tool-server", "demo"]

    with monkeypatch_module.context() as m:
        m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
        m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
            yield remote_server


@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
    args = ["--enforce-eager", "--tool-server", "demo"]

    with monkeypatch_module.context() as m:
        m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
        m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
        m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS",
                 "code_interpreter,container")
        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
            yield remote_server


@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
    async with mcp_disabled_server.get_async_client() as async_client:
        yield async_client


@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
    async with mcp_enabled_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI,
                                         model_name: str):
    response = await mcp_enabled_client.responses.create(
        model=model_name,
        # TODO: Ideally should be able to set max tool calls
        # to prevent multi-turn, but it is not currently supported
        # would speed up the test
        input=("What's the first 4 digits after the decimal point of "
               "cube root of `19910212 * 20250910`? "
               "Show only the digits. The python interpreter is not stateful "
               "and you must print to see the output."),
        tools=[{
            "type": "mcp",
            "server_label": "code_interpreter",
            # URL unused for DemoToolServer
            "server_url": "http://localhost:8888"
        }],
    )
    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens > 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.")
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI,
                                          model_name: str):
    response = await mcp_disabled_client.responses.create(
        model=model_name,
        # TODO: Ideally should be able to set max tool calls
        # to prevent multi-turn, but it is not currently supported
        # would speed up the test
        input=("What's the first 4 digits after the decimal point of "
               "cube root of `19910212 * 20250910`? "
               "Show only the digits. The python interpreter is not stateful "
               "and you must print to see the output."),
        tools=[{
            "type": "mcp",
            "server_label": "code_interpreter",
            # URL unused for DemoToolServer
            "server_url": "http://localhost:8888"
        }],
    )
    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens == 0
@@ -454,7 +454,13 @@ async def test_web_search(client: OpenAI, model_name: str):
async def test_code_interpreter(client: OpenAI, model_name: str):
    response = await client.responses.create(
        model=model_name,
        input="Multiply 64548*15151 using builtin python interpreter.",
        # TODO: Ideally should be able to set max tool calls
        # to prevent multi-turn, but it is not currently supported
        # would speed up the test
        input=("What's the first 4 digits after the decimal point of "
               "cube root of `19910212 * 20250910`? "
               "Show only the digits. The python interpreter is not stateful "
               "and you must print to see the output."),
        tools=[{
            "type": "code_interpreter",
            "container": {
@@ -464,6 +470,7 @@ async def test_code_interpreter(client: OpenAI, model_name: str):
    )
    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens > 0


def get_weather(latitude, longitude):
tests/test_envs.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from unittest.mock import patch

import pytest

from vllm.envs import env_list_with_choices, env_with_choices


class TestEnvWithChoices:
    """Test cases for env_with_choices function."""

    def test_default_value_returned_when_env_not_set(self):
        """Test default is returned when env var is not set."""
        env_func = env_with_choices("NONEXISTENT_ENV", "default",
                                    ["option1", "option2"])
        assert env_func() == "default"

    def test_none_default_returned_when_env_not_set(self):
        """Test that None is returned when env not set and default is None."""
        env_func = env_with_choices("NONEXISTENT_ENV", None,
                                    ["option1", "option2"])
        assert env_func() is None

    def test_valid_value_returned_case_sensitive(self):
        """Test that valid value is returned in case sensitive mode."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            env_func = env_with_choices("TEST_ENV",
                                        "default", ["option1", "option2"],
                                        case_sensitive=True)
            assert env_func() == "option1"

    def test_valid_lowercase_value_returned_case_insensitive(self):
        """Test that lowercase value is accepted in case insensitive mode."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            env_func = env_with_choices("TEST_ENV",
                                        "default", ["OPTION1", "OPTION2"],
                                        case_sensitive=False)
            assert env_func() == "option1"

    def test_valid_uppercase_value_returned_case_insensitive(self):
        """Test that uppercase value is accepted in case insensitive mode."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
            env_func = env_with_choices("TEST_ENV",
                                        "default", ["option1", "option2"],
                                        case_sensitive=False)
            assert env_func() == "OPTION1"

    def test_invalid_value_raises_error_case_sensitive(self):
        """Test that invalid value raises ValueError in case sensitive mode."""
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            env_func = env_with_choices("TEST_ENV",
                                        "default", ["option1", "option2"],
                                        case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                env_func()

    def test_case_mismatch_raises_error_case_sensitive(self):
        """Test that case mismatch raises ValueError in case sensitive mode."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1"}):
            env_func = env_with_choices("TEST_ENV",
                                        "default", ["option1", "option2"],
                                        case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'OPTION1' for TEST_ENV"):
                env_func()

    def test_invalid_value_raises_error_case_insensitive(self):
        """Test that invalid value raises ValueError when case insensitive."""
        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            env_func = env_with_choices("TEST_ENV",
                                        "default", ["option1", "option2"],
                                        case_sensitive=False)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                env_func()

    def test_callable_choices_resolved_correctly(self):
        """Test that callable choices are resolved correctly."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1"}):
            env_func = env_with_choices("TEST_ENV", "default", get_choices)
            assert env_func() == "dynamic1"

    def test_callable_choices_with_invalid_value(self):
        """Test that callable choices raise error for invalid values."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "invalid"}):
            env_func = env_with_choices("TEST_ENV", "default", get_choices)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' for TEST_ENV"):
                env_func()


class TestEnvListWithChoices:
    """Test cases for env_list_with_choices function."""

    def test_default_list_returned_when_env_not_set(self):
        """Test that default list is returned when env var is not set."""
        env_func = env_list_with_choices("NONEXISTENT_ENV",
                                         ["default1", "default2"],
                                         ["option1", "option2"])
        assert env_func() == ["default1", "default2"]

    def test_empty_default_list_returned_when_env_not_set(self):
        """Test that empty default list is returned when env not set."""
        env_func = env_list_with_choices("NONEXISTENT_ENV", [],
                                         ["option1", "option2"])
        assert env_func() == []

    def test_single_valid_value_parsed_correctly(self):
        """Test that single valid value is parsed correctly."""
        with patch.dict(os.environ, {"TEST_ENV": "option1"}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"])
            assert env_func() == ["option1"]

    def test_multiple_valid_values_parsed_correctly(self):
        """Test that multiple valid values are parsed correctly."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,option2"}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"])
            assert env_func() == ["option1", "option2"]

    def test_values_with_whitespace_trimmed(self):
        """Test that values with whitespace are trimmed correctly."""
        with patch.dict(os.environ, {"TEST_ENV": " option1 , option2 "}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"])
            assert env_func() == ["option1", "option2"]

    def test_empty_values_filtered_out(self):
        """Test that empty values are filtered out."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,,option2,"}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"])
            assert env_func() == ["option1", "option2"]

    def test_empty_string_returns_default(self):
        """Test that empty string returns default."""
        with patch.dict(os.environ, {"TEST_ENV": ""}):
            env_func = env_list_with_choices("TEST_ENV", ["default"],
                                             ["option1", "option2"])
            assert env_func() == ["default"]

    def test_only_commas_returns_default(self):
        """Test that string with only commas returns default."""
        with patch.dict(os.environ, {"TEST_ENV": ",,,"}):
            env_func = env_list_with_choices("TEST_ENV", ["default"],
                                             ["option1", "option2"])
            assert env_func() == ["default"]

    def test_case_sensitive_validation(self):
        """Test case sensitive validation."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,OPTION2"}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"],
                                             case_sensitive=True)
            with pytest.raises(ValueError,
                               match="Invalid value 'OPTION2' in TEST_ENV"):
                env_func()

    def test_case_insensitive_validation(self):
        """Test case insensitive validation."""
        with patch.dict(os.environ, {"TEST_ENV": "OPTION1,option2"}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"],
                                             case_sensitive=False)
            assert env_func() == ["OPTION1", "option2"]

    def test_invalid_value_in_list_raises_error(self):
        """Test that invalid value in list raises ValueError."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,invalid,option2"}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"])
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' in TEST_ENV"):
                env_func()

    def test_callable_choices_resolved_correctly(self):
        """Test that callable choices are resolved correctly."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,dynamic2"}):
            env_func = env_list_with_choices("TEST_ENV", [], get_choices)
            assert env_func() == ["dynamic1", "dynamic2"]

    def test_callable_choices_with_invalid_value(self):
        """Test that callable choices raise error for invalid values."""

        def get_choices():
            return ["dynamic1", "dynamic2"]

        with patch.dict(os.environ, {"TEST_ENV": "dynamic1,invalid"}):
            env_func = env_list_with_choices("TEST_ENV", [], get_choices)
            with pytest.raises(ValueError,
                               match="Invalid value 'invalid' in TEST_ENV"):
                env_func()

    def test_duplicate_values_preserved(self):
        """Test that duplicate values in the list are preserved."""
        with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
            env_func = env_list_with_choices("TEST_ENV", [],
                                             ["option1", "option2"])
            assert env_func() == ["option1", "option1", "option2"]
@@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union

from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm.entrypoints.harmony_utils import (
@@ -21,6 +22,24 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

# This is currently needed as the tool type doesn't 1:1 match the
# tool namespace, which is what is used to look up the
# connection to the tool server
_TOOL_NAME_TO_TYPE_MAP = {
    "browser": "web_search_preview",
    "python": "code_interpreter",
    "container": "container",
}


def _map_tool_name_to_tool_type(tool_name: str) -> str:
    if tool_name not in _TOOL_NAME_TO_TYPE_MAP:
        available_tools = ', '.join(_TOOL_NAME_TO_TYPE_MAP.keys())
        raise ValueError(
            f"Built-in tool name '{tool_name}' not defined in mapping. "
            f"Available tools: {available_tools}")
    return _TOOL_NAME_TO_TYPE_MAP[tool_name]


class TurnTokens:
    """Tracks token counts for a single conversation turn."""
@@ -59,8 +78,8 @@ class ConversationContext(ABC):

    @abstractmethod
    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                                 exit_stack: AsyncExitStack,
                                 request_id: str) -> None:
                                 exit_stack: AsyncExitStack, request_id: str,
                                 mcp_tools: dict[str, Mcp]) -> None:
        pass

    @abstractmethod
@@ -96,8 +115,8 @@ class SimpleContext(ConversationContext):
        raise NotImplementedError("Should not be called.")

    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                                 exit_stack: AsyncExitStack,
                                 request_id: str) -> None:
                                 exit_stack: AsyncExitStack, request_id: str,
                                 mcp_tools: dict[str, Mcp]) -> None:
        pass

    async def cleanup_session(self) -> None:
@@ -318,13 +337,17 @@ class HarmonyContext(ConversationContext):
        ]

    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
                                 exit_stack: AsyncExitStack,
                                 request_id: str) -> None:
                                 exit_stack: AsyncExitStack, request_id: str,
                                 mcp_tools: dict[str, Mcp]):
        if tool_server:
            for tool_name in self.available_tools:
                if tool_name not in self._tool_sessions:
                    tool_type = _map_tool_name_to_tool_type(tool_name)
                    headers = mcp_tools[
                        tool_type].headers if tool_type in mcp_tools else None
                    tool_session = await exit_stack.enter_async_context(
                        tool_server.new_session(tool_name, request_id))
                        tool_server.new_session(tool_name, request_id,
                                                headers))
                    self._tool_sessions[tool_name] = tool_session
            exit_stack.push_async_exit(self.cleanup_session)
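To see how the pieces connect: the request's `mcp` tools are keyed by `server_label`, the built-in tool name (e.g. `python`) is mapped to a tool type (`code_interpreter`), and headers on a matching entry are forwarded to `new_session`. A small illustrative sketch of that lookup, using a hypothetical stand-in object and a placeholder token:

from types import SimpleNamespace

# Hypothetical stand-in for an Mcp tool taken from request.tools.
mcp_tool = SimpleNamespace(server_label="code_interpreter",
                           headers={"Authorization": "Bearer <token>"})

# Keyed by server_label, as serving_responses builds it.
mcp_tools = {mcp_tool.server_label: mcp_tool}

# "python" is the built-in tool name; its tool type is "code_interpreter".
_TOOL_NAME_TO_TYPE_MAP = {
    "browser": "web_search_preview",
    "python": "code_interpreter",
    "container": "container",
}
tool_type = _TOOL_NAME_TO_TYPE_MAP["python"]

# Same lookup as init_tool_sessions; the headers would then be passed to
# tool_server.new_session(tool_name, request_id, headers).
headers = mcp_tools[tool_type].headers if tool_type in mcp_tools else None
assert headers == {"Authorization": "Bearer <token>"}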
@@ -126,8 +126,10 @@ def get_developer_message(
    function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
    for tool in tools:
        if tool.type in ("web_search_preview", "code_interpreter",
                         "container"):
                         "container", "mcp"):
            # These are built-in tools that are added to the system message.
            # Adding in MCP for now until we support MCP tools executed
            # server side
            pass

        elif tool.type == "function":
@@ -460,8 +460,12 @@ class OpenAIServingResponses(OpenAIServing):

        async with AsyncExitStack() as exit_stack:
            try:
                mcp_tools = {
                    tool.server_label: tool
                    for tool in request.tools if tool.type == "mcp"
                }
                await context.init_tool_sessions(self.tool_server, exit_stack,
                                                 request.request_id)
                                                 request.request_id, mcp_tools)
                async for _ in result_generator:
                    pass
            except asyncio.CancelledError:
@@ -748,11 +752,16 @@ class OpenAIServingResponses(OpenAIServing):
            # New conversation.
            reasoning_effort = (request.reasoning.effort
                                if request.reasoning else None)
            # Temporary: OpenAI types doesn't have container tool
            # so we used MCP to cover that, up for change
            tool_types = [tool.type for tool in request.tools]
            if envs.VLLM_GPT_OSS_USE_CONTAINER_TOOL:
                tool_types.append("container")

            # Allow the MCP Tool type to enable built in tools if the
            # server_label is allowlisted in
            # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS
            if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS:
                for tool in request.tools:
                    if (tool.type == "mcp" and tool.server_label
                            in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS):
                        tool_types.append(tool.server_label)
            enable_browser = ("web_search_preview" in tool_types
                              and self.tool_server is not None
                              and self.tool_server.has_tool("browser"))
@@ -1653,8 +1662,12 @@ class OpenAIServingResponses(OpenAIServing):
        async with AsyncExitStack() as exit_stack:
            processer = None
            if self.use_harmony:
                mcp_tools = {
                    tool.server_label: tool
                    for tool in request.tools if tool.type == "mcp"
                }
                await context.init_tool_sessions(self.tool_server, exit_stack,
                                                 request.request_id)
                                                 request.request_id, mcp_tools)
                processer = self._process_harmony_streaming_events
            else:
                processer = self._process_simple_streaming_events
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
async def list_server_and_tools(server_url: str):
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    async with sse_client(url=server_url) as streams, ClientSession(
            *streams) as session:
        initialize_response = await session.initialize()
@@ -86,8 +85,12 @@ class ToolServer(ABC):
        pass

    @abstractmethod
    def new_session(self, tool_name: str,
                    session_id: str) -> AbstractAsyncContextManager[Any]:
    def new_session(
        self,
        tool_name: str,
        session_id: str,
        headers: Optional[dict[str, str]] = None
    ) -> AbstractAsyncContextManager[Any]:
        """
        Create a session for the tool.
        """
@@ -144,16 +147,21 @@ class MCPToolServer(ToolServer):
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def new_session(self, tool_name: str, session_id: str):
    async def new_session(self,
                          tool_name: str,
                          session_id: str,
                          headers: Optional[dict[str, str]] = None):
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        url = self.urls.get(tool_name)
        headers = {"x-session-id": session_id}
        request_headers = {"x-session-id": session_id}
        if headers is not None:
            request_headers.update(headers)
        if not url:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        async with sse_client(url=url,
                              headers=headers) as streams, ClientSession(
                                  *streams) as session:
        async with sse_client(
                url=url, headers=request_headers) as streams, ClientSession(
                    *streams) as session:
            await session.initialize()
            yield session
@@ -189,7 +197,10 @@ class DemoToolServer(ToolServer):
        raise ValueError(f"Unknown tool {tool_name}")

    @asynccontextmanager
    async def new_session(self, tool_name: str, session_id: str):
    async def new_session(self,
                          tool_name: str,
                          session_id: str,
                          headers: Optional[dict[str, str]] = None):
        if tool_name not in self.tools:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        yield self.tools[tool_name]
vllm/envs.py (66 changed lines)
@@ -185,11 +185,11 @@ if TYPE_CHECKING:
    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
    VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
    VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
    VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
    VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
    VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
    VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
    VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
    GPT_OSS_SYSTEM_TOOL_MCP_LABELS: list[str] = []
    VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
@@ -261,6 +261,58 @@ def env_with_choices(
    return _get_validated_env


def env_list_with_choices(
        env_name: str,
        default: list[str],
        choices: Union[list[str], Callable[[], list[str]]],
        case_sensitive: bool = True) -> Callable[[], list[str]]:
    """
    Create a lambda that validates environment variable
    containing comma-separated values against allowed choices

    Args:
        env_name: Name of the environment variable
        default: Default list of values if not set
        choices: List of valid string options or callable that returns list
        case_sensitive: Whether validation should be case sensitive

    Returns:
        Lambda function for environment_variables
        dict that returns list of strings
    """

    def _get_validated_env_list() -> list[str]:
        value = os.getenv(env_name)
        if value is None:
            return default

        # Split comma-separated values and strip whitespace
        values = [v.strip() for v in value.split(",") if v.strip()]

        if not values:
            return default

        # Resolve choices if it's a callable (for lazy loading)
        actual_choices = choices() if callable(choices) else choices

        # Validate each value
        for val in values:
            if not case_sensitive:
                check_value = val.lower()
                check_choices = [choice.lower() for choice in actual_choices]
            else:
                check_value = val
                check_choices = actual_choices

            if check_value not in check_choices:
                raise ValueError(f"Invalid value '{val}' in {env_name}. "
                                 f"Valid options: {actual_choices}.")

        return values

    return _get_validated_env_list


def get_vllm_port() -> Optional[int]:
    """Get the port from VLLM_PORT environment variable.
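A short usage sketch of the new helper; the variable name MY_TOOL_LABELS and its choices are hypothetical, chosen only to show the parsing and validation behaviour:

import os

from vllm.envs import env_list_with_choices

# Hypothetical env var validated against a fixed set of labels.
read_labels = env_list_with_choices("MY_TOOL_LABELS", [], ["alpha", "beta"])

os.environ["MY_TOOL_LABELS"] = "alpha, beta"
print(read_labels())  # ['alpha', 'beta'] -- whitespace is trimmed

os.environ["MY_TOOL_LABELS"] = "alpha,gamma"
# read_labels() now raises:
# ValueError: Invalid value 'gamma' in MY_TOOL_LABELS. Valid options: ['alpha', 'beta'].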
@@ -1320,10 +1372,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_TUNED_CONFIG_FOLDER":
    lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),

    # Allows vllm use container tool
    "VLLM_GPT_OSS_USE_CONTAINER_TOOL":
    lambda: bool(int(os.getenv("VLLM_GPT_OSS_USE_CONTAINER_TOOL", "0"))),

    # Allows harmony instructions to be injected on system messages
    "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS":
    lambda: bool(
@@ -1343,6 +1391,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
    lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
                      "VLLM_OBJECT_STORAGE_SHM_BUFFER"),

    # Valid values are container,code_interpreter,web_search_preview
    # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
    "GPT_OSS_SYSTEM_TOOL_MCP_LABELS":
    env_list_with_choices("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", [],
                          ["container",
                           "code_interpreter",
                           "web_search_preview"]),
}

# --8<-- [end:env-vars-definition]
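Putting the flag to work, a minimal sketch of how the allowlist is read at request time; it mirrors the mcp_enabled_server test fixture above and assumes the variable is set before vllm.envs is consulted:

import os

# Mirrors the mcp_enabled_server fixture: allowlist two built-in tools.
os.environ["GPT_OSS_SYSTEM_TOOL_MCP_LABELS"] = "code_interpreter,container"

import vllm.envs as envs

# Parsed and validated via env_list_with_choices.
print(envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS)  # ['code_interpreter', 'container']

# serving_responses then enables a built-in tool only when a request's mcp
# tool has a server_label contained in this list.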